FLOWER17 Species classification using Transfer Learning
The model we’ll be using is the AlexNet CNN model, which is part of dffml-model-pytorch, a DFFML plugin which allows you to use PyTorch via DFFML. We can install it with pip. We will also be using image loading from dffml-config-image and YAML file loading from dffml-config-yaml.
$ pip install -U dffml-model-pytorch dffml-config-yaml dffml-config-image
There are two ways to perform transfer learning:
- Fine-tuning the CNN
Initialize the network with pre-trained weights (trained on the ImageNet1000 dataset) and train the whole network on the new dataset.
- Using the CNN as a fixed feature extractor
Freeze the parameters of the whole network except the final layer, so that back-propagation computes gradients only for the final layer.
In this example, we will be fine-tuning the AlexNet model, so we pass the -model-trainable flag (trainable=True).
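The difference between the two approaches comes down to which parameters are allowed to receive gradients. The following is a minimal PyTorch sketch of that idea, not DFFML's implementation: the tiny nn.Sequential network and the helper names make_network and count_trainable are stand-ins invented for illustration.

```python
import torch.nn as nn


def make_network():
    """Tiny stand-in for a pre-trained network (hypothetical, not AlexNet)."""
    return nn.Sequential(
        nn.Linear(8, 4),  # plays the role of the pre-trained layers
        nn.ReLU(),
        nn.Linear(4, 3),  # plays the role of the final classification layer
    )


def count_trainable(model):
    """Number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# 1. Fine-tuning: every parameter stays trainable (the PyTorch default).
fine_tuned = make_network()

# 2. Fixed feature extractor: freeze everything, then unfreeze the last layer.
feature_extractor = make_network()
for param in feature_extractor.parameters():
    param.requires_grad = False
for param in feature_extractor[-1].parameters():
    param.requires_grad = True

print("fine-tuning:", count_trainable(fine_tuned))
print("feature extractor:", count_trainable(feature_extractor))
```

With this stand-in network, fine-tuning leaves all 51 parameters trainable, while the feature-extractor setup leaves only the 15 parameters of the last layer trainable.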
We first create a YAML file, layers.yaml, that defines the new final layer(s) of the network, replacing the original ones:
linear1:
    layer_type: Linear
    in_features: 4096
    out_features: 256
relu:
    layer_type: ReLU
dropout:
    layer_type: Dropout
    p: 0.2
linear2:
    layer_type: Linear
    in_features: 256
    out_features: 17
logsoftmax:
    layer_type: LogSoftmax
    dim: 1
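Before training, it can help to confirm that the layer sizes line up. The sketch below is plain Python and makes no use of DFFML. It mirrors the YAML above as a dict (an assumption made to keep the example dependency-free) and checks that each Linear layer's in_features matches the previous Linear layer's out_features, and that the final width equals the number of flower classes. The helper name final_width is hypothetical.

```python
# The layers from the YAML file, written as a dict for illustration.
LAYERS = {
    "linear1": {"layer_type": "Linear", "in_features": 4096, "out_features": 256},
    "relu": {"layer_type": "ReLU"},
    "dropout": {"layer_type": "Dropout", "p": 0.2},
    "linear2": {"layer_type": "Linear", "in_features": 256, "out_features": 17},
    "logsoftmax": {"layer_type": "LogSoftmax", "dim": 1},
}

CLASSIFICATIONS = (
    "crocus windflower fritillary tulip pansy dandelion tigerlily sunflower "
    "bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy"
).split()


def final_width(layers):
    """Return the last Linear out_features, raising if the sizes do not chain."""
    width = None
    for name, spec in layers.items():
        if spec["layer_type"] != "Linear":
            continue  # ReLU, Dropout and LogSoftmax leave the width unchanged
        if width is not None and spec["in_features"] != width:
            raise ValueError(f"{name}: in_features should be {width}")
        width = spec["out_features"]
    return width


width = final_width(LAYERS)
if width != len(CLASSIFICATIONS):
    raise ValueError(f"{width} outputs for {len(CLASSIFICATIONS)} classes")
print(f"{width} outputs for {len(CLASSIFICATIONS)} classes")
```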
Train the model.
dffml train \
-model alexnet \
-model-add_layers \
-model-layers @layers.yaml \
-model-clstype str \
-model-classifications \
crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
-model-location alexnet_model \
-model-epochs 20 \
-model-batch_size 32 \
-model-imageSize 224 \
-model-validation_split 0.2 \
-model-trainable \
-model-enableGPU \
-model-normalize_mean 0.485 0.456 0.406 \
-model-normalize_std 0.229 0.224 0.225 \
-model-features image:int:$((500*500)) \
-model-predict label:str:1 \
-sources f=dir \
-source-foldername flower_dataset/train \
-source-feature image \
-source-labels \
crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
-log debug
INFO:dffml.AlexNetModelContext:Training complete in 5m 41s
INFO:dffml.AlexNetModelContext:Best Validation Accuracy: 0.927602
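The -model-normalize_mean and -model-normalize_std values above are the per-channel statistics commonly used with ImageNet pre-trained models. As a rough sketch of what such normalization does, assuming pixel values are first scaled from 0..255 to 0..1 (an assumption about the preprocessing, not something confirmed here for DFFML), each channel is shifted by its mean and divided by its standard deviation. The function name normalize_pixel is made up for this example.

```python
# Per-channel (R, G, B) statistics passed on the command line above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, MEAN, STD)
    )


# An arbitrary example pixel; values near the channel mean map close to 0.
print(normalize_pixel((111, 126, 128)))
```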
Assess the model’s accuracy.
dffml accuracy \
-model alexnet \
-model-add_layers \
-model-layers @layers.yaml \
-model-clstype str \
-model-classifications \
crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
-model-location alexnet_model \
-model-imageSize 224 \
-model-trainable \
-model-enableGPU \
-model-normalize_mean 0.485 0.456 0.406 \
-model-normalize_std 0.229 0.224 0.225 \
-model-features image:int:$((500*500)) \
-model-predict label:str:1 \
-features label:str:1 \
-sources f=dir \
-source-foldername flower_dataset/test \
-source-feature image \
-source-labels \
crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
-scorer pytorchscore
The output is:
0.8196078431372549
Create an unknown_images.csv file which contains the filenames of the images to predict on.
cat > unknown_images.csv << EOF
key,image
daisy,daisy.jpg
pansy,pansy.jpg
tigerlily,tigerlily.jpg
buttercup,buttercup.jpg
EOF
In this example, the unknown_images.csv file lists four images: daisy.jpg, pansy.jpg, tigerlily.jpg, and buttercup.jpg.
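For a larger set of images, the same file can be generated rather than typed by hand. The sketch below is one possible way to do that with the Python standard library; using the filename without its extension as the key is a convention assumed from the example above, and build_csv is a hypothetical helper name.

```python
import csv
import io
from pathlib import Path


def build_csv(image_paths):
    """Return CSV text with a key,image header and one row per image."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(["key", "image"])
    for image_path in image_paths:
        # Key is the filename without its extension, e.g. daisy.jpg -> daisy.
        writer.writerow([Path(image_path).stem, str(image_path)])
    return buffer.getvalue()


images = ["daisy.jpg", "pansy.jpg", "tigerlily.jpg", "buttercup.jpg"]
# Printed here; redirect the output to unknown_images.csv to save it.
print(build_csv(images), end="")
```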
Predict with the trained model.
dffml predict all \
-model alexnet \
-model-add_layers \
-model-layers @layers.yaml \
-model-clstype str \
-model-classifications \
crocus windflower fritillary tulip pansy dandelion tigerlily sunflower \
bluebell cowslip coltsfoot snowdrop daffodil lilyvalley iris buttercup daisy \
-model-location alexnet_model \
-model-imageSize 224 \
-model-trainable \
-model-enableGPU \
-model-normalize_mean 0.485 0.456 0.406 \
-model-normalize_std 0.229 0.224 0.225 \
-model-features image:int:$((500*500)) \
-model-predict label:str:1 \
-sources f=csv \
-source-filename unknown_images.csv \
-source-loadfiles image \
-pretty
Output
Key: daisy
Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
| image | 111, 126, 128, 110, 125, 127, 109, 124, 126, 1 ... (length:689520) |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
| label |
+----------------------------------------------------------------------------------------------------------------------------------------------+
| Value: daisy | Confidence: 0.9999972581863403 |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Key: pansy
Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
| image | 153, 92, 136, 120, 59, 103, 112, 51, 95, 107, ... (length:921600) |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
| label |
+----------------------------------------------------------------------------------------------------------------------------------------------+
| Value: pansy | Confidence: 0.999854564666748 |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Key: tigerlily
Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
| image | 0, 15, 1, 0, 15, 1, 0, 15, 1, 0, 15, 1, 0, 15, ... (length:817920) |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
| label |
+----------------------------------------------------------------------------------------------------------------------------------------------+
| Value: tigerlily | Confidence: 0.9924067854881287 |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Key: buttercup
Record Features
+----------------------------------------------------------------------------------------------------------------------------------------------+
| image | 10, 18, 17, 10, 18, 17, 10, 18, 17, 10, 18, 17 ... (length:814080) |
+----------------------------------------------------------------------------------------------------------------------------------------------+
Prediction
+----------------------------------------------------------------------------------------------------------------------------------------------+
| label |
+----------------------------------------------------------------------------------------------------------------------------------------------+
| Value: buttercup | Confidence: 0.9977046847343445 |
+----------------------------------------------------------------------------------------------------------------------------------------------+
The model predicts all four flower species correctly, each with a confidence above 99%!
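Because the last layer in layers.yaml is LogSoftmax, the network outputs log-probabilities. One plausible way to turn those into a confidence like the ones reported above is to exponentiate the largest log-probability. Whether DFFML computes its Confidence field exactly this way is an assumption. The raw scores and the helper names log_softmax and top_prediction below are invented for illustration.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one row of raw scores."""
    peak = max(scores)
    log_total = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return [s - log_total for s in scores]


def top_prediction(scores, labels):
    """Return the most likely label and its probability."""
    log_probs = log_softmax(scores)
    best = max(range(len(labels)), key=lambda i: log_probs[i])
    # exp() undoes the log, giving a probability between 0 and 1.
    return labels[best], math.exp(log_probs[best])


# Made-up raw scores for three of the seventeen classes.
label, confidence = top_prediction([9.0, 1.5, 0.5], ["daisy", "pansy", "tigerlily"])
print(label, confidence)
```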