:py:mod:`neural_compressor.experimental.pytorch_pruner.pruner`
==============================================================

.. py:module:: neural_compressor.experimental.pytorch_pruner.pruner

.. autoapi-nested-parse::

   Pruner module.


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   neural_compressor.experimental.pytorch_pruner.pruner.Pruner
   neural_compressor.experimental.pytorch_pruner.pruner.MagnitudePruner
   neural_compressor.experimental.pytorch_pruner.pruner.SnipPruner
   neural_compressor.experimental.pytorch_pruner.pruner.SnipMomentumPruner
   neural_compressor.experimental.pytorch_pruner.pruner.PatternLockPruner


Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.experimental.pytorch_pruner.pruner.register_pruners
   neural_compressor.experimental.pytorch_pruner.pruner.get_pruner


.. py:function:: register_pruners(name)

   Class decorator to register a Pruner subclass to the registry.

   Decorator function used before a Pruner subclass.
   It ensures that the decorated Pruner class is registered in PRUNERS.

   :param cls: The class to register.
   :type cls: class
   :param name: A string. Defines the pruner type.

   :returns: The registered class.
   :rtype: cls


.. py:function:: get_pruner(modules, config)

   Get registered pruner class.

   Get a Pruner object from PRUNERS.

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   :returns: A Pruner object.

   :raises AssertionError: Only pruners that have been registered in PRUNERS are currently supported.


.. py:class:: Pruner(modules, config)

   Pruning Pruner.

   The class which executes the pruning process.

   1. Defines pruning functions called at step begin/end and epoch begin/end.
   2. Defines the pruning criteria.

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   .. attribute:: modules

      A dict {"module_name": Tensor}. Store the pruning modules' weights.

   .. attribute:: config

      A config dict object. Contains the pruner information.

   .. attribute:: masks

      A dict {"module_name": Tensor}. Store the masks for modules' weights.

   .. attribute:: scores

      A dict {"module_name": Tensor}. Store the scores for modules' weights,
      which are used to decide which parts to prune according to a criterion.

   .. attribute:: pattern

      A Pattern object. Defined in ./patterns.py

   .. attribute:: scheduler

      A scheduler object. Defined in ./scheduler.py

   .. attribute:: current_sparsity_ratio

      A float. Current model's sparsity ratio, initialized as zero.

   .. attribute:: global_step

      An integer. The total steps the model has run.

   .. attribute:: start_step

      An integer. When to trigger the pruning process.

   .. attribute:: end_step

      An integer. When to end the pruning process.

   .. attribute:: update_frequency_on_step

      An integer. The pruning frequency, which is valid when iterative pruning is enabled.

   .. attribute:: target_sparsity_ratio

      A float. The final sparsity after pruning.

   .. attribute:: max_sparsity_ratio_per_layer

      A float. The maximum sparsity ratio for every module.


.. py:class:: MagnitudePruner(modules, config)

   Pruning Pruner.

   A Pruner class derived from Pruner. In this pruner, the scores are calculated based on weights.

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   All attributes are inherited from the parent class Pruner.


.. py:class:: SnipPruner(modules, config)

   Pruning Pruner.

   A Pruner class derived from Pruner. In this pruner, the scores are calculated based on SNIP.
   Please refer to SNIP: Single-shot Network Pruning based on Connection Sensitivity
   (https://arxiv.org/abs/1810.02340)

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   All attributes are inherited from the parent class Pruner.


.. py:class:: SnipMomentumPruner(modules, config)

   Pruning Pruner.

   A Pruner class derived from Pruner. In this pruner, the scores are calculated based on SNIP.
   Moreover, the score map is updated with a momentum-like process.

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   All attributes are inherited from the parent class Pruner.


.. py:class:: PatternLockPruner(modules, config)

   Pruning Pruner.

   A Pruner class derived from Pruner. In this pruner, the original model's sparsity pattern is kept fixed during training.
   This pruner is useful when you want to train a sparse model without changing its original structure.

   :param modules: A dict {"module_name": Tensor}. Store the pruning modules' weights.
   :param config: A config dict object. Contains the pruner information.

   All attributes are inherited from the parent class Pruner.
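
The registry-plus-factory arrangement described on this page (a decorator that stores pruner classes by name, and a lookup function that instantiates one from a config) can be illustrated with a small stand-alone sketch. This is not the ``neural_compressor`` implementation: every name below (``SKETCH_PRUNERS``, ``register_sketch_pruner``, ``SketchPruner``, ``SketchMagnitudePruner``, ``get_sketch_pruner``) and the config keys ``prune_type`` and ``target_sparsity`` are hypothetical, and plain Python lists stand in for tensors so the example runs without PyTorch.

```python
# Illustrative sketch only; all names and config keys are hypothetical
# and do not reproduce the neural_compressor code.

SKETCH_PRUNERS = {}  # stand-in for a registry such as PRUNERS


def register_sketch_pruner(name):
    """Return a class decorator that stores the class under `name`."""
    def decorator(cls):
        SKETCH_PRUNERS[name] = cls
        return cls
    return decorator


class SketchPruner:
    """Holds weights (lists of floats here) and one mask per module."""

    def __init__(self, modules, config):
        self.modules = modules
        self.config = config
        self.target_sparsity_ratio = config["target_sparsity"]
        # A mask value of 1 keeps a weight; 0 prunes it.
        self.masks = {key: [1] * len(w) for key, w in modules.items()}

    def compute_scores(self):
        raise NotImplementedError

    def update_masks(self):
        for key, scores in self.compute_scores().items():
            n_prune = int(len(scores) * self.target_sparsity_ratio)
            # The lowest-scoring weights are the ones selected for pruning.
            order = sorted(range(len(scores)), key=lambda i: scores[i])
            pruned = set(order[:n_prune])
            self.masks[key] = [0 if i in pruned else 1 for i in range(len(scores))]

    def apply_masks(self):
        for key, weights in self.modules.items():
            mask = self.masks[key]
            self.modules[key] = [w if m else 0.0 for w, m in zip(weights, mask)]


@register_sketch_pruner("magnitude")
class SketchMagnitudePruner(SketchPruner):
    """Scores each weight by its absolute value."""

    def compute_scores(self):
        return {key: [abs(w) for w in ws] for key, ws in self.modules.items()}


def get_sketch_pruner(modules, config):
    """Look up the registered class by name and instantiate it."""
    name = config["prune_type"]
    assert name in SKETCH_PRUNERS, f"unsupported pruner type: {name}"
    return SKETCH_PRUNERS[name](modules, config)


pruner = get_sketch_pruner(
    {"layer": [0.5, -0.1, 0.03, -2.0]},
    {"prune_type": "magnitude", "target_sparsity": 0.5},
)
pruner.update_masks()
pruner.apply_masks()
```

In this sketch, a pattern-lock style pruner would skip ``update_masks`` and call only ``apply_masks`` after each optimizer step, so that weights that were already zero stay zero. That is an assumption about how such a variant could be written here, not a description of ``PatternLockPruner``.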