Demo: MNIST classifier in PyTorch

26.4. Demo: MNIST classifier in PyTorch

This notebook is a PyTorch rewrite of the original TensorFlow/Keras demo.
It trains a simple fully-connected network on MNIST and reproduces:

  • data loading

  • model definition (Flatten → Dense(128, ReLU) → Dropout(0.2) → Dense(10); Keras `Dense` layers become `nn.Linear` in PyTorch)

  • training for 10 epochs

  • accuracy on the test set

  • class-probability predictions for a single test image

  • the same plotting helpers for predictions

# Imports 
import os, math, numpy as np


import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

print('You have PyTorch version:', torch.__version__)
# Use a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
# Seed PyTorch's global RNG; weight initialisation, shuffling order and dropout masks draw from it.
torch.manual_seed(1234)
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
You have PyTorch version: 2.6.0
Using device: cpu
<torch._C.Generator at 0x111a171f0>
# Data: MNIST via torchvision (the loaders yield tensors in [0,1])
batch_size = 128
num_workers = 2  # background loader processes; 0 loads batches in the main process

transform = transforms.ToTensor()  # simple scaling to [0,1]
train_dataset = datasets.MNIST(root='../assets', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='../assets', train=False, download=True, transform=transform)

# Raw NumPy copies of the stored arrays. NOTE: these read `.data`/`.targets` directly and
# so bypass `transform`; unlike the loader batches they are NOT rescaled to [0, 1].
x_train = train_dataset.data.numpy()
y_train = train_dataset.targets.numpy()
x_test  = test_dataset.data.numpy()
y_test  = test_dataset.targets.numpy()

# The train loader is iterated once per epoch. Without persistent workers, DataLoader
# re-creates its worker processes on every pass, which adds start-up cost each epoch
# (and is likely why the MKL import warning is printed twice per epoch in the output).
# persistent_workers is only valid with num_workers > 0, hence the guard.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, persistent_workers=num_workers > 0)
# The test loader is iterated once, so its workers need not outlive that pass.
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers)

