neural_compressor.compression.pruner.patterns

Pruners.


Package Contents

Functions

get_pattern(config, modules[, framework]) – Get registered pattern class.

neural_compressor.compression.pruner.patterns.get_pattern(config, modules, framework='pytorch')

Get a registered pattern class.

Look up the pattern described by config in the PATTERNS registry and return a Pattern object for modules.

Parameters:
  • config – A config dict object that contains the pattern information.

  • modules – Torch neural network modules to be pruned with the pattern.

  • framework – Name of the framework that modules belong to. Defaults to 'pytorch'.

Returns:

A Pattern object.

Raises:

AssertionError – If the requested pattern has not been registered in PATTERNS. Currently, only registered patterns are supported.
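The lookup described above follows a common registry pattern. The sketch below is a minimal, self-contained illustration under stated assumptions, not the library's implementation. `PATTERNS` here is a plain dict, while `register_pattern`, `BasePattern`, `FRAMEWORK_PREFIX`, the `"pt_"` key prefix and the two pattern classes are hypothetical. `config` is assumed to be a plain dict whose `"pattern"` entry names the pattern directly. It shows both the successful lookup and the AssertionError raised for an unregistered pattern.

```python
from typing import Callable, Dict

# Hypothetical registry playing the role of PATTERNS: key -> pattern class.
PATTERNS: Dict[str, type] = {}

# Hypothetical mapping from framework name to registry key prefix.
FRAMEWORK_PREFIX: Dict[str, str] = {"pytorch": "pt"}


def register_pattern(name: str) -> Callable[[type], type]:
    """Hypothetical class decorator that stores a pattern class in PATTERNS under `name`."""
    def decorator(cls: type) -> type:
        PATTERNS[name] = cls
        return cls
    return decorator


class BasePattern:
    """Minimal stand-in for a pattern: keeps the config and the modules to prune."""
    def __init__(self, config: dict, modules: dict) -> None:
        self.config = config
        self.modules = modules


@register_pattern("pt_NxM")
class PytorchPatternNxM(BasePattern):
    pass


@register_pattern("pt_N:M")
class PytorchPatternNInM(BasePattern):
    pass


def get_pattern(config: dict, modules: dict, framework: str = "pytorch") -> BasePattern:
    """Build the registered pattern named by config["pattern"] for `framework`."""
    assert framework in FRAMEWORK_PREFIX, f"unsupported framework: {framework}"
    key = f"{FRAMEWORK_PREFIX[framework]}_{config['pattern']}"
    # Only registered patterns are supported; anything else fails the assertion.
    assert key in PATTERNS, f"currently only support {sorted(PATTERNS)}"
    return PATTERNS[key](config, modules)


if __name__ == "__main__":
    pattern = get_pattern({"pattern": "NxM"}, modules={})
    print(type(pattern).__name__)  # PytorchPatternNxM
    try:
        get_pattern({"pattern": "unknown"}, modules={})
    except AssertionError as err:
        print("lookup failed:", err)
```

Keeping the classes in a registry means a new pattern only has to be decorated to become reachable through a single factory function, and the factory can reject unknown names in one place.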