Multimodal Breast Cancer Detection Explainability using the Intel® Explainable AI API

This application is a multimodal solution for predicting cancer diagnosis using categorized contrast-enhanced mammography data and radiology notes. It trains two models: one for image classification and one for text classification.

Import Dependencies and Setup Directories

[ ]:
# This notebook requires the latest version of intel-transfer-learning (v0.7.0)
# The package and directions to install it can be found at its repo:
# https://github.com/Intel/transfer-learning

! pip install --no-cache-dir  nltk docx2txt openpyxl et-xmlfile schema
[ ]:
import numpy as np
import os
import pandas as pd
import tensorflow as tf
import torch

from transformers import EvalPrediction, TrainingArguments, pipeline

# tlt imports
from tlt.datasets import dataset_factory
from tlt.models import model_factory

# explainability imports
import matplotlib.pyplot as plt
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import nltk
from nltk.corpus import words
import string
import shap
import warnings
# raw string so that "\." is passed to the regex engine and not treated as a string escape
warnings.filterwarnings("ignore", module=r"matplotlib\..*")

# Specify the root directory where the images and annotations are located
dataset_dir = os.path.join(os.environ["DATASET_DIR"]) if "DATASET_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "dataset")

# Specify a directory for output
output_dir = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "output")

print("Dataset directory:", dataset_dir)
print("Output directory:", output_dir)

Dataset

Download the images and radiology annotations from https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=109379611 and save in the path <dataset_dir>/brca/data.

[ ]:
! python prepare_nlp_data.py --data_root {dataset_dir}/brca/data
[ ]:
! python prepare_vision_data.py --data_root {dataset_dir}/brca/data

Image files should have the .jpg extension and be arranged in subfolders for each class. The annotation file should be a .csv. The final brca dataset directory should look something like this:

brca
  ├── data
  │   ├── PKG - CDD-CESM
  │   ├── Medical reports for cases .zip
  │   ├── Radiology manual annotations.xlsx
  │   └── Radiology_hand_drawn_segmentations_v2.csv
  ├── annotation
  │   └── annotation.csv
  └── vision_images
      ├── Benign
      │   ├── P100_L_CM_CC.jpg
      │   ├── P100_L_CM_MLO.jpg
      │   └── ...
      ├── Malignant
      │   ├── P102_R_CM_CC.jpg
      │   ├── P102_R_CM_MLO.jpg
      │   └── ...
      └── Normal
          ├── P100_R_CM_CC.jpg
          ├── P100_R_CM_MLO.jpg
          └── ...
[ ]:
# User input needed - supply the path to the images in the dataset_dir according to your system
source_image_path = os.path.join(dataset_dir, 'brca', 'data', 'vision_images')
image_path = source_image_path

# User input needed - supply the path and name of the annotation file in the dataset_dir
source_annotation_path = os.path.join(dataset_dir, 'brca', 'data', 'annotation', 'annotation.csv')
annotation_path = source_annotation_path

Optional: Group Data by Patient ID

This section is not required to run the workload, but it helps to keep all of a subject’s records entirely in either the train set or the test set, so that no patient appears in both. This section will do a random stratification based on patient ID and save new copies of the grouped data files.

[ ]:
from data_utils import split_images, split_annotation

grouped_image_path = '{}_grouped'.format(source_image_path)

if os.path.isdir(grouped_image_path):
    print("Grouped directory already exists and will be used: {}".format(grouped_image_path))
else:
    split_images(source_image_path, grouped_image_path)

train_image_path = os.path.join(grouped_image_path, 'train')
test_image_path = os.path.join(grouped_image_path, 'test')
[ ]:
from data_utils import split_images, split_annotation

file_dir, file_name = os.path.split(source_annotation_path)
grouped_annotation_path = os.path.join(file_dir, '{}_grouped.csv'.format(os.path.splitext(file_name)[0]))

if os.path.isfile(grouped_annotation_path):
    print("Grouped annotation already exists and will be used: {}".format(grouped_annotation_path))
else:
    train_dataset, test_dataset = split_annotation(file_dir, file_name, train_image_path, test_image_path)
    train_dataset.to_csv(grouped_annotation_path, index=False)
    test_dataset.to_csv(grouped_annotation_path[:-4] + '_test.csv', index=False)
    print('Grouped training annotation saved to: {}'.format(grouped_annotation_path))
    print('Grouped testing annotation saved to: {}'.format(grouped_annotation_path[:-4] + '_test.csv'))

train_annotation_path = grouped_annotation_path
test_annotation_path = grouped_annotation_path[:-4] + '_test.csv'
label_col = 0  # Index of the label column in the grouped data file
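
To illustrate the idea of a patient-level split, here is a minimal sketch using only the standard library. It is not the implementation of `split_images` or `split_annotation` from `data_utils`. It assumes the patient ID is the part of the file name before the first underscore (as in `P100_L_CM_CC.jpg`), it ignores class balance, and the function name, the `test_fraction` default and the `P103`/`P104` file names are made up for the example.

```python
import random
from collections import defaultdict


def split_by_patient(filenames, test_fraction=0.25, seed=0):
    """Split file names so that every patient ID lands in exactly one set."""
    by_patient = defaultdict(list)
    for name in filenames:
        patient_id = name.split("_")[0]  # "P100_L_CM_CC.jpg" -> "P100"
        by_patient[patient_id].append(name)

    # Shuffle patients (not files) so a patient can never straddle both sets
    patient_ids = sorted(by_patient)
    random.Random(seed).shuffle(patient_ids)
    n_test = max(1, round(len(patient_ids) * test_fraction))
    test_ids = set(patient_ids[:n_test])

    train = [f for pid in patient_ids if pid not in test_ids for f in by_patient[pid]]
    test = [f for pid in patient_ids if pid in test_ids for f in by_patient[pid]]
    return train, test


files = ["P100_L_CM_CC.jpg", "P100_L_CM_MLO.jpg", "P100_R_CM_CC.jpg",
         "P102_R_CM_CC.jpg", "P102_R_CM_MLO.jpg",
         "P103_L_CM_CC.jpg", "P104_R_CM_CC.jpg"]  # P103/P104 are hypothetical
train_files, test_files = split_by_patient(files)
print("train:", train_files)
print("test:", test_files)
```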

Model 1: Image Classification with PyTorch

Get the Model and Dataset

Call the model factory to get a pretrained model from PyTorch Hub and the dataset factory to load the images from their location. The get_model function returns a model object that will later be used for training. We will use resnet50 by default.

[ ]:
viz_model = model_factory.get_model(model_name="resnet50", framework='pytorch')

# Load the dataset from the custom dataset path
train_viz_dataset = dataset_factory.load_dataset(dataset_dir=train_image_path,
                                       use_case='image_classification',
                                       framework='pytorch')

test_viz_dataset = dataset_factory.load_dataset(dataset_dir=test_image_path,
                                       use_case='image_classification',
                                       framework='pytorch')

print("Class names:", str(train_viz_dataset.class_names))

Data Preparation

Once you have your dataset loaded, use the following cell to preprocess the dataset. We split the images into training and validation subsets, resize them to match the model, and then batch the images.

[ ]:
batch_size = 16
# shuffle split the training dataset
train_viz_dataset.shuffle_split(train_pct=.80, val_pct=.20, seed=3)
train_viz_dataset.preprocess(viz_model.image_size, batch_size=batch_size)
test_viz_dataset.preprocess(viz_model.image_size, batch_size=batch_size)

Image dataset analysis

Let’s take a look at the dataset and verify that we are loading the data correctly. This includes comparing the label distributions of the training, validation, and test sets and visually confirming the images themselves.

[ ]:
# Create a label map function and reverse label map for the dataset
def label_map_func(label):
    if label == 'Benign':
        return 0
    elif label == 'Malignant':
        return 1
    elif label == 'Normal':
        return 2

reverse_label_map = {0: 'Benign', 1: 'Malignant', 2: 'Normal'}
[ ]:
train_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}

for x, y in train_viz_dataset.train_subset:
    train_label_count[reverse_label_map[y]] += 1

print('Training label distribution:')
train_label_count
[ ]:
valid_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}

for x, y in train_viz_dataset.validation_subset:
    valid_label_count[reverse_label_map[y]] += 1

print('Validation label distribution:')
valid_label_count
[ ]:
test_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}

for x, y in test_viz_dataset.dataset:
    test_label_count[reverse_label_map[y]] += 1

print('Test label distribution:')
test_label_count
[ ]:
# get dataset distributions
form = {'type':'domain'}
fig = make_subplots(rows=1, cols=3, specs=[[form, form, form]], subplot_titles=['Training', 'Validation', 'Testing'])
fig.add_trace(go.Pie(values=list(train_label_count.values()), labels=list(train_label_count.keys())), 1, 1)
fig.add_trace(go.Pie(values=list(valid_label_count.values()), labels=list(valid_label_count.keys())), 1, 2)
fig.add_trace(go.Pie(values=list(test_label_count.values()), labels=list(test_label_count.keys())), 1, 3)

fig.update_layout(height=600, width=800, title_text="Label Distributions")
fig.show()
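
The three counting cells above repeat the same loop. As a possible simplification, the sketch below counts labels with `collections.Counter` and also reports each class's share, which makes it easier to compare splits of different sizes. The helper name `label_distribution` is hypothetical, and the integer labels in the example are made up.

```python
from collections import Counter

reverse_label_map = {0: 'Benign', 1: 'Malignant', 2: 'Normal'}


def label_distribution(labels, reverse_label_map):
    """Return {class name: (count, fraction)} for an iterable of integer labels."""
    counts = Counter(reverse_label_map[int(y)] for y in labels)
    total = sum(counts.values())
    return {
        name: (counts.get(name, 0), counts.get(name, 0) / total if total else 0.0)
        for name in reverse_label_map.values()
    }


# Made-up labels standing in for the `y` values of a dataset subset
example_labels = [0, 1, 1, 2]
print(label_distribution(example_labels, reverse_label_map))
```
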
[ ]:
def get_examples(dataset, reverse_label_map, n=6):
    # get n images from each label in dataset and return as dictionary

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    example_images = {'Benign': [], 'Malignant': [], 'Normal': []}
    for x, y in loader:
        for i, label in enumerate(y):
            label_name = reverse_label_map[int(label)]
            if len(example_images[label_name]) < n:
                example_images[label_name].append(x[i])
        if len(example_images['Malignant']) == n and\
        len(example_images['Benign']) == n and\
        len(example_images['Normal']) == n:
            break
    return example_images
[ ]:
# plot some training examples
fig = plt.figure(figsize=(12,6))
columns = 6
rows = 3
fig.suptitle('Training Torch Tensor examples', size=16)


train_example_images = get_examples(train_viz_dataset.train_subset, reverse_label_map)
for i in range(1, columns*rows +1):
    idx = i - 1
    if idx < 6:
        img = train_example_images['Benign'][idx]
    elif idx >= 6 and idx < 12:
        img = train_example_images['Malignant'][idx - 6]
    else:
        img = train_example_images['Normal'][idx - 12]

    fig.add_subplot(rows, columns, i)
    plt.axis('off')
    plt.tight_layout()
    if idx == 0 or idx == 6 or idx == 12:
        plt.axis('on')
        label_name = reverse_label_map[int(idx/6)]
        plt.ylabel(label_name, fontsize=16)
        plt.tick_params(axis='x', bottom=False, labelbottom=False)
        plt.tick_params(axis='y', left=False, labelleft=False)

    plt.imshow(torch.movedim(img, 0, 2).detach().cpu().numpy().astype(np.uint8))

plt.show()
[ ]:
# plot some validation images
fig = plt.figure(figsize=(12,6))
columns = 6
rows = 3
fig.suptitle('Validation Torch Tensor examples', size=16)


valid_example_images = get_examples(train_viz_dataset.validation_subset, reverse_label_map)

for i in range(1, columns*rows +1):
    idx = i - 1
    if idx < 6:
        img = valid_example_images['Benign'][idx]
    elif idx >= 6 and idx < 12:
        img = valid_example_images['Malignant'][idx - 6]
    else:
        img = valid_example_images['Normal'][idx - 12]

    fig.add_subplot(rows, columns, i)
    plt.axis('off')
    plt.tight_layout()
    if idx == 0 or idx == 6 or idx == 12:
        plt.axis('on')
        label_name = reverse_label_map[int(idx/6)]
        plt.ylabel(label_name, fontsize=16)
        plt.tick_params(axis='x', bottom=False, labelbottom=False)
        plt.tick_params(axis='y', left=False, labelleft=False)

    plt.imshow(torch.movedim(img, 0, 2).detach().cpu().numpy().astype(np.uint8))

plt.show()

Transfer Learning

This step calls the model’s train function with the dataset that was just prepared. The training function will get the PyTorch feature vector and add on a dense layer based on the number of classes in the dataset. The model is then compiled and trained based on the number of epochs specified in the argument. We also add two more dense layers using the extra_layers parameter.

To optionally insert additional dense layers between the base model and output layer, extra_layers=[1024, 512] will insert two dense layers, the first with 1024 neurons and the second with 512 neurons.
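
To make the effect of `extra_layers=[1024, 512]` concrete, the sketch below lists the (input, output) sizes of the dense layers in the resulting classification head and counts their parameters. It assumes the head is attached to the 2048-dimensional pooled feature vector of ResNet-50 and that every dense layer has a bias. Whether tlt adds activations or dropout between the layers is not covered here, and `head_layer_shapes` is a hypothetical helper.

```python
def head_layer_shapes(in_features, extra_layers, num_classes):
    """Return the (in, out) size of each dense layer in the classification head."""
    sizes = [in_features] + list(extra_layers) + [num_classes]
    return list(zip(sizes[:-1], sizes[1:]))


# 2048 features from ResNet-50, two extra layers, three classes
shapes = head_layer_shapes(2048, [1024, 512], 3)
# weights (in * out) plus one bias per output unit
n_params = sum(n_in * n_out + n_out for n_in, n_out in shapes)

print(shapes)
print("trainable head parameters:", n_params)
```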

[ ]:
viz_history = viz_model.train(train_viz_dataset, output_dir=output_dir, epochs=5, seed=10, extra_layers=[1024, 512], ipex_optimize=False)
[ ]:
validation_viz_metrics = viz_model.evaluate(train_viz_dataset)
test_viz_metrics = viz_model.evaluate(test_viz_dataset)
print(validation_viz_metrics)
print(test_viz_metrics)

Save the Computer Vision Model

[ ]:
saved_model_dir = viz_model.export(output_dir)

Error Analysis

Analyzing the errors via a confusion matrix and ROC and PR curves will help us identify if our model is exhibiting any label bias.

[ ]:
from scipy.special import softmax
y_pred = []
# get the logit predictions and then convert to probabilities
for batch in test_viz_dataset.dataset:
    y_pred.append(softmax(viz_model._model(batch[0][None, :]).detach().numpy())[0])

y_true =[y for x, y in test_viz_dataset.dataset]
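
For reference, the softmax used above turns a vector of logits into probabilities that sum to one. Here is a minimal pure-Python sketch of the same computation for a single prediction. The notebook itself uses `scipy.special.softmax`, and the logit values below are made up.

```python
import math


def softmax(logits):
    """Convert a list of logits to probabilities in a numerically stable way."""
    shift = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])  # made-up logits for Benign, Malignant, Normal
print(probs, "sum =", sum(probs))
```
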
[ ]:
from intel_ai_safety.explainer import metrics
viz_cm = metrics.confusion_matrix(y_true, y_pred, test_viz_dataset.class_names)
viz_cm.visualize()
print(viz_cm.report)
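
As a reading aid for the confusion matrix, the sketch below builds one by hand from true labels and predicted probabilities. It is not the `intel_ai_safety` implementation. It assumes rows are true classes and columns are predicted classes, and the data is made up.

```python
def confusion_matrix(y_true, y_prob, n_classes):
    """Rows index the true class, columns index the predicted (argmax) class."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for true_label, probs in zip(y_true, y_prob):
        predicted = max(range(n_classes), key=lambda k: probs[k])
        matrix[true_label][predicted] += 1
    return matrix


# Made-up example: 0 = Benign, 1 = Malignant, 2 = Normal
y_true_example = [0, 1, 1, 2]
y_prob_example = [[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.3, 0.5],   # a Malignant case predicted as Normal
                  [0.1, 0.1, 0.8]]

cm = confusion_matrix(y_true_example, y_prob_example, 3)
for row in cm:
    print(row)

# Recall for Malignant: correct Malignant predictions / all true Malignant cases
malignant_recall = cm[1][1] / sum(cm[1])
print("Malignant recall:", malignant_recall)
```
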
[ ]:
plotter = metrics.plot(y_true, y_pred, test_viz_dataset.class_names)
plotter.pr_curve()
[ ]:
plotter.roc_curve()
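
Each point on a PR curve is a precision/recall pair computed one-vs-rest at a single probability threshold. The sketch below computes one such point by hand. It illustrates the definition only and is not the code behind `metrics.plot`. The data is made up, and the convention of returning precision 1.0 when nothing is predicted positive is an assumption.

```python
def precision_recall(y_true, y_prob, positive_class, threshold):
    """One-vs-rest precision and recall for `positive_class` at `threshold`."""
    tp = fp = fn = 0
    for true_label, probs in zip(y_true, y_prob):
        predicted_positive = probs[positive_class] >= threshold
        actual_positive = true_label == positive_class
        if predicted_positive and actual_positive:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif actual_positive:
            fn += 1
    precision = tp / (tp + fp) if (tp + fp) else 1.0  # convention when nothing is flagged
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Made-up example: 0 = Benign, 1 = Malignant, 2 = Normal
y_true_example = [0, 1, 1, 2]
y_prob_example = [[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.3, 0.5],
                  [0.1, 0.1, 0.8]]

# Lowering the threshold trades precision for recall
for threshold in (0.5, 0.25, 0.15):
    print(threshold, precision_recall(y_true_example, y_prob_example, 1, threshold))
```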

Explainability

[ ]:
# convert the predicted class probabilities to index labels
y_pred_labels = np.array(y_pred).argmax(axis=1)

# get the malignant indexes and then the normal and benign prediction indexes
mal_idxs = np.where(np.array(y_true) == label_map_func('Malignant'))[0].tolist()
nor_preds = np.where(np.array(y_pred_labels) == label_map_func('Normal'))[0].tolist()
ben_preds = np.where(np.array(y_pred_labels) == label_map_func('Benign'))[0].tolist()
[ ]:
# get mal examples that were misclassified as nor
mal_classified_as_nor = list(set(mal_idxs).intersection(nor_preds))

# get mal examples that were misclassified as ben
mal_classified_as_ben = list(set(mal_idxs).intersection(ben_preds))
[ ]:
# get the images for all mals predicted as nors
mal_as_nor_images = [test_viz_dataset.dataset[i][0] for i in mal_classified_as_nor]

# get the images for all mals predicted as bens
mal_as_ben_images = [test_viz_dataset.dataset[i][0] for i in mal_classified_as_ben]
[ ]:
from skimage import io
# plot 14 mal_as_nor images
fig = plt.figure(figsize=(12,6))
columns = 7
rows = 2

for i in range(1, columns*rows +1):
    idx = i - 1
    # stop once every misclassified image has been plotted
    if idx >= len(mal_as_nor_images):
        break

    fig.add_subplot(rows, columns, i)
    plt.axis('off')
    plt.tight_layout()

    plt.imshow(torch.movedim(mal_as_nor_images[idx], 0, 2).detach().cpu().numpy().astype(np.uint8))

fig.suptitle('Malignant predicted as Normal', fontsize=18)
plt.tight_layout()
plt.show()
[ ]:
# let's calculate Grad-CAM on the images at indexes 0, 3 and 5 since they
# seem to have the clearest visual of a malignant tumor
from intel_ai_safety.explainer.cam import pt_cam as cam

images = [torch.movedim(mal_as_nor_images[0], 0, 2).detach().cpu().numpy().astype(np.uint8),
          torch.movedim(mal_as_nor_images[3], 0, 2).detach().cpu().numpy().astype(np.uint8),
          torch.movedim(mal_as_nor_images[5], 0, 2).detach().cpu().numpy().astype(np.uint8)]


final_image_dim = (224, 224)
targetLayer = viz_model._model.layer4
xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Normal'),
                      images[0],
                      final_image_dim,
                      'cpu')

xgc.visualize()

xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Normal'),
                      images[1],
                      final_image_dim,
                      'cpu')

xgc.visualize()

xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Normal'),
                      images[2],
                      final_image_dim,
                      'cpu')

xgc.visualize()
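
To give some intuition for the heatmaps above, here is a minimal sketch of the plain Grad-CAM computation on tiny nested lists: each channel of the target layer's activations is weighted by the average of its gradients, the weighted channels are summed, and negative values are clipped. This is a sketch of standard Grad-CAM only. The `x_gradcam` function used above may implement a different variant, and the activation and gradient values below are made up.

```python
def grad_cam(activations, gradients):
    """Plain Grad-CAM on [channel][row][col] nested lists of equal shape."""
    rows, cols = len(activations[0]), len(activations[0][0])

    # Channel weight = global average of that channel's gradients
    weights = [sum(sum(row) for row in grad) / (rows * cols) for grad in gradients]

    # Weighted sum of the activation maps
    cam = [[0.0] * cols for _ in range(rows)]
    for weight, act in zip(weights, activations):
        for r in range(rows):
            for c in range(cols):
                cam[r][c] += weight * act[r][c]

    # ReLU keeps only regions that support the target class, then scale to [0, 1]
    cam = [[max(v, 0.0) for v in row] for row in cam]
    peak = max(max(row) for row in cam)
    if peak > 0:
        cam = [[v / peak for v in row] for row in cam]
    return cam


# Two 2x2 channels: the first supports the class, the second argues against it
activations = [[[1.0, 0.0], [0.0, 0.0]],
               [[0.0, 0.0], [0.0, 2.0]]]
gradients = [[[1.0, 1.0], [1.0, 1.0]],
             [[-1.0, -1.0], [-1.0, -1.0]]]

heatmap = grad_cam(activations, gradients)
print(heatmap)
```
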
[ ]:
# plot 14 mal_as_ben images
fig = plt.figure(figsize=(12,6))
columns = 7
rows = 2

for i in range(1, columns*rows +1):
    idx = i - 1
    if idx == len(mal_as_ben_images):
        break

    fig.add_subplot(rows, columns, i)
    plt.axis('off')
    plt.tight_layout()

    plt.imshow(torch.movedim(mal_as_ben_images[idx], 0, 2).detach().cpu().numpy().astype(np.uint8))

fig.suptitle('Malignant predicted as Benign', fontsize=18)
plt.tight_layout()
plt.show()
[ ]:
# let's calculate Grad-CAM on the images at indexes 0, 1 and 2 since they
# seem to have the clearest visual of a malignant tumor

images = [torch.movedim(mal_as_ben_images[0], 0, 2).detach().cpu().numpy().astype(np.uint8),
          torch.movedim(mal_as_ben_images[1], 0, 2).detach().cpu().numpy().astype(np.uint8),
          torch.movedim(mal_as_ben_images[2], 0, 2).detach().cpu().numpy().astype(np.uint8)]



final_image_dim = (224, 224)
targetLayer = viz_model._model.layer4
xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Benign'),
                      images[0],
                      final_image_dim,
                      'cpu')

xgc.visualize()

xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Benign'),
                      images[1],
                      final_image_dim,
                      'cpu')

xgc.visualize()

xgc = cam.x_gradcam(viz_model._model, targetLayer,
                      label_map_func('Benign'),
                      images[2],
                      final_image_dim,
                      'cpu')

xgc.visualize()

Model 2: Text Classification with PyTorch

Get the Model and Dataset

Now we will call the model factory to get a pretrained model from HuggingFace and load the annotation file using the dataset factory. We will use clinical-bert for this part.

[ ]:
# Set up NLP parameters
model_name = 'clinical-bert'
seq_length = 64
batch_size = 5
quantization_criterion = 0.05
quantization_max_trial = 50
[ ]:
nlp_model = model_factory.get_model(model_name=model_name, framework='pytorch')
[ ]:
# Create a label map function and reverse label map for the dataset
def label_map_func(label):
    if label == 'Benign':
        return 0
    elif label == 'Malignant':
        return 1
    elif label == 'Normal':
        return 2

reverse_label_map = {0: 'Benign', 1: 'Malignant', 2: 'Normal'}
[ ]:
os.path.split(os.path.splitext(train_annotation_path)[0] + '.csv')
[ ]:
train_file_dir, train_file_name =  os.path.split(os.path.splitext(train_annotation_path)[0] +'.csv')
train_nlp_dataset = dataset_factory.load_dataset(dataset_dir=train_file_dir,
                       use_case='text_classification',
                       framework='pytorch',
                       dataset_name='brca',
                       csv_file_name=train_file_name,
                       label_map_func=label_map_func,
                       class_names=['Benign', 'Malignant', 'Normal'],
                       header=True,
                       label_col=label_col,
                       shuffle_files=True,
                       exclude_cols=[2])

test_file_dir, test_file_name =  os.path.split(os.path.splitext(test_annotation_path)[0] +'.csv')
test_nlp_dataset = dataset_factory.load_dataset(dataset_dir=test_file_dir,
                       use_case='text_classification',
                       framework='pytorch',
                       dataset_name='brca',
                       csv_file_name=test_file_name,
                       label_map_func=label_map_func,
                       class_names=['Benign', 'Malignant', 'Normal'],
                       header=True,
                       label_col=label_col,
                       shuffle_files=True,
                       exclude_cols=[2])

Data Preparation

[ ]:
train_nlp_dataset.preprocess(nlp_model.hub_name, batch_size=batch_size, max_length=seq_length)
test_nlp_dataset.preprocess(nlp_model.hub_name, batch_size=batch_size, max_length=seq_length)
train_nlp_dataset.shuffle_split(train_pct=0.67, val_pct=0.33, shuffle_files=False)

Corpus analysis

Let’s take a look at the word distribution across each label to get an idea of what BERT will be training on, as well as make sure that our training and validation datasets are distributed similarly.

[ ]:
import plotly.express as px

train_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}
for label in train_nlp_dataset.train_subset['label']:
    train_label_count[reverse_label_map[int(label)]] += 1

print('Training label distribution:')
train_label_count
[ ]:
valid_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}
for label in train_nlp_dataset.validation_subset['label']:
    valid_label_count[reverse_label_map[int(label)]] += 1

print('Validation label distribution:')
valid_label_count
[ ]:
test_label_count = {'Benign': 0, 'Malignant': 0, 'Normal': 0}
for label in test_nlp_dataset.dataset['label']:
    test_label_count[reverse_label_map[int(label)]] += 1

print('Test label distribution:')
test_label_count
[ ]:
form = {'type':'domain'}

fig = make_subplots(rows=1, cols=3, specs=[[form, form, form]], subplot_titles=['Training', 'Validation', 'Testing'])
fig.add_trace(go.Pie(values=list(train_label_count.values()), labels=list(train_label_count.keys())), 1, 1)
fig.add_trace(go.Pie(values=list(valid_label_count.values()), labels=list(valid_label_count.keys())), 1, 2)
fig.add_trace(go.Pie(values=list(test_label_count.values()), labels=list(test_label_count.keys())), 1, 3)


fig.update_layout(height=600, width=800, title_text="Label Distributions")
fig.show()

[ ]:
nltk.download('punkt')
nltk.download('words')

def get_mc_df(words_list, n=50, ignored_words=None):
    '''
    Gets the n most common non-punctuation words from a list of words
    and returns them as a pd DataFrame for Plotly
    '''
    # avoid a mutable default argument; tokens listed here are skipped
    ignored_words = set(ignored_words or [])

    frequency_dict = nltk.FreqDist(words_list)
    most_common = frequency_dict.most_common(n=500)

    final_fd = pd.DataFrame({'Token': [], 'Frequency': []})
    cnt = 0
    idx = 0
    # stop early if there are fewer than n usable tokens
    while cnt < n and idx < len(most_common):
        token, frequency = most_common[idx]
        if token in string.punctuation:
            print(f'{token} is not a word')
        elif token not in ignored_words:
            final_fd.loc[len(final_fd.index)] = [token, frequency]
            cnt += 1
        idx += 1

    return final_fd

[ ]:
df = pd.read_csv(train_annotation_path)

# get string arrays of symptoms for each label
mal_text = list(df.loc[df['label'] == 'Malignant']['symptoms'])
nor_text = list(df.loc[df['label'] == 'Normal']['symptoms'])
ben_text = list(df.loc[df['label'] == 'Benign']['symptoms'])

# get tokenized words for each
mal_tokenized: list[str] = nltk.word_tokenize(" ".join(mal_text))
nor_tokenized: list[str] = nltk.word_tokenize(" ".join(nor_text))
ben_tokenized: list[str] = nltk.word_tokenize(" ".join(ben_text))

# generate the dataframes necessary to plot distributions
mal_fd = get_mc_df(mal_tokenized)
nor_fd = get_mc_df(nor_tokenized)
ben_fd = get_mc_df(ben_tokenized)
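
The word-frequency step can also be sketched with the standard library alone, which may help when checking the counts without NLTK or pandas. The sketch below assumes tokens are already split, drops tokens made up entirely of punctuation, and uses a made-up sentence. `most_common_words` is a hypothetical helper, not part of the notebook.

```python
import string
from collections import Counter


def most_common_words(tokens, n=50):
    """Return the n most frequent tokens, skipping pure-punctuation tokens."""
    def is_punctuation(token):
        return all(ch in string.punctuation for ch in token)

    counts = Counter(t for t in tokens if not is_punctuation(t))
    return counts.most_common(n)


# Made-up, pre-tokenized text standing in for the 'symptoms' column
tokens = "mass , mass . benign mass ; benign".split()
print(most_common_words(tokens, n=2))
```
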
[ ]:
fig = px.bar(mal_fd, x="Token", y='Frequency', color='Frequency', title='Malignant word distribution')
fig.update(layout_coloraxis_showscale=False)
fig.show()
[ ]:
fig = px.bar(nor_fd, x="Token", y='Frequency', color='Frequency', title='Normal word distribution')
fig.update(layout_coloraxis_showscale=False)
fig.show()
[ ]:
fig = px.bar(ben_fd, x="Token", y='Frequency', color='Frequency', title='Benign word distribution')
fig.update(layout_coloraxis_showscale=False)
fig.show()

Transfer Learning

This step calls the model’s train function with the dataset that was just prepared. The training function will get the pretrained model from HuggingFace and add on a dense layer based on the number of classes in the dataset. The model is then trained using an instance of HuggingFace Trainer for the number of epochs specified. If desired, a native PyTorch loop can be invoked instead of Trainer by setting use_trainer=False.

[ ]:
import transformers
transformers.set_seed(1)
nlp_history = nlp_model.train(train_nlp_dataset, output_dir, epochs=3, use_trainer=True, seed=1)

Save the NLP Model

[ ]:
nlp_model.export(output_dir)
[ ]:
# This currently isn't showing the correct output for test
train_nlp_metrics = nlp_model.evaluate(train_nlp_dataset)
test_nlp_metrics = nlp_model.evaluate(test_nlp_dataset)

Error analysis

We can see that BERT has a much better accuracy than the CNN. Nonetheless, similar to the CNN, let’s see where BERT makes mistakes across the three classes using a confusion matrix and ROC and PR curves.

[ ]:
# get the raw logit predictions (not yet probabilities)
# NOTE: added a new flag to predict function
logit_predictions = nlp_model.predict(test_nlp_dataset.dataset, return_raw=True)['logits']
#convert logits to probability
from scipy.special import softmax
y_pred = softmax(logit_predictions.detach().numpy(), axis=1)
# labels must come from the same data the predictions were made on
y_true = test_nlp_dataset.dataset['label'].numpy().astype(int)
[ ]:
from intel_ai_safety.explainer import metrics

nlp_cm = metrics.confusion_matrix(y_true, y_pred, test_nlp_dataset.class_names)
nlp_cm.visualize()
print(nlp_cm.report)
[ ]:
plotter = metrics.plot(y_true, y_pred, test_nlp_dataset.class_names)
plotter.pr_curve()
[ ]:
plotter.roc_curve()

Explanation

[ ]:
mal_idxs = np.where(test_nlp_dataset.dataset['label'].numpy() == label_map_func('Malignant'))[0].tolist()
ben_preds = np.where(nlp_model.predict(test_nlp_dataset.dataset).numpy() == label_map_func('Benign'))[0].tolist()

# get mal examples that were misclassified as ben
mal_classified_as_ben = list(set(mal_idxs).intersection(ben_preds))
[ ]:
mal_classified_as_ben_text = test_nlp_dataset.get_text(test_nlp_dataset.dataset[mal_classified_as_ben]['input_ids'])
[ ]:
# define a prediction function
def f(x):
    encoded_input = nlp_model._tokenizer(x.tolist(), padding=True, return_tensors='pt')
    outputs = nlp_model._model(**encoded_input)
    return softmax(outputs.logits.detach().numpy(), axis=1)
[ ]:
from intel_ai_safety.explainer.attributions import attributions
partition_explainer = attributions.partition_text_explainer(f, test_nlp_dataset.class_names, np.array(mal_classified_as_ben_text), r"\W+")
partition_explainer.visualize()

Int8 Quantization

We can use the Intel® Extension for Transformers to quantize the trained model for faster inference. If you want to run this part of the notebook, make sure you have intel-extension-for-transformers installed in your environment.
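
As background on what int8 quantization does numerically, the sketch below maps floating-point values to 8-bit integers with a single symmetric scale and then maps them back. It illustrates the general idea only. The exact scheme used by Intel® Extension for Transformers (per-tensor or per-channel, symmetric or asymmetric) is not specified in this notebook, and both helper names are hypothetical.

```python
def quantize_int8(values):
    """Symmetric quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map the integers back to approximate float values."""
    return [q * scale for q in quantized]


weights = [-1.0, 0.0, 0.25, 1.0]  # made-up weights
q_weights, scale = quantize_int8(weights)
restored = dequantize(q_weights, scale)

print(q_weights)
print(restored)
# The round-trip error is bounded by half a quantization step
print(max(abs(a - b) for a, b in zip(weights, restored)) <= scale / 2 + 1e-12)
```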

[ ]:
! pip install --no-cache-dir intel-extension-for-transformers==1.4
[ ]:
from intel_extension_for_transformers.transformers.trainer import NLPTrainer
from intel_extension_for_transformers.transformers import objectives, OptimizedModel, QuantizationConfig
from intel_extension_for_transformers.transformers import metrics as nlptk_metrics
[ ]:
# Set up quantization config
tune_metric = nlptk_metrics.Metric(
    name="eval_accuracy",
    greater_is_better=True,
    is_relative=True,
    criterion=quantization_criterion,
    weight_ratio=None,
)

objective = objectives.Objective(
    name="performance", greater_is_better=True, weight_ratio=None
)

quantization_config = QuantizationConfig(
    approach="PostTrainingDynamic",
    max_trials=quantization_max_trial,
    metrics=[tune_metric],
    objectives=[objective],
)

# Set up metrics computation
def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
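
The tuning metric above sets `is_relative=True` with `criterion=quantization_criterion` (0.05). Assuming this means the quantized model may lose at most 5% of the baseline accuracy in relative terms, the acceptance check could be sketched as follows. This is an interpretation of the configuration, not the library's code. The function name and the accuracy values are made up.

```python
def within_relative_tolerance(baseline_acc, quantized_acc, criterion=0.05):
    """True if the relative accuracy drop does not exceed `criterion`."""
    if baseline_acc <= 0:
        return quantized_acc >= baseline_acc
    relative_loss = (baseline_acc - quantized_acc) / baseline_acc
    return relative_loss <= criterion


# Made-up accuracies: a drop from 0.90 to 0.86 is about 4.4% relative, 0.85 is about 5.6%
print(within_relative_tolerance(0.90, 0.86))
print(within_relative_tolerance(0.90, 0.85))
```
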
[ ]:
quantizer = NLPTrainer(model=nlp_model._model,
                       train_dataset=train_nlp_dataset.train_subset,
                       eval_dataset=train_nlp_dataset.validation_subset,
                       compute_metrics=compute_metrics,
                       tokenizer=train_nlp_dataset._tokenizer)
quantized_model = quantizer.quantize(quant_config=quantization_config)
[ ]:
results = quantizer.evaluate()
eval_acc = results.get("eval_accuracy")
print("Final Eval Accuracy: {:.5f}".format(eval_acc))

Save the Quantized NLP Model

[ ]:
quantizer.save_model(os.path.join(output_dir, 'quantized_BERT'))
nlp_model._model.config.save_pretrained(os.path.join(output_dir, 'quantized_BERT'))

Error analysis

The quantized BERT model has the same validation accuracy as its stock counterpart. This does not mean, however, that they perform the same. Let’s look at the confusion matrix and PR and ROC curves to see if the errors are different.

[ ]:
# get the raw logit predictions from the quantized model
logit_predictions = quantizer.predict(test_nlp_dataset.dataset)[0]
#convert logits to probability
from scipy.special import softmax
y_pred = softmax(logit_predictions, axis=1)
y_true = test_nlp_dataset.dataset['label'].numpy().astype(int)
[ ]:
quant_cm = metrics.confusion_matrix(y_true, y_pred, test_nlp_dataset.class_names)
quant_cm.visualize()
print(quant_cm.report)
[ ]:
plotter = metrics.plot(y_true, y_pred, test_nlp_dataset.class_names)
plotter.pr_curve()

Citations

Data Citation

Khaled R., Helal M., Alfarghaly O., Mokhtar O., Elkorany A., El Kassas H., Fahmy A. Categorized Digital Database for Low energy and Subtracted Contrast Enhanced Spectral Mammography images [Dataset]. (2021) The Cancer Imaging Archive. DOI: 10.7937/29kw-ae92

Publication Citation

Khaled, R., Helal, M., Alfarghaly, O., Mokhtar, O., Elkorany, A., El Kassas, H., & Fahmy, A. Categorized contrast enhanced mammography dataset for diagnostic and artificial intelligence research. (2022) Scientific Data, Volume 9, Issue 1. DOI: 10.1038/s41597-022-01238-0

TCIA Citation

Clark K, Vendt B, Smith K, Freymann J, Kirby J, Koppel P, Moore S, Phillips S, Maffitt D, Pringle M, Tarbox L, Prior F. The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository, Journal of Digital Imaging, Volume 26, Number 6, December, 2013, pp 1045-1057. DOI: 10.1007/s10278-013-9622-7