tlt.datasets.pytorch_dataset.PyTorchDataset

class tlt.datasets.pytorch_dataset.PyTorchDataset(dataset_dir, dataset_name='', dataset_catalog='')

Base class to represent a PyTorch Dataset

__init__(dataset_dir, dataset_name='', dataset_catalog='')

Class constructor

Methods

__init__(dataset_dir[, dataset_name, ...])

Class constructor

get_batch([subset])

Get a single batch of images and labels from the dataset.

get_inc_dataloaders()

preprocess([image_size, batch_size, add_aug])

Preprocess the dataset to resize, normalize, and batch the images.

shuffle_split([train_pct, val_pct, ...])

Randomly split the dataset into train, validation, and test subsets with a pseudo-random seed option.
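
The idea behind a seeded shuffle split can be sketched in plain Python. This is an illustration only, not the library's implementation: the helper name `seeded_split` and its `test_pct` and `seed` parameters are hypothetical, and only `train_pct` and `val_pct` are taken from the signature above.

```python
import random


def seeded_split(items, train_pct=0.75, val_pct=0.25, test_pct=0.0, seed=None):
    """Shuffle `items` and split them into train, validation, and test lists.

    Hypothetical helper for illustration; `test_pct` and `seed` are
    assumptions of this sketch, not documented PyTorchDataset parameters.
    """
    pcts = (train_pct, val_pct, test_pct)
    # Small tolerance guards against floating-point rounding in the sum.
    if min(pcts) < 0 or sum(pcts) > 1.0 + 1e-9:
        raise ValueError("percentages must be non-negative and sum to at most 1.0")

    indices = list(range(len(items)))
    # A local Random instance keeps the split reproducible for a given seed
    # without touching the global random state.
    random.Random(seed).shuffle(indices)

    n_train = int(len(items) * train_pct)
    n_val = int(len(items) * val_pct)
    n_test = int(len(items) * test_pct)

    train = [items[i] for i in indices[:n_train]]
    val = [items[i] for i in indices[n_train:n_train + n_val]]
    test = [items[i] for i in indices[n_train + n_val:n_train + n_val + n_test]]
    return train, val, test


train, val, test = seeded_split(list(range(100)), 0.5, 0.25, 0.25, seed=0)
print(len(train), len(val), len(test))
```

Calling the helper twice with the same seed yields the same three lists, which is the property a pseudo-random seed option is meant to provide.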

Attributes

data_loader

A data loader object corresponding to the dataset

dataset

The framework dataset object

dataset_catalog

The string name of the dataset catalog (or None)

dataset_dir

Host directory containing the dataset files

dataset_name

Name of the dataset

test_loader

A data loader object corresponding to the test subset

test_subset

A subset of the dataset held out for final testing/evaluation

train_loader

A data loader object corresponding to the training subset

train_subset

A subset of the dataset used for training

validation_loader

A data loader object corresponding to the validation subset

validation_subset

A subset of the dataset used for validation/evaluation
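
As a rough sketch of how the subset attributes might be inspected, the helper below reports the size of each subset on any object that exposes `train_subset`, `validation_subset`, and `test_subset`. It assumes that each subset supports `len()` and that an unset subset is `None` or absent; neither assumption is confirmed by this page. The helper name `subset_sizes` and the stand-in object are hypothetical.

```python
from types import SimpleNamespace

# Attribute names as listed in the Attributes table above.
SUBSET_ATTRS = ("train_subset", "validation_subset", "test_subset")


def subset_sizes(dataset):
    """Return a dict mapping each subset attribute to its length, or None if unset.

    Hypothetical helper: assumes subsets support len() and that a subset
    that has not been created is None or missing.
    """
    sizes = {}
    for name in SUBSET_ATTRS:
        subset = getattr(dataset, name, None)
        sizes[name] = None if subset is None else len(subset)
    return sizes


# Stand-in object for demonstration only; in practice a dataset instance
# would be passed here instead.
fake_dataset = SimpleNamespace(
    train_subset=[0] * 75,
    validation_subset=[0] * 25,
    test_subset=None,
)
print(subset_sizes(fake_dataset))
```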