neural_compressor.compression.pruner.regs

Regularizer.

Module Contents

Classes

BaseReg

Regularizer.

GroupLasso

Regularizer.

Functions

register_reg(name)

Register a regularizer in the registry.

get_reg_type(config)

Obtain the regularizer type.

get_reg(config, modules, pattern)

Get the registered regularizer class.

neural_compressor.compression.pruner.regs.register_reg(name)[source]

Register a regularizer in the registry.

Parameters:

name – A string that defines the regularizer type.

Returns:

The registered class.

Return type:

cls
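As an illustration of the registration pattern described above, here is a minimal, self-contained sketch of a name-keyed class registry driven by a decorator. It does not import neural_compressor: the `REGS` dict, the local `register_reg` function, the `"my_reg"` key, and the `MyReg` class are all stand-ins chosen for this example, and the library's internal registry may be laid out differently.

```python
# Hypothetical registry; the library's own registry name and layout may differ.
REGS = {}


def register_reg(name):
    """Return a class decorator that stores the decorated class under `name`."""

    def register(cls):
        REGS[name] = cls  # map the type string to the class
        return cls        # hand the class back unchanged

    return register


@register_reg("my_reg")  # "my_reg" is an illustrative key, not a library value
class MyReg:
    """Placeholder regularizer class used only for this sketch."""


registered = REGS["my_reg"]
```

Because the decorator returns the class unchanged, the decorated class can still be used directly, while other code can look it up by its type string.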

neural_compressor.compression.pruner.regs.get_reg_type(config)[source]

Obtain the regularizer type.

Parameters:

config – A config dict containing the regularizer settings.

neural_compressor.compression.pruner.regs.get_reg(config, modules, pattern)[source]

Get the registered regularizer class.

Parameters:
  • config – A config dict containing the regularizer settings.

  • modules – A dict {“module_name”: Tensor} that stores the weights of the modules to be pruned.

  • pattern – A pattern object that describes the pruning pattern.
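The lookup described above can be pictured as a dictionary lookup keyed by a type string taken from the config. The sketch below is an assumption-laden stand-in, not the library's code: `REGISTRY`, `PlainReg`, `ScaledReg`, `lookup_reg`, and the config keys `"reg_type"` and `"reg_coeff"` are all hypothetical names chosen for illustration.

```python
# Stand-in classes and registry; none of these names are taken from the library.
class PlainReg:
    def __init__(self, config, modules, pattern):
        self.config = config
        self.modules = modules
        self.pattern = pattern


class ScaledReg(PlainReg):
    def __init__(self, config, modules, pattern, coeff):
        super().__init__(config, modules, pattern)
        self.alpha = float(coeff)  # coefficient that scales the penalty


REGISTRY = {"scaled": ScaledReg}


def lookup_reg(config, modules, pattern):
    """Build the regularizer named in `config`, or a plain one if none is named."""
    reg_type = config.get("reg_type")  # assumed key name
    if reg_type is None:
        return PlainReg(config, modules, pattern)
    if reg_type not in REGISTRY:
        raise ValueError(f"unsupported regularizer type: {reg_type!r}")
    # "reg_coeff" is likewise an assumed key name.
    return REGISTRY[reg_type](config, modules, pattern, config["reg_coeff"])


reg = lookup_reg({"reg_type": "scaled", "reg_coeff": 0.1}, modules={}, pattern=None)
```

Failing loudly on an unknown type string makes a misspelled config value visible immediately instead of silently disabling regularization.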

class neural_compressor.compression.pruner.regs.BaseReg(config: dict, modules: dict, pattern: neural_compressor.compression.pruner.patterns.base.PytorchBasePattern)[source]

Regularizer.

The class that performs regularization.

Parameters:
  • modules – A dict {“module_name”: Tensor} that stores the weights of the modules to be pruned.

  • config – A config dict containing the regularizer settings.

  • pattern – A PytorchBasePattern object that describes the pruning pattern.

class neural_compressor.compression.pruner.regs.GroupLasso(config: dict, modules: dict, pattern: neural_compressor.compression.pruner.patterns.base.PytorchBasePattern, coeff)[source]

Regularizer.

A regularizer class derived from BaseReg that applies group-lasso regularization, a method for variable selection and regularization.

Parameters:
  • modules – A dict {“module_name”: Tensor} that stores the weights of the modules to be pruned.

  • config – A config dict containing the regularizer settings.

  • pattern – A PytorchBasePattern object that describes the pruning pattern.

  • coeff – The group-lasso coefficient.

reg_terms[source]

A dict {“module_name”: Tensor} of regularization terms.

alpha[source]

A float giving the group-lasso coefficient.
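To make the role of the coefficient concrete, the following sketch computes a textbook group-lasso penalty: the sum of the L2 norms of each weight group, scaled by alpha. It is plain Python on a flat list, not the GroupLasso class itself; the function name `group_lasso_penalty`, the `group_size` argument, and the grouping into consecutive chunks are assumptions made for this example, and the library groups weights according to its pruning pattern instead.

```python
import math


def group_lasso_penalty(weights, group_size, alpha):
    """Return alpha times the sum of L2 norms of consecutive weight groups."""
    penalty = 0.0
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        # L2 norm of one group; a group of all zeros contributes nothing.
        penalty += math.sqrt(sum(w * w for w in group))
    return alpha * penalty


weights = [3.0, 4.0, 0.0, 0.0]  # two groups: [3, 4] and [0, 0]
penalty = group_lasso_penalty(weights, group_size=2, alpha=0.5)
```

Because the norm is taken per group rather than per weight, the penalty tends to push whole groups toward zero together, which is what makes it useful for structured pruning.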