PyTorch Smooth Quantization

  1. Introduction

  2. Usage

  3. Supported Framework Matrix

Introduction

Quantization is a common compression technique that reduces memory usage and accelerates inference by converting floating-point matrices to integer matrices. For large language models (LLMs) with very large parameter counts, systematic outliers in the activations make activation quantization difficult. SmoothQuant, a training-free post-training quantization (PTQ) solution, migrates this difficulty offline from activations to weights through a mathematically equivalent transformation. Please refer to the SmoothQuant documentation for the underlying fundamentals.
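To make the "mathematically equivalent transformation" concrete, here is a minimal pure-Python sketch that does not call PyTorch or Neural Compressor. It assumes a linear layer Y = XW, where X has shape (tokens, in_channels) and W has shape (in_channels, out_channels). Each activation channel j is divided by a per-channel scale s_j, and the matching weight row is multiplied by s_j. The helper names `matmul` and `smooth` and the scale values are hypothetical and only illustrate the idea. SmoothQuant itself derives the scales from calibration statistics.

```python
# Minimal sketch of SmoothQuant's equivalent transformation (pure Python, illustration only).
# Assumption: X is (tokens, in_channels), W is (in_channels, out_channels),
# and s holds one positive smoothing scale per input channel.

def matmul(a, b):
    """Plain matrix multiply for lists of lists: (n x k) @ (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def smooth(x, w, s):
    """Divide activation column j by s[j] and multiply weight row j by s[j]."""
    x_hat = [[row[j] / s[j] for j in range(len(s))] for row in x]
    w_hat = [[val * s[j] for val in w[j]] for j in range(len(s))]
    return x_hat, w_hat

# Input channel 1 of X is an outlier channel (roughly 100x larger than channel 0).
X = [[1.0, 100.0],
     [2.0, -80.0]]
W = [[0.5, -0.2],
     [0.1, 0.3]]
s = [1.0, 10.0]  # hypothetical scales, chosen by hand for this example

X_hat, W_hat = smooth(X, W, s)
Y = matmul(X, W)
Y_hat = matmul(X_hat, W_hat)
print(Y)
print(Y_hat)
```

Y_hat matches Y up to floating-point rounding. Meanwhile, the largest activation magnitude drops from 100 to 10, so part of the outlier has moved into the weights, which are easier to quantize.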

Usage

Fixed Alpha

To set a fixed alpha for the entire model, users can follow this example:

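from neural_compressor.torch.quantization import SmoothQuantConfig, convert, prepare

# Assumed to be defined by the user:
#   fp32_model     - the FP32 torch.nn.Module to quantize
#   example_inputs - a sample input (e.g. a torch.Tensor) accepted by fp32_model


def run_fn(model):
    # Calibration: run representative inputs through the prepared model
    # so that activation statistics can be collected.
    model(example_inputs)


quant_config = SmoothQuantConfig(alpha=0.5)
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)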

SmoothQuantConfig description:

alpha: the smoothing factor used to compute the per-channel conversion scale, which balances the quantization difficulty between activations and weights. A float value; the default is 0.5.
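The SmoothQuant paper defines the per-channel scale as s_j = max(|X_j|)^alpha / max(|W_j|)^(1 - alpha), where X_j and W_j are the j-th input channel of the activations and of the weights. The sketch below assumes that formula; Neural Compressor's internal implementation may add details such as clamping. It uses hypothetical calibration statistics to show how alpha shifts range between activations and weights. The function name `smoothing_scales` is illustrative, not a library API.

```python
# Sketch of how alpha controls the smoothing scale (formula from the SmoothQuant paper;
# not Neural Compressor's actual code). Statistics below are hypothetical.

def smoothing_scales(x_absmax, w_absmax, alpha=0.5):
    """Per-input-channel scales s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)."""
    return [(xa ** alpha) / (wa ** (1.0 - alpha)) for xa, wa in zip(x_absmax, w_absmax)]

# Per-channel absolute maxima for 3 input channels; channel 2 is an activation outlier.
x_absmax = [1.0, 4.0, 64.0]
w_absmax = [1.0, 1.0, 1.0]

for alpha in (0.0, 0.5, 1.0):
    s = smoothing_scales(x_absmax, w_absmax, alpha)
    # After smoothing, activations are divided by s and weights multiplied by s.
    act_after = [xa / sj for xa, sj in zip(x_absmax, s)]
    wt_after = [wa * sj for wa, sj in zip(w_absmax, s)]
    print(alpha, act_after, wt_after)
```

With alpha = 0 nothing is migrated. With alpha = 1 every activation channel is flattened to a maximum of 1, and the outlier moves entirely into the weights. With alpha = 0.5 the two sides are equalized: both become sqrt(max|X_j| * max|W_j|) per channel.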

Note: alpha="auto" and alpha auto-tuning were supported in the old API; please stay tuned for auto-alpha support in the new API.

Specify Quantization Rules

Intel(R) Neural Compressor supports specifying quantization rules by operator type for Smooth Quantization. To do so, call set_local on a SmoothQuantConfig to fall back a given op type to FP32.

In the following example, Linear layers are kept in FP32, i.e. they are not quantized.

# Fall back all layers of op type "Linear" to FP32 (quant_config is the one created above)
quant_config.set_local("Linear", SmoothQuantConfig(w_dtype="fp32", act_dtype="fp32"))
prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
run_fn(prepared_model)
q_model = convert(prepared_model)

For more information, please refer to the examples.

Supported Framework Matrix

Framework   Alpha range   Folding
PyTorch     [0, 1]        False
IPEX        [0, 1]        True / False (Version > 2.1)
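As we understand it, the Folding column concerns where the division of activations by s is applied. It is either a separate multiplication inserted before the smoothed layer, or it is folded into the preceding layer's parameters so that no extra op runs at inference. The pure-Python sketch below assumes the preceding layer is a LayerNorm, a common case in transformer blocks. It shows that scaling the LayerNorm's affine parameters gamma and beta by 1/s gives the same output as an explicit multiplication. The function `layer_norm` and all values are hypothetical and are not Neural Compressor code.

```python
import math

# Hypothetical illustration of "folding": absorbing the 1/s smoothing factor
# into the affine parameters of a preceding LayerNorm.

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over a single vector with per-element affine parameters."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

x = [0.3, -1.2, 2.5, 0.1]
gamma = [1.0, 0.8, 1.5, 1.2]
beta = [0.0, 0.1, -0.2, 0.05]
s = [1.0, 2.0, 8.0, 0.5]  # hypothetical per-channel smoothing scales

# Option A (no folding): keep LayerNorm as-is and apply an explicit multiply by 1/s.
with_mul = [y / sj for y, sj in zip(layer_norm(x, gamma, beta), s)]

# Option B (folding): scale gamma and beta by 1/s once, offline; no extra op at inference.
gamma_f = [g / sj for g, sj in zip(gamma, s)]
beta_f = [b / sj for b, sj in zip(beta, s)]
folded = layer_norm(x, gamma_f, beta_f)

print(max(abs(a - b) for a, b in zip(with_mul, folded)))
```

The printed difference is at floating-point rounding level. Folding works here because (gamma_j * norm_j + beta_j) / s_j equals (gamma_j / s_j) * norm_j + beta_j / s_j.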