Explaining a Custom Neural Network's Heart Disease Prediction Using the Attributions Explainer

[ ]:
from intel_ai_safety.explainer.attributions import attributions

import warnings
warnings.filterwarnings('ignore')
[ ]:
import pandas as pd
import tensorflow as tf

# Echo the installed TensorFlow release so the rendered output records which
# version produced the results below.
tf.__version__
[ ]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
[ ]:
# Heart-disease dataset CSV hosted by TensorFlow. Fetched over HTTPS rather than
# plain HTTP so the downloaded data cannot be altered in transit.
file_url = "https://storage.googleapis.com/download.tensorflow.org/data/heart.csv"
df = pd.read_csv(file_url)

# Fail fast if the download does not contain the columns used later on.
assert {'age', 'thalach', 'trestbps', 'chol', 'oldpeak', 'target'} <= set(df.columns), \
    "heart.csv is missing expected columns"
[ ]:
# Make the target (label) variable.
# Select the column instead of `df.pop(...)`: popping mutates `df`, so re-running
# this cell would raise KeyError. The feature matrix is built from an explicit
# list of numerical columns, so keeping 'target' in `df` does not leak it into X.
y = df['target']
[ ]:
# Prepare the features: only these numerical columns are used as model inputs.
list_numerical = ['age', 'thalach', 'trestbps', 'chol', 'oldpeak']

X = df.loc[:, list_numerical]

Data Splitting

[ ]:
# Hold out 20% of the rows for evaluation; the fixed seed makes the split
# reproducible across runs.
TEST_SIZE = 0.2
SPLIT_SEED = 42

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED
)

Feature Preprocessing

[ ]:
# Standardise the numerical features. The scaler is fit on the training split
# only, so test-set statistics never influence preprocessing; the same fitted
# transform is then applied to both splits.
scaler = StandardScaler().fit(X_train[list_numerical])

# `train_test_split` hands back row subsets of `X`. Take explicit copies before
# assigning columns so the writes land in independent frames rather than going
# through pandas' chained-assignment path (which may emit SettingWithCopyWarning,
# masked here by the global warnings filter at the top of the notebook).
X_train = X_train.copy()
X_test = X_test.copy()
X_train[list_numerical] = scaler.transform(X_train[list_numerical])
X_test[list_numerical] = scaler.transform(X_test[list_numerical])

Model

[ ]:
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation (and hence the
# trained model) is repeatable on Restart & Run All; the train/test split is
# already seeded via `random_state`.
tf.keras.utils.set_random_seed(42)

# Small fully connected binary classifier: two hidden ReLU layers and a single
# sigmoid unit whose output is read later as the probability of heart disease.
# Declaring the input width up front builds the model immediately, so its
# weights and layer graph exist before training.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(X_train.shape[1],)),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
[ ]:
# Binary cross-entropy pairs with the single sigmoid output unit.
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

# The held-out split is also passed as validation data purely to monitor
# progress per epoch; with no callbacks it does not influence the weights.
model.fit(
    X_train,
    y_train,
    batch_size=13,
    epochs=15,
    validation_data=(X_test, y_test),
)

Visualize the model's layer connectivity graph:

[ ]:
# Render the layers left-to-right with their input/output shapes.
graph_image = tf.keras.utils.plot_model(model, rankdir="LR", show_shapes=True)
graph_image

Accuracy

[ ]:
# Evaluate on the held-out split; `evaluate` returns the loss followed by the
# compiled metrics (here just accuracy).
eval_results = model.evaluate(X_test, y_test)
loss, accuracy = eval_results

print("Accuracy", accuracy)
[ ]:
# Score every training row; row 0 is the patient examined in the next cells.
predictions = model.predict(x=X_train)
[ ]:
# Model output for the first training patient, as a percentage. `predictions`
# has one row per patient and one column for the sigmoid output.
patient_pct = 100 * predictions[0][0]
print(
    "This particular patient had a %.1f percent probability "
    "of having a heart disease, as evaluated by our model." % (patient_pct,)
)
[ ]:
# Estimate SHAP values for the first training patient's features, i.e. how each
# feature pushed the model toward the probability printed above. The following
# 100 training rows serve as the background sample.
background = X_train.iloc[1:101, :]
patient = X_train.iloc[0, :]
ke = attributions.kernel_explainer(model, background, patient)
ke.visualize()