Explaining Custom CNN MNIST Classification Using the Attributions Explainer

1. Design the CNN from scratch

[ ]:
import torch, torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
torch.manual_seed(0)

import numpy as np

# Run configuration: DataLoader batch size, number of training epochs,
# and the device that the model and every batch are moved to.
batch_size, num_epochs = 128, 1
device = torch.device('cpu')

class Net(nn.Module):
    """Small two-conv-layer CNN for single-channel 28x28 MNIST digits.

    The network ends in ``nn.Softmax(dim=1)``, so ``forward`` returns
    per-class probabilities (each row sums to 1), not logits. Callers that
    need log-probabilities must take the log themselves.
    """

    def __init__(self):
        super().__init__()

        # Shape trace for a 1x28x28 input:
        #   conv5 -> 10x24x24 -> pool2 -> 10x12x12
        #   conv5 -> 20x8x8   -> pool2 -> 20x4x4   (= 320 features)
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.Dropout(),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return class probabilities of shape (batch, 10) for ``x``.

        ``x`` is expected to be (batch, 1, 28, 28). Other spatial sizes make
        the first Linear layer raise a shape error instead of being silently
        reshaped across the batch dimension.
        """
        x = self.conv_layers(x)
        # Flatten per sample (keep dim 0). view(-1, 320) would fold features
        # from different samples into fake rows whenever the flattened size
        # per sample is not exactly 320.
        x = torch.flatten(x, 1)
        x = self.fc_layers(x)
        return x

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: network whose forward returns class probabilities (for
            example ``Net``, which ends in Softmax), not logits.
        device: torch device that each batch is moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in the progress log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # The model emits probabilities. One that underflows to exactly 0
        # would make log() return -inf, giving an infinite loss and NaN
        # gradients that corrupt every weight, so clamp before the log.
        log_probs = torch.log(output.clamp_min(1e-12))
        loss = F.nll_loss(log_probs, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


# Both splits use only ToTensor(), which converts the PIL images to float
# tensors scaled to [0, 1]; no normalization is applied.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('mnist_data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=batch_size, shuffle=True)

# The test split is only evaluated, so it is not shuffled. A fixed order
# keeps the indices of the collected X_test / y_true / y_pred aligned with
# dataset order, so the examples picked for the explainers are reproducible.
# download=True is a no-op when the files are already present.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('mnist_data', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ])),
    batch_size=batch_size, shuffle=False)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

2. Train the CNN on the MNIST dataset

[ ]:
# Train for num_epochs full passes over train_loader; epochs are numbered
# from 1 in the progress log printed by train().
for epoch in range(1, num_epochs + 1):
    train(model=model, device=device, train_loader=train_loader,
          optimizer=optimizer, epoch=epoch)

3. Predict the MNIST test data

[ ]:
# Evaluate the trained model on the whole test split. The inputs, labels and
# predicted class probabilities are kept for the metrics and explainer cells.
model.eval()
test_loss = 0
correct = 0
inputs_batches, label_batches, prob_batches = [], [], []

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)  # class probabilities, one row per sample
        inputs_batches.append(data)
        label_batches.append(target)
        prob_batches.append(output)

        # reduction='sum' accumulates per-example losses, so dividing by the
        # dataset size below gives the true mean. The default 'mean' would
        # divide by the batch size a second time. The clamp avoids
        # log(0) = -inf for probabilities that underflow.
        test_loss += F.nll_loss(torch.log(output.clamp_min(1e-12)), target,
                                reduction='sum').item()
        pred = output.max(1, keepdim=True)[1]  # index of the max probability
        correct += pred.eq(target.view_as(pred)).sum().item()

# Concatenate once instead of re-allocating with torch.cat on every batch.
X_test = torch.cat(inputs_batches)
# Labels are stored as float32 for the downstream metrics/explainer cells;
# cast with .long() where integer class ids are required.
y_true = torch.cat(label_batches).float()
y_pred = torch.cat(prob_batches)

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

4. Survey performance across all classes using the metrics_explainer plugin

[ ]:
from intel_ai_safety.explainer import metrics

# Class names are the digit strings '0'..'9'; position i names class index i.
classes = np.array([str(digit) for digit in range(10)])

# Confusion matrix of the test-set labels against the predicted probabilities.
cm = metrics.confusion_matrix(y_true, y_pred, classes)
cm.visualize()
print(cm.report)
[ ]:
# Build the metrics plotter from the test-set labels (y_true), the predicted
# class probabilities (y_pred) and the class names, then draw the
# precision-recall (PR) curve.
plotter = metrics.plot(y_true, y_pred, classes)
plotter.pr_curve()
[ ]:
plotter.roc_curve()  # ROC curve from the same plotter (same y_true / y_pred / classes)

5. Explain performance across the classes using the feature_attributions_explainer plugin

From (4), the confusion matrix shows that classes 4 and 9 perform poorly. Additionally, there is a high misclassification rate specifically between these two labels. In other words, the CNN appears to confuse 4’s with 9’s, and vice versa: 7.4% of all the 9 examples were misclassified as 4, and 10% of all the 4 examples were misclassified as 9. These percentages come from one training run and will vary between runs.

Let’s take a closer look at the pixel-based SHAP values for the test examples where the CNN predicts ‘9’ but the correct ground-truth label is ‘4’.

[ ]:
# Test-set positions (indices into X_test / y_true / y_pred) where the model
# predicted class 9.
pred_idx = torch.nonzero(y_pred.argmax(dim=1) == 9).flatten().tolist()
# Test-set positions whose ground-truth label is 4.
gt_idx = torch.nonzero(y_true == 4).flatten().tolist()

# Positions where the CNN misclassified a 4 as a 9. The result is sorted
# because set iteration order is an implementation detail; sorting makes
# matches[:k] (the examples handed to the explainers) well-defined.
matches = sorted(set(pred_idx).intersection(gt_idx))
[ ]:
from intel_ai_safety.explainer.attributions import attributions
# Run the deep explainer. The first 100 test images are passed as the second
# argument (presumably the background/reference set, per the SHAP convention;
# confirm against the attributions API). The examples explained are up to six
# of the 4 -> 9 misclassifications collected in `matches`.
deViz = attributions.deep_explainer(model, X_test[:100], X_test[matches[:6]], classes)
deViz.visualize()
[ ]:
# Run the gradient explainer on the same inputs as the deep explainer cell:
# the first 100 test images plus up to six 4 -> 9 misclassifications.
# NOTE(review): the trailing positional argument (2) is passed straight to
# gradient_explainer and its meaning is not visible here; confirm against
# the attributions API.
grViz = attributions.gradient_explainer(model, X_test[:100],  X_test[matches[:6]], classes, 2)
grViz.visualize()

6. Conclusion

From the deep and gradient explainer visuals, we can observe that the CNN pays close attention to the top of the digit when distinguishing between a 4 and a 9. In the first and last rows of the gradient explainer visualization above, we can see that the 4’s are closed. This contributes to positive SHAP values (red) for the 9 classification, which begins to explain why the CNN confuses the two digits.