MNIST classifier in PyTorch (Binder-ready)#
This notebook is a teachable, Binder-friendly rewrite of a classic MNIST multilayer perceptron (MLP). It emphasizes clarity, reproducibility, and portability for classroom/demo use.
# --- Setup & Reproducibility --------------------------------------------------
import os, math, time, numpy as np
import torch
SEED = int(os.environ.get("SEED", 1234))
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("PyTorch:", torch.__version__)
print("Device :", device)
PyTorch: 2.6.0
Device : cpu
# --- Imports ------------------------------------------------------------------
import matplotlib.pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
# --- Data: MNIST --------------------------------------------------------------
transform = transforms.ToTensor()
train_full = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
n_val = 10000
n_train = len(train_full) - n_val
train_ds, val_ds = random_split(train_full, [n_train, n_val], generator=torch.Generator().manual_seed(SEED))
batch_size = 128
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False)
print(f"Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")
Train/Val/Test sizes: 50000/10000/10000
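transforms.ToTensor() only scales pixels to [0, 1]. If you prefer zero-centered inputs, a common variant (not used in the runs below) adds normalization with the widely quoted MNIST mean/std of 0.1307/0.3081; a minimal sketch:
# Optional variant: normalize pixel intensities. Switching transforms requires
# re-creating the datasets and retraining, so the results below do not use this.
norm_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# train_full = datasets.MNIST(root="./data", train=True, download=True, transform=norm_transform)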
# --- Model --------------------------------------------------------------------
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),            # (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10)       # one logit per digit class
        )

    def forward(self, x):
        return self.net(x)
model = MLP().to(device)
print(model)
print("Total parameters:", sum(p.numel() for p in model.parameters()))
MLP(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=128, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)
Total parameters: 101770
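As a quick sanity check (not part of the original cells), you can push one batch through the untrained model and confirm that each image yields ten logits:
# Sanity check: one batch of images -> one row of 10 logits per image.
xb, yb = next(iter(train_loader))
with torch.no_grad():
    out = model(xb.to(device))
print(xb.shape, "->", out.shape)  # expect torch.Size([128, 1, 28, 28]) -> torch.Size([128, 10])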
# --- Training -----------------------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
def run_epoch(loader, train=True):
    # Train/eval mode matters here because of the Dropout layer.
    if train:
        model.train()
    else:
        model.eval()
    total, correct, running_loss = 0, 0, 0.0
    t0 = time.time()
    # Disable autograd during validation to save time and memory.
    with torch.set_grad_enabled(train):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            if train:
                optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            if train:
                loss.backward()
                optimizer.step()
            running_loss += loss.item() * xb.size(0)
            pred = logits.argmax(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    elapsed = time.time() - t0
    return running_loss/total, correct/total, elapsed
EPOCHS = 10
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for epoch in range(1, EPOCHS+1):
    tr_loss, tr_acc, tr_t = run_epoch(train_loader, train=True)
    va_loss, va_acc, va_t = run_epoch(val_loader, train=False)
    history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)
    history["val_loss"].append(va_loss); history["val_acc"].append(va_acc)
    print(f"Epoch {epoch:2d}/{EPOCHS} | "
          f"train loss {tr_loss:.4f} acc {tr_acc:.4f} ({tr_t:.1f}s) | "
          f"val loss {va_loss:.4f} acc {va_acc:.4f} ({va_t:.1f}s)")
Epoch 1/10 | train loss 0.4838 acc 0.8712 (4.3s) | val loss 0.2659 acc 0.9238 (0.9s)
Epoch 2/10 | train loss 0.2324 acc 0.9344 (6.5s) | val loss 0.1941 acc 0.9451 (0.9s)
Epoch 3/10 | train loss 0.1747 acc 0.9500 (7.2s) | val loss 0.1584 acc 0.9551 (0.9s)
Epoch 4/10 | train loss 0.1430 acc 0.9578 (7.7s) | val loss 0.1349 acc 0.9613 (0.8s)
Epoch 5/10 | train loss 0.1204 acc 0.9647 (7.9s) | val loss 0.1168 acc 0.9672 (0.8s)
Epoch 6/10 | train loss 0.1039 acc 0.9689 (9.6s) | val loss 0.1065 acc 0.9696 (0.9s)
Epoch 7/10 | train loss 0.0920 acc 0.9724 (8.1s) | val loss 0.1004 acc 0.9711 (1.0s)
Epoch 8/10 | train loss 0.0817 acc 0.9756 (9.0s) | val loss 0.0971 acc 0.9731 (0.9s)
Epoch 9/10 | train loss 0.0735 acc 0.9779 (8.3s) | val loss 0.0933 acc 0.9737 (0.9s)
Epoch 10/10 | train loss 0.0679 acc 0.9789 (8.2s) | val loss 0.0936 acc 0.9724 (0.9s)
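The history dict is collected but never plotted in the cells above; a minimal sketch of the learning curves, assuming the training cell has already run:
# Learning curves from the history dict populated during training.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
epochs = range(1, EPOCHS + 1)
ax_loss.plot(epochs, history["train_loss"], label="train")
ax_loss.plot(epochs, history["val_loss"], label="val")
ax_loss.set_xlabel("epoch"); ax_loss.set_ylabel("loss"); ax_loss.legend()
ax_acc.plot(epochs, history["train_acc"], label="train")
ax_acc.plot(epochs, history["val_acc"], label="val")
ax_acc.set_xlabel("epoch"); ax_acc.set_ylabel("accuracy"); ax_acc.legend()
plt.tight_layout(); plt.show()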
# --- Test ---------------------------------------------------------------------
model.eval()
total, correct, running_loss = 0, 0, 0.0
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        running_loss += loss.item() * xb.size(0)
        pred = logits.argmax(1)
        correct += (pred == yb).sum().item()
        total += yb.size(0)
test_loss = running_loss/total
test_acc = correct/total
print(f"\nTest loss {test_loss:.4f} | Test accuracy {test_acc:.4f}")
Test loss 0.0770 | Test accuracy 0.9747
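To avoid retraining on a fresh Binder session, you can checkpoint the weights; a minimal sketch using a hypothetical mnist_mlp.pt filename:
# Save only the learned weights (state_dict), the usual PyTorch checkpoint pattern.
torch.save(model.state_dict(), "mnist_mlp.pt")  # hypothetical filename

# Later (same or new session): rebuild the architecture, then load the weights back.
restored = MLP().to(device)
restored.load_state_dict(torch.load("mnist_mlp.pt", map_location=device))
restored.eval()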
# --- Prediction & Visualization ----------------------------------------------
import numpy as np
softmax = nn.Softmax(dim=1)
def predict_probs(x_batch):
    """Return class probabilities for a batch of images as a NumPy array."""
    model.eval()
    with torch.no_grad():
        return softmax(model(x_batch)).cpu().numpy()

def plot_image(predictions_array, true_label, img):
    # Image with predicted label; blue if correct, red if wrong.
    true_label = int(true_label)
    plt.grid(False); plt.xticks([]); plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)
    predicted_label = int(np.argmax(predictions_array))
    color = 'blue' if predicted_label == true_label else 'red'
    plt.xlabel(f"{predicted_label} {100*np.max(predictions_array):2.0f}% (true: {true_label})", color=color)

def plot_value_array(predictions_array, true_label):
    # Bar chart of the 10 class probabilities; predicted bar red, true bar green.
    true_label = int(true_label)
    plt.grid(False); plt.xticks(range(10)); plt.yticks([])
    bars = plt.bar(range(10), predictions_array)
    bars[int(np.argmax(predictions_array))].set_color('red')
    bars[true_label].set_color('green')
# Show a small gallery
num_rows, num_cols = 4, 3
num_images = num_rows * num_cols
fig = plt.figure(figsize=(2*2*num_cols, 2*num_rows))
imgs, labels = [], []
for i in range(num_images):
    img, label = test_ds[i]
    imgs.append(img)
    labels.append(label)
xb = torch.stack(imgs).to(device)
probs = predict_probs(xb)
for i in range(num_images):
    img = imgs[i].squeeze().numpy()
    label = labels[i]
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(probs[i], label, img)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(probs[i], label)
plt.tight_layout(); plt.show()
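The gallery above simply shows the first twelve test images. If you would rather inspect mistakes, here is a small sketch (not in the original cells) that collects the indices of misclassified test images; the exact count and indices will vary with seed and hardware:
# Collect indices of misclassified test images. test_loader is not shuffled,
# so batch offsets map straight back to dataset indices.
wrong = []
model.eval()
offset = 0
with torch.no_grad():
    for xb, yb in test_loader:
        pred = model(xb.to(device)).argmax(1).cpu()
        mism = (pred != yb).nonzero(as_tuple=True)[0]
        wrong.extend((offset + mism).tolist())
        offset += yb.size(0)
print(f"{len(wrong)} misclassified test images, e.g. indices {wrong[:5]}")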
Notes & Troubleshooting#
- Binder resources: DataLoader uses num_workers=0 to avoid fork issues on Binder.
- Reproducibility: we fix seeds but do not force fully deterministic algorithms (a trade-off for speed and compatibility).
- MKL/OpenMP warnings: if you see native warnings like the one above, you can try setting KMP_WARNINGS=0 and MKL_SERVICE_FORCE_INTEL=1 before importing NumPy/PyTorch in the first cell (see the sketch after this list).
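For reference, a minimal sketch that applies both notes at the very top of the first cell, before NumPy/PyTorch are imported anywhere in the session; note that requesting deterministic kernels can be slower, and with warn_only=True unsupported ops only emit a warning:
import os
# Must run before NumPy/PyTorch are imported in this kernel.
os.environ.setdefault("KMP_WARNINGS", "0")
os.environ.setdefault("MKL_SERVICE_FORCE_INTEL", "1")

import torch
torch.manual_seed(1234)
# Prefer deterministic kernels where available; warn (instead of raising) otherwise.
torch.use_deterministic_algorithms(True, warn_only=True)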