tlt.models.image_classification.torchvision_image_classification_model.TorchvisionImageClassificationModel.train
- TorchvisionImageClassificationModel.train(dataset: ImageClassificationDataset, output_dir, epochs=1, initial_checkpoints=None, do_eval=True, early_stopping=False, lr_decay=True, seed=None, extra_layers=None, ipex_optimize=True, distributed=False, hostfile=None, nnodes=1, nproc_per_node=1, use_horovod=False, hvd_start_timeout=30, enable_auto_mixed_precision=None, device=None)
Trains the model using the specified image classification dataset. The first time train is called, the model is fetched from torchvision and a fully connected layer with linear activation is added on top, sized to the number of classes in the specified dataset. The model and optimizer are then defined, and the model is trained for the specified number of epochs.
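To illustrate how the output layer, and the optional extra_layers argument from the signature, determine the shape of the classification head, here is a minimal sketch in plain Python. The helper head_layer_shapes and the feature size of 512 are hypothetical and used only for illustration. The sketch computes layer sizes only; it is not the library's implementation and does not show any activations or other modules the library may place between layers.

```python
def head_layer_shapes(base_features, extra_layers, num_classes):
    """Return (in_features, out_features) pairs for the dense layers
    stacked on top of the base model.

    base_features: size of the base model's feature output (assumed value)
    extra_layers:  list of ints such as [1024, 512], or None
    num_classes:   number of classes in the dataset
    """
    # Chain the sizes: base output -> each extra layer -> class outputs.
    sizes = [base_features, *(extra_layers or []), num_classes]
    return list(zip(sizes[:-1], sizes[1:]))


# With extra_layers=[1024, 512] and a 3-class dataset, two dense layers
# are inserted before the output layer.
print(head_layer_shapes(512, [1024, 512], 3))

# With extra_layers=None, only the output layer is added.
print(head_layer_shapes(512, None, 3))
```

The first call prints three pairs, one per dense layer; the second prints a single pair for the output layer alone.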
- Parameters
dataset (ImageClassificationDataset) – Dataset to use when training the model
output_dir (str) – Path to a writeable directory for output files
epochs (int) – Number of epochs to train the model (default: 1)
initial_checkpoints (str) – Path to checkpoint weights to load. If the path provided is a directory, the latest checkpoint will be used.
do_eval (bool) – If do_eval is True and the dataset has a validation subset, the model will be evaluated at the end of each epoch.
early_stopping (bool) – If True, stop training early once convergence is reached
enable_auto_mixed_precision (bool or None) – Enable auto mixed precision for training. Mixed precision uses both 16-bit and 32-bit floating point types to make training run faster and use less memory. It is recommended to enable auto mixed precision when running on platforms that support bfloat16 (Intel third or fourth generation Xeon processors). If it is enabled on a platform that does not support bfloat16, it can be detrimental to training performance. If enable_auto_mixed_precision is set to None, auto mixed precision is automatically enabled when running on Intel fourth generation Xeon processors and disabled on other platforms.
lr_decay (bool) – If lr_decay is True and do_eval is True, learning rate decay on the validation loss is applied at the end of each epoch.
seed (int) – Optionally set a seed for reproducibility.
extra_layers (list[int]) – Optionally insert additional dense layers between the base model and output layer. This can help increase accuracy when fine-tuning a PyTorch model. The input should be a list of integers representing the number and size of the layers, for example [1024, 512] will insert two dense layers, the first with 1024 neurons and the second with 512 neurons.
ipex_optimize (bool) – Use Intel Extension for PyTorch (IPEX). Defaults to True.
distributed (bool) – Boolean flag to use distributed training. Defaults to False.
hostfile (str) – Name of the hostfile for distributed training. Defaults to None.
nnodes (int) – Number of nodes to use for distributed training. Defaults to 1.
nproc_per_node (int) – Number of processes to spawn per node to use for distributed training. Defaults to 1.
device (str) – Hardware device to run training on, either "cpu" or "hpu". If device="hpu" is specified but no HPU hardware or installation is detected, CPU is used. (default: "cpu")
- Returns
Trained PyTorch model object
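
As a usage illustration, the following sketch wraps a call to train with several of the keyword arguments documented above. The helper name fine_tune is hypothetical, and the argument values are arbitrary examples rather than recommendations. It assumes model is a TorchvisionImageClassificationModel and dataset is an ImageClassificationDataset that have already been created elsewhere; how to obtain them is outside the scope of this page.

```python
def fine_tune(model, dataset, output_dir):
    """Train `model` on `dataset` and return the trained PyTorch model.

    `model` is assumed to expose the train() method documented above;
    every keyword used here appears in that method's signature.
    """
    return model.train(
        dataset,
        output_dir=output_dir,      # writeable directory for output files
        epochs=3,                   # illustrative value; the default is 1
        do_eval=True,               # evaluate each epoch if a validation subset exists
        early_stopping=True,        # stop once convergence is reached
        lr_decay=True,              # applies only when do_eval is also True
        seed=10,                    # fixed seed for reproducibility
        extra_layers=[1024, 512],   # two extra dense layers before the output layer
        ipex_optimize=False,        # set to True to use Intel Extension for PyTorch
        device="cpu",
    )
```

Because lr_decay takes effect only when do_eval is True, the two flags are set together here.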