FLOWER17 Species classification using Transfer Learning

The model we’ll be using is the AlexNet CNN model, which is part of dffml-model-pytorch, a DFFML plugin that lets you use PyTorch via DFFML. We can install it with pip. We will also use the image loading from dffml-config-image and the YAML file loading from dffml-config-yaml.

$ pip install -U dffml-model-pytorch dffml-config-yaml dffml-config-image

There are 2 ways to perform Transfer Learning:

  1. Fine-tuning the CNN

    Initialize the network with pre-trained weights (trained on the ImageNet1000 dataset) and train the whole network on the new dataset.

  2. Using the CNN as fixed feature-extractor

    We freeze the parameters of the whole network except the final layer, so that back-propagation computes gradients only for that last layer.

In this example, we will be fine-tuning the AlexNet model (we set trainable=True).

We first create a YAML file, layers.yaml, that defines the layer(s) that will replace the last layer of the network architecture:

linear1:
  layer_type: Linear
  in_features: 4096
  out_features: 256
relu:
  layer_type: ReLU
dropout:
  layer_type: Dropout
  p: 0.2
linear2:
  layer_type: Linear
  in_features: 256
  out_features: 17
logsoftmax:
  layer_type: LogSoftmax
  dim: 1
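
For readers more familiar with PyTorch than with the YAML form, the layers above presumably correspond to the following torch.nn modules. This is a sketch under the assumption that each layer_type names a torch.nn class and the remaining keys are passed as its keyword arguments; how the plugin actually builds the layers is not shown here. The 4096 input features match the input size of the final classifier layer in torchvision's AlexNet, and the 17 outputs match the 17 flower species.

```python
import torch
import torch.nn as nn

# Assumed PyTorch equivalent of layers.yaml, one module per YAML entry.
head = nn.Sequential(
    nn.Linear(in_features=4096, out_features=256),  # linear1
    nn.ReLU(),                                      # relu
    nn.Dropout(p=0.2),                              # dropout
    nn.Linear(in_features=256, out_features=17),    # linear2
    nn.LogSoftmax(dim=1),                           # logsoftmax
)

head.eval()  # disable dropout for a deterministic forward pass
features = torch.randn(2, 4096)  # a made-up batch of two feature vectors
with torch.no_grad():
    log_probs = head(features)

print(log_probs.shape)               # one row of 17 log-probabilities per input
print(log_probs.exp().sum(dim=1))    # each row sums to 1 after exponentiation
```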

Train the model.

dffml train \
  -model alexnet \
  -model-add_layers \
  -model-layers @layers.yaml \
  -model-clstype str \
  -model-classifications \
    crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
    bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
  -model-location alexnet_model \
  -model-epochs 20 \
  -model-batch_size 32 \
  -model-imageSize 224 \
  -model-validation_split 0.2 \
  -model-trainable \
  -model-enableGPU \
  -model-normalize_mean 0.485 0.456 0.406 \
  -model-normalize_std 0.229 0.224 0.225 \
  -model-features image:int:$((500*500)) \
  -model-predict label:str:1 \
  -sources f=dir \
    -source-foldername flower_dataset/train \
    -source-feature image \
    -source-labels \
      crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
      bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
  -log debug
INFO:dffml.AlexNetModelContext:Training complete in 5m 41s
INFO:dffml.AlexNetModelContext:Best Validation Accuracy: 0.927602
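
The values passed to -model-normalize_mean and -model-normalize_std are the per-channel statistics commonly used for ImageNet pre-trained models. The sketch below shows what such a normalization does to a single RGB pixel, assuming (as torchvision's Normalize transform does) that pixel values are first scaled to the range 0 to 1. The function name normalize_pixel and the example pixel are illustrative only.

```python
MEAN = (0.485, 0.456, 0.406)  # same values as -model-normalize_mean
STD = (0.229, 0.224, 0.225)   # same values as -model-normalize_std


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Scale 8-bit RGB values to [0, 1], then standardise each channel."""
    return tuple(
        (value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std)
    )


# An arbitrary example pixel (red, green, blue).
print(normalize_pixel((111, 126, 128)))
```

Using the same statistics as the pre-training data keeps the inputs in the range the pre-trained weights were fitted to.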

Assess the model’s accuracy.

dffml accuracy \
  -model alexnet \
  -model-add_layers \
  -model-layers @layers.yaml \
  -model-clstype str \
  -model-classifications \
    crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
    bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
  -model-location alexnet_model \
  -model-imageSize 224 \
  -model-trainable \
  -model-enableGPU \
  -model-normalize_mean 0.485 0.456 0.406 \
  -model-normalize_std 0.229 0.224 0.225 \
  -model-features image:int:$((500*500)) \
  -model-predict label:str:1 \
  -features label:str:1 \
  -sources f=dir \
    -source-foldername flower_dataset/test \
    -source-feature image \
    -source-labels \
      crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
      bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
  -scorer pytorchscore

The output is:

0.8196078431372549
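
If the scorer reports plain classification accuracy (an assumption; the scorer's internals are not shown here), the value is the fraction of test images whose predicted label matches the true label. A minimal sketch with made-up labels:

```python
def accuracy(predicted, actual):
    """Fraction of positions where the predicted label equals the true label."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("need at least one label")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


# Hypothetical labels, for illustration only.
actual = ["daisy", "pansy", "tulip", "iris", "daisy"]
predicted = ["daisy", "pansy", "tulip", "crocus", "daisy"]
print(accuracy(predicted, actual))  # 4 of 5 correct -> 0.8
```

The reported value matches 209/255 to the printed precision, so it would be consistent with, for example, 209 correct predictions out of 255 test images; the actual size of the test set is not stated here.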

Create an unknown_images.csv file which contains the filenames of the images to predict on.

cat > unknown_images.csv << EOF
key,image
daisy,daisy.jpg
pansy,pansy.jpg
tigerlily,tigerlily.jpg
buttercup,buttercup.jpg
EOF
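
The same file can also be produced from Python, which may be more convenient when the list of images is generated by a script. This is a sketch using only the standard library; the function name write_unknown_images is hypothetical.

```python
import csv
from pathlib import Path

# Same rows as in the shell heredoc: record key, image filename.
ROWS = [
    ("daisy", "daisy.jpg"),
    ("pansy", "pansy.jpg"),
    ("tigerlily", "tigerlily.jpg"),
    ("buttercup", "buttercup.jpg"),
]


def write_unknown_images(path="unknown_images.csv"):
    """Write the key,image CSV and return its path."""
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle, lineterminator="\n")
        writer.writerow(["key", "image"])  # header row
        writer.writerows(ROWS)
    return Path(path)


if __name__ == "__main__":
    print(write_unknown_images().read_text())
```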

In this example, the unknown_images.csv file contains the filenames of the following images:

../examples/flower17/daisy.jpg ../examples/flower17/pansy.jpg ../examples/flower17/tigerlily.jpg ../examples/flower17/buttercup.jpg

Predict with the trained model.

dffml predict all \
  -model alexnet \
  -model-add_layers \
  -model-layers @layers.yaml \
  -model-clstype str \
  -model-classifications \
    crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
    bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
  -model-location alexnet_model \
  -model-imageSize 224 \
  -model-trainable \
  -model-enableGPU \
  -model-normalize_mean 0.485 0.456 0.406 \
  -model-normalize_std 0.229 0.224 0.225 \
  -model-features image:int:$((500*500)) \
  -model-predict label:str:1 \
  -sources f=csv \
  -source-filename unknown_images.csv \
  -source-loadfiles image \
  -pretty

The output is:


	Key:	daisy
                                                               Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
|               image               |                    111, 126, 128, 110, 125, 127, 109, 124, 126, 1 ... (length:689520)                    |
+----------------------------------------------------------------------------------------------------------------------------------------------+

                                                                  Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                    label                                                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+
|           Value:  daisy           |                                     Confidence:   0.9999972581863403                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+

	Key:	pansy
                                                               Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
|               image               |                    153, 92, 136, 120, 59, 103, 112, 51, 95, 107,  ... (length:921600)                    |
+----------------------------------------------------------------------------------------------------------------------------------------------+

                                                                  Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                    label                                                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+
|           Value:  pansy           |                                     Confidence:   0.999854564666748                                      |
+----------------------------------------------------------------------------------------------------------------------------------------------+

	Key:	tigerlily
                                                               Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
|               image               |                    0, 15, 1, 0, 15, 1, 0, 15, 1, 0, 15, 1, 0, 15, ... (length:817920)                    |
+----------------------------------------------------------------------------------------------------------------------------------------------+

                                                                  Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                    label                                                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+
|         Value:  tigerlily         |                                     Confidence:   0.9924067854881287                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+

	Key:	buttercup
                                                               Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
|               image               |                    10, 18, 17, 10, 18, 17, 10, 18, 17, 10, 18, 17 ... (length:814080)                    |
+----------------------------------------------------------------------------------------------------------------------------------------------+

                                                                  Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
|                                                                    label                                                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+
|         Value:  buttercup         |                                     Confidence:   0.9977046847343445                                     |
+----------------------------------------------------------------------------------------------------------------------------------------------+

The model predicts all four flower species correctly, each with a confidence above 99%!
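
Because the last layer in layers.yaml is a LogSoftmax, the network outputs log-probabilities. One plausible way to obtain a label and a confidence like those shown above is to pick the class with the largest log-probability and exponentiate it. The sketch below assumes that is how the confidence is derived, which is not confirmed here, and uses made-up scores and a shortened class list.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_sum for x in logits]


def predict(log_probs, classes):
    """Return the most likely class and its probability."""
    best = max(range(len(classes)), key=lambda i: log_probs[i])
    return classes[best], math.exp(log_probs[best])


# Hypothetical raw scores for a shortened list of classes.
classes = ["daisy", "pansy", "tigerlily"]
logits = [9.0, 1.0, 0.5]

label, confidence = predict(log_softmax(logits), classes)
print(label, confidence)
```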