MNIST classifier in PyTorch

24.5. MNIST classifier in PyTorch

This notebook is a PyTorch rewrite of the original TensorFlow/Keras demo.
It trains a simple fully-connected network on MNIST and reproduces:

  • data loading

  • model definition (Flatten → Dense(128, ReLU) → Dropout(0.2) → Dense(10))

  • training for 10 epochs

  • accuracy on the test set

  • probability predictions for a sample

  • the same plotting helpers for predictions

# Imports 
import os, math, numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

print('You have PyTorch version:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Seed torch's global RNG before the model is built, so weight initialisation,
# dropout masks and DataLoader shuffling are repeatable across runs.
SEED = 1234
# manual_seed returns the default Generator; bind it to `_` so the notebook
# does not echo its repr as a stray cell output.
_ = torch.manual_seed(SEED)
You have PyTorch version: 2.8.0
Using device: cpu
<torch._C.Generator at 0x7f7524f3bf90>
# Data: MNIST via torchvision (returns tensors in [0,1])
batch_size = 128

transform = transforms.ToTensor()  # simple scaling to [0,1]
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# NumPy copies of the datasets' stored `.data` / `.targets` tensors, kept to
# mirror the arrays the Keras version works with.
x_train = train_dataset.data.numpy()
y_train = train_dataset.targets.numpy()
x_test  = test_dataset.data.numpy()
y_test  = test_dataset.targets.numpy()

# Page-locked host memory speeds up host-to-GPU batch copies; it brings
# nothing on CPU, so only enable it when training on CUDA.
pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=pin_memory)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=pin_memory)

