neural_compressor.data

Built-in dataloaders, datasets, transforms, and filters for multiple framework backends.


Package Contents

Classes

Datasets

A base class for all framework datasets.

Dataset

The base class of dataset.

IterableDataset

An iterable Dataset.

DataLoader

Entry point for all configured DataLoaders; dispatches to the framework-specific DataLoader.

TRANSFORMS

Transforms collection class.

BaseTransform

The base class for transform.

Postprocess

Collects the information needed to construct a Postprocess.

FILTERS

The filter register for all frameworks.

Filter

The base class for filters.

Functions

dataset_registry(dataset_type, framework[, dataset_format])

Register dataset subclasses.

transform_registry(transform_type, process, framework)

Class decorator used to register all transform subclasses.

filter_registry(filter_type, framework)

Register all filter subclasses.

Attributes

DATALOADERS

class neural_compressor.data.Datasets(framework)

Bases: object

A base class for all framework datasets.

Parameters:

framework (str) – framework name, such as “tensorflow”, “tensorflow_itex”, “mxnet”, “onnxrt_qdq”, “onnxrt_qlinearops”, “onnxrt_integerops”, “pytorch”, “pytorch_ipex”, “pytorch_fx”, “onnxrt_qoperator”.

class neural_compressor.data.Dataset

Bases: object

The base class of dataset.

Subclasses should override two methods: __getitem__, which returns the data sample at a given index, and __len__, which returns the size of the dataset.
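A minimal sketch of a map-style dataset that follows this contract. To stay runnable without Neural Compressor installed, it inherits from object; in a real project the class would subclass neural_compressor.data.Dataset instead. The class name ToyImageDataset and the (input, label) sample layout are illustrative assumptions.

```python
from typing import List, Tuple


class ToyImageDataset:  # in real code: subclass neural_compressor.data.Dataset
    """Hypothetical map-style dataset holding (input, label) pairs in memory."""

    def __init__(self, images: List[List[float]], labels: List[int]):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels

    def __getitem__(self, index: int) -> Tuple[List[float], int]:
        # Return the sample at `index` as an (input, label) pair.
        return self.images[index], self.labels[index]

    def __len__(self) -> int:
        # Number of samples in the dataset.
        return len(self.labels)


dataset = ToyImageDataset([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]], [0, 1, 0])
print(len(dataset))  # 3
print(dataset[1])    # ([0.2, 0.3], 1)
```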

class neural_compressor.data.IterableDataset

Bases: object

An iterable Dataset.

Subclasses should implement __iter__, which iterates over the samples of the dataset.
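A minimal sketch of an iterable dataset that produces samples lazily through __iter__. It inherits from object so it runs without Neural Compressor; in real code it would subclass neural_compressor.data.IterableDataset. The class name CountingStream and the (value, label) samples are hypothetical.

```python
from typing import Iterator, Tuple


class CountingStream:  # in real code: subclass neural_compressor.data.IterableDataset
    """Hypothetical iterable dataset yielding (value, label) pairs on the fly."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self) -> Iterator[Tuple[float, int]]:
        # Samples are generated one at a time; no indexing or length is needed.
        for i in range(self.n):
            yield float(i), i % 2


samples = list(CountingStream(4))
print(samples)  # [(0.0, 0), (1.0, 1), (2.0, 0), (3.0, 1)]
```

This form suits sources that are streamed or too large to index, where __getitem__ and __len__ would be awkward to provide.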

neural_compressor.data.dataset_registry(dataset_type, framework, dataset_format='')

Register dataset subclasses.

Parameters:
  • cls (class) – The class to register.

  • dataset_type (str) – The dataset registration name.

  • framework (str) – Supported frameworks: ‘tensorflow’, ‘pytorch’, and ‘mxnet’.

  • dataset_format (str) – The format in which the dataset is saved, e.g. ‘raw_image’, ‘tfrecord’.

Returns:

The registered class.

Return type:

cls

class neural_compressor.data.DataLoader

Bases: object

Entry point for all configured DataLoaders. It dispatches to the framework-specific DataLoader, so users work with one unified interface and never need to handle the dispatching themselves.

class neural_compressor.data.TRANSFORMS(framework, process)

Bases: object

Transforms collection class.

Provides a register method to register new transforms and a __getitem__ method to retrieve a transform by its type.

register(name, transform_cls)

Register a new transform under the given name.

Parameters:
  • name (str) – process name

  • transform_cls (class) – process function wrapper class
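A self-contained sketch of a collection with the register and __getitem__ behaviour described above. ToyTransforms, the Scale transform, and the (input, label) sample layout are hypothetical; refusing duplicate names is a choice made in this sketch, not documented library behaviour.

```python
from typing import Dict, Tuple, Type


class ToyTransforms:
    """Hypothetical stand-in for a per-(framework, process) transforms collection."""

    def __init__(self, framework: str, process: str):
        self.framework = framework
        self.process = process
        self.transforms: Dict[str, Type] = {}

    def register(self, name: str, transform_cls: Type) -> None:
        # This sketch refuses silent overwrites so two transforms cannot share a name.
        if name in self.transforms:
            raise ValueError(f"transform {name!r} is already registered")
        self.transforms[name] = transform_cls

    def __getitem__(self, transform_type: str) -> Type:
        # Look up a registered transform class by its type name.
        return self.transforms[transform_type]


class Scale:
    """Hypothetical transform that multiplies the input of an (input, label) sample."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, sample: Tuple[float, int]) -> Tuple[float, int]:
        value, label = sample
        return value * self.factor, label


preprocess = ToyTransforms(framework="pytorch", process="preprocess")
preprocess.register("Scale", Scale)
scale = preprocess["Scale"](factor=2.0)
print(scale((3.0, 1)))  # (6.0, 1)
```

Indexing returns the class, not an instance, so each caller can construct the transform with its own arguments.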

class neural_compressor.data.BaseTransform

Bases: object

The base class for transform.

neural_compressor.data.transform_registry(transform_type, process, framework)

Class decorator used to register all transform subclasses.

Parameters:
  • transform_type (str) – Transform registration name

  • process (str) – Supported processes: ‘preprocess’, ‘postprocess’, and ‘general’.

  • framework (str) – Supported frameworks: ‘tensorflow’, ‘pytorch’, ‘mxnet’, and ‘onnxrt’.

  • cls (class) – The class to register.

Returns:

The registered class.

Return type:

cls

class neural_compressor.data.Postprocess(postprocess_cls, name='user_postprocess', **kwargs)

Bases: object

Collects the information needed to construct a Postprocess: the postprocess class, its name, and the keyword arguments to pass to it.
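A sketch of the "collect now, construct later" idea, mirroring the Postprocess signature. ToyPostprocess only stores its constructor arguments; its build method and the TopK postprocess are hypothetical helpers that exist only to show how the stored information could later produce the real object.

```python
from typing import Any, Dict, List


class ToyPostprocess:
    """Hypothetical mirror of Postprocess: records how to build a postprocess later."""

    def __init__(self, postprocess_cls: type, name: str = "user_postprocess", **kwargs: Any):
        self.postprocess_cls = postprocess_cls  # class to instantiate later
        self.name = name                        # name the postprocess is referred to by
        self.kwargs: Dict[str, Any] = kwargs    # constructor arguments, stored as-is

    def build(self):
        # Hypothetical helper: construct the actual postprocess object on demand.
        return self.postprocess_cls(**self.kwargs)


class TopK:
    """Hypothetical postprocess returning the indices of the k largest scores."""

    def __init__(self, k: int = 1):
        self.k = k

    def __call__(self, scores: List[float]) -> List[int]:
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return order[: self.k]


info = ToyPostprocess(TopK, name="topk", k=2)
topk = info.build()
print(info.name, topk([0.1, 0.7, 0.2]))  # topk [1, 2]
```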

class neural_compressor.data.FILTERS(framework)

Bases: object

The filter register for all frameworks.

Parameters:

framework (str) – frameworks in [“tensorflow”, “tensorflow_itex”, “mxnet”, “onnxrt_qdq”, “pytorch”, “pytorch_ipex”, “pytorch_fx”, “onnxrt_integerops”, “onnxrt_qlinearops”, “onnxrt_qoperator”].

class neural_compressor.data.Filter

Bases: object

The base class for filters.

A __call__ method must be implemented when writing a user-specific filter.
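A minimal sketch of a user-specific filter with a __call__ method. It inherits from object so it runs standalone; in real code it would subclass neural_compressor.data.Filter. The (image, label) call signature and the convention that returning True keeps a sample are assumptions of this sketch, as are the names ToyFilter and allowed.

```python
from typing import Iterable, List, Tuple


class ToyFilter:  # in real code: subclass neural_compressor.data.Filter
    """Hypothetical filter that keeps only samples whose label is in `allowed`."""

    def __init__(self, allowed: Iterable[int]):
        self.allowed = set(allowed)

    def __call__(self, image, label) -> bool:
        # Assumed convention: True keeps the sample, False drops it.
        return label in self.allowed


samples: List[Tuple[str, int]] = [("img0", 0), ("img1", 3), ("img2", 1)]
keep = ToyFilter(allowed=[0, 1])
filtered = [s for s in samples if keep(*s)]
print(filtered)  # [('img0', 0), ('img2', 1)]
```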

neural_compressor.data.filter_registry(filter_type, framework)

Register all filter subclasses.

Parameters:
  • filter_type (str) – Filter registration name.

  • framework (str) – Supported frameworks: ‘tensorflow’, ‘pytorch’, ‘mxnet’, and ‘onnxrt’.

  • cls (class) – The class to register.

Returns:

The registered class.

Return type:

cls