neural_compressor.compression.pruner.regs
Regularizer.
Classes
BaseReg | Base regularizer.
GroupLasso | Group Lasso regularizer.
Functions
register_reg(name) | Register a regularizer class in the registry.
get_reg_type(config) | Obtain the regularizer type.
get_reg(config, modules, pattern) | Get the registered regularizer class.
Module Contents
- neural_compressor.compression.pruner.regs.register_reg(name)
Register a regularizer class in the registry.
- Parameters:
name – A string that defines the regularizer type.
- Returns:
The registered class.
- Return type:
cls
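The entry above does not show how register_reg is applied. A common shape for this kind of registry, and an assumption here, is a decorator factory: register_reg(name) returns a function that records the decorated class under name and hands the class back unchanged. The sketch below uses hypothetical names (REGISTRY, register, MyGroupLasso, and the "group_lasso" key) and is not neural_compressor's source.

```python
# Hypothetical registry illustrating the decorator-factory pattern.
REGISTRY = {}


def register(name):
    """Return a decorator that records the decorated class under `name`."""

    def decorator(cls):
        REGISTRY[name] = cls
        return cls  # return the class unchanged so it remains usable

    return decorator


@register("group_lasso")  # hypothetical key, chosen for illustration
class MyGroupLasso:
    pass
```

Returning the class from the inner function lets the decorator register a class without replacing it, so the class can still be imported and subclassed normally.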
- neural_compressor.compression.pruner.regs.get_reg_type(config)
Obtain the regularizer type.
- Parameters:
config – A config dict object that includes information about the regularizer.
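get_reg_type(config) maps a config to a registered type name. The sketch below assumes the type is stored under a "reg_type" key and that a missing or unregistered type yields None; both the key and the None fallback are assumptions, and get_type and REGISTRY are hypothetical names.

```python
# Hypothetical registry contents: type name -> regularizer class.
REGISTRY = {"group_lasso": object}


def get_type(config):
    """Return the regularizer type named in `config`, or None if it is not registered."""
    reg_type = config.get("reg_type")  # assumed config key
    return reg_type if reg_type in REGISTRY else None
```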
- neural_compressor.compression.pruner.regs.get_reg(config, modules, pattern)
Get the registered regularizer class.
- Parameters:
config – A config dict object that includes information about the regularizer.
modules – A dict {"module_name": Tensor} that stores the pruning modules’ weights.
pattern – A pattern object that includes information about the pruning pattern.
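get_reg(config, modules, pattern) combines the type lookup with construction. Below is a minimal factory sketch, assuming that a missing type falls back to a no-op base regularizer and that registered subclasses take an extra coefficient read from a "reg_coeff" key (mirroring the coeff argument of GroupLasso). NoOpReg, ScaledReg, build_reg, and both config keys are hypothetical.

```python
class NoOpReg:
    """Stand-in for a base regularizer that applies no penalty."""

    def __init__(self, config, modules, pattern):
        self.config, self.modules, self.pattern = config, modules, pattern


class ScaledReg(NoOpReg):
    """Stand-in for a registered regularizer that needs an extra coefficient."""

    def __init__(self, config, modules, pattern, coeff):
        super().__init__(config, modules, pattern)
        self.coeff = coeff


# Hypothetical registry: type name -> class.
REGISTRY = {"group_lasso": ScaledReg}


def build_reg(config, modules, pattern):
    """Instantiate the regularizer named by config["reg_type"] (assumed key)."""
    reg_type = config.get("reg_type")
    if reg_type is None:
        return NoOpReg(config, modules, pattern)  # no regularization requested
    if reg_type not in REGISTRY:
        raise ValueError(f"unsupported regularizer type: {reg_type}")
    return REGISTRY[reg_type](config, modules, pattern, config["reg_coeff"])
```

Raising on an unknown type, rather than silently falling back, surfaces typos in the config early.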
- class neural_compressor.compression.pruner.regs.BaseReg(config: dict, modules: dict, pattern: neural_compressor.compression.pruner.patterns.base.PytorchBasePattern)
Base regularizer.
The base class that performs regularization; GroupLasso is derived from it.
- Parameters:
config – A config dict object that includes information about the regularizer.
modules – A dict {"module_name": Tensor} that stores the pruning modules’ weights.
pattern – A PytorchBasePattern object that includes information about the pruning pattern.
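BaseReg stores the config, the pruning modules' weights, and the pattern, while subclasses such as GroupLasso supply the actual penalty. The sketch below illustrates that base/subclass split with a hypothetical penalty() hook; BaseReg's real method names are not documented on this page, and L2Reg is an illustrative stand-in, not a class in the library. Weights are plain lists of floats so the example runs without PyTorch.

```python
class BaseRegSketch:
    """No-op regularizer: stores its inputs and adds nothing to the loss."""

    def __init__(self, config, modules, pattern):
        self.config = config
        self.modules = modules  # {"module_name": list of weights}
        self.pattern = pattern

    def penalty(self):
        """Hypothetical hook; the base class contributes no penalty."""
        return 0.0


class L2Reg(BaseRegSketch):
    """Illustrative subclass: a plain L2 penalty over all weights (not Group Lasso)."""

    def __init__(self, config, modules, pattern, coeff):
        super().__init__(config, modules, pattern)
        self.coeff = coeff

    def penalty(self):
        return self.coeff * sum(w * w for weights in self.modules.values() for w in weights)
```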
- class neural_compressor.compression.pruner.regs.GroupLasso(config: dict, modules: dict, pattern: neural_compressor.compression.pruner.patterns.base.PytorchBasePattern, coeff)
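Group Lasso regularizer.
A regularizer class derived from BaseReg that performs Group Lasso regularization. Group Lasso is a variable-selection and regularization method that penalizes the sum of the L2 norms of predefined groups of weights, which encourages entire groups to become zero together.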
- Parameters:
config – A config dict object that includes information about the regularizer.
modules – A dict {"module_name": Tensor} that stores the pruning modules’ weights.
pattern – A PytorchBasePattern object that includes information about the pruning pattern.
coeff – The regularization coefficient that scales the Group Lasso penalty.
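Group Lasso adds coeff * sum_g ||W_g||_2 to the loss, where W_g is one group of weights. Because the group norm is not differentiable at zero, a proximal (group soft-thresholding) step can set groups whose norm falls below the threshold exactly to zero, which is what makes the penalty useful for structured pruning. The sketch computes the penalty, its (sub)gradient, and that proximal step. It treats each row of a 2-D weight as one group; this grouping is an assumption, since in neural_compressor the groups presumably follow the pruning pattern. All function names are hypothetical.

```python
import math


def group_norms(weight):
    """L2 norm of each group; here each row of the 2-D list is one group."""
    return [math.sqrt(sum(w * w for w in row)) for row in weight]


def group_lasso_penalty(weight, coeff):
    """coeff * sum of the group L2 norms."""
    return coeff * sum(group_norms(weight))


def group_lasso_grad(weight, coeff):
    """(Sub)gradient: coeff * w / ||group||, and 0 for an all-zero group."""
    grads = []
    for row, norm in zip(weight, group_norms(weight)):
        scale = coeff / norm if norm > 0.0 else 0.0
        grads.append([scale * w for w in row])
    return grads


def group_soft_threshold(weight, threshold):
    """Proximal step: shrink each group's norm by `threshold`, zeroing groups below it."""
    out = []
    for row, norm in zip(weight, group_norms(weight)):
        factor = max(0.0, 1.0 - threshold / norm) if norm > 0.0 else 0.0
        out.append([factor * w for w in row])
    return out
```

In a proximal-gradient training loop, the thresholding step would typically run after each optimizer step with threshold = learning_rate * coeff.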