Pruning
=======

## Introduction

Network pruning is one of the most popular approaches to network compression. It removes the least important parameters in a network to achieve a compact architecture with minimal accuracy drop.

## Pruning Types

- Unstructured Pruning

  Unstructured pruning finds and removes the less salient connections in the model. The resulting nonzero patterns are irregular and can appear anywhere in the weight matrix.

- Structured Pruning

  Structured pruning removes parameters in groups, deleting entire blocks, filters, or channels according to some pruning criterion.
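The difference between the two types can be illustrated with masks over a weight tensor. Below is a minimal sketch in plain PyTorch (not the Neural Compressor API): unstructured pruning zeroes individual elements wherever they fall, while structured pruning zeroes whole output channels.

```python
import torch

weight = torch.randn(4, 8)  # toy weight matrix: 4 output channels, 8 inputs

# Unstructured: zero the half of individual elements with the smallest magnitude.
threshold = weight.abs().flatten().kthvalue(weight.numel() // 2).values
unstructured_mask = (weight.abs() > threshold).float()

# Structured: keep only the 2 output channels (rows) with the largest L2 norm.
channel_norms = weight.norm(dim=1)
keep = channel_norms.topk(2).indices
structured_mask = torch.zeros(4, 1)
structured_mask[keep] = 1.0

print(weight * unstructured_mask)  # irregular zeros scattered anywhere
print(weight * structured_mask)    # entire rows (channels) zeroed
```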
## Pruning Algorithms

| Pruning Type         | Pruning Granularity | Pruning Algorithm    | Framework           |
|----------------------|---------------------|----------------------|---------------------|
| Unstructured Pruning | Element-wise        | Magnitude            | PyTorch, TensorFlow |
| Unstructured Pruning | Element-wise        | Pattern Lock         | PyTorch             |
| Structured Pruning   | Filter/Channel-wise | Gradient Sensitivity | PyTorch             |
| Structured Pruning   | Block-wise          | Group Lasso          | PyTorch             |
| Structured Pruning   | Element-wise        | Pattern Lock         | PyTorch             |

- Magnitude

  The algorithm prunes the weights with the lowest absolute values at each layer, given a sparsity target (see the sketch after this list).

- Gradient Sensitivity

  The algorithm prunes the heads, intermediate layers, and hidden states in NLP models according to an importance score calculated following the paper [FastFormers](https://arxiv.org/abs/2010.13382).

- Group Lasso

  The algorithm uses Group Lasso regularization to prune entire rows, columns, or blocks of parameters, resulting in a smaller dense network.

- Pattern Lock

  The algorithm locks the sparsity pattern during the fine-tuning phase by freezing the zero values of the weight tensor during the weight updates of training.
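As a concrete illustration of the magnitude criterion, the following is a minimal sketch in plain PyTorch (not the Neural Compressor implementation) that zeroes the lowest-magnitude weights of a single layer to hit a given sparsity target:

```python
import torch

def magnitude_prune(weight: torch.Tensor, target_sparsity: float) -> torch.Tensor:
    """Zero the `target_sparsity` fraction of weights with the smallest |value|."""
    num_to_prune = int(weight.numel() * target_sparsity)
    if num_to_prune == 0:
        return weight
    # The k-th smallest |w| gives the magnitude threshold below which weights are dropped.
    threshold = weight.abs().flatten().kthvalue(num_to_prune).values
    mask = (weight.abs() > threshold).float()
    return weight * mask

layer_weight = torch.randn(256, 256)
pruned = magnitude_prune(layer_weight, target_sparsity=0.9)
print(1.0 - pruned.count_nonzero().item() / pruned.numel())  # ~0.9 sparsity
```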
## Pruning API

### User facing API

The Neural Compressor pruning API is defined under `neural_compressor.experimental.Pruning`, which takes a user-defined yaml file as input. The user-defined yaml defines training, pruning, and evaluation behaviors. See the [API Readme](../docs/pruning_api.md).

### Usage 1: Launch pruning with user-defined yaml

#### Launcher code

Below is the launcher code if the training behavior is defined in the user-defined yaml.

```python
from neural_compressor.experimental import Pruning
prune = Pruning('/path/to/user/pruning/yaml')
prune.model = model
model = prune.fit()
```

#### User-defined yaml

The user-defined yaml follows the syntax below. Note that the `train` section is optional if the user implements a `pruning_func` and assigns it to the `train_func` attribute of the pruning instance (see Usage 2). Users can refer to [the yaml template file](../docs/pruning.yaml) for the meaning of each field.

##### `train`

The `train` section defines the training behavior, including which training hyper-parameters are used and which dataloader is used during training.

##### `approach`

The `approach` section defines which pruning algorithm is used and how it is applied during the training process.

- ``weight compression``: the pruning target. Currently only ``weight compression`` is supported; it means zeroing entries of the weight matrix. The parameters for ``weight compression`` are divided into global parameters and local parameters in different ``pruners``. Global parameters may contain `start_epoch`, `end_epoch`, `initial_sparsity`, `target_sparsity`, and `frequency`.

  - `start_epoch`: the epoch on which pruning begins.
  - `end_epoch`: the epoch on which pruning ends.
  - `initial_sparsity`: the initial sparsity goal; default 0.
  - `target_sparsity`: the target sparsity goal.
  - `frequency`: the frequency of updating sparsity.

- `Pruner`:

  - `prune_type`: the pruning algorithm. Currently ``basic_magnitude``, ``gradient_sensitivity``, and ``group_lasso`` are supported.
  - `names`: the names of the weights to be pruned. If no weight is specified, all weights of the model will be pruned.
  - `parameters`: additional parameters required by the ``gradient_sensitivity`` prune_type, defined in the ``parameters`` field. These parameters determine how a weight is pruned, including the pruning target and how the weight's importance is calculated. It contains:
    - `target`: the pruning target for the weight; overrides the global `target_sparsity` if set.
    - `stride`: the stride of each pruned unit in the weight.
    - `transpose`: whether to transpose the weight before pruning.
    - `normalize`: whether to normalize the calculated importance.
    - `index`: the index of the calculated importance.
    - `importance_inputs`: the inputs of the importance calculation for the weight.
    - `importance_metric`: the metric used in the importance calculation. Currently ``abs_gradient`` and ``weighted_gradient`` are supported.
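For instance, a `gradient_sensitivity` pruner entry for a BERT attention output layer might look like the following sketch. The layout follows the yaml template referenced above; the `importance_inputs` and `importance_metric` values shown are illustrative assumptions, not the only valid choices.

```yaml
approach:
  weight_compression:
    pruners:
      - !Pruner
        prune_type: gradient_sensitivity
        names: ['bert.encoder.layer.0.attention.output.dense.weight']
        parameters: {
          target: 8,
          stride: 64,
          transpose: True,
          normalize: True,
          index: 0,
          importance_inputs: ['head_mask'],      # assumed example input
          importance_metric: 'weighted_gradient' # or 'abs_gradient'
        }
```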
Take the above as an example: assume 'bert.encoder.layer.0.attention.output.dense.weight' has the shape [N, 12\*64]. The target 8 and stride 64 control the pruned weight shape to be [N, 8\*64]. `transpose` set to True indicates that the weight is pruned at dim 1 and should be transposed to [12\*64, N] before pruning. `importance_inputs` and `importance_metric` specify the actual input and metric used to calculate the importance matrix.

### Usage 2: Launch pruning with user-defined pruning function

#### Launcher code

In this case, the launcher code looks like the following:

```python
from neural_compressor.experimental import Pruning, common
prune = Pruning(args.config)
prune.model = model
prune.train_func = pruning_func
model = prune.fit()
```

#### User-defined pruning function

Users can pass customized training/evaluation functions to `Pruning` for flexible scenarios. In this case, the pruning process is driven by hooks pre-defined in Neural Compressor, which the user places inside the training function. Neural Compressor defines several hooks for user use:

```
on_epoch_begin(epoch) : Hook executed at each epoch beginning
on_step_begin(batch) : Hook executed at each batch beginning
on_step_end() : Hook executed at each batch end
on_epoch_end() : Hook executed at each epoch end
on_before_optimizer_step() : Hook executed after gradients are computed and before the optimizer step
```

The following shows how to use these hooks in a user pass-in training function, taken in part from a BERT training example:

```python
def pruning_func(model):
    for epoch in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_step_begin(step)
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'labels': batch[3]}
            #inputs['token_type_ids'] = batch[2]
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if (step + 1) % args.gradient_accumulation_steps == 0:
                prune.on_before_optimizer_step()
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            prune.on_step_end()
        ...
```

## Examples

For related examples, please refer to [Pruning examples](../examples/README.md).