neural_compressor.compression.pruner.pruners

Pruning patterns.

Submodules

Functions

parse_valid_pruner_types()

Get all valid pruner names.

get_pruner(config, modules[, framework])

Get an instance of a registered pruner.

Package Contents

neural_compressor.compression.pruner.pruners.parse_valid_pruner_types()

Get all valid pruner names.
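The idea behind a "valid pruner names" query can be illustrated with a name-based registry. The following is a minimal sketch under stated assumptions, not the library's implementation: it assumes PRUNERS is a plain dict filled by a hypothetical `register_pruner` decorator, and the pruner names and classes are illustrative only. The real function may derive additional names beyond the raw registry keys.

```python
"""Hypothetical sketch of a name-based pruner registry (not neural_compressor's code)."""

from typing import Callable, Dict, List, Type

# Hypothetical registry mapping a pruner name to its class.
PRUNERS: Dict[str, Type] = {}


def register_pruner(name: str) -> Callable[[Type], Type]:
    """Hypothetical class decorator that records a pruner class under `name`."""

    def decorator(cls: Type) -> Type:
        PRUNERS[name] = cls
        return cls

    return decorator


@register_pruner("basic")  # illustrative name
class BasicPruner:
    pass


@register_pruner("pattern_lock")  # illustrative name
class PatternLockPruner:
    pass


def parse_valid_pruner_types() -> List[str]:
    """Return the names of all registered pruners, in registration order."""
    return list(PRUNERS.keys())


print(parse_valid_pruner_types())  # ['basic', 'pattern_lock']
```

Because Python dicts preserve insertion order, the names come back in the order the classes were registered.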

neural_compressor.compression.pruner.pruners.get_pruner(config, modules, framework='pytorch')

Get an instance of a registered pruner.

Look up the pruner class registered in PRUNERS for the given config and return a Pruner object built from it.

Parameters:
  • config – A config dict that contains the pruner information.

  • modules – A dict {"module_name": Tensor} that stores the weights of the modules to be pruned.

  • framework – The framework the modules belong to. Defaults to 'pytorch'.

Returns:

A Pruner object.

Raises:
  • AssertionError – If the requested pruner has not been registered in PRUNERS; only registered pruners are supported.
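The lookup-and-instantiate behavior described above can be sketched as a registry-backed factory. This is a hypothetical illustration, not the library's code: it assumes the pruner name is read from a `"pruning_type"` config key, that every registered class accepts `(config, modules, framework)`, and it uses nested lists in place of Tensors so the example runs without PyTorch. `register_pruner` and `BasicPruner` are illustrative names.

```python
"""Hypothetical sketch of a registry-based factory like get_pruner (not neural_compressor's code)."""

from typing import Any, Callable, Dict, Type

# Hypothetical registry mapping a pruner name to its class.
PRUNERS: Dict[str, Type] = {}


def register_pruner(name: str) -> Callable[[Type], Type]:
    """Hypothetical class decorator that records a pruner class under `name`."""

    def decorator(cls: Type) -> Type:
        PRUNERS[name] = cls
        return cls

    return decorator


@register_pruner("basic")  # illustrative name
class BasicPruner:
    def __init__(self, config: Dict[str, Any], modules: Dict[str, Any], framework: str = "pytorch"):
        self.config = config
        self.modules = modules  # {"module_name": weight}; lists stand in for Tensors here
        self.framework = framework


def get_pruner(config: Dict[str, Any], modules: Dict[str, Any], framework: str = "pytorch"):
    """Instantiate the pruner registered under config["pruning_type"] (an assumed key)."""
    name = config["pruning_type"]
    # Unregistered names fail with AssertionError, mirroring the documented Raises clause.
    assert name in PRUNERS, f"Unsupported pruner {name!r}; registered pruners: {list(PRUNERS)}"
    return PRUNERS[name](config, modules, framework)


pruner = get_pruner({"pruning_type": "basic"}, {"layer1.weight": [[0.5, -0.1]]})
print(type(pruner).__name__)  # BasicPruner

try:
    get_pruner({"pruning_type": "unknown"}, {})
except AssertionError as err:
    print(err)  # Unsupported pruner 'unknown'; registered pruners: ['basic']
```

A design note on this sketch: `assert` checks are skipped when Python runs with `-O`, so a factory written this way would not reject unknown names under optimization; raising `ValueError` explicitly is the usual alternative when the check must always run.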