print('Train:', len(train_dataset), 'Test:', len(test_dataset))
Train: 60000 Test: 10000
# Model: Flatten → Linear(784,128) → ReLU → Dropout(0.2) → Linear(128,10)
class MLP(nn.Module):
    """One-hidden-layer perceptron mirroring the Keras demo model.

    Architecture: Flatten -> Linear(in_features, hidden) -> ReLU -> Dropout(dropout)
    -> Linear(hidden, num_classes). The defaults reproduce the original network
    (101,770 parameters). The output is raw logits; apply a softmax externally to
    obtain class probabilities.

    Args:
        in_features: number of values per input after flattening (28*28 for MNIST).
        hidden: width of the hidden layer.
        num_classes: number of output logits.
        dropout: dropout probability after the hidden ReLU (only active in train mode).
    """

    def __init__(self, in_features=28 * 28, hidden=128, num_classes=10, dropout=0.2):
        super().__init__()
        # Layers stay in a single Sequential named `net` so the module repr and
        # state_dict keys ("net.1.weight", ...) match the original definition.
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Flatten `x` from dim 1 onward and return (batch, num_classes) logits."""
        return self.net(x)

model = MLP()
model = model.to(device)  # move all parameters to the selected device (returns the same module)
# Cell display value: (total parameter count, module repr).
(sum(torch.numel(w) for w in model.parameters()), model)
(101770,
 MLP(
   (net): Sequential(
     (0): Flatten(start_dim=1, end_dim=-1)
     (1): Linear(in_features=784, out_features=128, bias=True)
     (2): ReLU()
     (3): Dropout(p=0.2, inplace=False)
     (4): Linear(in_features=128, out_features=10, bias=True)
   )
 ))
# Training: Adam with default settings, cross-entropy on the model's raw logits.
criterion = nn.CrossEntropyLoss()  # applies log-softmax itself, so the model must output logits
optimizer = torch.optim.Adam(model.parameters())
epochs = 10

history = {'accuracy': [], 'loss': []}

for epoch in range(1, epochs + 1):
    model.train()  # dropout on
    running_loss, correct, total = 0.0, 0, 0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        batch_n = yb.size(0)
        # loss is a per-batch mean; weight it by batch size so the epoch figure is a per-sample mean
        running_loss += loss.item() * batch_n
        # Training accuracy, measured on the logits computed before this step's update
        correct += int((logits.argmax(dim=1) == yb).sum())
        total += batch_n

    epoch_loss, epoch_acc = running_loss / total, correct / total
    for key, value in (('loss', epoch_loss), ('accuracy', epoch_acc)):
        history[key].append(value)
    print(f'Epoch {epoch:2d}/{epochs} - loss: {epoch_loss:.4f} - acc: {epoch_acc:.4f}')

history
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Epoch  1/10 - loss: 0.4442 - acc: 0.8808
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
Epoch  2/10 - loss: 0.2119 - acc: 0.9388
# Evaluation on test set: dropout off (eval mode) and no autograd graph.
model.eval()
correct = total = 0
test_loss = 0.0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        n = yb.size(0)
        # criterion returns the batch mean; re-weight by batch size for an overall per-sample mean
        test_loss += criterion(logits, yb).item() * n
        correct += (logits.argmax(dim=1) == yb).sum().item()
        total += n

test_acc, test_loss = correct / total, test_loss / total
print(f'\nTest accuracy: {test_acc:5.3f}')
test_acc
# Class probabilities for the first test image, mirroring the Keras model's softmax output.
softmax = nn.Softmax(dim=1)
model.eval()  # dropout off for inference
with torch.no_grad():
    sample = test_dataset[0][0][None].to(device)  # prepend a batch dimension -> 1x1x28x28
    probs = softmax(model(sample))[0].cpu().numpy()

# Display: the probability vector, its sum (≈1), and the most likely class.
probs, probs.sum(), probs.argmax()
# Helper plotting functions (ported from the Keras version)
def plot_image(i, predictions_array, true_label, img):
    """Draw one image on the current axes, captioned with its predicted class.

    Args:
        i: unused; kept so the signature matches the Keras-tutorial helper.
        predictions_array: per-class probabilities for this image.
        true_label: ground-truth class; anything `int()` accepts (e.g. an int or 0-d tensor).
        img: array accepted by `plt.imshow` (a 2-D grayscale image in this notebook).

    The caption reads "<predicted> <confidence>% (true: <label>)" and is blue when the
    prediction is correct, red otherwise.
    """
    true_label = int(true_label)
    predicted_label = int(np.argmax(predictions_array))
    confidence = 100 * np.max(predictions_array)

    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)

    color = 'blue' if predicted_label == true_label else 'red'
    plt.xlabel(f"{predicted_label} {confidence:2.0f}% (true: {true_label})", color=color)

def plot_value_array(i, predictions_array, true_label):
    """Draw a bar chart of the per-class probabilities for one image on the current axes.

    Args:
        i: unused; kept so the signature matches the Keras-tutorial helper.
        predictions_array: per-class probabilities (one bar per class).
        true_label: ground-truth class; anything `int()` accepts (e.g. an int or 0-d tensor).

    The predicted class's bar is coloured red and the true class's bar green. The true bar
    is coloured last, so a correct prediction shows a single green bar.
    """
    true_label = int(true_label)
    classes = range(len(predictions_array))
    plt.grid(False)
    plt.xticks(classes)
    plt.yticks([])
    bars = plt.bar(classes, predictions_array)
    # Pin the y-range to [0, 1] so bar heights are comparable across subplots. With the
    # y-ticks hidden, autoscaling would make a 40% prediction look as tall as a 99% one.
    plt.ylim(0, 1)
    predicted_label = int(np.argmax(predictions_array))
    bars[predicted_label].set_color('red')
    bars[true_label].set_color('green')
# Plot a grid of test images with predicted probabilities
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))

# Fetch the first num_images test samples once; they feed both the model and the plots.
samples = [test_dataset[i] for i in range(num_images)]

# Precompute probabilities for the first num_images. Calling eval() here (rather than
# relying on an earlier cell) matters if this cell runs right after training: dropout
# would otherwise still be active and the probabilities would be stochastic.
model.eval()
with torch.no_grad():
    xs = torch.stack([img for img, _ in samples]).to(device)
    probs_batch = nn.Softmax(dim=1)(model(xs)).cpu().numpy()

for i, (img, label) in enumerate(samples):
    # Each sample occupies two adjacent cells: the image, then its probability bars.
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, probs_batch[i], label, img.squeeze().numpy())
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, probs_batch[i], label)

plt.tight_layout()
plt.show()