:py:mod:`neural_compressor.pruner.pruners`
==========================================

.. py:module:: neural_compressor.pruner.pruners

.. autoapi-nested-parse::

   Pruner classes and the helper functions that register and create them.



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   neural_compressor.pruner.pruners.BasePruner
   neural_compressor.pruner.pruners.BasicPruner
   neural_compressor.pruner.pruners.PatternLockPruner
   neural_compressor.pruner.pruners.ProgressivePruner



Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.pruner.pruners.register_pruner
   neural_compressor.pruner.pruners.parse_valid_pruner_types
   neural_compressor.pruner.pruners.get_pruner



.. py:function:: register_pruner(name)

   Class decorator to register a Pruner subclass to the registry.

   Decorator factory used before a Pruner subclass definition. The class it decorates is
   registered in ``PRUNERS`` under ``name``.

   :param name: A string that defines the pruner type.
   :type name: str

   :returns: A decorator that registers the decorated class in ``PRUNERS`` and returns the class unchanged.

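   The registration pattern can be illustrated with a minimal stand-alone sketch. It assumes
   the registry is a plain dict named ``PRUNERS``; ``MyPruner`` and the ``"my_pruner"`` key are
   hypothetical and are not part of the library.

   .. code-block:: python

      # Stand-alone sketch of a name-keyed class registry (not the library's source).
      PRUNERS = {}


      def register_pruner(name):
          """Return a class decorator that stores the decorated class under ``name``."""

          def register(pruner_cls):
              PRUNERS[name] = pruner_cls
              return pruner_cls  # the class itself is returned unchanged

          return register


      @register_pruner("my_pruner")  # hypothetical pruner type name
      class MyPruner:
          def __init__(self, config, modules):
              self.config = config
              self.modules = modules

   Because the decorator returns the class unchanged, ``MyPruner`` remains usable as an
   ordinary class after registration.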

.. py:function:: parse_valid_pruner_types()

   Get all valid pruner names.


.. py:function:: get_pruner(config, modules)

   Get registered pruner class.

   Look up the pruner type requested in ``config`` in ``PRUNERS`` and return an instance of it.

   :param config: A config dict object that contains the pruner information.
   :param modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.

   :returns: A Pruner object.

   :raises AssertionError: If the requested pruner type has not been registered in ``PRUNERS``.

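   A minimal sketch of the lookup, assuming the pruner type is read from a ``"pruning_type"``
   key in ``config`` and that ``PRUNERS`` is a plain dict. ``StubPruner`` and the ``"stub"``
   key are hypothetical, and the sketch omits any preprocessing of the type name that the
   real function may perform.

   .. code-block:: python

      PRUNERS = {}


      class StubPruner:  # hypothetical registered pruner class
          def __init__(self, config, modules):
              self.config = config
              self.modules = modules


      PRUNERS["stub"] = StubPruner


      def get_pruner(config, modules):
          # Assumption: the pruner type is stored under the "pruning_type" key.
          name = config["pruning_type"]
          assert name in PRUNERS, f"{name!r} is not registered; valid types: {sorted(PRUNERS)}"
          return PRUNERS[name](config, modules)


      pruner = get_pruner({"pruning_type": "stub"}, {"fc1": [[0.5, -0.2]]})

   Requesting an unregistered type, for example ``{"pruning_type": "missing"}``, fails the
   assertion instead of returning ``None``.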

.. py:class:: BasePruner(config, modules)

   Base pruner.

   The class that executes the pruning process.

   :param modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
   :param config: A config dict object that contains the pruner information.

   .. attribute:: modules

      A dict {"module_name": Tensor} that stores the pruning modules' weights.

   .. attribute:: config

      A config dict object that contains the pruner information.

   .. attribute:: masks

      A dict {"module_name": Tensor} that stores the masks for modules' weights.

   .. attribute:: scores

      A dict {"module_name": Tensor} that stores the scores for modules' weights,
      which a criterion uses to determine which parts to prune.

   .. attribute:: pattern

      A Pattern object defined in ``patterns.py``.

   .. attribute:: scheduler

      A scheduler object defined in ``scheduler.py``.

   .. attribute:: current_sparsity_ratio

      A float representing the current model's sparsity ratio; it is initialized to zero.

   .. attribute:: global_step

      An integer counting the total number of steps the model has run.

   .. attribute:: start_step

      An integer giving the step at which the pruning process starts.

   .. attribute:: end_step

      An integer giving the step at which the pruning process ends.

   .. attribute:: pruning_frequency

      An integer giving the pruning frequency in steps; it applies only when iterative
      pruning is enabled.

   .. attribute:: target_sparsity_ratio

      A float giving the target sparsity ratio to reach by the end of pruning.

   .. attribute:: max_sparsity_ratio_per_op

      A float giving the maximum sparsity ratio allowed for any single module.

   .. py:method:: on_epoch_begin(epoch)

      Called at the beginning of each epoch.


   .. py:method:: mask_weights()

      Apply masks to the corresponding modules' weights.

      Weights are multiplied element-wise by their masks. This is the step that actually prunes the weights.

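      The masking step can be sketched in pure Python, with nested lists standing in for the
      framework tensors that the real method operates on. ``apply_masks`` is a hypothetical
      helper name, not part of the library.

      .. code-block:: python

         def apply_masks(modules, masks):
             """Multiply every weight element-wise by its mask, in place."""
             for name, weight in modules.items():
                 for w_row, m_row in zip(weight, masks[name]):
                     for j, m in enumerate(m_row):
                         w_row[j] *= m


         modules = {"fc1": [[0.5, 0.2], [0.1, 0.9]]}
         masks = {"fc1": [[1, 0], [0, 1]]}
         apply_masks(modules, masks)
         # modules["fc1"] is now [[0.5, 0.0], [0.0, 0.9]]

      With PyTorch tensors, the inner loops collapse to a single element-wise product such as
      ``weight.data *= mask``.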

   .. py:method:: mask_weights_general(input_masks)

      Apply input masks to corresponding modules' weights.

      Weights are multiplied element-wise by ``input_masks``.

      :param input_masks: A dict {"module_name": Tensor} that stores the masks for modules' weights.


   .. py:method:: on_step_begin(local_step)

      Called at the start of each step.


   .. py:method:: update_masks(local_step)

      Update the masks at a given local step.

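      During iterative pruning, the sparsity to reach at the current step typically comes from
      the scheduler. As an illustration only, the sketch below uses a cubic schedule, one common
      choice that raises sparsity quickly at first and levels off toward the target; it is not
      necessarily the formula this module's scheduler implements. ``cubic_sparsity`` is a
      hypothetical name.

      .. code-block:: python

         def cubic_sparsity(step, start_step, end_step, target_sparsity, initial_sparsity=0.0):
             """Sparsity to reach at ``step`` under a cubic ramp from start_step to end_step."""
             if step <= start_step:
                 return initial_sparsity
             if step >= end_step:
                 return target_sparsity
             progress = (step - start_step) / (end_step - start_step)
             return target_sparsity + (initial_sparsity - target_sparsity) * (1 - progress) ** 3


         schedule = [cubic_sparsity(s, 0, 100, 0.8) for s in (0, 25, 50, 100)]

      The ratio rises from 0.0 to 0.8, with about 0.7 already reached halfway through the window.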

   .. py:method:: on_epoch_end()

      Called at the end of each epoch.


   .. py:method:: on_step_end()

      Called at the end of each step.


   .. py:method:: on_before_optimizer_step()

      Called before ``optimizer.step()``.


   .. py:method:: on_after_optimizer_step()

      Called after ``optimizer.step()``.

      Prune the model after the optimizer step.


   .. py:method:: on_train_begin()

      Called at the beginning of the training phase.


   .. py:method:: on_train_end()

      Called at the end of the training phase.


   .. py:method:: on_before_eval()

      Called at the beginning of the evaluation phase.


   .. py:method:: on_after_eval()

      Called at the end of the evaluation phase.

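   The hooks above are meant to be called from a training loop. The sketch below shows one
   plausible calling order, inferred from the hook names; ``RecordingPruner`` and ``train``
   are hypothetical, and the library's own integration may call the hooks differently.

   .. code-block:: python

      class RecordingPruner:
          """Hypothetical stand-in that records the name of every ``on_*`` hook it receives."""

          def __init__(self):
              self.calls = []

          def __getattr__(self, name):
              # Only invoked for attributes not found normally, i.e. the on_* hooks here.
              if name.startswith("on_"):
                  return lambda *args: self.calls.append(name)
              raise AttributeError(name)


      def train(pruner, num_epochs, steps_per_epoch):
          pruner.on_train_begin()
          for epoch in range(num_epochs):
              pruner.on_epoch_begin(epoch)
              for local_step in range(steps_per_epoch):
                  pruner.on_step_begin(local_step)
                  # forward pass and loss.backward() would run here
                  pruner.on_before_optimizer_step()
                  # optimizer.step() would run here
                  pruner.on_after_optimizer_step()
                  pruner.on_step_end()
              pruner.on_epoch_end()
              pruner.on_before_eval()
              # evaluation would run here
              pruner.on_after_eval()
          pruner.on_train_end()


      pruner = RecordingPruner()
      train(pruner, num_epochs=1, steps_per_epoch=1)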

   .. py:method:: check_is_pruned_step(step)

      Check if a pruning process should be performed at the current step.

      :param step: An integer, the current step number.

      :returns: ``True`` if pruning should be performed at ``step``, otherwise ``False``.

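      A sketch of the check, assuming pruning happens every ``pruning_frequency`` steps inside
      the closed interval ``[start_step, end_step]``. The attributes are passed as plain
      arguments here so that the function stands alone.

      .. code-block:: python

         def check_is_pruned_step(step, start_step, end_step, pruning_frequency):
             """Return True if ``step`` is a pruning step under the assumed schedule."""
             if step < start_step or step > end_step:
                 return False  # outside the pruning window
             return (step - start_step) % pruning_frequency == 0


         pruned_steps = [s for s in range(0, 61) if check_is_pruned_step(s, 10, 50, 10)]
         # pruned_steps == [10, 20, 30, 40, 50]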


.. py:class:: BasicPruner(config, modules)

   Bases: :py:obj:`BasePruner`

   Basic pruner.

   The class that executes the pruning process. It:

   1. Defines the pruning functions called at step begin/end and epoch begin/end.
   2. Defines the pruning criterion.

   :param modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
   :param config: A config dict object that contains the pruner information.

   .. attribute:: pattern

      A Pattern object that defines how the pruned weights are arranged spatially.

   .. attribute:: criterion

      A Criterion object that defines which weights are to be pruned.

   .. attribute:: scheduler

      A Scheduler object that defines how the model's sparsity changes as training/pruning proceeds.

   .. attribute:: reg

      A Reg object that defines regularization terms.

   .. py:method:: set_global_step(global_step)

      Set global step number.


   .. py:method:: update_masks(local_step)

      Update the masks at a given local step.

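      To illustrate what a criterion contributes to a mask update, the sketch below scores
      weights by absolute value (a magnitude criterion, chosen only as an example) and zeros
      the lowest-scoring fraction element by element, i.e. with an unstructured pattern.
      ``magnitude_mask`` is a hypothetical helper; the criterion and pattern that
      ``BasicPruner`` actually uses come from its configuration.

      .. code-block:: python

         def magnitude_mask(weight, sparsity):
             """Return a 0/1 mask that prunes the ``sparsity`` fraction of smallest-|w| entries."""
             cols = len(weight[0])
             scores = [abs(w) for row in weight for w in row]  # criterion: magnitude
             k = int(len(scores) * sparsity)  # number of entries to prune
             pruned = set(sorted(range(len(scores)), key=scores.__getitem__)[:k])
             return [
                 [0 if r * cols + c in pruned else 1 for c in range(cols)]
                 for r in range(len(weight))
             ]


         mask = magnitude_mask([[0.5, -0.1], [0.05, -0.9]], sparsity=0.5)
         # mask == [[1, 0], [0, 1]]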

   .. py:method:: on_before_optimizer_step()

      Called before ``optimizer.step()``.


   .. py:method:: on_after_optimizer_step()

      Prune the model after the optimizer step.



.. py:class:: PatternLockPruner(config, modules)

   Bases: :py:obj:`BasePruner`

   Pattern-lock pruner.

   A Pruner class derived from BasePruner.
   In this pruner, the original model's sparsity pattern is kept fixed during training.
   This pruner is useful when a user wants to train a sparse model without changing its sparsity pattern.

   :param modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
   :param config: A config dict object that contains the pruner information.

   All attributes are inherited from the parent class :py:class:`BasePruner`.

   .. py:method:: update_masks(local_step)

      Update the masks at a given local step.


   .. py:method:: on_after_optimizer_step()

      Called after ``optimizer.step()``.

      Prune the model after the optimizer step.

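      The pattern-lock idea can be sketched as follows: derive a mask once from the entries that
      are already zero, then re-apply that same mask after every optimizer step so that pruned
      positions stay zero. Nested lists stand in for tensors, and ``lock_pattern`` and
      ``reapply`` are hypothetical names.

      .. code-block:: python

         def lock_pattern(weight):
             """Build a mask that keeps exactly the entries that are currently non-zero."""
             return [[0 if w == 0 else 1 for w in row] for row in weight]


         def reapply(weight, mask):
             """Zero the locked positions again, in place."""
             for w_row, m_row in zip(weight, mask):
                 for j, m in enumerate(m_row):
                     w_row[j] *= m


         weight = [[0.0, 0.4], [0.7, 0.0]]  # an already-sparse weight
         mask = lock_pattern(weight)        # computed once and kept fixed
         weight[0][0] = 0.3                 # simulate an optimizer update reviving a pruned entry
         reapply(weight, mask)              # called after optimizer.step()
         # weight == [[0.0, 0.4], [0.7, 0.0]]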


.. py:class:: ProgressivePruner(config, modules)

   Bases: :py:obj:`BasicPruner`

   Progressive pruner.

   A Pruner class derived from BasicPruner. In this pruner, mask interpolation is applied.
   Mask interpolation is a fine-grained improvement for NxM structured pruning: it adds
   intermediate masks between the masks of two consecutive pruning steps.

   :param modules: A dict {"module_name": Tensor} that stores the pruning modules' weights.
   :param config: A config dict object that contains the pruner information.

   All attributes are inherited from the parent class :py:class:`BasicPruner`.

   .. py:method:: check_progressive_validity()

      Check if the settings of progressive pruning are valid.


   .. py:method:: check_is_pruned_progressive_step(step)

      Check if a progressive pruning process should be performed at the current step.

      :param step: An integer, the current step number.

      :returns: ``True`` if a progressive pruning update should be performed at ``step``, otherwise ``False``.


   .. py:method:: update_masks_progressive(local_step)

      Update the masks in progressive pruning mode at a given local step.

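      Mask interpolation can be sketched as splitting the entries that the next mask newly
      prunes into groups and pruning one more group at each intermediate step, so the last
      intermediate mask equals the target. This element-wise sketch ignores the NxM block
      structure and prunes in positional order; a structured implementation would group entries
      into NxM blocks and choose their order by some criterion. ``interpolate_masks`` is a
      hypothetical name.

      .. code-block:: python

         def interpolate_masks(old_mask, new_mask, num_steps):
             """Return ``num_steps`` masks that move from ``old_mask`` to ``new_mask`` in stages."""
             # Positions kept by the old mask but pruned by the new one.
             newly_pruned = [
                 (r, c)
                 for r, row in enumerate(old_mask)
                 for c, kept in enumerate(row)
                 if kept == 1 and new_mask[r][c] == 0
             ]
             stages = []
             for i in range(1, num_steps + 1):
                 count = len(newly_pruned) * i // num_steps  # prune a growing prefix
                 mask = [row[:] for row in old_mask]
                 for r, c in newly_pruned[:count]:
                     mask[r][c] = 0
                 stages.append(mask)
             return stages


         stages = interpolate_masks([[1, 1, 1, 1]], [[1, 0, 0, 0]], num_steps=3)
         # stages == [[[1, 0, 1, 1]], [[1, 0, 0, 1]], [[1, 0, 0, 0]]]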

   .. py:method:: on_step_begin(local_step)

      Update the masks at a given ``local_step``.

      Called at the start of each step.


   .. py:method:: on_before_optimizer_step()

      Called before ``optimizer.step()``.


   .. py:method:: on_after_optimizer_step()

      Prune the model after the optimizer step.


   .. py:method:: print_progressive_sparsity()

      Output the progressive sparsity.