:py:mod:`neural_compressor.compression.pruner.utils`
====================================================

.. py:module:: neural_compressor.compression.pruner.utils

.. autoapi-nested-parse::

   Prune utils.



Module Contents
---------------


Functions
~~~~~~~~~

.. autoapisummary::

   neural_compressor.compression.pruner.utils.get_sparsity_ratio
   neural_compressor.compression.pruner.utils.get_sparsity_ratio_tf
   neural_compressor.compression.pruner.utils.check_config
   neural_compressor.compression.pruner.utils.reset_none_to_default
   neural_compressor.compression.pruner.utils.update_params
   neural_compressor.compression.pruner.utils.process_weight_config
   neural_compressor.compression.pruner.utils.process_yaml_config
   neural_compressor.compression.pruner.utils.check_key_validity
   neural_compressor.compression.pruner.utils.process_and_check_config
   neural_compressor.compression.pruner.utils.process_config
   neural_compressor.compression.pruner.utils.parse_last_linear
   neural_compressor.compression.pruner.utils.parse_last_linear_tf
   neural_compressor.compression.pruner.utils.parse_to_prune
   neural_compressor.compression.pruner.utils.parse_to_prune_tf
   neural_compressor.compression.pruner.utils.generate_pruner_config
   neural_compressor.compression.pruner.utils.get_layers
   neural_compressor.compression.pruner.utils.collect_layer_inputs


.. py:function:: get_sparsity_ratio(pruners, model)

   Calculate the sparsity ratios of the model.

   :returns: Three floats:

      * ``elementwise_over_matmul_gemm_conv``: the ratio of zero elements in the layers targeted for pruning.
      * ``elementwise_over_all``: the ratio of zero elements in all layers of the model.
      * ``blockwise_over_matmul_gemm_conv``: the ratio of all-zero blocks in the layers targeted for pruning.


.. py:function:: get_sparsity_ratio_tf(pruners, model)

   Calculate the sparsity ratios of the model (TensorFlow version).

   :returns: Three floats:

      * ``elementwise_over_matmul_gemm_conv``: the ratio of zero elements in the layers targeted for pruning.
      * ``elementwise_over_all``: the ratio of zero elements in all layers of the model.
      * ``blockwise_over_matmul_gemm_conv``: the ratio of all-zero blocks in the layers targeted for pruning.


.. py:function:: check_config(prune_config)

   Check whether the configuration dict is valid for running a Pruning object.

   :param prune_config: A config dict object that contains pruning parameters and configurations.

   :returns: None if everything is correct.

   :raises AssertionError: If the configuration is invalid.


.. py:function:: reset_none_to_default(obj, key, default)

   Set undefined configurations to default values.

   :param obj: A dict of ``{key: value}`` pairs.
   :param key: A string representing the key in ``obj``.
   :param default: The value to use for ``key`` when ``key`` is not in ``obj``.


.. py:function:: update_params(info)

   Update parameters.


.. py:function:: process_weight_config(global_config, local_configs, default_config)

   Process pruning configurations.

   :param global_config: A config dict object that contains the global pruning parameters and configurations.
   :param local_configs: Layer-specific config dicts that contain pruning parameters and configurations.
   :param default_config: A config dict object that contains the default pruning parameters and configurations.

   :returns: A config dict object that contains pruning parameters and configurations.
   :rtype: pruners_info


.. py:function:: process_yaml_config(global_config, local_configs, default_config)

   Process the YAML configuration file.

   :param global_config: A config dict object that contains the global pruning parameters and configurations.
   :param local_configs: Layer-specific config dicts that contain pruning parameters and configurations.
   :param default_config: A config dict object that contains the default pruning parameters and configurations.

   :returns: A config dict object that contains pruning parameters and configurations.
   :rtype: pruners_info


.. py:function:: check_key_validity(template_config, user_config)

   Check the validity of keys.

   :param template_config: A default config dict object that contains pruning parameters and configurations.
   :param user_config: A user config dict object that contains pruning parameters and configurations.


.. py:function:: process_and_check_config(val)

   Process and check configurations.

   :param val: A dict that contains the layer-specific pruning configurations.


.. py:function:: process_config(config)

   Obtain a config dict object from the config file.

   :param config: A string representing the path to the configuration file.

   :returns: A config dict object.


.. py:function:: parse_last_linear(model)

   Locate the last linear layers of the model.

   The final linear layer often acts as a classifier head, so pruning it may cause an accuracy drop.

   :param model: The model to be pruned.


.. py:function:: parse_last_linear_tf(model)

   Locate the last linear layers of the model.

   The final linear layer often acts as a classifier head, so pruning it may cause an accuracy drop.

   :param model: The model to be pruned.
   :type model: tf.keras.Model


.. py:function:: parse_to_prune(config, model)

   Select the target layers to be pruned.

   :param config: A string representing the path to the configuration file.
   :param model: The model to be pruned.


.. py:function:: parse_to_prune_tf(config, model)

   Select the target layers to be pruned.

   :param config: A string representing the path to the configuration file.
   :type config: str
   :param model: The model to be pruned.
   :type model: tf.keras.Model


.. py:function:: generate_pruner_config(info)

   Generate a pruner config object from the pruning information.

   :param info: A dotdict that stores the pruning information.

   :returns: A pruner config object.
   :rtype: pruner


.. py:function:: get_layers(model)

   Get each layer's name and its module.

   :param model: The model to be pruned.

   :returns: Each layer's name and its module.


.. py:function:: collect_layer_inputs(model, layers, layer_idx, layer_inputs, device='cuda:0')

   Get the forward inputs of a layer.

   :param model: The model to be pruned.
   :param layers: Selectable layers of the model.
   :param layer_idx: The layer index.
   :param layer_inputs: The dataloader, or the output of the previous layer.
   :param device: The device on which the inputs are returned.

   :returns: A list of inputs.
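
   The layer-by-layer input collection described above can be illustrated with a minimal pure-Python sketch. It assumes a toy sequential model in which every layer is a plain callable on floats, and it introduces a hypothetical ``embed`` callable standing in for the part of the model that runs before the first layer. Neither ``collect_layer_inputs_sketch`` nor ``embed`` is part of ``neural_compressor``; the documented function works on real model layers and places data on ``device``, which this sketch ignores.

   .. code-block:: python

      from typing import Callable, Iterable, List, Sequence

      # Hypothetical toy setup: every "layer" is a plain callable on floats.
      Layer = Callable[[float], float]


      def collect_layer_inputs_sketch(
          embed: Layer,
          layers: Sequence[Layer],
          layer_idx: int,
          layer_inputs: Iterable[float],
      ) -> List[float]:
          """Return the inputs that ``layers[layer_idx]`` receives.

          For layer_idx == 0, ``layer_inputs`` plays the role of the dataloader:
          each raw sample goes through ``embed``, the (hypothetical) part of the
          model that runs before the first layer. For layer_idx > 0,
          ``layer_inputs`` holds the inputs collected for the previous layer, and
          running that layer on them gives the inputs of the current one.
          """
          if layer_idx == 0:
              return [embed(sample) for sample in layer_inputs]
          prev_layer = layers[layer_idx - 1]
          return [prev_layer(x) for x in layer_inputs]


      # Usage: walk the layers one at a time, reusing the previous inputs.
      def embed(x: float) -> float:
          return x + 1.0


      layers = [lambda x: 2.0 * x, lambda x: x - 3.0, lambda x: x * x]
      dataloader = [1.0, 2.0]

      inputs = collect_layer_inputs_sketch(embed, layers, 0, dataloader)
      for idx in range(1, len(layers)):
          inputs = collect_layer_inputs_sketch(embed, layers, idx, inputs)
      print(inputs)  # inputs of layers[2]: [1.0, 3.0]

   Because each call needs only the previous layer and the inputs already collected for it, a pruner following this pattern can handle one layer at a time without rerunning the earlier layers, which is presumably why ``layer_inputs`` accepts either a dataloader or the previous layer's output.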