tlt.models.text_classification.tf_text_classification_model.TFTextClassificationModel.train
- TFTextClassificationModel.train(dataset: TextClassificationDataset, output_dir, epochs=1, initial_checkpoints=None, do_eval=True, early_stopping=False, lr_decay=True, enable_auto_mixed_precision=None, shuffle_files=True, seed=None, distributed=False, hostfile=None, nnodes=1, nproc_per_node=1, **kwargs)
Trains the model using the specified binary text classification dataset. If a path to initial checkpoints is provided, those weights are loaded before training.
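A rough sketch of how the checkpoint path might be resolved before weights are loaded, following the rule for `initial_checkpoints` in the parameter list (a directory means "use the latest checkpoint"). `resolve_initial_checkpoint` is a hypothetical helper, not part of tlt. For a directory it picks the most recently modified `*.index` file, which is a stand-in heuristic: TensorFlow's own `tf.train.latest_checkpoint(checkpoint_dir)` reads the prefix recorded in the directory's `checkpoint` state file instead, and the real method may well rely on that.

```python
import os
from typing import Optional


def resolve_initial_checkpoint(initial_checkpoints: Optional[str]) -> Optional[str]:
    """Return the checkpoint prefix to load, or None to keep the base weights.

    Hypothetical helper: a directory resolves to its newest ``*.index`` file,
    returned as a checkpoint prefix (the path without the ``.index`` suffix).
    """
    if initial_checkpoints is None:
        # Nothing requested: train starting from the model's current weights.
        return None
    if not os.path.isdir(initial_checkpoints):
        # A file path or checkpoint prefix is used exactly as given.
        return initial_checkpoints
    index_files = [
        os.path.join(initial_checkpoints, name)
        for name in os.listdir(initial_checkpoints)
        if name.endswith(".index")
    ]
    if not index_files:
        # Mirrors tf.train.latest_checkpoint, which returns None when empty.
        return None
    # Stand-in for "latest": the most recently modified index file.
    newest = max(index_files, key=os.path.getmtime)
    return newest[: -len(".index")]
```

The returned prefix is the form TensorFlow checkpoint readers expect, which is why the `.index` suffix is stripped rather than returned.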
- Parameters
dataset (TextClassificationDataset) – The dataset to use for training. If a train subset has been defined, that subset will be used to fit the model. Otherwise, the entire non-partitioned dataset will be used.
output_dir (str) – A writable output directory where checkpoint files are written during training.
epochs (int) – The number of training epochs [default: 1]
initial_checkpoints (str) – Path to checkpoint weights to load. If the path provided is a directory, the latest checkpoint will be used.
do_eval (bool) – If do_eval is True and the dataset has a validation subset, the model will be evaluated at the end of each epoch.
early_stopping (bool) – Enable early stopping: at the end of each epoch, training stops if convergence has been reached.
lr_decay (bool) – If lr_decay is True and do_eval is True, learning rate decay based on the validation loss is applied at the end of each epoch.
enable_auto_mixed_precision (bool or None) – Enable auto mixed precision for training. Mixed precision uses both 16-bit and 32-bit floating point types to make training run faster and use less memory. It is recommended to enable auto mixed precision training when running on platforms that support bfloat16 (Intel third or fourth generation Xeon processors). If it is enabled on a platform that does not support bfloat16, it can be detrimental to the training performance. If enable_auto_mixed_precision is set to None, auto mixed precision will be automatically enabled when running with Intel fourth generation Xeon processors, and disabled for other platforms.
shuffle_files (bool) – Boolean specifying whether to shuffle the training data before each epoch.
seed (int) – Optionally set a seed for reproducibility.
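The interaction between `do_eval`, `lr_decay`, `early_stopping`, and `enable_auto_mixed_precision` described above can be summarized as a small decision function. This is a hypothetical sketch, not tlt code: `plan_training`, `EpochPlan`, and the `has_validation_subset` and `is_fourth_gen_xeon` inputs are illustrative names (the real method inspects the dataset and platform itself). Treating learning rate decay as requiring an actual validation pass is an assumption drawn from its dependence on the validation loss; the documentation does not say whether early stopping also requires evaluation, so it is passed through unchanged.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EpochPlan:
    """Which end-of-epoch behaviors are active (hypothetical summary type)."""
    evaluate: bool
    early_stopping: bool
    lr_decay: bool
    mixed_precision: bool


def plan_training(do_eval: bool,
                  early_stopping: bool,
                  lr_decay: bool,
                  enable_auto_mixed_precision: Optional[bool],
                  has_validation_subset: bool,
                  is_fourth_gen_xeon: bool) -> EpochPlan:
    # Evaluation needs both the flag and a validation subset to evaluate on.
    evaluate = do_eval and has_validation_subset
    # Decay follows the validation loss, so it is assumed to need evaluation.
    decay = lr_decay and evaluate
    # None means "decide from the platform": enabled only on 4th gen Xeon.
    if enable_auto_mixed_precision is None:
        amp = is_fourth_gen_xeon
    else:
        amp = enable_auto_mixed_precision
    # early_stopping is passed through; the docs state no extra condition.
    return EpochPlan(evaluate, early_stopping, decay, amp)
```

For example, `plan_training(True, False, True, None, True, False)` evaluates and decays the learning rate but leaves mixed precision off, because the default of None only enables it on fourth generation Xeon processors.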
- Returns
The History object returned by the model.fit() call.
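A Keras `History` object exposes a `history` attribute: a dict mapping each logged metric name to a list with one value per epoch, where validation metrics carry a `val_` prefix. The helper below is a hypothetical sketch for picking the best epoch from that dict. Which keys are present depends on how the model was compiled and whether a validation subset was evaluated, so `val_loss` is an assumption about this model's logs rather than a guarantee.

```python
from typing import Dict, List, Tuple


def best_epoch(history: Dict[str, List[float]],
               monitor: str = "val_loss") -> Tuple[int, float]:
    """Return (1-based epoch, value) where `monitor` was lowest."""
    values = history.get(monitor)
    if not values:
        # e.g. no validation subset, so no "val_" keys were logged.
        raise KeyError(monitor)
    # Index of the smallest recorded value; ties resolve to the earliest epoch.
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]
```

Usage would look like `best_epoch(model.train(dataset, output_dir, epochs=3).history)`, passing the dict rather than the `History` object itself.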
- Raises
FileExistsError – if the output directory is a file
TypeError – if the dataset specified is not a TextClassificationDataset
TypeError – if the output_dir parameter is not a string
TypeError – if the epochs parameter is not an integer
TypeError – if the initial_checkpoints parameter is not a string
NotImplementedError – if the specified dataset has more than 2 classes
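The conditions in the Raises list can be sketched as a standalone argument validator. Everything here is illustrative: `validate_train_args` is hypothetical, `TextClassificationDataset` is a minimal stand-in (only a `class_names` attribute is modeled, which is an assumption about the real class), and the check order and messages are guesses. Because `initial_checkpoints` defaults to None, the sketch assumes it is only type-checked when a value is supplied.

```python
import os
from typing import List


class TextClassificationDataset:
    """Hypothetical stand-in for the tlt dataset; only class names are modeled."""

    def __init__(self, class_names: List[str]):
        self.class_names = class_names


def validate_train_args(dataset, output_dir, epochs, initial_checkpoints=None) -> None:
    """Raise the documented exceptions for invalid train() arguments."""
    if not isinstance(dataset, TextClassificationDataset):
        raise TypeError("The dataset must be a TextClassificationDataset")
    # Check the type before touching the filesystem with output_dir.
    if not isinstance(output_dir, str):
        raise TypeError("The output_dir must be a string")
    if os.path.isfile(output_dir):
        raise FileExistsError(f"{output_dir} is a file, not a directory")
    if not isinstance(epochs, int):
        raise TypeError("The epochs parameter must be an integer")
    # None is the default, so only non-None values are type-checked.
    if initial_checkpoints is not None and not isinstance(initial_checkpoints, str):
        raise TypeError("The initial_checkpoints parameter must be a string")
    # Only binary classification is supported.
    if len(dataset.class_names) > 2:
        raise NotImplementedError("Only binary text classification is supported")
```

Running the cheap type checks first means the filesystem is consulted only once `output_dir` is known to be a string.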