Identifying hand-written digits (MNIST database) using PyTorch¶
Implementation Steps
Library & Dataset Setup
The necessary libraries are imported and the MNIST dataset is downloaded using torchvision. The dataset is split into training and testing sets, and dataloaders are created to manage the input pipeline efficiently.
Hyperparameter Initialization
Key training parameters such as input size, number of output classes, number of epochs, batch size, and learning rate are defined. Each image in MNIST is 28×28 pixels, giving an input size of 784.
Model Definition
A logistic regression model and a multi-layer perceptron (MLP) are defined using PyTorch's module system.
Loss Function & Optimizer
The cross-entropy loss function is selected for its suitability to multi-class classification. Stochastic gradient descent is used to optimize the model weights.
Model Training
Each model is trained in rounds of five epochs. Each epoch iterates through mini-batches of the training data: for every batch, the model performs a forward pass, computes the loss, backpropagates the gradients, and updates the weights with the optimizer.
Model Evaluation
After training, the models are evaluated on the test dataset. The evaluation loop counts the number of correct predictions out of the total test samples to compute the final accuracy.
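The loss choice in the steps above can be checked numerically: cross-entropy on raw logits is just the negative log-softmax probability of the true class, averaged over the batch. A minimal sketch with standalone fake tensors (not the models defined below):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # a batch of 4 fake score vectors
labels = torch.tensor([3, 1, 4, 9])  # fake target classes

# F.cross_entropy applies log-softmax + negative log-likelihood in one call
loss = F.cross_entropy(logits, labels)

# manual equivalent: pick out -log p(true class) for each example and average
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

print(torch.allclose(loss, manual))  # True
```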
1. Library & Dataset Setup¶
import torch
import torch.nn as nn
import torchvision ## Contains some utilities for working with the image data
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
#%matplotlib inline
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn.functional as F
from IPython.display import Image
Image("HandDigits.png", width=600)
Downloading the MNIST dataset¶
### load the MNIST dataset
dataset = MNIST(root = 'data/', download = True)
print(len(dataset))
60000
type(dataset)
torchvision.datasets.mnist.MNIST
image, label = dataset[10]
plt.imshow(image, cmap = 'gray')
print('Label:', label)
Label: 3
image, label = dataset[1]
plt.imshow(image, cmap = 'gray')
print('Label:', label)
Label: 0
Loading the MNIST data with transformation¶
PyTorch models operate on tensors rather than PIL images, so we need to convert the images into tensors. We can do this by specifying a transform while creating our dataset.
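Under the hood, `ToTensor` roughly does the following: add a channel axis, cast the 0–255 pixel values to float32, and scale them into [0, 1]. A sketch of that conversion with a synthetic array (no real MNIST image assumed):

```python
import numpy as np
import torch

# synthetic 28x28 grayscale "image" with uint8 pixel values 0..255
img = np.random.randint(0, 256, size=(28, 28), dtype=np.uint8)

# roughly what transforms.ToTensor does for a grayscale image
tensor = torch.from_numpy(img).unsqueeze(0).float() / 255.0

print(tensor.shape)  # torch.Size([1, 28, 28]), values in [0, 1]
```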
### Convert to tensors
## MNIST dataset(images and labels)
mnist_dataset = MNIST(root = 'data/', train = True, transform = transforms.ToTensor())
print(mnist_dataset)
Dataset MNIST
Number of datapoints: 60000
Root location: data/
Split: Train
StandardTransform
Transform: ToTensor()
The image is now converted to a 1×28×28 tensor. The first dimension keeps track of the color channels. Since images in the MNIST dataset are grayscale, there is just one channel; datasets with color images would have 3 channels (red, green, blue).¶
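This (1, 28, 28) shape is also where the input size of 784 in the hyperparameters comes from: both models flatten each image into a vector before the first linear layer. A quick shape check with a fake batch:

```python
import torch

batch = torch.rand(32, 1, 28, 28)     # a fake batch of grayscale images
flat = batch.view(batch.size(0), -1)  # flatten everything but the batch dim

print(flat.shape)  # torch.Size([32, 784])
```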
image_tensor, label = mnist_dataset[0]
print(image_tensor.shape, label)
torch.Size([1, 28, 28]) 5
print(image_tensor[:,10:15,10:15])
print(torch.max(image_tensor), torch.min(image_tensor))
tensor([[[0.0039, 0.6039, 0.9922, 0.3529, 0.0000],
[0.0000, 0.5451, 0.9922, 0.7451, 0.0078],
[0.0000, 0.0431, 0.7451, 0.9922, 0.2745],
[0.0000, 0.0000, 0.1373, 0.9451, 0.8824],
[0.0000, 0.0000, 0.0000, 0.3176, 0.9412]]])
tensor(1.) tensor(0.)
## Plot the full 28x28 image of the tensor
plt.imshow(image_tensor[0], cmap = 'gray')
<matplotlib.image.AxesImage at 0x143eb7250>
## Plot the 5x5 patch of the tensor printed above
plt.imshow(image_tensor[0,10:15,10:15], cmap = 'gray')
<matplotlib.image.AxesImage at 0x143f3d310>
2. Hyperparameter Initialization¶
# --- Hyperparameters ---
n_trials = 4          # repeat the experiment 4 times
epochs = 40           # number of training epochs
batch_size = 128      # samples per mini-batch
lr_logreg = 0.1       # learning rate for logistic regression
lr_mlp = 1e-3         # learning rate for the MLP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden1 = 256         # units in the first hidden layer
hidden2 = 128         # units in the second hidden layer
p_dropout = 0.2       # dropout probability
input_size = 28 * 28  # each 28x28 image flattens to 784 features
num_classes = 10      # digits 0-9
Training and Validation Datasets¶
#### split training/validation
train_data, validation_data = random_split(mnist_dataset, [50000, 10000])
## Print the length of train and validation datasets
print("length of Train Datasets: ", len(train_data))
print("length of Validation Datasets: ", len(validation_data))
length of Train Datasets:  50000
length of Validation Datasets:  10000
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)
Define Logistic Regression Model in Pytorch¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class LogisticRegression(nn.Module):
def __init__(self, input_size=input_size, num_classes=num_classes):
super().__init__()
self.linear = nn.Linear(input_size, num_classes)
self._init_weights()
def forward(self, xb):
xb = xb.view(xb.size(0), -1) # flatten, device-safe
return self.linear(xb)
def _init_weights(self):
# small normal init for weights, zero biases
nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
if self.linear.bias is not None:
nn.init.zeros_(self.linear.bias)
# --- training / validation helpers (device-safe) ---
def training_step(self, batch):
images, labels = batch
device = next(self.parameters()).device
images = images.to(device)
labels = labels.to(device)
out = self(images) # logits
loss = F.cross_entropy(out, labels) # scalar tensor
return loss
def validation_step(self, batch):
images, labels = batch
device = next(self.parameters()).device
images = images.to(device)
labels = labels.to(device)
out = self(images)
loss = F.cross_entropy(out, labels)
preds = out.argmax(dim=1)
acc = (preds == labels).float().mean() # tensor
return {'val_loss': loss.detach(), 'val_acc': acc.detach()}
def validation_epoch_end(self, outputs):
# outputs: list of {'val_loss': tensor, 'val_acc': tensor}
batch_losses = [x['val_loss'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean()
batch_accs = [x['val_acc'] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean()
return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
def epoch_end(self, epoch, result):
print(f"Epoch [{epoch}], val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
modelL = LogisticRegression()
MLP with 2 hidden layers¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, input_size=input_size, hidden1=hidden1, hidden2=hidden2, num_classes=num_classes, p_dropout=p_dropout):
super().__init__()
# Layers
self.fc1 = nn.Linear(input_size, hidden1)
self.bn1 = nn.BatchNorm1d(hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.bn2 = nn.BatchNorm1d(hidden2)
self.fc3 = nn.Linear(hidden2, num_classes)
self.dropout = nn.Dropout(p_dropout)
self.relu = nn.ReLU(inplace=True)
# weight init
self._init_weights()
def forward(self, xb):
xb = xb.view(xb.size(0), -1) # flatten
x = self.fc1(xb)
x = self.bn1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.dropout(x)
x = self.fc3(x)
return x
def _init_weights(self):
# He/Kaiming init for linear layers and sensible BN init
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
nn.init.ones_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
# --- Convenience training/validation helpers (device-safe) ---
def training_step(self, batch):
images, labels = batch
device = next(self.parameters()).device
images = images.to(device)
labels = labels.to(device)
out = self(images)
loss = F.cross_entropy(out, labels)
return loss
def validation_step(self, batch):
images, labels = batch
device = next(self.parameters()).device
images = images.to(device)
labels = labels.to(device)
out = self(images)
loss = F.cross_entropy(out, labels)
preds = out.argmax(dim=1)
acc = (preds == labels).float().mean()
# return tensors (not Python floats) so we can stack them later
return {'val_loss': loss.detach(), 'val_acc': acc.detach()}
def validation_epoch_end(self, outputs):
# outputs: list of {'val_loss': tensor, 'val_acc': tensor}
batch_losses = [x['val_loss'] for x in outputs]
epoch_loss = torch.stack(batch_losses).mean()
batch_accs = [x['val_acc'] for x in outputs]
epoch_acc = torch.stack(batch_accs).mean()
return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
def epoch_end(self, epoch, result):
print(f"Epoch [{epoch}] val_loss: {result['val_loss']:.4f}, val_acc: {result['val_acc']:.4f}")
modelM = MLP()
import torch
import copy
import numpy as np
def accuracy(outputs, labels):
"""Return accuracy as a tensor (0..1)."""
_, preds = torch.max(outputs, dim=1)
return (preds == labels).float().mean()
def evaluate(model, val_loader, device=None):
"""Run validation loop using model.validation_step and validation_epoch_end."""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
outputs = []
with torch.no_grad():
for batch in val_loader:
# assume model.validation_step moves tensors to device or expects CPU — we handle both:
out = model.validation_step(batch)
# ensure tensors (val_loss and val_acc) so validation_epoch_end can stack them
if isinstance(out['val_loss'], float):
out['val_loss'] = torch.tensor(out['val_loss'], device=device)
if isinstance(out['val_acc'], float):
out['val_acc'] = torch.tensor(out['val_acc'], device=device)
outputs.append(out)
return model.validation_epoch_end(outputs)
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD,
device=None, scheduler=None, save_path=None, print_every=1):
"""
Train model using the model.training_step / validation_step / validation_epoch_end helpers.
- opt_func: optimizer class (e.g. torch.optim.SGD or torch.optim.Adam)
- scheduler: optional lr scheduler (StepLR or ReduceLROnPlateau)
- save_path: optional path to save best model (by val_acc)
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = opt_func(model.parameters(), lr)
best_val_acc = -1.0
best_state = None
history = []
for epoch in range(epochs):
# Training
model.train()
train_losses = []
for batch in train_loader:
optimizer.zero_grad() # zero BEFORE backward
loss = model.training_step(batch) # model should handle moving batch to device
# ensure loss is a tensor on correct device
if not torch.is_tensor(loss):
loss = torch.tensor(float(loss), device=device, requires_grad=True)
loss.backward()
optimizer.step()
train_losses.append(loss.detach().cpu().item())
# optional scheduler step for per-epoch schedulers (not ReduceLROnPlateau)
if scheduler is not None and not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
scheduler.step()
# Validation
result = evaluate(model, val_loader, device=device)
history.append(result)
# save best
if result.get("val_acc", -1) > best_val_acc:
best_val_acc = result["val_acc"]
best_state = copy.deepcopy(model.state_dict())
if save_path is not None:
torch.save(best_state, save_path)
# scheduler that depends on metric
if scheduler is not None and isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
# commonly step on validation loss; change to val_acc if configured that way
scheduler.step(result.get("val_loss"))
if (epoch + 1) % print_every == 0:
model.epoch_end(epoch, result)
# restore best weights (so returned model is the best one)
if best_state is not None:
model.load_state_dict(best_state)
return history
result0L = evaluate(modelL, val_loader)
result0L
{'val_loss': 2.3027942180633545, 'val_acc': 0.13568037748336792}
history1L = fit(5, 0.001, modelL, train_loader, val_loader)
Epoch [0], val_loss: 1.9445, val_acc: 0.7042
Epoch [1], val_loss: 1.6802, val_acc: 0.7541
Epoch [2], val_loss: 1.4828, val_acc: 0.7741
Epoch [3], val_loss: 1.3338, val_acc: 0.7858
Epoch [4], val_loss: 1.2192, val_acc: 0.7962
history2L = fit(5, 0.001, modelL, train_loader, val_loader)
Epoch [0], val_loss: 1.1292, val_acc: 0.8043
Epoch [1], val_loss: 1.0570, val_acc: 0.8106
Epoch [2], val_loss: 0.9979, val_acc: 0.8153
Epoch [3], val_loss: 0.9488, val_acc: 0.8205
Epoch [4], val_loss: 0.9072, val_acc: 0.8245
history3L = fit(5, 0.001, modelL, train_loader, val_loader)
Epoch [0], val_loss: 0.8716, val_acc: 0.8268
Epoch [1], val_loss: 0.8408, val_acc: 0.8304
Epoch [2], val_loss: 0.8137, val_acc: 0.8336
Epoch [3], val_loss: 0.7899, val_acc: 0.8355
Epoch [4], val_loss: 0.7687, val_acc: 0.8378
history4L = fit(5, 0.001, modelL, train_loader, val_loader)
Epoch [0], val_loss: 0.7495, val_acc: 0.8399
Epoch [1], val_loss: 0.7323, val_acc: 0.8413
Epoch [2], val_loss: 0.7167, val_acc: 0.8437
Epoch [3], val_loss: 0.7024, val_acc: 0.8451
Epoch [4], val_loss: 0.6893, val_acc: 0.8463
historyL = [result0L] + history1L + history2L + history3L + history4L
accuracies = [result['val_acc'] for result in historyL]
plt.plot(accuracies, '-x')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Accuracy Vs. No. of epochs')
Text(0.5, 1.0, 'Accuracy Vs. No. of epochs')
Now train the MLP model¶
result0M = evaluate(modelM, val_loader)
result0M
{'val_loss': 2.382965326309204, 'val_acc': 0.06398338824510574}
history1M = fit(5, 0.001, modelM, train_loader, val_loader)
Epoch [0] val_loss: 1.4342, val_acc: 0.5825
Epoch [1] val_loss: 1.0730, val_acc: 0.7175
Epoch [2] val_loss: 0.8916, val_acc: 0.7675
Epoch [3] val_loss: 0.7782, val_acc: 0.7992
Epoch [4] val_loss: 0.6975, val_acc: 0.8211
history2M = fit(5, 0.001, modelM, train_loader, val_loader)
Epoch [0] val_loss: 0.6370, val_acc: 0.8373
Epoch [1] val_loss: 0.5944, val_acc: 0.8509
Epoch [2] val_loss: 0.5576, val_acc: 0.8575
Epoch [3] val_loss: 0.5257, val_acc: 0.8664
Epoch [4] val_loss: 0.5024, val_acc: 0.8715
history3M = fit(5, 0.001, modelM, train_loader, val_loader)
Epoch [0] val_loss: 0.4810, val_acc: 0.8766
Epoch [1] val_loss: 0.4638, val_acc: 0.8797
Epoch [2] val_loss: 0.4487, val_acc: 0.8817
Epoch [3] val_loss: 0.4344, val_acc: 0.8844
Epoch [4] val_loss: 0.4187, val_acc: 0.8888
history4M = fit(5, 0.001, modelM, train_loader, val_loader)
Epoch [0] val_loss: 0.4086, val_acc: 0.8900
Epoch [1] val_loss: 0.3995, val_acc: 0.8936
Epoch [2] val_loss: 0.3885, val_acc: 0.8953
Epoch [3] val_loss: 0.3815, val_acc: 0.8973
Epoch [4] val_loss: 0.3710, val_acc: 0.8985
historyM = [result0M] + history1M + history2M + history3M + history4M
accuracies = [result['val_acc'] for result in historyM]
plt.plot(accuracies, '-x', color='C1')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Accuracy Vs. No. of epochs')
Text(0.5, 1.0, 'Accuracy Vs. No. of epochs')
Logistic regression converges faster but has limited capacity, while an MLP converges more slowly but achieves higher final accuracy.¶
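The capacity gap can be made concrete by counting trainable parameters. A sketch with standalone layers matching the architectures above (784→10 for logistic regression; 784→256→128→10 with batch norm for the MLP):

```python
import torch.nn as nn

def n_params(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

logreg = nn.Linear(784, 10)
mlp = nn.Sequential(
    nn.Linear(784, 256), nn.BatchNorm1d(256), nn.ReLU(),
    nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(),
    nn.Linear(128, 10),
)

print(n_params(logreg))  # 7850
print(n_params(mlp))     # 235914
```

The MLP has roughly 30× more parameters, which is consistent with its slower convergence but higher final accuracy.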
historyL = [result0L] + history1L + history2L + history3L + history4L
accL = [result['val_acc'] for result in historyL]
# MLP history
historyM = [result0M] + history1M + history2M + history3M + history4M
accM = [result['val_acc'] for result in historyM]
# Plot both on the same figure
plt.figure(figsize=(7, 5))
plt.plot(accL, '-x', label='Logistic Regression')
plt.plot(accM, '-o', label='MLP with 2 hidden layers')
plt.xlabel('Epoch')
plt.ylabel('Validation Accuracy')
plt.title('Validation Accuracy vs Number of Epochs')
plt.legend()
plt.grid(True)
plt.savefig("accuracy_comparison1.png", dpi=300, bbox_inches="tight")
plt.show()
6. Model Evaluation¶
### test dataset
test_dataset = MNIST(root = 'data/', train = False, transform = transforms.ToTensor())
img, label = test_dataset[0]
plt.imshow(img[0], cmap = 'gray')
print("shape: ", img.shape)
print('Label: ', label)
shape:  torch.Size([1, 28, 28])
Label:  7
print(img.unsqueeze(0).shape)
print(img.shape)
torch.Size([1, 1, 28, 28])
torch.Size([1, 28, 28])
def predict_image(img, model):
xb = img.unsqueeze(0)  # add a batch dimension
yb = model(xb)
_, preds = torch.max(yb, dim = 1)
return preds[0].item()
img, label = test_dataset[0]
plt.imshow(img[0], cmap = 'gray')
print('Label:', label, ', Predicted :', predict_image(img, modelL))
Label: 7 , Predicted : 7
test_loader = DataLoader(test_dataset, batch_size = 256)
resultL = evaluate(modelL, test_loader)
resultL
{'val_loss': 0.6390308141708374, 'val_acc': 0.8609374761581421}
resultM = evaluate(modelM, test_loader)
resultM
{'val_loss': 0.3206319808959961, 'val_acc': 0.9159179925918579}
torch.save(modelL.state_dict(), 'mnist-logistic.pth')
torch.save(modelM.state_dict(), 'mnist-MLP.pth')
modelL.state_dict()
OrderedDict([('linear.weight',
tensor([[ 0.0194, 0.0009, -0.0122, ..., 0.0026, 0.0062, -0.0072],
[-0.0002, -0.0110, 0.0028, ..., 0.0160, -0.0014, -0.0027],
[-0.0003, -0.0111, -0.0031, ..., 0.0075, -0.0115, 0.0059],
...,
[-0.0242, -0.0007, -0.0163, ..., -0.0071, 0.0064, 0.0015],
[-0.0057, 0.0084, -0.0080, ..., -0.0046, -0.0181, -0.0037],
[ 0.0053, 0.0077, 0.0021, ..., -0.0068, -0.0114, -0.0157]])),
('linear.bias',
tensor([-0.0448, 0.0907, -0.0222, -0.0230, 0.0257, 0.0416, -0.0090, 0.0434,
-0.0916, -0.0107]))])
The weights¶
### weights from Logistic Regression
import matplotlib.pyplot as plt
w = modelL.state_dict()['linear.weight'].cpu().numpy() # shape (10, 784)
fig, axes = plt.subplots(2,5, figsize=(10,5))
for i, ax in enumerate(axes.flat):
ax.imshow(w[i].reshape(28,28), cmap='seismic')
ax.set_title(f"class {i}")
ax.axis('off')
plt.show()
### weights from MLP
# approximate effective input → output weights
W_eff = modelM.fc3.weight @ modelM.fc2.weight @ modelM.fc1.weight
# shape: (10, 784)
def plot_class_templates(W_eff):
plt.figure(figsize=(10,4))
for i in range(10):
plt.subplot(2,5,i+1)
plt.imshow(W_eff[i].detach().cpu().reshape(28,28), cmap="seismic")
plt.title(f"Class {i}")
plt.axis("off")
plt.suptitle("MLP Approx. Class Templates", fontsize=14)
plt.tight_layout()
plt.show()
plot_class_templates(W_eff)
### Check wrong predictions
modelL.eval()
LogisticRegression(
  (linear): Linear(in_features=784, out_features=10, bias=True)
)
import matplotlib.pyplot as plt
wrong_images = []
wrong_preds = []
wrong_labels = []
with torch.no_grad():
for images, labels in test_loader:
outputs = modelL(images)
_, preds = torch.max(outputs, dim=1)
# find misclassified indices
wrong_idx = preds != labels
wrong_images.extend(images[wrong_idx])
wrong_preds.extend(preds[wrong_idx])
wrong_labels.extend(labels[wrong_idx])
def show_wrong_predictions(images, preds, labels, n=10):
plt.figure(figsize=(12, 4))
for i in range(n):
plt.subplot(1, n, i + 1)
plt.imshow(images[i].squeeze(), cmap="gray")
plt.title(f"P:{preds[i].item()} / T:{labels[i].item()}")
plt.axis("off")
plt.show()
show_wrong_predictions(wrong_images, wrong_preds, wrong_labels, n=10)
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
def to_cpu(t):
"""Detach and move a tensor to CPU (returns cpu tensor)."""
return t.detach().cpu()
def im_to_plot(img_tensor, mean=None, std=None):
"""
Convert a CHW tensor to HWC numpy for plt.imshow.
Handles grayscale (1,H,W) and RGB (3,H,W).
If mean/std provided (iterable of length C) it will unnormalize.
"""
img = to_cpu(img_tensor)
if img.dim() == 3 and img.size(0) == 1:
img = img.squeeze(0) # (H, W)
if mean is not None and std is not None:
img = img * std[0] + mean[0]
return img.numpy()
elif img.dim() == 3 and img.size(0) == 3:
if mean is not None and std is not None:
mean_t = torch.tensor(mean).view(-1,1,1)
std_t = torch.tensor(std).view(-1,1,1)
img = img * std_t + mean_t
img = img.permute(1,2,0) # HWC
return img.numpy()
else:
return img.squeeze().numpy()
def collect_disagreements(ModelL, ModelM, dataloader, device=None,
max_examples_per_case=30, return_probs=True,
sort_by_confidence_gap=False):
"""
Collect examples where:
- ModelL wrong & ModelM correct
- ModelM wrong & ModelL correct
Returns dict with keys:
'L_wrong_M_correct' and 'M_wrong_L_correct' containing lists of tuples:
(img_tensor_cpu, predL, probL, predM, probM, label, conf_gap)
conf_gap = ModelM_conf - ModelL_conf (useful for sorting)
"""
# device auto-detect
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ModelL.to(device)
ModelM.to(device)
ModelL.eval()
ModelM.eval()
l_wrong_m_correct = []
m_wrong_l_correct = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
labels = labels.to(device)
outL = ModelL(images)
outM = ModelM(images)
# handle binary vs multi-class logits
if outL.dim() == 1 or (outL.dim() == 2 and outL.size(1) == 1):
probsL = torch.sigmoid(outL.view(-1))
predsL = (probsL > 0.5).long()
confL = probsL.clone() # confidence for predicted class
else:
probsL = F.softmax(outL, dim=1)
predsL = torch.argmax(probsL, dim=1)
confL = probsL.max(dim=1).values
if outM.dim() == 1 or (outM.dim() == 2 and outM.size(1) == 1):
probsM = torch.sigmoid(outM.view(-1))
predsM = (probsM > 0.5).long()
confM = probsM.clone()
else:
probsM = F.softmax(outM, dim=1)
predsM = torch.argmax(probsM, dim=1)
confM = probsM.max(dim=1).values
# boolean masks
L_wrong = predsL != labels
M_correct = predsM == labels
mask_Lwrong_Mcorrect = L_wrong & M_correct
M_wrong = predsM != labels
L_correct = predsL == labels
mask_Mwrong_Lcorrect = M_wrong & L_correct
# collect examples
for idx in torch.where(mask_Lwrong_Mcorrect)[0]:
if len(l_wrong_m_correct) >= max_examples_per_case:
break
i = int(idx)
conf_gap = float(confM[i].cpu().item()) - float(confL[i].cpu().item())
l_wrong_m_correct.append((
to_cpu(images[i]),
int(predsL[i].cpu().item()),
float(confL[i].cpu().item()),
int(predsM[i].cpu().item()),
float(confM[i].cpu().item()),
int(labels[i].cpu().item()),
conf_gap
))
for idx in torch.where(mask_Mwrong_Lcorrect)[0]:
if len(m_wrong_l_correct) >= max_examples_per_case:
break
i = int(idx)
conf_gap = float(confM[i].cpu().item()) - float(confL[i].cpu().item())
m_wrong_l_correct.append((
to_cpu(images[i]),
int(predsL[i].cpu().item()),
float(confL[i].cpu().item()),
int(predsM[i].cpu().item()),
float(confM[i].cpu().item()),
int(labels[i].cpu().item()),
conf_gap
))
if len(l_wrong_m_correct) >= max_examples_per_case and len(m_wrong_l_correct) >= max_examples_per_case:
break
# optional sorting by confidence gap (largest positive gap first)
if sort_by_confidence_gap:
l_wrong_m_correct.sort(key=lambda t: t[-1], reverse=True) # M much more confident than L
m_wrong_l_correct.sort(key=lambda t: t[-1], reverse=True)
return {
"L_wrong_M_correct": l_wrong_m_correct,
"M_wrong_L_correct": m_wrong_l_correct
}
def plot_disagreements(case_list, title, n=8, mean=None, std=None, cmap="gray"):
"""
case_list: list of tuples (img_tensor_cpu, predL, confL, predM, confM, label, conf_gap)
"""
n = min(n, len(case_list))
if n == 0:
print(f"No examples for: {title}")
return
plt.figure(figsize=(3.5 * n, 3))
plt.suptitle(title, fontsize=14)
for i in range(n):
img_tensor, predL, confL, predM, confM, label, conf_gap = case_list[i]
ax = plt.subplot(1, n, i+1)
npimg = im_to_plot(img_tensor, mean=mean, std=std)
if npimg.ndim == 2:
ax.imshow(npimg, cmap=cmap)
else:
ax.imshow(np.clip(npimg, 0, 1))
ax.set_title(f"T:{label}\nL:{predL} ({confL:.2f})\nM:{predM} ({confM:.2f})", fontsize=9)
ax.axis("off")
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
# -------------------------
# Example usage (replace with your objects / loader):
# -------------------------
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ModelL.to(device); ModelM.to(device)
# cases = collect_disagreements(ModelL, ModelM, test_loader, device=device,
# max_examples_per_case=40, sort_by_confidence_gap=True)
# plot_disagreements(cases["L_wrong_M_correct"], "ModelL WRONG, ModelM CORRECT", n=8, mean=None, std=None)
# plot_disagreements(cases["M_wrong_L_correct"], "ModelM WRONG, ModelL CORRECT", n=8, mean=None, std=None)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelL.to(device); modelM.to(device)
cases = collect_disagreements(modelL, modelM, test_loader, device=device, max_examples_per_case=40, sort_by_confidence_gap=True)
plot_disagreements(cases["L_wrong_M_correct"], "ModelL WRONG, ModelM CORRECT", n=8, mean=None, std=None)
plot_disagreements(cases["M_wrong_L_correct"], "ModelM WRONG, ModelL CORRECT", n=8, mean=None, std=None)
Confusion Matrix¶
import torch
import numpy as np
def collect_preds(model, dataloader, device):
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for images, labels in dataloader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
preds = torch.argmax(outputs, dim=1)
all_preds.append(preds.cpu())
all_labels.append(labels.cpu())
return torch.cat(all_labels), torch.cat(all_preds)
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def plot_confusion_with_boxes(y_true, y_pred, title, top_k=6):
cm = confusion_matrix(y_true, y_pred)
num_classes = cm.shape[0]
fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
plt.colorbar(im, ax=ax)
ax.set_title(title)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
# annotate counts
for i in range(num_classes):
for j in range(num_classes):
ax.text(j, i, cm[i, j], ha="center", va="center", fontsize=9)
# find top off-diagonal entries
off_diag = []
for i in range(num_classes):
for j in range(num_classes):
if i != j and cm[i, j] > 0:
off_diag.append((cm[i, j], i, j))
off_diag.sort(reverse=True)
top_errors = off_diag[:top_k]
# draw boxes
for _, i, j in top_errors:
rect = plt.Rectangle(
(j - 0.5, i - 0.5),
1,
1,
fill=False,
edgecolor="red",
linewidth=2
)
ax.add_patch(rect)
plt.tight_layout()
plt.show()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelL.to(device)
modelM.to(device)
y_true_L, y_pred_L = collect_preds(modelL, test_loader, device)
y_true_M, y_pred_M = collect_preds(modelM, test_loader, device)
plot_confusion_with_boxes(
y_true_L,
y_pred_L,
title="Confusion Matrix – ModelL (Logistic Regression)",
top_k=6
)
plot_confusion_with_boxes(
y_true_M,
y_pred_M,
title="Confusion Matrix – ModelM (MLP)",
top_k=6
)
Assess Confidence in ML Models¶
from sklearn.calibration import calibration_curve
import torch
import matplotlib.pyplot as plt
import numpy as np
def collect_confidence_and_correctness(model, dataloader, device):
model.eval()
probs_all = []
correct_all = []
with torch.no_grad():
for xb, yb in dataloader:
xb = xb.to(device)
yb = yb.to(device)
logits = model(xb)
probs = torch.softmax(logits, dim=1)
confs, preds = probs.max(dim=1)
probs_all.append(confs.cpu())
correct_all.append((preds == yb).float().cpu())
probs_all = torch.cat(probs_all).numpy()
correct_all = torch.cat(correct_all).numpy()
return probs_all, correct_all
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Logistic Regression
probs_L, correct_L = collect_confidence_and_correctness(modelL, test_loader, device)
frac_pos_L, mean_pred_L = calibration_curve(
correct_L, probs_L, n_bins=10, strategy="uniform"
)
# MLP
probs_M, correct_M = collect_confidence_and_correctness(modelM, test_loader, device)
frac_pos_M, mean_pred_M = calibration_curve(
correct_M, probs_M, n_bins=10, strategy="uniform"
)
plt.figure(figsize=(6, 6))
# Perfect calibration line
plt.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
# Logistic Regression
plt.plot(mean_pred_L, frac_pos_L, "-o", label="Logistic Regression")
# MLP
plt.plot(mean_pred_M, frac_pos_M, "-s", label="MLP")
plt.xlabel("Mean predicted confidence")
plt.ylabel("Fraction of correct predictions")
plt.title("Reliability Diagram (Calibration Curve)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(
"calibration_curve_logistic_vs_mlp.png",
dpi=300,
bbox_inches="tight"
)
plt.show()
def expected_calibration_error(probs, correct, n_bins=10):
bins = np.linspace(0, 1, n_bins + 1)
ece = 0.0
for i in range(n_bins):
mask = (probs >= bins[i]) & (probs < bins[i+1])
if mask.sum() > 0:
acc = correct[mask].mean()
conf = probs[mask].mean()
ece += (mask.sum() / len(probs)) * abs(acc - conf)
return ece
print("ECE Logistic:", expected_calibration_error(probs_L, correct_L))
print("ECE MLP: ", expected_calibration_error(probs_M, correct_M))
ECE Logistic: 0.23214034938812259
ECE MLP:      0.07543363214433194
import numpy as np
import matplotlib.pyplot as plt
import torch
# If you already have probs_* and correct_* arrays from before, skip collection and set:
# probs_L, correct_L = probs_L, correct_L
# probs_M, correct_M = probs_M, correct_M
# Otherwise collect them:
def collect_confidence_and_correctness(model, dataloader, device):
model.eval()
probs_all = []
correct_all = []
with torch.no_grad():
for xb, yb in dataloader:
xb = xb.to(device); yb = yb.to(device)
logits = model(xb)
if logits.dim() == 1 or (logits.dim()==2 and logits.size(1) == 1):
probs = torch.sigmoid(logits.view(-1))
confs = probs.cpu().numpy()
preds = (probs > 0.5).long().cpu().numpy()
correct = (preds == yb.cpu().numpy()).astype(float)
else:
probs = torch.softmax(logits, dim=1)
confs = probs.max(dim=1).values.cpu().numpy()
preds = probs.argmax(dim=1).cpu().numpy()
correct = (preds == yb.cpu().numpy()).astype(float)
probs_all.append(confs)
correct_all.append(correct)
probs_all = np.concatenate(probs_all, axis=0)
correct_all = np.concatenate(correct_all, axis=0)
return probs_all, correct_all
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ensure models are on same device
modelL.to(device); modelM.to(device)
probs_L, correct_L = collect_confidence_and_correctness(modelL, test_loader, device)
probs_M, correct_M = collect_confidence_and_correctness(modelM, test_loader, device)
# split into correct / wrong
probs_L_corr = probs_L[correct_L == 1]
probs_L_wrong = probs_L[correct_L == 0]
probs_M_corr = probs_M[correct_M == 1]
probs_M_wrong = probs_M[correct_M == 0]
# summary stats
def print_summary(name, probs_corr, probs_wrong):
print(f"=== {name} ===")
print(f" correct: n={len(probs_corr)}, mean_conf={np.mean(probs_corr):.3f}, std={np.std(probs_corr):.3f}")
print(f" wrong: n={len(probs_wrong)}, mean_conf={np.mean(probs_wrong):.3f}, std={np.std(probs_wrong):.3f}")
print()
print_summary("Logistic Regression", probs_L_corr, probs_L_wrong)
print_summary("MLP", probs_M_corr, probs_M_wrong)
# Plotting
bins = np.linspace(0.0, 1.0, 21) # 20 bins (0.0-0.05-...-1.0)
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
plt.suptitle("Confidence histograms (correct vs wrong)")
# Logistic column
ax = axes[0]
ax.hist(probs_L_corr, bins=bins, density=True, alpha=0.6, label=f"correct (n={len(probs_L_corr)})", color="#1f77b4")
ax.hist(probs_L_wrong, bins=bins, density=True, alpha=0.6, label=f"wrong (n={len(probs_L_wrong)})", color="#ff7f0e")
ax.set_title("Logistic Regression")
ax.set_xlabel("Max softmax confidence")
ax.set_ylabel("Density")
ax.legend()
ax.grid(True, alpha=0.3)
# MLP column
ax = axes[1]
ax.hist(probs_M_corr, bins=bins, density=True, alpha=0.6, label=f"correct (n={len(probs_M_corr)})", color="#1f77b4")
ax.hist(probs_M_wrong, bins=bins, density=True, alpha=0.6, label=f"wrong (n={len(probs_M_wrong)})", color="#ff7f0e")
ax.set_title("MLP")
ax.set_xlabel("Max softmax confidence")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("confidence_histograms_logistic_vs_mlp.png", dpi=300, bbox_inches="tight")
plt.show()
=== Logistic Regression ===
  correct: n=8636, mean_conf=0.672, std=0.193
  wrong:   n=1364, mean_conf=0.374, std=0.123

=== MLP ===
  correct: n=9139, mean_conf=0.869, std=0.163
  wrong:   n=861, mean_conf=0.516, std=0.171
Saliency maps¶
For an image model, a saliency map highlights the pixels that most influenced the model’s decision for one image and one class. Saliency asks: if I slightly change each pixel, how much does the prediction change? $$ \text{Saliency}(x) = \left| \frac{\partial S_c(x)}{\partial x} \right| $$
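Before the full-featured helper below, the core computation fits in a few lines: forward pass, select the class score, backpropagate to the input, and take the absolute gradient. A minimal sketch with a throwaway linear model (not the trained models above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(784, 10)    # stand-in for a trained classifier
img = torch.rand(1, 1, 28, 28)

x = img.view(1, -1).requires_grad_(True)  # flattened input with grad tracking
score = model(x)[0].max()     # score of the predicted class
score.backward()              # d(score)/d(pixel) for every pixel

saliency = x.grad.abs().view(28, 28)  # |gradient|, reshaped to image size
print(saliency.shape)         # torch.Size([28, 28])
```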
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
def _to_device(tensor, device):
return tensor.to(device)
def compute_saliency(model, images, labels=None, device=None,
target='pred', abs_val=True,
smooth=False, n_samples=25, stdev=0.1):
"""
Compute saliency maps for a batch of images.
model: PyTorch model (in eval mode)
images: torch.Tensor (B, C, H, W) on CPU or device
labels: optional true labels (B,)
device: torch.device or None (auto)
target: 'pred' (use predicted class) or 'true' (use provided labels)
abs_val: if True, take absolute value of gradients (common practice)
smooth: if True, use SmoothGrad (averaging gradients over noisy samples)
n_samples: number of noisy samples for SmoothGrad
stdev: noise std (fraction of data range) for SmoothGrad
Returns:
saliency_maps: numpy array (B, H, W) normalized to [0,1]
preds: predicted classes (B,)
targets_used: classes used for gradient (B,)
"""
model.eval()
if device is None:
device = next(model.parameters()).device
images = images.to(device)
images_orig = images.detach()
B = images.shape[0]
# helper to compute gradients for one input batch
def grads_for_inputs(x, target_idxs):
x.requires_grad_(True)
outputs = model(x) # logits
if outputs.dim() == 1 or (outputs.dim()==2 and outputs.size(1)==1):
# single-logit binary head: score for the target class is
# sigmoid(logit) for class 1 and 1 - sigmoid(logit) for class 0
probs = torch.sigmoid(outputs.view(-1))
if target_idxs is None:
scores = torch.where(probs > 0.5, probs, 1 - probs)
else:
scores = torch.where(target_idxs == 1, probs, 1 - probs)
else:
# multiclass
if target_idxs is None:
# fallback: use predicted class
preds = outputs.argmax(dim=1)
scores = outputs.gather(1, preds.unsqueeze(1)).squeeze(1)
else:
scores = outputs.gather(1, target_idxs.unsqueeze(1)).squeeze(1)
# summing the per-example scores lets a single backward() fill x.grad
model.zero_grad()
scores.sum().backward()
grad = x.grad.detach().clone() # (B, C, H, W)
x.requires_grad_(False)
x.grad = None
return grad
# Determine targets to use
with torch.no_grad():
logits = model(images)
if logits.dim() == 1 or (logits.dim()==2 and logits.size(1)==1):
probs = torch.sigmoid(logits.view(-1))
preds = (probs>0.5).long().cpu()
else:
preds = torch.argmax(logits, dim=1).cpu()
preds = preds.to('cpu')
if target == 'pred':
targets_used = preds.to(device)
elif target == 'true':
if labels is None:
raise ValueError("labels must be provided when target='true'")
targets_used = labels.to(device)
else:
raise ValueError("target must be 'pred' or 'true'")
# compute vanilla or smoothgrad
if not smooth:
grad = grads_for_inputs(images, targets_used).cpu() # (B, C, H, W)
else:
# SmoothGrad: average gradients over noisy copies of the input
# (accumulate on CPU to match the .cpu() gradients returned below)
grads_accum = torch.zeros_like(images_orig).cpu()
data_range = (images_orig.max() - images_orig.min()).item()
noise_std = stdev * (data_range if data_range > 0 else 1.0)
for _ in range(n_samples):
noise = torch.randn_like(images_orig) * noise_std
noisy = (images_orig + noise).to(device)
g = grads_for_inputs(noisy, targets_used).cpu()
grads_accum += g
grad = grads_accum / float(n_samples)
# convert grad → saliency map (collapse channel)
# common options: absolute max across channels, L2 across channels
grad_np = grad.numpy() # (B, C, H, W)
# use absolute max across channels:
saliency = np.max(np.abs(grad_np), axis=1) if abs_val else np.max(grad_np, axis=1)
# normalize each image to [0,1]
saliency_norm = np.zeros_like(saliency)
for i in range(B):
s = saliency[i]
s = s - s.min()
if s.max() > 0:
s = s / s.max()
saliency_norm[i] = s
return saliency_norm, preds.cpu().numpy(), targets_used.cpu().numpy()
def plot_saliency_grid(images, saliency_maps, preds=None, labels=None, cmap='hot', alpha=0.6, ncols=5, title=None, unnormalize_fn=None):
"""
images: torch.Tensor (B, C, H, W) on CPU
saliency_maps: np.array (B, H, W) in [0,1]
preds, labels: optional lists or arrays (B,)
unnormalize_fn: function to convert image tensor -> numpy HxW or HxWx3 (optional)
"""
B = images.shape[0]
ncols = min(ncols, B)
nrows = int(np.ceil(B / ncols))
plt.figure(figsize=(3.2 * ncols, 3 * nrows))
if title:
plt.suptitle(title, fontsize=14)
for i in range(B):
img = images[i]
sal = saliency_maps[i]
ax = plt.subplot(nrows, ncols, i + 1)
# get image numpy for plotting
if unnormalize_fn:
npimg = unnormalize_fn(img)
else:
npimg = img.detach().cpu().numpy()
# convert CHW -> HWC if needed
if npimg.ndim == 3 and npimg.shape[0] in (1,3):
npimg = np.transpose(npimg, (1,2,0))
# if single-channel, reduce to HxW
if npimg.ndim == 3 and npimg.shape[2] == 1:
npimg = npimg[:,:,0]
# show base image
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
ax.imshow(sal, cmap=cmap, alpha=alpha, vmin=0, vmax=1)
else:
# color image: show image and overlay saliency as heatmap (use clipped)
ax.imshow(np.clip(npimg, 0, 1))
ax.imshow(sal, cmap=cmap, alpha=alpha, vmin=0, vmax=1)
title_lines = []
if preds is not None:
title_lines.append(f"P:{int(preds[i])}")
if labels is not None:
title_lines.append(f"T:{int(labels[i])}")
if title_lines:
ax.set_title(" / ".join(title_lines), fontsize=9)
ax.axis('off')
plt.tight_layout(rect=[0,0,1,0.95])
plt.show()
# -----------------------------
# Usage example: compare two models side-by-side
# -----------------------------
def compare_models_saliency(modelA, modelB, images_batch, labels_batch=None,
device=None, target='pred', smooth=False, n_samples=25, stdev=0.15,
unnormalize_fn=None, n_show=6):
"""
Compute and plot saliency maps for two models for the same inputs side-by-side.
modelA, modelB: PyTorch models (in eval mode)
images_batch: torch.Tensor (B, C, H, W) -- pick a small batch (e.g. B=6)
labels_batch: optional true labels (B,)
unnormalize_fn: optional function(img_tensor) -> numpy for displaying original image properly
"""
if device is None:
device = next(modelA.parameters()).device
images_cpu = images_batch.detach().cpu()
# compute for model A
salA, predsA, targetsA = compute_saliency(modelA, images_batch, labels=labels_batch, device=device,
target=target, abs_val=True, smooth=smooth, n_samples=n_samples, stdev=stdev)
# compute for model B
salB, predsB, targetsB = compute_saliency(modelB, images_batch, labels=labels_batch, device=device,
target=target, abs_val=True, smooth=smooth, n_samples=n_samples, stdev=stdev)
# choose how many to show
n_show = min(n_show, images_batch.shape[0])
# plot grid: inputs on top, ModelA saliency below; ModelB follows in a second figure
print("Figure 1: inputs + ModelA saliency; Figure 2: ModelB saliency")
plt.figure(figsize=(6 * n_show, 6))
for i in range(n_show):
# original image (top row)
ax = plt.subplot(2, n_show, i + 1)
npimg = (unnormalize_fn(images_cpu[i]) if unnormalize_fn else images_cpu[i].numpy())
if npimg.ndim == 3 and npimg.shape[0] in (1,3): # CHW
npimg = np.transpose(npimg, (1,2,0))
if npimg.shape[2] == 1:
npimg = npimg[:,:,0]
ax.imshow(npimg, cmap='gray' if npimg.ndim==2 else None)
ax.set_title(f"Input #{i}")
ax.axis('off')
# modelA saliency (bottom-left)
ax = plt.subplot(2, n_show, n_show + i + 1)
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
ax.imshow(salA[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
else:
ax.imshow(npimg)
ax.imshow(salA[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
ax.set_title(f"A P:{predsA[i]} T:{(labels_batch[i].item() if labels_batch is not None else '-')}")
ax.axis('off')
# ModelB saliency is plotted in a separate figure below for easier side-by-side inspection
plt.suptitle("Top row: inputs — Bottom row: ModelA saliency", fontsize=14)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
# Now show ModelB saliency in a separate figure (side-by-side comparison easier to inspect)
plt.figure(figsize=(6 * n_show, 3))
for i in range(n_show):
ax = plt.subplot(1, n_show, i + 1)
npimg = (unnormalize_fn(images_cpu[i]) if unnormalize_fn else images_cpu[i].numpy())
if npimg.ndim == 3 and npimg.shape[0] in (1,3):
npimg = np.transpose(npimg, (1,2,0))
if npimg.shape[2] == 1:
npimg = npimg[:,:,0]
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
ax.imshow(salB[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
else:
ax.imshow(npimg)
ax.imshow(salB[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
ax.set_title(f"B P:{predsB[i]} T:{(labels_batch[i].item() if labels_batch is not None else '-')}")
ax.axis('off')
plt.suptitle("ModelB saliency maps", fontsize=14)
plt.tight_layout(rect=[0,0,1,0.95])
plt.show()
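The SmoothGrad branch above (`smooth=True`) averages input gradients over noisy copies of the image. Stripped to its core it looks like the sketch below, using a toy untrained model and illustrative names:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 32), nn.ReLU(), nn.Linear(32, 10))
x = torch.rand(1, 784)
target = model(x).argmax(dim=1)       # class to explain

n_samples, stdev = 8, 0.1
acc = torch.zeros_like(x)
for _ in range(n_samples):
    # perturb the input, then take the gradient of the target-class score
    noisy = (x + stdev * torch.randn_like(x)).requires_grad_(True)
    score = model(noisy).gather(1, target.unsqueeze(1)).sum()
    score.backward()
    acc += noisy.grad.abs()           # accumulate |grad| per noisy sample

smooth_saliency = (acc / n_samples).view(28, 28)
print(smooth_saliency.shape)          # torch.Size([28, 28])
```

Averaging suppresses the pixel-level speckle that single-pass gradients tend to show, at the cost of `n_samples` extra forward/backward passes.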
images, labels = next(iter(test_loader))
k = 6
images = images[:k]
labels = labels[:k]
compare_models_saliency(modelL, modelM, images, labels, device=device,
target='pred', smooth=True, n_samples=30, stdev=0.12,
unnormalize_fn=None, n_show=6)
Figure 1: inputs + ModelA saliency; Figure 2: ModelB saliency
### Saving saliency maps
import matplotlib.pyplot as plt
import numpy as np
import torch
def _img_to_numpy_for_display(img_tensor, unnormalize_fn=None):
"""Return HxW or HxWx3 numpy array in [0,1] for plotting."""
if unnormalize_fn:
return unnormalize_fn(img_tensor)
arr = img_tensor.detach().cpu().numpy()
# CHW -> HWC if needed
if arr.ndim == 3 and arr.shape[0] in (1,3):
arr = np.transpose(arr, (1,2,0))
if arr.shape[2] == 1:
arr = arr[:,:,0]
return arr
def compare_models_saliency_return_figs(
modelA, modelB, images_batch, labels_batch=None,
device=None, target='pred', smooth=False, n_samples=25, stdev=0.15,
unnormalize_fn=None, n_show=6, show=True
):
"""
Compute saliency maps for modelA and modelB and return matplotlib Figure objects.
Returns: (figA, figB)
If show=True the figures are displayed; otherwise they are not shown but still returned.
"""
if device is None:
device = next(modelA.parameters()).device
images_batch = images_batch.to(device)
B = images_batch.shape[0]
n_show = min(n_show, B)
# compute saliency for both models (uses your compute_saliency function)
salA, predsA, targetsA = compute_saliency(
modelA, images_batch, labels=labels_batch, device=device,
target=target, abs_val=True, smooth=smooth, n_samples=n_samples, stdev=stdev
)
salB, predsB, targetsB = compute_saliency(
modelB, images_batch, labels=labels_batch, device=device,
target=target, abs_val=True, smooth=smooth, n_samples=n_samples, stdev=stdev
)
images_cpu = images_batch.detach().cpu()
# ---------- Figure A: Inputs (top row) + ModelA saliency (bottom row) ----------
figA, axesA = plt.subplots(2, n_show, figsize=(3.5 * n_show, 6))
figA.suptitle("Top: input images — Bottom: ModelA saliency", fontsize=14)
for i in range(n_show):
# input image (top)
ax = axesA[0, i] if n_show > 1 else axesA[0]
npimg = _img_to_numpy_for_display(images_cpu[i], unnormalize_fn)
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
else:
ax.imshow(np.clip(npimg, 0, 1))
ax.set_title(f"Input #{i}")
ax.axis('off')
# ModelA saliency (bottom)
ax = axesA[1, i] if n_show > 1 else axesA[1]
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
ax.imshow(salA[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
else:
ax.imshow(np.clip(npimg, 0, 1))
ax.imshow(salA[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
title_text = f"A P:{predsA[i]}"
if labels_batch is not None:
title_text += f" T:{int(labels_batch[i].item())}"
ax.set_title(title_text)
ax.axis('off')
plt.tight_layout(rect=[0,0,1,0.95])
# ---------- Figure B: ModelB saliency in one row ----------
figB, axesB = plt.subplots(1, n_show, figsize=(3.5 * n_show, 3.5))
figB.suptitle("ModelB saliency maps", fontsize=14)
for i in range(n_show):
ax = axesB[i] if n_show > 1 else axesB
npimg = _img_to_numpy_for_display(images_cpu[i], unnormalize_fn)
if npimg.ndim == 2:
ax.imshow(npimg, cmap='gray')
ax.imshow(salB[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
else:
ax.imshow(np.clip(npimg, 0, 1))
ax.imshow(salB[i], cmap='hot', alpha=0.6, vmin=0, vmax=1)
title_text = f"B P:{predsB[i]}"
if labels_batch is not None:
title_text += f" T:{int(labels_batch[i].item())}"
ax.set_title(title_text)
ax.axis('off')
plt.tight_layout(rect=[0,0,1,0.95])
if show:
plt.show() # display both figures
# with show=False the figures are not displayed but are still returned,
# so the caller can save them first
return figA, figB
images, labels = next(iter(test_loader))
images = images[:6]
labels = labels[:6]
# compute saliency and get figure objects (do not auto-show so we can save)
figA, figB = compare_models_saliency_return_figs(
modelL, modelM, images, labels,
device=device, target='pred', smooth=True, n_samples=30, stdev=0.12,
unnormalize_fn=None, n_show=6, show=False
)
# ensure the output directory exists, then save the figures there
import os
os.makedirs("figures", exist_ok=True)
figA.savefig("figures/saliency_modelL.png", dpi=300, bbox_inches="tight")
figB.savefig("figures/saliency_modelM.png", dpi=300, bbox_inches="tight")
# close to free memory
plt.close(figA)
plt.close(figB)
print("Saved: figures/saliency_modelL.png and figures/saliency_modelM.png")
Saved: figures/saliency_modelL.png and figures/saliency_modelM.png
# Requires compute_saliency() defined above, plus the per_class_average_saliency
# and plot_per_class_comparison helper functions.
import numpy as np
import torch
from tqdm import tqdm # optional for progress bar
def collect_saliency_for_loader(model, dataloader, device=None,
target='pred', smooth=False, n_samples=25, stdev=0.12,
max_batches=None):
"""
Run compute_saliency() over dataloader and collect:
sal_all: (N, H, W) numpy, normalized [0,1]
preds_all: (N,) ints
confs_all: (N,) floats (max softmax prob)
labels_all: (N,) ints (true labels if provided by loader)
max_batches: optional limit for quick debugging
"""
model.eval()
if device is None:
device = next(model.parameters()).device
sal_list = []
preds_list = []
conf_list = []
labels_list = []
for b_idx, (images, labels) in enumerate(tqdm(dataloader, desc="Saliency batches")):
if (max_batches is not None) and (b_idx >= max_batches):
break
# compute saliency & preds using your function
sal_maps, preds_np, targets_used = compute_saliency(
model, images, labels=labels, device=device,
target=target, abs_val=True, smooth=smooth, n_samples=n_samples, stdev=stdev
)
# sal_maps: (B, H, W) numpy, preds_np: (B,), targets_used: (B,)
# compute confidences (max softmax) on the batch
with torch.no_grad():
images_dev = images.to(device)
logits = model(images_dev)
if logits.dim() == 1 or (logits.dim()==2 and logits.size(1) == 1):
# binary head: confidence is the probability of the predicted class
probs = torch.sigmoid(logits.view(-1))
confs = torch.where(probs > 0.5, probs, 1 - probs).cpu().numpy()
else:
probs = torch.softmax(logits, dim=1)
confs = probs.max(dim=1).values.cpu().numpy() # (B,)
# append
sal_list.append(sal_maps) # numpy (B,H,W)
preds_list.append(preds_np) # numpy (B,)
conf_list.append(confs) # numpy (B,)
labels_list.append(labels.cpu().numpy())# numpy (B,)
# concat
sal_all = np.concatenate(sal_list, axis=0)
preds_all = np.concatenate(preds_list, axis=0)
conf_all = np.concatenate(conf_list, axis=0)
labels_all= np.concatenate(labels_list, axis=0)
return sal_all, preds_all, conf_all, labels_all
# Example usage:
# Assuming train/test loaders: test_loader (or val_loader)
# and models: modelL, modelM already on device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelL.to(device)
modelM.to(device)
# Collect saliency for both models (set smooth=True for SmoothGrad if you want)
salL, predsL, confL, labels = collect_saliency_for_loader(modelL, test_loader, device=device,
target='pred', smooth=False, n_samples=25, stdev=0.12)
salM, predsM, confM, labelsM = collect_saliency_for_loader(modelM, test_loader, device=device,
target='pred', smooth=False, n_samples=25, stdev=0.12)
# Sanity: labels should match
assert np.array_equal(labels, labelsM), "Labels mismatch between runs (shouldn't happen)."
# Now you have:
# salL: (N, H, W)
# salM: (N, H, W)
# labels: (N,)
# -------------------------
# Compute per-class averages and plot (re-using functions from earlier)
# -------------------------
# per_class_average_saliency and plot_per_class_comparison must be defined
# in a prior cell before running these lines.
avgL, counts = per_class_average_saliency(salL, labels, num_classes=10, normalize="max")
avgM, _ = per_class_average_saliency(salM, labels, num_classes=10, normalize="max")
# Plot side-by-side comparison (top row ModelL, bottom row ModelM)
plot_per_class_comparison(avgL, avgM, counts, cmap="hot", show_counts=True)
# -------------------------
# Bonus: save arrays for later
# -------------------------
np.save("saliency_modelL.npy", salL)
np.save("saliency_modelM.npy", salM)
np.save("labels.npy", labels)
np.save("predsL.npy", predsL)
np.save("predsM.npy", predsM)
np.save("confsL.npy", confL)
np.save("confsM.npy", confM)
print("Saved saliency arrays and confidences to disk.")
Saliency batches: 100%|██████████| 40/40 [00:00<00:00, 99.97it/s] Saliency batches: 100%|██████████| 40/40 [00:00<00:00, 82.57it/s]
Saved saliency arrays and confidences to disk.
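The `per_class_average_saliency` helper called above is not defined in this notebook. A minimal numpy sketch that is consistent with how it is called (the signature and the `normalize="max"` behavior are assumptions inferred from the usage, not the original definition) could look like:

```python
import numpy as np

def per_class_average_saliency(sal, labels, num_classes=10, normalize="max"):
    """Average saliency map per class. Returns (num_classes, H, W) and counts."""
    H, W = sal.shape[1:]
    avg = np.zeros((num_classes, H, W), dtype=np.float64)
    counts = np.zeros(num_classes, dtype=np.int64)
    for c in range(num_classes):
        mask = labels == c
        counts[c] = mask.sum()
        if counts[c] > 0:
            avg[c] = sal[mask].mean(axis=0)       # mean map over class c
            if normalize == "max" and avg[c].max() > 0:
                avg[c] /= avg[c].max()            # rescale each map to [0, 1]
    return avg, counts

# tiny smoke test with random maps
sal = np.random.rand(20, 28, 28)
labels = np.random.randint(0, 10, size=20)
avg, counts = per_class_average_saliency(sal, labels)
print(avg.shape, counts.sum())  # (10, 28, 28) 20
```

Per-class averages like these make it easy to see whether each model attends to class-typical strokes (e.g. the loop of a 6) rather than background pixels.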