neural_compressor.experimental.distillation

Distillation class.

Module Contents

Classes

Distillation

Distillation class derived from Component class.

class neural_compressor.experimental.distillation.Distillation(conf_fname_or_obj=None)

Bases: neural_compressor.experimental.component.Component

Distillation class derived from Component class.

The Distillation class abstracts the knowledge distillation pipeline, which transfers the knowledge of the teacher model to the student model.
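As an illustration of what "transferring knowledge" typically means, the following is a minimal, dependency-free sketch of a classic soft-label distillation loss: the KL divergence between temperature-softened teacher and student distributions. The function names and the default temperature are assumptions chosen for this example; this is not the criterion implemented by neural_compressor, which is selected through the configuration.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on softened distributions, scaled by T^2.

    Hypothetical helper for illustration only; not a neural_compressor API.
    """
    p = softmax(teacher_logits, temperature)  # soft targets from the teacher
    q = softmax(student_logits, temperature)  # student predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2


# The loss is zero when the student matches the teacher and positive otherwise.
matching = kd_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
mismatched = kd_loss([3.0, 2.0, 1.0], [1.0, 2.0, 3.0])
```

In practice this term is usually combined with an ordinary supervised loss on the ground-truth labels; the weighting between the two is a design choice left to the configuration.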

Parameters:

conf_fname_or_obj (string or obj) – The path to the YAML configuration file, or a Distillation_Conf object, containing the accuracy goal, the distillation objective, the related dataloaders, etc.

_epoch_ran

An integer indicating how many epochs have run.

eval_frequency

How often the student model is evaluated, measured in epochs.

best_score

The best metric value achieved by the student model during training.

best_model

The best student model found during training.
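The way these four attributes could interact is sketched below. The BestTracker class and its on_epoch_end method are hypothetical names invented for this example, and the sketch assumes a higher score is better; the actual bookkeeping inside Distillation may differ.

```python
import copy


class BestTracker:
    """Hypothetical helper mirroring the attributes documented above."""

    def __init__(self, eval_frequency=1):
        self.eval_frequency = eval_frequency  # evaluate every N epochs
        self._epoch_ran = 0                   # number of epochs completed
        self.best_score = None                # best metric seen so far
        self.best_model = None                # snapshot of the best student

    def on_epoch_end(self, model, evaluate):
        """Count the epoch and, when due, evaluate and keep the best model."""
        self._epoch_ran += 1
        if self._epoch_ran % self.eval_frequency != 0:
            return None  # no evaluation scheduled for this epoch
        score = evaluate(model)
        # Assumption: a higher score is better.
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.best_model = copy.deepcopy(model)  # snapshot, not a reference
        return score


# Usage: a dict stands in for the student model, and a lookup table
# stands in for the evaluation function.
scores = {1: 0.9, 2: 0.5, 3: 0.95, 4: 0.7}
tracker = BestTracker(eval_frequency=2)
for epoch in range(1, 5):
    student = {"epoch": epoch}
    tracker.on_epoch_end(student, lambda m: scores[m["epoch"]])
```

With eval_frequency set to 2, only epochs 2 and 4 are evaluated, so the stronger scores at epochs 1 and 3 are never observed. A larger evaluation interval saves time but can miss the true best checkpoint.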

property criterion

Getter of criterion.

Returns:

The criterion used in the distillation process.

property optimizer

Getter of optimizer.

Returns:

The optimizer used in the distillation process.

property teacher_model

Getter of the teacher model.

Returns:

The teacher model used in the distillation process.

property student_model

Getter of the student model.

Returns:

The student model used in the distillation process.

property train_cfg

Getter of the train configuration.

Returns:

The train configuration used in the distillation process.

property evaluation_distributed

Getter indicating whether a distributed evaluation dataloader is needed.

property train_distributed

Getter indicating whether a distributed training dataloader is needed.

on_post_forward(input, teacher_output=None)

Set or compute the output of the teacher model.

Deprecated.
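The "set or compute" behaviour can be pictured with the following sketch, which assumes that a precomputed teacher_output takes precedence and that the teacher is otherwise called on the input. The function name resolve_teacher_output is hypothetical and the teacher is modelled as a plain callable; the real method operates on framework models and may behave differently.

```python
def resolve_teacher_output(teacher, inputs, teacher_output=None):
    """Return a cached teacher output if given, else run the teacher.

    Hypothetical helper for illustration; not a neural_compressor API.
    """
    if teacher_output is not None:
        return teacher_output  # "set": reuse the precomputed output
    return teacher(inputs)     # "compute": forward pass through the teacher


# Usage with a toy callable standing in for the teacher model.
def toy_teacher(x):
    return x * 2


computed = resolve_teacher_output(toy_teacher, 3)
reused = resolve_teacher_output(toy_teacher, 3, teacher_output=10)
```

Passing precomputed outputs avoids repeating the teacher's forward pass when its predictions have already been cached.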

init_train_cfg()

Initialize the training configuration.

create_criterion()

Create the criterion for training.

create_optimizer()

Create the optimizer for training.

prepare()

Prepare hooks.

pre_process()

Preprocessing before the distillation pipeline.

Initialize the parts needed by the distillation pipeline.

execute()

Run the distillation pipeline.

First train the student model with the teacher model; after training, evaluate the best student model, if any.

Returns:

The best distilled model found.

generate_hooks()

Register hooks for distillation.

Register the hooks needed by the distillation pipeline.
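The general idea of registering hooks that a training loop later fires at fixed points can be sketched as follows. The HookRegistry class and the event names used here are assumptions made for this example and are not the hook names or registration API used by neural_compressor.

```python
class HookRegistry:
    """Hypothetical minimal hook registry, for illustration only."""

    def __init__(self):
        self._hooks = {}  # event name -> list of callbacks

    def register(self, event, fn):
        """Attach a callback to a named event."""
        self._hooks.setdefault(event, []).append(fn)

    def fire(self, event, *args, **kwargs):
        """Call every callback registered for the event, in order."""
        return [fn(*args, **kwargs) for fn in self._hooks.get(event, [])]


# Usage: a training loop fires events and the callbacks record them.
log = []
registry = HookRegistry()
registry.register("on_train_begin", lambda: log.append("begin"))
registry.register("on_epoch_end", lambda epoch: log.append(f"epoch {epoch} done"))

registry.fire("on_train_begin")
for epoch in range(2):
    registry.fire("on_epoch_end", epoch)
```

Keeping the pipeline logic in hooks lets the same user-supplied training loop serve several compression techniques, since each technique only needs to attach its own callbacks.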