Explaining Custom CNN MNIST Classification Using the Attributions Explainer
1. Design the CNN from scratch
[ ]:
import torch, torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
torch.manual_seed(0)
import numpy as np
# Notebook-wide settings: mini-batch size, number of training epochs and the
# device every tensor/model is moved to (CPU only here).
batch_size, num_epochs = 128, 1
device = torch.device('cpu')
class Net(nn.Module):
    """Small two-conv / two-linear CNN for single-channel 28x28 digit images.

    ``forward`` returns per-class probabilities (the last layer is
    ``Softmax(dim=1)``), not logits, so a log must be taken before feeding the
    output to ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # Shape walk-through for a 1x28x28 input:
        #   conv5 -> 10x24x24, pool2 -> 10x12x12,
        #   conv5 -> 20x8x8,   pool2 -> 20x4x4  => 20*4*4 = 320 features.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.Dropout(),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return class probabilities of shape (batch, 10).

        Args:
            x: image batch; the layer sizes assume (batch, 1, 28, 28).
        """
        x = self.conv_layers(x)
        # Flatten per sample, keeping the batch dimension. The previous
        # ``view(-1, 320)`` silently re-batched inputs of the wrong spatial
        # size (e.g. 16x1x32x32 became 25 rows); now such inputs fail loudly
        # at the first Linear layer.
        x = torch.flatten(x, 1)
        return self.fc_layers(x)
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print progress every 100 batches.

    Args:
        model: network whose forward is expected to return class
            probabilities (rows of a softmax); their log is fed to nll_loss.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of (data, target) batches exposing ``.dataset``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Clamp before the log: a probability that underflows to exactly 0
        # would give log(0) = -inf, an infinite loss and NaN gradients that
        # corrupt every weight on the next optimizer.step().
        eps = torch.finfo(output.dtype).tiny
        loss = F.nll_loss(output.clamp_min(eps).log(), target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
# Both splits only apply ToTensor (no normalization step). The transform is
# stateless, so a single Compose instance is shared by the two datasets.
mnist_transform = transforms.Compose([transforms.ToTensor()])

# Training split; downloaded into ./mnist_data on first use.
train_dataset = datasets.MNIST(
    'mnist_data', train=True, download=True, transform=mnist_transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

# Test split. NOTE: shuffle=True here too, so the order of the collected test
# tensors (and hence which images land in X_test[:100]) depends on RNG state.
test_dataset = datasets.MNIST(
    'mnist_data', train=False, transform=mnist_transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True)

# Model on the configured device, trained with momentum SGD.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
2. Train the CNN on the MNIST dataset
[ ]:
# One call to `train` per epoch; epochs are numbered from 1 in the progress log.
for epoch in range(1, num_epochs + 1):
    train(
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
    )
3. Predict the MNIST test data
[ ]:
# Test the model on the whole test set, collecting the images (X_test),
# labels (y_true) and predicted class probabilities (y_pred) used by the
# metrics and explainer cells.
model.eval()
test_loss = 0
correct = 0
# Gather per-batch tensors in lists and concatenate once at the end; calling
# torch.cat inside the loop re-copied the growing tensors on every batch
# (quadratic in the number of batches). Each list is seeded with the original
# empty tensor so result shapes/dtypes are unchanged (y_true stays float,
# because the empty float tensor is concatenated with the integer labels).
X_parts = [torch.empty((0, 1, 28, 28))]
y_true_parts = [torch.empty(0)]
y_pred_parts = [torch.empty((0, 10))]
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        X_parts.append(data)
        y_true_parts.append(target)
        y_pred_parts.append(output)
        # reduction='sum' adds up per-example losses, so dividing by the
        # dataset size below yields the true mean; the default 'mean'
        # reduction made the reported loss roughly batch_size times too small.
        # Clamping avoids log(0) = -inf for an underflowed probability.
        log_probs = output.clamp_min(torch.finfo(output.dtype).tiny).log()
        test_loss += F.nll_loss(log_probs, target, reduction='sum').item()
        pred = output.max(1, keepdim=True)[1]  # index of the highest class probability
        correct += pred.eq(target.view_as(pred)).sum().item()
X_test = torch.cat(X_parts)
y_true = torch.cat(y_true_parts)
y_pred = torch.cat(y_pred_parts)
del X_parts, y_true_parts, y_pred_parts  # release the duplicate per-batch copies
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
4. Survey performance across all classes using the metrics_explainer plugin
[ ]:
from intel_ai_safety.explainer import metrics
# Class names '0'..'9' as a NumPy string array, where classes[i] == str(i).
classes = np.array([str(digit) for digit in range(10)])
# Confusion matrix over the collected test-set labels and predictions:
# render it, then print its text report.
cm = metrics.confusion_matrix(y_true, y_pred, classes)
cm.visualize()
print(cm.report)
[ ]:
# Plotter over the test-set labels (y_true) and predicted probabilities
# (y_pred); draw its precision-recall curve.
plotter = metrics.plot(
    y_true,
    y_pred,
    classes,
)
plotter.pr_curve()
[ ]:
# Draw the ROC curve from the `plotter` built in the previous cell.
plotter.roc_curve()
5. Explain performance across the classes using the feature_attributions_explainer plugin
From (4), the confusion matrix shows that classes 4 and 9 perform poorly. Additionally, the misclassifications for these two classes occur almost exclusively between each other. In other words, the CNN appears to confuse 4’s with 9’s, and vice versa. In this run, 7.4% of all the 9 examples were misclassified as 4, and 10% of all the 4 examples were misclassified as 9.
Let’s take a closer look at the pixel-based SHAP values for the test examples where the CNN predicts ‘9’ but the correct ground-truth label is ‘4’.
[ ]:
# get the prediction indices where the model predicted 9
pred_idx = list(np.where(np.argmax(y_pred, axis=1) == 9)[0])
# get the groundtruth indices where the true label is 4
gt_idx = list(np.where(y_true == 4)[0])
# collect the indices where the CNN misclassified 4 as 9, in ascending order.
# Sorting matters: set iteration order is an implementation detail, so without
# it `matches[:6]` would not pick a well-defined subset of the mistakes.
matches = sorted(set(pred_idx).intersection(gt_idx))
[ ]:
from intel_ai_safety.explainer.attributions import attributions
# Deep explainer on the CNN, then render its attributions.
deViz = attributions.deep_explainer(
    model,
    X_test[:100],         # NOTE(review): presumably the background/reference set — confirm
    X_test[matches[:6]],  # up to six test images with true label 4 predicted as 9
    classes,
)
deViz.visualize()
[ ]:
# Instantiate the gradient explainer on the same model and inputs, then
# render its attributions.
# NOTE(review): the trailing positional `2` is passed straight through to
# intel_ai_safety's gradient_explainer; its meaning is not visible here —
# confirm against that API and consider passing it by keyword.
grViz = attributions.gradient_explainer(model, X_test[:100], X_test[matches[:6]], classes, 2)
grViz.visualize()