print('Train:', len(train_dataset), 'Test:', len(test_dataset))
Train: 60000 Test: 10000
# Model: Flatten → Linear(784,128) → ReLU → Dropout(0.2) → Linear(128,10)
class MLP(nn.Module):
    """Single-hidden-layer perceptron that maps images to class logits.

    The defaults reproduce the Keras demo network:
    Flatten → Linear(784, 128) → ReLU → Dropout(0.2) → Linear(128, 10).

    Args:
        in_features: number of values per sample after flattening
            (28 * 28 for MNIST).
        hidden: width of the hidden layer.
        num_classes: number of output logits.
        dropout: dropout probability applied after the hidden ReLU.
    """

    def __init__(self, in_features=28 * 28, hidden=128, num_classes=10, dropout=0.2):
        super().__init__()
        # Layer order (and hence the Sequential indices / state_dict keys such
        # as ``net.1.weight``) is unchanged from the original definition.
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return raw, unnormalised logits of shape (batch, num_classes).

        No softmax is applied here; ``nn.CrossEntropyLoss`` expects logits.
        """
        return self.net(x)

# Build the network and move it to the selected device, then display the
# total parameter count together with the module summary.
model = MLP()
model = model.to(device)  # Module.to() moves in place and returns the module
sum(map(torch.numel, model.parameters())), model
(101770,
 MLP(
   (net): Sequential(
     (0): Flatten(start_dim=1, end_dim=-1)
     (1): Linear(in_features=784, out_features=128, bias=True)
     (2): ReLU()
     (3): Dropout(p=0.2, inplace=False)
     (4): Linear(in_features=128, out_features=10, bias=True)
   )
 ))
# Training
criterion = nn.CrossEntropyLoss()  # applies log-softmax internally: feed it raw logits
optimizer = torch.optim.Adam(model.parameters())
epochs = 10

history = {'accuracy': [], 'loss': []}


def _train_epoch(net, loader):
    """Run one optimisation pass over *loader*; return (mean loss, accuracy).

    The loss is averaged per sample (each batch weighted by its size). The
    accuracy is measured on the fly from the logits computed before each
    optimiser step, with the network in training mode (dropout active).
    """
    net.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        scores = net(inputs)
        batch_loss = criterion(scores, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * inputs.size(0)
        hits = scores.argmax(dim=1) == targets
        n_correct += hits.sum().item()
        n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(1, epochs + 1):
    epoch_loss, epoch_acc = _train_epoch(model, train_loader)
    history['loss'].append(epoch_loss)
    history['accuracy'].append(epoch_acc)
    print(f'Epoch {epoch:2d}/{epochs} - loss: {epoch_loss:.4f} - acc: {epoch_acc:.4f}')

history
Epoch  1/10 - loss: 0.4464 - acc: 0.8796
Epoch  2/10 - loss: 0.2149 - acc: 0.9373
Epoch  3/10 - loss: 0.1596 - acc: 0.9535
Epoch  4/10 - loss: 0.1290 - acc: 0.9617
Epoch  5/10 - loss: 0.1077 - acc: 0.9687
Epoch  6/10 - loss: 0.0928 - acc: 0.9725
Epoch  7/10 - loss: 0.0822 - acc: 0.9751
Epoch  8/10 - loss: 0.0748 - acc: 0.9772
Epoch  9/10 - loss: 0.0672 - acc: 0.9794
Epoch 10/10 - loss: 0.0591 - acc: 0.9823
{'accuracy': [0.8795666666666667,
  0.9372833333333334,
  0.95345,
  0.9617,
  0.9687166666666667,
  0.9725,
  0.97515,
  0.9772333333333333,
  0.9794,
  0.9823],
 'loss': [0.4463983178456624,
  0.21493444612026213,
  0.1596157058954239,
  0.1290262913386027,
  0.10767514850695928,
  0.09282231144507726,
  0.08220300891399383,
  0.07477020881772041,
  0.06717816098332405,
  0.05907997849583626]}
# Evaluation on test set: per-sample mean cross-entropy and accuracy
model.eval()  # disable dropout for deterministic predictions
correct, total, test_loss = 0, 0, 0.0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        # weight by batch size so the final average is per sample even when
        # the last batch is smaller
        test_loss += loss.item() * xb.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

test_acc = correct / total
test_loss = test_loss / total
print(f'\nTest accuracy: {test_acc:5.3f}')
# Report the loss as well (previously computed but never shown).
print(f'Test loss: {test_loss:.4f}')
test_acc
Test accuracy: 0.978
0.9775
# Class probabilities for the first test image, mirroring the Keras softmax
# output: the model emits logits, so softmax is applied explicitly here.
softmax = nn.Softmax(dim=1)
model.eval()
first_image = test_dataset[0][0]
with torch.no_grad():
    sample = first_image.unsqueeze(0).to(device)  # add a batch axis: 1x1x28x28
    logits = model(sample)
    probs = softmax(logits).cpu().numpy()[0]

probs, probs.sum(), probs.argmax()
(array([3.6496281e-06, 6.8333939e-08, 1.1807497e-05, 5.3805462e-04,
        5.3760618e-10, 1.7252582e-06, 1.9389098e-13, 9.9938619e-01,
        2.5817704e-07, 5.8283444e-05], dtype=float32),
 1.0,
 7)
# Helper plotting functions (ported from the Keras version)
def plot_image(i, predictions_array, true_label, img):
    """Show *img* without ticks, captioned with the predicted class.

    The caption is "<predicted> <confidence>% (true: <label>)", drawn in blue
    when the argmax of *predictions_array* equals *true_label* and in red
    otherwise. *i* is unused; it is kept so the call signature matches the
    ported Keras helper.
    """
    label = int(true_label)
    guess = int(np.argmax(predictions_array))
    confidence = 100 * np.max(predictions_array)
    colour = 'blue' if guess == label else 'red'

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)
    plt.xlabel(f"{guess} {confidence:2.0f}% (true: {label})", color=colour)

def plot_value_array(i, predictions_array, true_label):
    """Bar-plot the per-class probabilities of one prediction.

    The predicted (argmax) class bar is coloured red and the true class bar
    green; when they coincide the bar ends up green. The y-axis is pinned to
    [0, 1]: y tick labels are hidden, so with autoscaling a 0.5 probability
    could fill the whole panel and look the same as 1.0 in a neighbouring
    subplot. *i* is unused; it is kept for signature compatibility with the
    ported Keras helper.
    """
    label = int(true_label)
    n_classes = len(predictions_array)  # 10 for MNIST
    plt.grid(False)
    plt.xticks(range(n_classes))
    plt.yticks([])
    bars = plt.bar(range(n_classes), predictions_array)
    plt.ylim([0, 1])
    bars[int(np.argmax(predictions_array))].set_color('red')
    bars[label].set_color('green')
# Plot a grid of test images with predicted probabilities
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))

# Fetch the first num_images test samples once; they are reused both for the
# batched prediction and for the per-image plots.
samples = [test_dataset[i] for i in range(num_images)]

# Predict in eval mode: the training loop leaves the model in train mode, in
# which dropout would make the plotted probabilities random.
model.eval()
with torch.no_grad():
    xs = torch.stack([img for img, _ in samples]).to(device)
    probs_batch = nn.Softmax(dim=1)(model(xs)).cpu().numpy()

for i, (img, label) in enumerate(samples):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, probs_batch[i], label, img.squeeze().numpy())
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, probs_batch[i], label)

plt.tight_layout()
plt.show()
../../../../_images/3bb799fb9c1bf3a45d5679480471403f9857d2af7dc8bb6a31170235673c6f06.png