API Reference
Datasets
The simplest way to create datasets is with the dataset factory methods load_dataset(), for using a custom dataset, and get_dataset(), for downloading and using a third-party dataset from a catalog such as TensorFlow Datasets or Torchvision.
Factory Methods
- tlt.datasets.dataset_factory.load_dataset(dataset_dir: str, use_case: UseCaseType, framework: FrameworkType, dataset_name=None, **kwargs)
A factory method for loading a custom dataset.
Image classification datasets expect a directory of images organized with subfolders for each image class, which can themselves be in split directories named ‘train’, ‘validation’, and/or ‘test’. Each class subfolder should contain .jpg images for the class. The name of the subfolder will be used as the class label.
dataset_dir
├── class_a
├── class_b
└── class_c
Or:
dataset_dir
├── train
|   ├── class_a
|   ├── class_b
|   └── class_c
├── validation
|   ├── class_a
|   ├── class_b
|   └── class_c
└── test
    ├── class_a
    ├── class_b
    └── class_c
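The layout rules above can be checked before calling load_dataset(). The sketch below uses a hypothetical helper, describe_image_layout(), which is not part of tlt: it reports the class names it would infer from subfolder names, grouped by split when 'train', 'validation', or 'test' directories are present. It does not verify that the class folders actually contain .jpg images.

```python
import tempfile
from pathlib import Path

SPLIT_NAMES = ("train", "validation", "test")


def describe_image_layout(dataset_dir):
    """Return {split_name: sorted class names} for an image dataset directory.

    A flat layout (class folders directly under dataset_dir) is reported
    under the key None. Hypothetical helper for illustration, not tlt code.
    """
    root = Path(dataset_dir)
    subdirs = sorted(p for p in root.iterdir() if p.is_dir())
    split_dirs = [p for p in subdirs if p.name in SPLIT_NAMES]
    if split_dirs:
        # Split layout: each split directory holds one folder per class
        return {
            split.name: sorted(c.name for c in split.iterdir() if c.is_dir())
            for split in split_dirs
        }
    # Flat layout: each subfolder name becomes a class label
    return {None: [p.name for p in subdirs]}


# Build a tiny split layout in a temporary directory and inspect it
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "validation"):
        for label in ("class_a", "class_b"):
            (Path(tmp) / split / label).mkdir(parents=True)
    print(describe_image_layout(tmp))
    # {'train': ['class_a', 'class_b'], 'validation': ['class_a', 'class_b']}
```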
Text classification datasets are expected to be a directory containing a text/CSV file with two columns: the label and the text/sentence to classify. See the TFCustomTextClassificationDataset documentation for a list of the additional kwargs used to load a text classification dataset file.
class_a,<text>
class_b,<text>
class_a,<text>
...
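A minimal sketch of reading the two-column format with Python's csv module, assuming a comma delimiter and no header row. read_labeled_text() is a hypothetical helper for illustration; TFCustomTextClassificationDataset has its own loader and options.

```python
import csv
import io


def read_labeled_text(csv_text, delimiter=","):
    """Parse 'label<delimiter>text' rows into (label, text) pairs and class names.

    Hypothetical helper; not the parser used by TFCustomTextClassificationDataset.
    """
    rows = []
    for record in csv.reader(io.StringIO(csv_text), delimiter=delimiter):
        if not record:
            continue  # skip blank lines
        # Re-join any extra columns in case the sentence had unquoted delimiters
        label, text = record[0], delimiter.join(record[1:])
        rows.append((label, text))
    class_names = sorted({label for label, _ in rows})
    return rows, class_names


sample = 'class_a,The first sentence\nclass_b,"A sentence, with a comma"\nclass_a,Another sentence\n'
rows, class_names = read_labeled_text(sample)
print(class_names)  # ['class_a', 'class_b']
print(rows[1])      # ('class_b', 'A sentence, with a comma')
```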
- Parameters
dataset_dir (str) – directory containing the dataset
use_case (str or UseCaseType) – use case or task the dataset will be used to model
framework (str or FrameworkType) – framework
dataset_name (str) – optional; name of the dataset used for informational purposes
kwargs – optional; additional keyword arguments depending on the type of dataset being loaded
- Returns
(dataset)
- Raises
NotImplementedError – if the type of dataset being loaded is not supported
Example
>>> from tlt.datasets.dataset_factory import load_dataset
>>> data = load_dataset('/tmp/data/flower_photos', 'image_classification', 'tensorflow')
Found 3670 files belonging to 5 classes.
>>> data.class_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
- tlt.datasets.dataset_factory.get_dataset(dataset_dir: str, use_case: UseCaseType, framework: FrameworkType, dataset_name: Optional[str] = None, dataset_catalog: Optional[str] = None, **kwargs)
A factory method for using a dataset from a catalog.
- Parameters
dataset_dir (str) – directory containing the dataset or to which the dataset should be downloaded
use_case (str or UseCaseType) – use case or task the dataset will be used to model
framework (str or FrameworkType) – framework
dataset_name (str) – optional; name of the dataset
dataset_catalog (str) – optional; catalog from which to download the dataset. If a dataset name is provided and no dataset catalog is given, it defaults to tf_datasets for TensorFlow models, torchvision for PyTorch CV models, and Hugging Face datasets for PyTorch NLP models or Hugging Face models.
**kwargs – optional; additional keyword arguments for the framework or dataset_catalog
- Returns
(dataset)
- Raises
NotImplementedError – if the dataset requested is not supported yet
Example
>>> from tlt.datasets.dataset_factory import get_dataset
>>> data = get_dataset('/tmp/data/', 'image_classification', 'tensorflow', 'tf_flowers', 'tf_datasets')
>>> sorted(data.class_names)
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
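The default-catalog rule described for dataset_catalog can be written as a small decision function. This is a sketch, not tlt's code: the helper name, the use-case strings, the 'huggingface' catalog string, and the choice to give Hugging Face models precedence over the framework check are all assumptions.

```python
def default_dataset_catalog(framework, use_case, hugging_face_model=False):
    """Pick a dataset catalog following the documented get_dataset() defaults.

    Hypothetical helper: catalog strings, use-case names, and the Hugging Face
    precedence are assumptions made for illustration.
    """
    if hugging_face_model:
        return "huggingface"  # Hugging Face models use Hugging Face datasets
    if framework == "tensorflow":
        return "tf_datasets"
    if framework == "pytorch":
        if use_case in ("image_classification", "image_anomaly_detection"):
            return "torchvision"  # PyTorch computer-vision use cases
        if use_case in ("text_classification", "text_generation"):
            return "huggingface"  # PyTorch NLP use cases
    raise NotImplementedError(f"No default catalog for {framework!r}/{use_case!r}")


print(default_dataset_catalog("tensorflow", "image_classification"))  # tf_datasets
print(default_dataset_catalog("pytorch", "image_classification"))     # torchvision
print(default_dataset_catalog("pytorch", "text_classification"))      # huggingface
```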
Class Reference
Image Classification
- An image classification dataset from the TensorFlow datasets catalog
- An image classification dataset from the torchvision catalog
- A custom image classification dataset that can be used with TensorFlow models
- A custom image classification dataset that can be used with PyTorch models
- Base class for an image classification dataset
Image Anomaly Detection
- A custom image anomaly detection dataset that can be used with PyTorch models
Text Classification
- A text classification dataset from the TensorFlow datasets catalog
- A text classification dataset from the Hugging Face datasets catalog
- A custom text classification dataset that can be used with TensorFlow models
- A custom text classification dataset that can be used with Transformer models
- Base class for a text classification dataset
Text Generation
- A custom text generation dataset that can be used with Transformer models
- Base class for a text generation dataset
Base Classes
Note
Users should rarely need to interact directly with these.
- Base class to represent a PyTorch Dataset
- Base class to represent a TF Dataset
- Base class to represent a Hugging Face Dataset
- Abstract base class for a dataset used for training and evaluation
Models
Discover and work with available models by using the model factory methods. The get_model() function downloads third-party models, while the load_model() function loads a custom model from either a path or a model object in memory. The model discovery and inspection methods are get_supported_models() and print_supported_models().
Factory Methods
- tlt.models.model_factory.load_model(model_name: str, model, framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, model_hub: Optional[str] = None, **kwargs)
A factory method for loading an existing model.
- Parameters
model_name (str) – name of model
model (model or str) – model object or directory with a saved_model.pb or model.pt file to load
framework (str or FrameworkType) – framework
use_case (str or UseCaseType) – use case
model_hub (str) – optional; the model hub where the model originated
kwargs – optional; additional keyword arguments for optimizer and loss function configuration. The optimizer and loss arguments can be set to Optimizer and Loss classes, depending on the model’s framework (examples: optimizer=tf.keras.optimizers.Adam for TensorFlow, loss=torch.nn.CrossEntropyLoss for PyTorch). Additional keywords for those classes’ initialization can then be provided to further configure the objects when they are created (example: amsgrad=True for the PyTorch Adam optimizer). Refer to the framework documentation for the function you want to use.
- Returns
model object
Examples
>>> from tensorflow.keras import Sequential, Input
>>> from tensorflow.keras.layers import Dense
>>> from tlt.models.model_factory import load_model
>>> my_model = Sequential([Input(shape=(3,)), Dense(4, activation='relu'), Dense(5, activation='softmax')])
>>> model = load_model('my_model', my_model, 'tensorflow', 'image_classification')
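The optimizer/loss keyword pattern described for kwargs (pass a class, then extra keywords for its initializer) can be sketched with inspect.signature. split_init_kwargs() and StandInAdam are hypothetical stand-ins, not tlt or PyTorch APIs, and this is not how tlt necessarily routes keywords; the real torch.optim.Adam also takes the model parameters as its first argument, which a training loop would supply.

```python
import inspect


def split_init_kwargs(cls, kwargs):
    """Split kwargs into those accepted by cls's initializer and the rest.

    Hypothetical helper sketching how a keyword such as amsgrad=True could be
    routed to an optimizer or loss class.
    """
    params = inspect.signature(cls).parameters
    accepts_any = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
    accepted = {k: v for k, v in kwargs.items() if accepts_any or k in params}
    remaining = {k: v for k, v in kwargs.items() if k not in accepted}
    return accepted, remaining


class StandInAdam:
    """Hypothetical stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, lr=0.001, amsgrad=False):
        self.lr = lr
        self.amsgrad = amsgrad


# 'epochs' is a hypothetical keyword that the optimizer does not accept
accepted, remaining = split_init_kwargs(StandInAdam, {"amsgrad": True, "epochs": 5})
optimizer = StandInAdam(**accepted)
print(optimizer.amsgrad, remaining)  # True {'epochs': 5}
```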
- tlt.models.model_factory.get_model(model_name: str, framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, **kwargs)
A factory method for creating models.
- Parameters
model_name (str) – name of model
framework (str or FrameworkType) – framework
use_case (str or UseCaseType) – use case
kwargs – optional; additional keyword arguments for optimizer and loss function configuration. The optimizer and loss arguments can be set to Optimizer and Loss classes, depending on the model’s framework (examples: optimizer=tf.keras.optimizers.Adam for TensorFlow, loss=torch.nn.CrossEntropyLoss for PyTorch). Additional keywords for those classes’ initialization can then be provided to further configure the objects when they are created (example: amsgrad=True for the PyTorch Adam optimizer). Refer to the framework documentation for the function you want to use.
- Returns
model object
- Raises
NotImplementedError – if the model requested is not supported yet
Example
>>> from tlt.models.model_factory import get_model
>>> model = get_model('efficientnet_b0', 'tensorflow')
>>> model.image_size
224
- tlt.models.model_factory.get_supported_models(framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None)
Returns a dictionary of supported models organized by use case, model name, and framework. The leaf items in the dictionary are attributes of the pretrained model.
- Parameters
framework (str or FrameworkType) – framework
use_case (str or UseCaseType) – use case
- Returns
dictionary
- Raises
NameError – if a model config file is found with an unknown or missing use case
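Because the returned dictionary is nested by use case, model name, and framework, it can be flattened into rows for filtering or display. A sketch assuming exactly that nesting order; flatten_supported_models() is hypothetical, and the sample dictionary uses placeholder names and empty attribute dictionaries, not real tlt entries.

```python
def flatten_supported_models(supported):
    """Turn {use_case: {model_name: {framework: attrs}}} into sorted rows.

    Hypothetical helper; assumes the nesting order documented for
    get_supported_models().
    """
    rows = []
    for use_case, models in supported.items():
        for model_name, frameworks in models.items():
            for framework, attrs in frameworks.items():
                rows.append((use_case, model_name, framework, attrs))
    # Sort on the three keys only; the attribute dicts are not comparable
    return sorted(rows, key=lambda row: row[:3])


# Placeholder data for illustration only (not real tlt entries)
sample = {
    "image_classification": {
        "model_a": {"tensorflow": {}, "pytorch": {}},
    },
    "text_classification": {
        "model_b": {"pytorch": {}},
    },
}

for use_case, model_name, framework, _ in flatten_supported_models(sample):
    print(use_case, model_name, framework)
```

With tlt installed, the same helper could be applied to the dictionary returned by get_supported_models().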
- tlt.models.model_factory.print_supported_models(framework: Optional[FrameworkType] = None, use_case: Optional[UseCaseType] = None, verbose: bool = False, markdown: bool = False)
Prints a list of the supported models, categorized by use case. The results can be filtered to only show a given framework or use case.
- Parameters
framework (str or FrameworkType) – framework
use_case (str or UseCaseType) – use case
verbose (boolean) – include all model data from the config file in the results; default is False
markdown (boolean) – Print results as markdown tables (used for updating documentation). Not compatible with verbose=True.
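As a rough illustration of the markdown option, a nested {use_case: {model_name: {framework: attrs}}} dictionary can be rendered as one markdown table per use case. to_markdown() is hypothetical and its table layout is an assumption; the actual output of print_supported_models(markdown=True) may differ.

```python
def to_markdown(supported):
    """Render {use_case: {model_name: {framework: attrs}}} as markdown tables.

    Hypothetical rendering with one table per use case; not tlt's format.
    """
    lines = []
    for use_case in sorted(supported):
        lines.append(f"## {use_case}")
        lines.append("")
        lines.append("| Model name | Framework |")
        lines.append("| --- | --- |")
        for name in sorted(supported[use_case]):
            for framework in sorted(supported[use_case][name]):
                lines.append(f"| {name} | {framework} |")
        lines.append("")  # blank line between tables
    return "\n".join(lines)


# Placeholder data for illustration only (not real tlt entries)
sample = {"image_classification": {"model_a": {"pytorch": {}, "tensorflow": {}}}}
print(to_markdown(sample))
```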
Class Reference
Image Classification
- Class to represent a TF Hub pretrained model for image classification
- Class to represent a TF custom pretrained model for image classification
- Class to represent a Keras.applications pretrained model for image classification
- Class to represent a Torchvision pretrained model for image classification
- Class to represent a PyTorch model for image classification
- Class to represent a PyTorch Hub pretrained model for image classification
- Base class to represent a pretrained model for image classification
Image Anomaly Detection
- Class to represent a Torchvision pretrained model for anomaly detection
- Class to represent a PyTorch model for image anomaly detection
Text Classification
- Class to represent a TF pretrained model that can be used for binary text classification fine-tuning
- Class to represent a PyTorch Hugging Face pretrained model that can be used for multi-class text classification fine-tuning
- Class to represent a TensorFlow pretrained model from Hugging Face that can be used for binary text classification fine-tuning
- Class to represent a pretrained model for text classification
Text Generation
- Class to represent a PyTorch Hugging Face pretrained model that can be used for text generation fine-tuning
- Class to represent a pretrained model for text generation
Base Classes
Note
Users should rarely need to interact directly with these.
- Base class to represent a PyTorch model
- Base class to represent a TF pretrained model
- Base class to represent a Hugging Face model
- Abstract base class for a pretrained model that can be used for transfer learning