tlt.models.text_classification.tf_text_classification_model.TFTextClassificationModel
- class tlt.models.text_classification.tf_text_classification_model.TFTextClassificationModel(model_name: str, model=None, optimizer=None, loss=None, **kwargs)
Class representing a pretrained TensorFlow (TF) model that can be fine-tuned for binary text classification.
- __init__(model_name: str, model=None, optimizer=None, loss=None, **kwargs)
Class constructor
Methods
__init__(model_name[, model, optimizer, loss]): Class constructor.
benchmark(dataset[, saved_model_dir, ...]): Uses Intel Neural Compressor to benchmark the model on the given dataset.
cleanup_saved_objects_for_distributed()
evaluate(dataset[, use_test_set, ...]): Evaluates the model. If the dataset has a validation set, evaluation is done on it by default; set use_test_set=True to evaluate on the test set instead.
export(output_dir): Exports a trained model as a saved_model.pb file.
export_for_distributed([export_dir, ...]): Exports the model, optimizer, loss, training data, and validation data to export_dir so that the distributed training script can access them.
load_from_directory(model_dir): Loads a saved model from the specified directory.
optimize_graph(output_dir[, overwrite_model]): Performs FP32 graph optimization on the model using Intel Neural Compressor and writes the inference-optimized model to output_dir.
predict(input_samples[, ...]): Generates predictions for the specified input samples.
quantize(output_dir, dataset[, config, ...]): Performs post-training quantization of the model with Intel Neural Compressor, using the given dataset.
set_auto_mixed_precision(...): Enables auto mixed precision for training.
train(dataset, output_dir[, epochs, ...]): Trains the model on the specified binary text classification dataset.
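The rule that evaluate uses to choose a split can be sketched in plain Python. This is only an illustration, not the library's implementation. The SimpleDataset class, its validation_subset and test_subset attributes, and the select_eval_split helper are hypothetical stand-ins. The description above does not say what happens when the dataset has no validation set. In that case, this sketch raises ValueError, and that choice is an assumption that may not match the library.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class SimpleDataset:
    """Hypothetical dataset with optional validation and test splits."""
    validation_subset: Optional[Sequence[Any]] = None
    test_subset: Optional[Sequence[Any]] = None


def select_eval_split(dataset: SimpleDataset, use_test_set: bool = False) -> Sequence[Any]:
    """Return the split that evaluate() is documented to use (hypothetical helper).

    - use_test_set=True: the test set.
    - Otherwise: the validation set, if the dataset has one.
    Raising ValueError for a missing split is an assumption of this sketch.
    """
    if use_test_set:
        if dataset.test_subset is None:
            raise ValueError("use_test_set=True but the dataset has no test set")
        return dataset.test_subset
    if dataset.validation_subset is None:
        # Assumed behavior; the documented description does not cover this case.
        raise ValueError("the dataset has no validation set; pass use_test_set=True")
    return dataset.validation_subset


if __name__ == "__main__":
    ds = SimpleDataset(validation_subset=["val text"], test_subset=["test text"])
    print(select_eval_split(ds))                     # validation split (default)
    print(select_eval_split(ds, use_test_set=True))  # test split
```

Keeping the choice in one small function makes the default (validation) and the opt-in override (use_test_set=True) easy to see at a glance.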
Attributes
dropout_layer_rate
The probability of any one node being dropped when a dropout layer is used
framework
Framework with which the model is compatible
learning_rate
Learning rate for the model
model_name
Name of the model
num_classes
The number of output neurons in the model; equal to the number of classes in the dataset
preprocessor
Preprocessor for the model
use_case
Use case (or category) to which the model belongs
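To make dropout_layer_rate concrete, here is a minimal pure-Python sketch of inverted dropout, the scheme documented for Keras' Dropout layer. During training, each value is zeroed independently with probability rate. The surviving values are scaled by 1 / (1 - rate), so the expected activation stays unchanged. The apply_dropout function is hypothetical and illustrates only the concept. It is not the layer this model uses.

```python
import random
from typing import List, Optional


def apply_dropout(values: List[float], rate: float, seed: Optional[int] = None) -> List[float]:
    """Inverted dropout (hypothetical helper): zero each value with probability `rate`.

    `rate` plays the role of dropout_layer_rate. Surviving values are scaled
    by 1 / (1 - rate) so the expected value of each output matches its input.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    rng = random.Random(seed)
    keep_scale = 1.0 / (1.0 - rate)
    # Each node is dropped independently with probability `rate`.
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


if __name__ == "__main__":
    activations = [1.0] * 10_000
    out = apply_dropout(activations, rate=0.1, seed=0)
    dropped = sum(1 for v in out if v == 0.0)
    print(f"dropped {dropped} of {len(out)} nodes")      # roughly 10%
    print(f"mean activation {sum(out) / len(out):.3f}")  # close to 1.0
```

Scaling at training time means no rescaling is needed at inference, when dropout is disabled.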