Generating a Model Card for Image Classification with PyTorch
This notebook shows how to use the Intel Model Card Generator to create a model card for a PyTorch model fine-tuned to classify images as Deepfake or Normal.
1. Download and Import Dependencies
[ ]:
!pip install evaluate datasets transformers[torch] scikit-learn torchvision Pillow
[3]:
from datasets import load_dataset
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import numpy as np
import pandas as pd
from intel_ai_safety.model_card_gen.model_card_gen import ModelCardGen
2. Download Dataset from Hugging Face Datasets
[5]:
ds = load_dataset("itsLeen/deepfake_vs_real_image_detection")
[6]:
ds
[6]:
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 6327
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1117
    })
})
[7]:
train_dataset = ds["train"]
test_dataset = ds["test"]
3. Transform the Image Dataset
[8]:
# Define the preprocessing steps
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize image to the 224x224 input size used by ResNet50
    transforms.ToTensor(),  # Convert PIL image to a float tensor in [0, 1] with shape (C, H, W)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with the ImageNet channel statistics used to pretrain ResNet50
])
# Dataset class that wraps a Hugging Face split so a DataLoader can batch it
class DeepfakeDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        img = item['image'].convert('RGB')  # Ensure 3 channels (e.g. for grayscale or RGBA images)
        if self.transform:
            img = self.transform(img)
        label = item['label']
        return img, label
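Per channel, ToTensor rescales 8-bit pixel values to [0, 1] and Normalize then computes (x - mean) / std. The NumPy sketch below reproduces those two steps on a hypothetical (height, width, 3) uint8 array so the arithmetic is visible; the function name is illustrative, it skips Resize, and it is not torchvision's implementation.

```python
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def to_tensor_and_normalize(img_hwc_uint8):
    """Hypothetical helper mimicking ToTensor + Normalize for an (H, W, 3) uint8 array.

    Returns a float array of shape (3, H, W).
    """
    x = img_hwc_uint8.astype(np.float64) / 255.0  # ToTensor: scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))                # ToTensor: HWC -> CHW
    # Normalize: subtract mean and divide by std per channel (broadcast over H and W)
    return (x - MEAN[:, None, None]) / STD[:, None, None]


# Usage: a 2x2 all-white image
white = np.full((2, 2, 3), 255, dtype=np.uint8)
out = to_tensor_and_normalize(white)
print(out.shape)     # (3, 2, 2)
print(out[:, 0, 0])  # per-channel value (1 - mean) / std
```

A white pixel therefore maps to values above 2 in every channel, and a black pixel to negative values, which is the input range the pretrained ResNet50 weights expect.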
Preparing Data Loaders for Training and Testing
[9]:
train_data = DeepfakeDataset(train_dataset, transform=preprocess)
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_data = DeepfakeDataset(test_dataset, transform=preprocess)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
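With drop_last left at its default (False), a DataLoader yields ceil(N / batch_size) batches, and the last batch may be smaller. That batch count is the len(train_dataloader) used to average the loss in the training loop. A plain-Python sketch using the split sizes printed above (the helper name is illustrative):

```python
def batch_layout(num_rows, batch_size):
    """Return (number of batches, size of the last batch) when drop_last=False."""
    num_batches = (num_rows + batch_size - 1) // batch_size  # integer ceiling division
    last = num_rows - (num_batches - 1) * batch_size
    return num_batches, last


# Split sizes from the DatasetDict output, with batch_size=32 as in the DataLoaders
print(batch_layout(6327, 32))  # (198, 23)
print(batch_layout(1117, 32))  # (35, 29)
```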
4. Download Model and Replace the Output Layer
Load the pre-trained ResNet50 model
[ ]:
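# Load ResNet50 with ImageNet-pretrained weights (replaces the deprecated pretrained=True argument)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Freeze the pretrained backbone; only the new classification head is trained
for param in model.parameters():
    param.requires_grad = False
# Replace the final fully connected layer with a 2-class head (Deepfake vs. Normal)
num_classes = 2
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Define loss function and optimizer (the optimizer updates the head's parameters only)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)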
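Only the replaced fc layer is optimized. In torchvision's ResNet50, model.fc.in_features is 2048, so a Linear(2048, 2) head has 2048 * 2 weights plus 2 biases. A small sketch of that count (the helper name is illustrative):

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of trainable parameters in an nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


# ResNet50's final layer receives 2048 pooled features; the new head has 2 outputs
print(linear_param_count(2048, 2))     # 4098
# For comparison, the original 1000-class ImageNet head
print(linear_param_count(2048, 1000))  # 2049000
```

Training about four thousand parameters instead of the whole network is what makes two epochs on a few thousand images practical.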
5. Fine-tune Image Classification Model
[21]:
# Training loop
num_epochs = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean per-batch training loss for this epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_dataloader):.4f}')
Epoch [1/2], Loss: 0.3365
Epoch [2/2], Loss: 0.2964
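nn.CrossEntropyLoss takes raw logits and, with its default mean reduction, returns the batch average of -log softmax(logits)[true class]. A NumPy sketch of that computation, using a numerically stable log-softmax (an illustration, not PyTorch's code):

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels given (N, C) logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Stable log-softmax: subtract the row max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row and average
    return -log_probs[np.arange(len(labels)), labels].mean()


# Equal logits give probability 0.5 for each class, so the loss is ln 2
print(round(cross_entropy([[0.0, 0.0]], [0]), 4))  # 0.6931
```

For reference, a two-class classifier that always outputs equal logits scores ln 2 (about 0.693) on this loss, so the epoch losses above are well below that level.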
6. Save Model
Save the fine-tuned model's weights (its state_dict) to a local file so they can be reloaded later.
[22]:
torch.save(model.state_dict(), 'simple_model.pth')
7. Evaluate Model
Evaluating Model Performance on the Test Dataset
[29]:
model.eval()
correct = 0
total = 0
y_pred_prob = []  # Predicted probability of class 1 for each test image
with torch.no_grad():
    for images, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Softmax over the two logits; keep the probability of class 1
        probs = torch.softmax(outputs, dim=1)[:, 1]
        # Predicted class = index of the larger logit
        _, predicted = torch.max(outputs, 1)
        y_pred_prob.extend(probs.cpu().numpy())
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy: {100 * correct / total:.2f}%')
Accuracy: 87.47%
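For two classes, taking the argmax of the logits and thresholding the softmax probability of class 1 at 0.5 give the same predictions (apart from exact ties), so the accuracy printed above should match the accuracy column of metrics_by_threshold at a threshold of about 0.5. A NumPy sketch on hypothetical logits:

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax of an (N, C) array of logits."""
    z = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Hypothetical logits for four images (columns: class 0, class 1)
logits = np.array([[2.0, -1.0], [0.3, 0.9], [-0.5, 0.4], [1.2, 1.1]])
p1 = softmax(logits)[:, 1]           # like torch.softmax(outputs, dim=1)[:, 1]
argmax_pred = logits.argmax(axis=1)  # like torch.max(outputs, 1)
threshold_pred = (p1 > 0.5).astype(int)
print(argmax_pred)     # [0 1 1 0]
print(threshold_pred)  # [0 1 1 0]
```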
8. Transform Output for Model Card
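Create a metrics_by_threshold DataFrame holding the precision, recall, F1 score, and accuracy obtained when each of 1,001 evenly spaced thresholds between 0 and 1 is applied to the predicted class-1 probabilities.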
[ ]:
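# Labels and class-1 probabilities as NumPy arrays (comparing a Python list to a float raises a TypeError)
y_true = np.array(test_dataset['label'])
y_pred_prob = np.array(y_pred_prob)
thetas = np.linspace(0, 1, 1001)  # Thresholds 0.000, 0.001, ..., 1.000
metrics_dict = {
    'threshold': thetas,
    # zero_division=0 avoids warnings at high thresholds where nothing is predicted positive
    'precision': [precision_score(y_true, y_pred_prob > theta, zero_division=0) for theta in thetas],
    'recall': [recall_score(y_true, y_pred_prob > theta, zero_division=0) for theta in thetas],
    'f1': [f1_score(y_true, y_pred_prob > theta, zero_division=0) for theta in thetas],
    'accuracy': [accuracy_score(y_true, y_pred_prob > theta) for theta in thetas]
}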
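At each threshold, an image is predicted as class 1 when its probability exceeds the threshold, and the four scikit-learn metrics reduce to confusion-matrix counts: precision = TP / (TP + FP), recall = TP / (TP + FN), F1 = 2PR / (P + R), and accuracy = (TP + TN) / N. A NumPy sketch of one row of metrics_by_threshold on hypothetical labels and probabilities; undefined ratios return 0, as scikit-learn does with zero_division=0.

```python
import numpy as np


def metrics_at_threshold(y_true, y_prob, theta):
    """Precision, recall, F1 and accuracy when predicting 1 for y_prob > theta."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_prob) > theta
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    tn = np.sum(~y_pred & ~y_true)
    # Undefined ratios (0 / 0) are reported as 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = (tp + tn) / len(y_true)
    return float(precision), float(recall), float(f1), float(accuracy)


# Hypothetical labels and class-1 probabilities for six images
y_true = [1, 0, 1, 1, 0, 0]
y_prob = [0.9, 0.6, 0.4, 0.8, 0.2, 0.1]
print(metrics_at_threshold(y_true, y_prob, 0.5))  # each value is about 0.667
```

Raising the threshold trades recall for precision; at a threshold of 1.0 nothing is predicted positive, so precision, recall, and F1 all drop to 0.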
[ ]:
metrics_by_threshold = pd.DataFrame.from_dict(metrics_dict)
[41]:
metrics_by_threshold.to_csv('metrics_by_threshold.csv', index=False)
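Once metrics_by_threshold exists, a single row can be pulled out for reporting, for example the threshold with the highest F1 score. A pandas sketch on a small hypothetical frame with the same column names (the values are made up for illustration):

```python
import pandas as pd

# Hypothetical stand-in with the same columns as metrics_by_threshold
df = pd.DataFrame({
    "threshold": [0.3, 0.4, 0.5, 0.6],
    "precision": [0.70, 0.78, 0.85, 0.90],
    "recall":    [0.95, 0.90, 0.84, 0.70],
    "f1":        [0.806, 0.836, 0.845, 0.788],
    "accuracy":  [0.80, 0.84, 0.86, 0.83],
})

best = df.loc[df["f1"].idxmax()]  # Row with the highest F1 (first one on ties)
print(best["threshold"])  # 0.5
```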
9. Generate Model Card
Pass the metrics_by_threshold DataFrame, together with a model card dictionary describing the model, to the ModelCardGen.generate class method to build the model card.
[5]:
metrics_by_threshold = pd.read_csv("metrics_by_threshold.csv")
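Because the DataFrame is reloaded from CSV, a quick check before calling ModelCardGen.generate can catch a stale or truncated file. The helper below is hypothetical: it checks only the columns this notebook wrote and the ranges they should have, not any schema documented by ModelCardGen.

```python
import pandas as pd

# Columns written by this notebook's metrics_dict
EXPECTED_COLUMNS = ["threshold", "precision", "recall", "f1", "accuracy"]


def check_metrics_frame(df):
    """Hypothetical sanity check for a reloaded metrics_by_threshold frame."""
    missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    values = df[EXPECTED_COLUMNS]
    if values.isna().any().any():
        raise ValueError("metrics contain NaN values")
    # Thresholds and all four metrics are fractions in [0, 1]
    if not values.apply(lambda col: col.between(0, 1).all()).all():
        raise ValueError("thresholds and metrics should lie in [0, 1]")
    if not df["threshold"].is_monotonic_increasing:
        raise ValueError("thresholds should be sorted in increasing order")
    return True


# Usage on a small hypothetical frame
frame = pd.DataFrame({
    "threshold": [0.0, 0.5, 1.0],
    "precision": [0.5, 0.8, 0.0],
    "recall": [1.0, 0.8, 0.0],
    "f1": [0.667, 0.8, 0.0],
    "accuracy": [0.5, 0.8, 0.5],
})
print(check_metrics_frame(frame))  # True
```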
[1]:
mc = {
    "schema_version": "0.0.1",
    "model_details": {
        "name": "Deepfake Image Detection",
        "version": {
            "name": "0.1",
            "date": "2024"
        },
        "graphics": {},
        "citations": [
            {
                "citation": '''@inproceedings{he2016deep,
title={Deep residual learning for image recognition},
author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={770--778},
year={2016}
}
'''
            },
        ],
        "overview": 'The fine-tuned ResNet50 model is used for classifying an image as either Deepfake or Normal. The model is fine-tuned using a dataset containing both deepfake and normal images along with their corresponding labels.',
    }
}
[6]:
mcg = ModelCardGen.generate(metrics_by_threshold=metrics_by_threshold, model_card=mc)
mcg
[6]:
Model Details
Overview
The fine-tuned ResNet50 model is used for classifying an image as either Deepfake or Normal. The model is fine-tuned using a dataset containing both deepfake and normal images along with their corresponding labels.
Model Performance
Overall Accuracy/Precision/Recall/F1
Version
name: 0.1
date: 2024
Citations
- @inproceedings{he2016deep, title={Deep residual learning for image recognition}, author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, pages={770--778}, year={2016} }