Transfer Learning for Image Classification using TensorFlow and the Intel® Transfer Learning Tool API¶
This notebook uses the tlt library to do transfer learning for image classification with a TensorFlow pretrained model.
1. Import dependencies and set up parameters¶
This notebook assumes that you have already followed the instructions to set up a TensorFlow environment with all the dependencies required to run the notebook.
[ ]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import PIL.Image as Image
import tensorflow as tf
# tlt imports
from tlt.datasets import dataset_factory
from tlt.models import model_factory
from tlt.utils.file_utils import download_file, download_and_extract_tar_file
# Specify a directory for the dataset to be downloaded
dataset_dir = os.environ["DATASET_DIR"] if "DATASET_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "dataset")
# Specify a directory for output
output_dir = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "output")
print("Dataset directory:", dataset_dir)
print("Output directory:", output_dir)
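The environment-variable fallback used above can also be wrapped in a small helper that creates the directory if it does not exist yet. This is only a sketch: `resolve_dir` and the `DEMO_DATASET_DIR` variable are hypothetical names used for illustration and are not part of tlt.

```python
import os
import tempfile


def resolve_dir(env_var, default_name, base=None):
    """Return the directory named by env_var, or <base>/<default_name>, creating it if needed."""
    base = base or os.path.expanduser("~")
    path = os.environ.get(env_var) or os.path.join(base, default_name)
    os.makedirs(path, exist_ok=True)
    return path


# Demonstrate with a temporary base directory so nothing is written to the home folder
with tempfile.TemporaryDirectory() as tmp:
    os.environ.pop("DEMO_DATASET_DIR", None)
    fallback = resolve_dir("DEMO_DATASET_DIR", "dataset", base=tmp)

    os.environ["DEMO_DATASET_DIR"] = os.path.join(tmp, "custom")
    override = resolve_dir("DEMO_DATASET_DIR", "dataset", base=tmp)
    os.environ.pop("DEMO_DATASET_DIR")

    print("Fallback:", os.path.basename(fallback))
    print("Override:", os.path.basename(override))
```

Creating the directory up front avoids a failure later in the notebook when the download or training step first tries to write to it.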
2. Get the model¶
In this step, we call the Intel Transfer Learning Tool model factory to list supported TensorFlow image classification models. This is a list of pretrained models from TFHub and Keras Applications that we tested with our API. Optionally, the verbose=True argument can be added to the print_supported_models function call to get more information about each model (such as the image size, the original dataset, the preprocessor, etc.).
[ ]:
# See a list of available models
model_factory.print_supported_models(use_case='image_classification', framework='tensorflow')
Next, use the model factory to get one of the models listed in the previous cell. The get_model function returns a model object that will later be used for training.
[ ]:
model = model_factory.get_model(model_name='resnet_v1_50', framework='tensorflow')
print("Model name:", model.model_name)
print("Framework:", model.framework)
print("Use case:", model.use_case)
print("Image size:", model.image_size)
3. Get the dataset¶
We call the dataset factory to get a sample image classification dataset. For demonstration purposes, we are using the tf_flowers dataset from the TensorFlow Datasets catalog. This dataset contains images of flowers in 5 different classes.
Option A: Use your own dataset¶
To use your own image dataset for transfer learning with the rest of this notebook, format your images as .jpg files and save them in folders named after the classes that you want the model to predict. To provide a working example using the correct layout, we will download a flower species dataset. After downloading and extracting, you will have the following subdirectories in your dataset directory. Each species subfolder will contain numerous .jpg files:
flower_photos
└── daisy
└── dandelion
└── roses
└── sunflowers
└── tulips
When using your own dataset, ensure that it is similarly organized with folders for each class. Change the custom_dataset_path variable to point to your dataset folder.
[ ]:
# For demonstration purposes, we download a flowers dataset. To instead use your own dataset, set the
# custom_dataset_path to point to your dataset's directory and comment out the download_and_extract_tar_file line.
custom_dataset_path = os.path.join(dataset_dir, "flower_photos")
if not os.path.exists(custom_dataset_path):
    download_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    download_and_extract_tar_file(download_url, dataset_dir)
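Before loading a custom dataset, it can help to confirm that the folder really follows the one-folder-per-class layout described above. The sketch below uses only the standard library; `count_images_per_class` is a hypothetical helper (not a tlt function), and the tiny directory tree is a stand-in for a real dataset.

```python
import tempfile
from pathlib import Path


def count_images_per_class(dataset_path, extensions=(".jpg", ".jpeg")):
    """Map each class subfolder name to the number of image files it contains."""
    counts = {}
    for class_dir in sorted(Path(dataset_path).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in extensions
            )
    return counts


# Build a tiny stand-in directory tree with two classes
with tempfile.TemporaryDirectory() as tmp:
    for name, num_files in [("daisy", 2), ("roses", 3)]:
        class_dir = Path(tmp) / name
        class_dir.mkdir()
        for i in range(num_files):
            (class_dir / f"{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

In the notebook, the same helper could be pointed at custom_dataset_path; a class with zero images or an unexpected folder name is usually a sign that the layout needs fixing.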
Call the dataset factory to load the dataset from the directory.
[ ]:
# Load the dataset from the custom dataset path
dataset = dataset_factory.load_dataset(dataset_dir=custom_dataset_path,
                                       use_case='image_classification',
                                       framework='tensorflow')
print("Class names:", str(dataset.class_names))
Skip ahead to step 4, "Prepare the dataset", to continue using the custom dataset.
Option B: Use a dataset from the TensorFlow Datasets catalog¶
[ ]:
dataset = dataset_factory.get_dataset(dataset_dir=dataset_dir,
                                      use_case='image_classification',
                                      framework='tensorflow',
                                      dataset_name='tf_flowers',
                                      dataset_catalog='tf_datasets')
print(dataset.info)
print("\nClass names:", str(dataset.class_names))
4. Prepare the dataset¶
Once you have your dataset from Option A or Option B above, use the following cells to split and preprocess the data. We split the dataset into training and validation subsets, resize the images to match the selected model, and then batch the images. Data augmentation can be applied by listing the desired augmentations in the add_aug parameter. Supported augmentations are:
1. hvflip - RandomHorizontalandVerticalFlip
2. hflip - RandomHorizontalFlip
3. vflip - RandomVerticalFlip
4. rotate - RandomRotate
5. zoom - RandomZoom
[ ]:
# Split the dataset into training and validation subsets
dataset.shuffle_split(train_pct=.75, val_pct=.25)
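To make the idea of a shuffled 75/25 split concrete, here is a minimal sketch that operates on plain indices. It illustrates the general technique only and is not the tlt implementation; `shuffle_split_indices` and its `seed` argument are hypothetical.

```python
import random


def shuffle_split_indices(num_items, train_pct=0.75, val_pct=0.25, seed=None):
    """Shuffle item indices and split them into disjoint training and validation lists."""
    if train_pct + val_pct > 1.0:
        raise ValueError("train_pct and val_pct must not sum to more than 1.0")
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    train_size = int(num_items * train_pct)
    val_size = int(num_items * val_pct)
    return indices[:train_size], indices[train_size:train_size + val_size]


train_idx, val_idx = shuffle_split_indices(100, train_pct=0.75, val_pct=0.25, seed=0)
print(len(train_idx), len(val_idx))
```

Shuffling before splitting matters for directory-based datasets, where files are otherwise ordered by class and an unshuffled split could leave entire classes out of the validation subset.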
[ ]:
# Preprocess the dataset with an image size and preprocessor that match the model and a batch size of 32
batch_size = 32
dataset.preprocess(model.image_size, batch_size=batch_size, preprocessor=model.preprocessor)
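The flip augmentations listed above can be pictured as simple array operations on a height x width x channels image. The NumPy sketch below is an illustration under that assumption, not how tlt or Keras implements them; `random_flip` is a hypothetical helper and the 0.5 flip probability is an assumed value.

```python
import numpy as np


def random_flip(image, mode, rng):
    """Randomly flip an HxWxC image; each enabled flip is applied with probability 0.5."""
    if mode not in ("hflip", "vflip", "hvflip"):
        raise ValueError(f"Unsupported mode: {mode}")
    if mode in ("hflip", "hvflip") and rng.random() < 0.5:
        image = image[:, ::-1, :]  # horizontal flip: reverse the width axis
    if mode in ("vflip", "hvflip") and rng.random() < 0.5:
        image = image[::-1, :, :]  # vertical flip: reverse the height axis
    return image


rng = np.random.default_rng(0)
image = np.arange(24).reshape(2, 4, 3)
flipped = random_flip(image, "hvflip", rng)
print(flipped.shape)
```

Flips only rearrange pixels, so the image shape and pixel values are preserved; this is why they are a cheap way to increase variety in a small training set.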
5. Predict using the original model¶
We get a single batch from our dataset and use it to call predict on our model. Since we haven’t trained the model yet, it will give us predictions from the original ImageNet-trained model.
[ ]:
# Get a single batch from the dataset
images, labels = dataset.get_batch()
labels = [dataset.class_names[id] for id in labels]
# Download the ImageNet labels and load them into a list
labels_file = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
downloaded_file = tf.keras.utils.get_file("labels.txt", origin=labels_file)
imagenet_classes = []
with open(downloaded_file) as f:
    imagenet_labels = f.readlines()
    imagenet_classes = [line.strip() for line in imagenet_labels]
# Predict using the original model
predictions = model.predict(images)
predictions = [imagenet_classes[id] for id in predictions]
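The cell above indexes a label list with the ids returned by predict. For a classifier that outputs one score per class, such ids are typically obtained by taking the argmax over the class axis. The sketch below shows that mapping with made-up scores and a short stand-in label list; it assumes nothing about how tlt computes its predictions internally.

```python
import numpy as np

# Stand-in label list and made-up scores for two images (one row per image)
class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
scores = np.array([[0.1, 0.2, 3.0, 0.0, -1.0],
                   [2.5, 0.3, 0.1, 0.2, 0.0]])

# The predicted id is the index of the highest score in each row
predicted_ids = np.argmax(scores, axis=-1)
predicted_names = [class_names[i] for i in predicted_ids]
print(predicted_names)
```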
[ ]:
# Display the images with the predicted ImageNet label
plt.figure(figsize=(18,14))
plt.subplots_adjust(hspace=0.5)
for n in range(min(batch_size, 30)):
    plt.subplot(6,5,n+1)
    # Rescale pixel values to the [0, 1] range for display
    norm_images = (images[n]-np.min(images[n]))/(np.max(images[n])-np.min(images[n]))
    plt.imshow(norm_images, vmin=np.min(norm_images), vmax=np.max(norm_images))
    correct_prediction = labels[n] == predictions[n]
    color = "darkgreen" if correct_prediction else "crimson"
    title = predictions[n] if correct_prediction else "{}\n({})".format(predictions[n], labels[n])
    plt.title(title, fontsize=14, color=color)
    plt.axis('off')
_ = plt.suptitle("ImageNet predictions", fontsize=20)
plt.show()
print("Correct predictions are shown in green")
print("Incorrect predictions are shown in red with the actual label in parentheses")
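The display code rescales each preprocessed image to the [0, 1] range before plotting. A slightly more defensive version of that min-max normalization is sketched below; `min_max_normalize` is a hypothetical helper, and the guard for constant images is an addition that the notebook cell does not have.

```python
import numpy as np


def min_max_normalize(image):
    """Rescale pixel values to [0, 1]; a constant image maps to all zeros."""
    image = np.asarray(image, dtype=np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # Avoid dividing by zero when every pixel has the same value
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


sample = np.array([[-1.0, 0.0], [1.0, 3.0]])
normalized = min_max_normalize(sample)
constant = min_max_normalize(np.full((2, 2), 7.0))
print(normalized.min(), normalized.max())
```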
6. Transfer Learning¶
This step calls the model’s train function with the dataset that was just prepared. The training function will get the base model and add on a dense layer based on the number of classes in the dataset. The model is then compiled and trained for the number of epochs specified in the argument. With the do_eval parameter set to True by default, this step will also show how the model can be evaluated and will return a list of metrics calculated from the dataset’s validation subset.
Arguments¶
Required¶
dataset (ImageClassificationDataset, required): Dataset to use when training the model
output_dir (str): Path to a writeable directory for checkpoint files
epochs (int): Number of epochs to train the model (default: 1)
Optional¶
initial_checkpoints (str): Path to checkpoint weights to load. If the path provided is a directory, the latest checkpoint will be used.
early_stopping (bool): Enable early stopping if convergence is reached while training at the end of each epoch. (default: False)
lr_decay (bool): If lr_decay is True and do_eval is True, learning rate decay on the validation loss is applied at the end of each epoch.
enable_auto_mixed_precision (bool or None): Enable auto mixed precision for training. Mixed precision uses both 16-bit and 32-bit floating point types to make training run faster and use less memory. It is recommended to enable auto mixed precision training when running on platforms that support bfloat16 (Intel third or fourth generation Xeon processors). If it is enabled on a platform that does not support bfloat16, it can be detrimental to the training performance. If enable_auto_mixed_precision is set to None, auto mixed precision will be automatically enabled when running with Intel fourth generation Xeon processors, and disabled for other platforms.
extra_layers (list[int]): Optionally insert additional dense layers between the base model and output layer. This can help increase accuracy when fine-tuning. The input should be a list of integers representing the number and size of the layers, for example [1024, 512] will insert two dense layers, the first with 1024 neurons and the second with 512 neurons.
Note: Refer to the release documentation for an up-to-date list of train arguments and their current descriptions.
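To get a feel for how extra_layers changes the size of the classification head, the sketch below lists the dense layer shapes and counts their weights and biases. It assumes a 2048-dimensional feature vector from the base model (typical for a pooled ResNet-50 output) and plain fully connected layers with biases; both are assumptions, and the helper names are hypothetical.

```python
def dense_head_shapes(feature_dim, num_classes, extra_layers=None):
    """Return (inputs, outputs) for each dense layer stacked on top of the base model."""
    sizes = [feature_dim] + list(extra_layers or []) + [num_classes]
    return list(zip(sizes[:-1], sizes[1:]))


def count_head_parameters(shapes):
    """Count weights plus biases across the dense layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


# Assumed: 2048 features from the base model and 5 flower classes
plain_head = dense_head_shapes(2048, 5)
deeper_head = dense_head_shapes(2048, 5, extra_layers=[1024, 512])

print(plain_head, count_head_parameters(plain_head))
print(deeper_head, count_head_parameters(deeper_head))
```

Under these assumptions, adding [1024, 512] grows the head from roughly ten thousand to a few million trainable parameters, which can help accuracy but also increases training time and the risk of overfitting on a small dataset.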
[ ]:
enable_auto_mixed_precision = None
# Train using the pretrained model with the new dataset
history = model.train(dataset, output_dir=output_dir, epochs=1,
                      enable_auto_mixed_precision=enable_auto_mixed_precision)
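The lr_decay option described in the arguments list applies learning rate decay based on the validation loss at the end of each epoch. A common way to do this is a reduce-on-plateau rule: lower the learning rate when the validation loss has not improved for a few epochs. The sketch below illustrates that general rule in plain Python; the `factor` and `patience` values and the function name are assumptions, not the settings tlt uses.

```python
def decay_on_plateau(val_losses, initial_lr=0.001, factor=0.2, patience=2):
    """Return the learning rate in effect after each epoch's validation check."""
    lr = initial_lr
    best_loss = float("inf")
    epochs_without_improvement = 0
    schedule = []
    for loss in val_losses:
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                lr *= factor  # decay after `patience` epochs without improvement
                epochs_without_improvement = 0
        schedule.append(lr)
    return schedule


# Made-up validation losses: improvement stalls at epochs 3 and 4
schedule = decay_on_plateau([1.0, 0.8, 0.81, 0.82, 0.7])
print(schedule)
```

Early stopping follows the same bookkeeping, except that training ends instead of the learning rate being reduced once the patience is used up.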
7. Predict¶
Let’s predict using the same single batch that we used earlier with the ImageNet trained model to visualize the model’s predictions after training.
[ ]:
# Predict with a single batch
predictions = model.predict(images, enable_auto_mixed_precision=enable_auto_mixed_precision)
# Map the predicted ids to the class names
predictions = [dataset.class_names[id] for id in predictions]
# Display the results
plt.figure(figsize=(18,14))
plt.subplots_adjust(hspace=0.5)
for n in range(min(batch_size, 30)):
    plt.subplot(6,5,n+1)
    # Rescale pixel values to the [0, 1] range for display
    norm_images = (images[n]-np.min(images[n]))/(np.max(images[n])-np.min(images[n]))
    plt.imshow(norm_images, vmin=np.min(norm_images), vmax=np.max(norm_images))
    correct_prediction = labels[n] == predictions[n]
    color = "darkgreen" if correct_prediction else "crimson"
    title = predictions[n] if correct_prediction else "{}\n({})".format(predictions[n], labels[n])
    plt.title(title, fontsize=14, color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions", fontsize=16)
plt.show()
print("Correct predictions are shown in green")
print("Incorrect predictions are shown in red with the actual label in parentheses")
Custom Single Image Prediction¶
We can also predict using a single image that wasn’t part of our original dataset. We download a flower image from the Open Images Dataset and then resize it to match our model.
[ ]:
# Download an image from the web and resize it to match our model
image_url = "https://c8.staticflickr.com/8/7095/7210797228_c7fe51c3cb_z.jpg"
daisy = download_file(image_url, output_dir)
image_shape = (model.image_size, model.image_size)
daisy = Image.open(daisy).resize(image_shape)
daisy
Then, we convert the image to a NumPy array, add a leading dimension to represent the batch, and pass the result to predict.
[ ]:
# Get the image as a np array and call predict while adding a batch dimension (with np.newaxis)
daisy = np.array(daisy)/255.0
result = model.predict(daisy[np.newaxis, ...])
# Print the predicted class name
print(dataset.class_names[result[0]])
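The shape handling in the cell above can be checked without a model. The sketch below uses random pixel data in place of the downloaded photo and assumes a 224x224 RGB input size, which may differ from the image_size of the model you selected.

```python
import numpy as np

image_size = 224  # assumed input size; use model.image_size in the notebook

# Stand-in for a resized RGB photo with 8-bit pixel values
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(image_size, image_size, 3), dtype=np.uint8)

# Scale to [0, 1] and add a leading batch dimension of size 1
scaled = pixels / 255.0
batch = scaled[np.newaxis, ...]
print(batch.shape)
```

The leading dimension is needed because the model expects a batch of images, even when that batch contains a single image.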
8. Export¶
Next, we can call the model export function to generate a saved_model.pb. The model is saved in a format that is ready to use with TensorFlow Serving. Each time the model is exported, a new numbered directory is created, which allows serving to pick up the latest model.
[ ]:
saved_model_dir = model.export(output_dir)
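Because each export creates a new numbered directory, the most recent export is the one with the highest number. The sketch below shows one way to locate it with the standard library; `latest_numbered_dir` is a hypothetical helper, and the exact parent directory that export writes to is not assumed here, so a temporary folder stands in for it.

```python
import os
import tempfile


def latest_numbered_dir(parent):
    """Return the path of the highest-numbered subdirectory, or None if there is none."""
    numbered = [
        name for name in os.listdir(parent)
        if name.isdigit() and os.path.isdir(os.path.join(parent, name))
    ]
    if not numbered:
        return None
    # Compare as integers so that "10" sorts after "2"
    return os.path.join(parent, max(numbered, key=int))


# Stand-in for a directory holding three exports
with tempfile.TemporaryDirectory() as tmp:
    for name in ["1", "2", "10"]:
        os.mkdir(os.path.join(tmp, name))
    latest_name = os.path.basename(latest_numbered_dir(tmp))

print(latest_name)
```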
9. Post-training quantization¶
In this section, the tlt API uses Intel® Neural Compressor (INC) to benchmark and quantize the model to get optimal inference performance.
We use the Intel Neural Compressor config to benchmark the full precision model to see how it performs, as our baseline.
Note that there is a known issue when running Intel Neural Compressor from a notebook: you may sometimes see the error zmq.error.ZMQError: Address already in use. If you see this error, rerun the cell.
Likewise, if the benchmark function returns an empty dictionary {}, run the cell again.
[ ]:
model.benchmark(dataset=dataset)
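The advice to rerun the cell when the benchmark returns an empty dictionary can be automated with a small retry wrapper. This is a sketch only: `call_until_non_empty` is a hypothetical helper, and `flaky_benchmark` is a stand-in function whose return values are invented for the demonstration and do not reflect real benchmark output.

```python
def call_until_non_empty(func, max_attempts=3):
    """Call func until it returns a non-empty result, up to max_attempts times."""
    for _ in range(max_attempts):
        result = func()
        if result:
            return result
    raise RuntimeError(f"No result after {max_attempts} attempts")


# Stand-in that returns {} on the first call and a placeholder result afterwards
calls = {"count": 0}


def flaky_benchmark():
    calls["count"] += 1
    return {} if calls["count"] == 1 else {"status": "done"}


result = call_until_non_empty(flaky_benchmark)
print(result, "after", calls["count"], "calls")
```

In the notebook, the wrapped call would be something like call_until_non_empty(lambda: model.benchmark(dataset=dataset)). The wrapper does not catch exceptions, so the ZMQError case still requires rerunning the cell by hand.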
Next, we use Intel Neural Compressor to automatically search for the optimal quantization recipe for low-precision model inference within the accuracy loss constraints defined in the config. Running post-training quantization may take several minutes, depending on your hardware and the exit policy (timeout and max trials).
[ ]:
inc_output_dir = os.path.join(output_dir, 'quantized_models', model.model_name,
                              os.path.basename(saved_model_dir))
model.quantize(inc_output_dir, dataset=dataset)
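For intuition about what low-precision inference trades away, the sketch below applies symmetric int8 quantization to a small weight array and measures the round-trip error. It is a simplified illustration of the general idea, not the recipe that Intel Neural Compressor searches for; the helper names are hypothetical.

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor quantization of float weights to int8."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return quantized, scale


def dequantize(quantized, scale):
    """Map int8 values back to approximate float weights."""
    return quantized.astype(np.float32) * scale


rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 4)).astype(np.float32)
quantized, scale = quantize_int8(weights)
max_error = float(np.max(np.abs(weights - dequantize(quantized, scale))))
print("scale:", scale, "max round-trip error:", max_error)
```

Each weight is stored in one byte instead of four, and the round-trip error is bounded by about half the scale. Whether that error is acceptable for a given model is what the accuracy constraints in the quantization config are meant to check.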
Let’s benchmark using the quantized model, so that we can compare the performance to the full precision model that was originally benchmarked.
[ ]:
model.benchmark(dataset=dataset, saved_model_dir=inc_output_dir)
Dataset Citations¶
@ONLINE {tfflowers,
author = "The TensorFlow Team",
title = "Flowers",
month = "jan",
year = "2019",
url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }
@article{openimages,
title={OpenImages: A public dataset for large-scale multi-label and multi-class image classification.},
author={Krasin, Ivan and Duerig, Tom and Alldrin, Neil and Veit, Andreas and Abu-El-Haija, Sami
and Belongie, Serge and Cai, David and Feng, Zheyun and Ferrari, Vittorio and Gomes, Victor
and Gupta, Abhinav and Narayanan, Dhyanesh and Sun, Chen and Chechik, Gal and Murphy, Kevin},
journal={Dataset available from https://github.com/openimages},
year={2016}
}