API Reference

Datasets

The simplest way to create datasets is with the dataset factory methods: load_dataset() loads a custom dataset, and get_dataset() downloads and uses a third-party dataset from a catalog such as TensorFlow Datasets or Torchvision.

Factory Methods

tlt.datasets.dataset_factory.load_dataset(dataset_dir: str, use_case: UseCaseType, framework: FrameworkType, dataset_name=None, **kwargs)

A factory method for loading a custom dataset.

Image classification datasets are expected to be a directory of images with one subfolder per image class. The class subfolders can optionally be nested inside split directories named ‘train’, ‘validation’, and/or ‘test’. Each class subfolder should contain the .jpg images for that class, and the name of the subfolder is used as the class label.

dataset_dir
  ├── class_a
  ├── class_b
  └── class_c

Or:

dataset_dir
  ├── train
  |   ├── class_a
  |   ├── class_b
  |   └── class_c
  ├── validation
  |   ├── class_a
  |   ├── class_b
  |   └── class_c
  └── test
      ├── class_a
      ├── class_b
      └── class_c
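
The two layouts above can be told apart by checking whether the top-level subfolders are split names. The following is a minimal standard-library sketch of that idea, not the library's own loading code; the helper names find_class_names and make_demo_layout are hypothetical, and the demo folders are created empty in a temporary directory so the sketch runs without real images.

```python
import tempfile
from pathlib import Path

# Split directory names taken from the description above.
SPLIT_NAMES = {"train", "validation", "test"}


def find_class_names(dataset_dir):
    """Return {split name or None: sorted class labels} for an image folder layout."""
    root = Path(dataset_dir)
    subdirs = sorted(p.name for p in root.iterdir() if p.is_dir())
    splits = [name for name in subdirs if name in SPLIT_NAMES]
    if not splits:
        # Flat layout: every subfolder is a class and its name is the label.
        return {None: subdirs}
    # Split layout: class subfolders live one level down, inside each split.
    return {
        split: sorted(p.name for p in (root / split).iterdir() if p.is_dir())
        for split in splits
    }


def make_demo_layout(root, with_splits):
    """Create empty class folders (hypothetical demo data, no images)."""
    if with_splits:
        parents = [root / s for s in ("train", "validation", "test")]
    else:
        parents = [root]
    for parent in parents:
        for class_name in ("class_a", "class_b", "class_c"):
            (parent / class_name).mkdir(parents=True)


with tempfile.TemporaryDirectory() as tmp:
    flat_dir = Path(tmp) / "flat"
    split_dir = Path(tmp) / "split"
    make_demo_layout(flat_dir, with_splits=False)
    make_demo_layout(split_dir, with_splits=True)
    flat_classes = find_class_names(flat_dir)
    split_classes = find_class_names(split_dir)

print(flat_classes)
print(split_classes)
```

In the flat case the result has a single entry keyed by None; in the split case there is one entry per split directory found.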

Text classification datasets are expected to be a directory containing a text/csv file with two columns: the label and the text/sentence to classify. See the TFCustomTextClassificationDataset documentation for a list of the additional kwargs that are used for loading a text classification dataset file.

class_a,<text>
class_b,<text>
class_a,<text>
...
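
As a rough illustration of the two-column layout shown above, the sketch below parses label/text rows with Python's csv module. It does not use the library; the sample rows, the comma delimiter, and the helper name read_label_text_rows are assumptions made for the example.

```python
import csv
import io

# Hypothetical sample in the two-column layout: label first, then the text.
SAMPLE = (
    "class_a,first example sentence\n"
    'class_b,"a sentence, with a comma"\n'
    "class_a,third example sentence\n"
)


def read_label_text_rows(file_obj, delimiter=","):
    """Split each row into (label, text), skipping blank lines."""
    labels, texts = [], []
    for row in csv.reader(file_obj, delimiter=delimiter):
        if not row:
            continue
        if len(row) != 2:
            raise ValueError(f"expected 2 columns, found {len(row)}: {row!r}")
        labels.append(row[0])
        texts.append(row[1])
    return labels, texts


labels, texts = read_label_text_rows(io.StringIO(SAMPLE))
class_names = sorted(set(labels))
print(class_names)
print(texts[1])
```

Text that contains the delimiter needs to be quoted, as in the second sample row, so that it still parses as a single column.
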

Parameters
  • dataset_dir (str) – directory containing the dataset

  • use_case (str or UseCaseType) – use case or task the dataset will be used to model

  • framework (str or FrameworkType) – framework

  • dataset_name (str) – optional; name of the dataset used for informational purposes

  • kwargs – optional; additional keyword arguments depending on the type of dataset being loaded

Returns

(dataset)

Raises

NotImplementedError – if the type of dataset being loaded is not supported

Example

>>> from tlt.datasets.dataset_factory import load_dataset
>>> data = load_dataset('/tmp/data/flower_photos', 'image_classification', 'tensorflow')
Found 3670 files belonging to 5 classes.
>>> data.class_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
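
Because load_dataset() raises NotImplementedError for unsupported dataset types, calling code may want to handle that case explicitly. The sketch below shows one possible pattern. The wrapper load_or_none and the stand-in fake_load_dataset are hypothetical and exist only so the example runs without the library installed; in real use the actual factory method would be passed in instead.

```python
def load_or_none(loader, *args, **kwargs):
    """Call a dataset factory and return None if the dataset type is unsupported."""
    try:
        return loader(*args, **kwargs)
    except NotImplementedError as err:
        print(f"Dataset not supported: {err}")
        return None


# Stand-in for load_dataset(); it only mimics the documented error behavior.
def fake_load_dataset(dataset_dir, use_case, framework, dataset_name=None, **kwargs):
    if use_case != "image_classification":
        raise NotImplementedError(f"use case '{use_case}' is not handled by this stand-in")
    return {"dataset_dir": dataset_dir, "use_case": use_case, "framework": framework}


supported = load_or_none(
    fake_load_dataset, "/tmp/data/flower_photos", "image_classification", "tensorflow"
)
unsupported = load_or_none(
    fake_load_dataset, "/tmp/data/flower_photos", "hypothetical_use_case", "tensorflow"
)
print(supported)
print(unsupported)
```
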

tlt.datasets.dataset_factory.get_dataset(dataset_dir: str, use_case: UseCaseType, framework: FrameworkType, dataset_name: Optional[str] = None, dataset_catalog: Optional[str] = None, **kwargs)

A factory method for using a dataset from a catalog.

Parameters
  • dataset_dir (str) – directory containing the dataset or to which the dataset should be downloaded

  • use_case (str or UseCaseType) – use case or task the dataset will be used to model

  • framework (str or FrameworkType) – framework

  • dataset_name (str) – optional; name of the dataset

  • dataset_catalog (str) – optional; catalog from which to download the dataset. If a dataset name is provided and no dataset catalog is given, the catalog defaults to tf_datasets for TensorFlow models, torchvision for PyTorch CV models, and huggingface datasets for PyTorch NLP models or Hugging Face models.

  • **kwargs – optional; additional keyword arguments for the framework or dataset_catalog

Returns

(dataset)

Raises

NotImplementedError – if the dataset requested is not supported yet

Example

>>> from tlt.datasets.dataset_factory import get_dataset
>>> data = get_dataset('/tmp/data/', 'image_classification', 'tensorflow', 'tf_flowers', 'tf_datasets')
>>> sorted(data.class_names)
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
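
The default described for dataset_catalog can be read as a small decision rule. The sketch below restates that rule as a standalone function for illustration; it is not the library's implementation. The function name default_catalog, the 'pytorch' and 'huggingface' strings, and the grouping of use cases into CV and NLP sets are assumptions, and the "Hugging Face models" case is left out because the reference does not say how such a model is identified.

```python
# Assumed use case groupings, for illustration only.
CV_USE_CASES = {"image_classification", "image_anomaly_detection"}
NLP_USE_CASES = {"text_classification", "text_generation"}


def default_catalog(framework, use_case, dataset_catalog=None):
    """Pick a catalog name following the documented defaulting rule (sketch)."""
    if dataset_catalog is not None:
        # An explicitly given catalog always wins.
        return dataset_catalog
    if framework == "tensorflow":
        return "tf_datasets"
    if framework == "pytorch" and use_case in CV_USE_CASES:
        return "torchvision"
    if framework == "pytorch" and use_case in NLP_USE_CASES:
        return "huggingface"  # assumed string form of the Hugging Face catalog
    raise ValueError(f"no default catalog for {framework!r} / {use_case!r}")


print(default_catalog("tensorflow", "image_classification"))
print(default_catalog("pytorch", "image_classification"))
print(default_catalog("pytorch", "text_classification"))
```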

Class Reference

Image Classification

tfds_image_classification_dataset.TFDSImageClassificationDataset

An image classification dataset from the TensorFlow datasets catalog

torchvision_image_classification_dataset.TorchvisionImageClassificationDataset

An image classification dataset from the torchvision catalog

tf_custom_image_classification_dataset.TFCustomImageClassificationDataset

A custom image classification dataset that can be used with TensorFlow models.

pytorch_custom_image_classification_dataset.PyTorchCustomImageClassificationDataset

A custom image classification dataset that can be used with PyTorch models.

image_classification_dataset.ImageClassificationDataset

Base class for an image classification dataset

Image Anomaly Detection

pytorch_custom_image_anomaly_detection_dataset.PyTorchCustomImageAnomalyDetectionDataset

A custom image anomaly detection dataset that can be used with PyTorch models.

Text Classification

tfds_text_classification_dataset.TFDSTextClassificationDataset

A text classification dataset from the TensorFlow datasets catalog

hf_text_classification_dataset.HFTextClassificationDataset

A text classification dataset from the Hugging Face datasets catalog

tf_custom_text_classification_dataset.TFCustomTextClassificationDataset

A custom text classification dataset that can be used with TensorFlow models.

hf_custom_text_classification_dataset.HFCustomTextClassificationDataset

A custom text classification dataset that can be used with Transformer models.

text_classification_dataset.TextClassificationDataset

Base class for a text classification dataset

Text Generation

hf_custom_text_generation_dataset.HFCustomTextGenerationDataset

A custom text generation dataset that can be used with Transformer models.

text_generation_dataset.TextGenerationDataset

Base class for a text generation dataset

Base Classes

Note

Users should rarely need to interact directly with these.

pytorch_dataset.PyTorchDataset

Base class to represent a PyTorch Dataset

tf_dataset.TFDataset

Base class to represent a TF Dataset

hf_dataset.HFDataset

Base class to represent a Hugging Face Dataset

dataset.BaseDataset

Abstract base class for a dataset used for training and evaluation

Models

Discover and work with available models by using model factory methods. The get_model() function will download third-party models, while the load_model() function will load a custom model, from either a path location or a model object in memory. The model discovery and inspection methods are get_supported_models() and print_supported_models().

Factory Methods

tlt.models.model_factory.load_model(model_name: str, model, framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, model_hub: Optional[str] = None, **kwargs)

A factory method for loading an existing model.

Parameters
  • model_name (str) – name of model

  • model (model or str) – model object or directory with a saved_model.pb or model.pt file to load

  • framework (str or FrameworkType) – framework

  • use_case (str or UseCaseType) – use case

  • model_hub (str) – The model hub where the model originated

  • kwargs – optional; additional keyword arguments for optimizer and loss function configuration. The optimizer and loss arguments can be set to Optimizer and Loss classes, depending on the model’s framework (examples: optimizer=tf.keras.optimizers.Adam for TensorFlow, loss=torch.nn.CrossEntropyLoss for PyTorch). Additional keywords for those classes’ initialization can then be provided to further configure the objects when they are created (example: amsgrad=True for the PyTorch Adam optimizer). Refer to the framework documentation for the function you want to use.

Returns

model object

Examples

>>> from tensorflow.keras import Sequential, Input
>>> from tensorflow.keras.layers import Dense
>>> from tlt.models.model_factory import load_model
>>> my_model = Sequential([Input(shape=(3,)), Dense(4, activation='relu'), Dense(5, activation='softmax')])
>>> model = load_model('my_model', my_model, 'tensorflow', 'image_classification')

tlt.models.model_factory.get_model(model_name: str, framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, **kwargs)

A factory method for creating models.

Parameters
  • model_name (str) – name of model

  • framework (str or FrameworkType) – framework

  • use_case (str or UseCaseType) – use case

  • kwargs – optional; additional keyword arguments for optimizer and loss function configuration. The optimizer and loss arguments can be set to Optimizer and Loss classes, depending on the model’s framework (examples: optimizer=tf.keras.optimizers.Adam for TensorFlow, loss=torch.nn.CrossEntropyLoss for PyTorch). Additional keywords for those classes’ initialization can then be provided to further configure the objects when they are created (example: amsgrad=True for the PyTorch Adam optimizer). Refer to the framework documentation for the function you want to use.

Returns

model object

Raises

NotImplementedError – if the model requested is not supported yet

Example

>>> from tlt.models.model_factory import get_model
>>> model = get_model('efficientnet_b0', 'tensorflow')
>>> model.image_size
224
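
The kwargs description for load_model() and get_model() says that optimizer and loss are passed as classes, with extra keywords forwarded when the objects are created. One way to picture that is the sketch below, which is not the library's code: DemoOptimizer and DemoLoss are stand-ins for framework classes such as a PyTorch optimizer or loss, and the rule of matching keywords against each class's __init__ signature is an assumption made for illustration.

```python
import inspect


class DemoOptimizer:
    """Stand-in for a framework optimizer class (hypothetical)."""

    def __init__(self, lr=0.001, amsgrad=False):
        self.lr = lr
        self.amsgrad = amsgrad


class DemoLoss:
    """Stand-in for a framework loss class (hypothetical)."""

    def __init__(self, reduction="mean"):
        self.reduction = reduction


def accepted_kwargs(cls, kwargs):
    """Keep only the keywords that appear in the class's __init__ signature."""
    params = inspect.signature(cls.__init__).parameters
    return {k: v for k, v in kwargs.items() if k in params and k != "self"}


def build_optimizer_and_loss(optimizer=DemoOptimizer, loss=DemoLoss, **kwargs):
    """Instantiate the given classes, forwarding matching keywords to each."""
    opt = optimizer(**accepted_kwargs(optimizer, kwargs))
    loss_fn = loss(**accepted_kwargs(loss, kwargs))
    return opt, loss_fn


opt, loss_fn = build_optimizer_and_loss(
    optimizer=DemoOptimizer, loss=DemoLoss, amsgrad=True, reduction="sum"
)
print(opt.amsgrad, opt.lr, loss_fn.reduction)
```

The point of passing classes rather than instances is that construction can be deferred until the model is ready to train, with the extra keywords applied at that time.
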

tlt.models.model_factory.get_supported_models(framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None)

Returns a dictionary of supported models organized by use case, model name, and framework. The leaf items in the dictionary are attributes about the pretrained model.

Parameters
  • framework (str or FrameworkType) – framework

  • use_case (str or UseCaseType) – use case

Returns

dictionary

Raises

NameError – if a model config file is found with an unknown or missing use case
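
A dictionary nested by use case, model name, and framework can be flattened with three loops. The sketch below walks a small hand-written sample shaped like the description of get_supported_models(); the sample content and the helper iter_models are hypothetical, and the real dictionary may hold different attributes.

```python
# Hypothetical sample: use case -> model name -> framework -> attributes.
SAMPLE_MODELS = {
    "image_classification": {
        "efficientnet_b0": {
            "tensorflow": {"image_size": 224},
        },
    },
}


def iter_models(models, framework=None, use_case=None):
    """Yield (use_case, model_name, framework, attributes) with optional filters."""
    for case, by_name in models.items():
        if use_case is not None and case != use_case:
            continue
        for name, by_framework in by_name.items():
            for fw, attributes in by_framework.items():
                if framework is not None and fw != framework:
                    continue
                yield case, name, fw, attributes


rows = list(iter_models(SAMPLE_MODELS, framework="tensorflow"))
no_rows = list(iter_models(SAMPLE_MODELS, framework="pytorch"))
for case, name, fw, attributes in rows:
    print(case, name, fw, attributes)
```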

tlt.models.model_factory.print_supported_models(framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, verbose: bool = False, markdown: bool = False)

Prints a list of the supported models, categorized by use case. The results can be filtered to only show a given framework or use case.

Parameters
  • framework (str or FrameworkType) – framework

  • use_case (str or UseCaseType) – use case

  • verbose (boolean) – include all model data from the config file in result, default is False

  • markdown (boolean) – Print results as markdown tables (used for updating documentation). Not compatible with verbose=True.

Class Reference

Image Classification

tfhub_image_classification_model.TFHubImageClassificationModel

Class to represent a TF Hub pretrained model for image classification

tf_image_classification_model.TFImageClassificationModel

Class to represent a TF custom pretrained model for image classification

keras_image_classification_model.KerasImageClassificationModel

Class to represent a Keras.applications pretrained model for image classification

torchvision_image_classification_model.TorchvisionImageClassificationModel

Class to represent a Torchvision pretrained model for image classification

pytorch_image_classification_model.PyTorchImageClassificationModel

Class to represent a PyTorch model for image classification

pytorch_hub_image_classification_model.PyTorchHubImageClassificationModel

Class to represent a PyTorch Hub pretrained model for image classification

image_classification_model.ImageClassificationModel

Base class to represent a pretrained model for image classification

Image Anomaly Detection

torchvision_image_anomaly_detection_model.TorchvisionImageAnomalyDetectionModel

Class to represent a Torchvision pretrained model for anomaly detection

pytorch_image_anomaly_detection_model.PyTorchImageAnomalyDetectionModel

Class to represent a PyTorch model for image anomaly detection

Text Classification

tf_text_classification_model.TFTextClassificationModel

Class to represent a TF pretrained model that can be used for binary text classification fine tuning.

pytorch_hf_text_classification_model.PyTorchHFTextClassificationModel

Class to represent a PyTorch Hugging Face pretrained model that can be used for multi-class text classification fine tuning.

tf_hf_text_classification_model.TFHFTextClassificationModel

Class to represent a TensorFlow pretrained model from Hugging Face that can be used for binary text classification fine tuning.

text_classification_model.TextClassificationModel

Class to represent a pretrained model for text classification

Text Generation

pytorch_hf_text_generation_model.PyTorchHFTextGenerationModel

Class to represent a PyTorch Hugging Face pretrained model that can be used for text generation fine tuning.

text_generation_model.TextGenerationModel

Class to represent a pretrained model for text generation

Base Classes

Note

Users should rarely need to interact directly with these.

pytorch_model.PyTorchModel

Base class to represent a PyTorch model

tf_model.TFModel

Base class to represent a TF pretrained model

hf_model.HFModel

Base class to represent a Hugging Face model

model.BaseModel

Abstract base class for a pretrained model that can be used for transfer learning