PyTorch Quantization-Aware Training (QAT)

This document describes how to use Quantization-Aware Training (QAT) in Intel Neural Compressor (INC) to achieve high-accuracy, hardware-friendly quantization for large language models (LLMs). QAT explicitly simulates quantization during training, significantly narrowing the accuracy gap compared to pure Post-Training Quantization (PTQ), especially under aggressive compression (e.g., 4-bit).


Contents

  1. Overview

  2. QAT Workflow and Benefits

  3. Quick Start

  4. Core Components

  5. Configuration and Key Parameters

  6. Best Practices and Troubleshooting

  7. References


Overview

Quantization-Aware Training (QAT) is a training strategy that inserts “fake quantization” operators into the model during training. The forward pass mimics the quantized inference behavior, while gradients are still propagated in floating point. This allows the model to adapt to quantization noise and recover most of the accuracy that may be lost with PTQ-only workflows.
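As a concept illustration (not the INC implementation; all names below are hypothetical), the following sketch shows what a symmetric fake-quantize step does in the forward pass: values are rounded onto a low-bit integer grid and immediately dequantized, so downstream layers see realistic quantization noise while everything stays in floating point.

```python
# Hypothetical illustration of symmetric fake quantization (quantize-dequantize).
# This is NOT the INC implementation, just the concept QAT builds on.

def fake_quantize(values, bits=4):
    """Round values onto a symmetric b-bit integer grid, then dequantize."""
    qmax = 2 ** (bits - 1) - 1          # e.g. 7 for 4-bit symmetric
    amax = max(abs(v) for v in values)  # per-tensor absolute maximum
    if amax == 0.0:
        return list(values)
    scale = amax / qmax                 # step size of the integer grid
    # round-to-nearest onto the grid, clamp to the representable range
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]

weights = [0.9, -0.31, 0.02, -0.7]
print(fake_quantize(weights, bits=4))
```

During QAT, the backward pass typically treats this rounding as the identity (the straight-through estimator), which is why gradients can still propagate in floating point as described above.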

In this repository, QAT is integrated with:

  • PyTorch and Hugging Face Transformers

  • AutoRound for quantization schemes and kernels

  • Microscaling (MX) formats (e.g., MXFP4, MXFP8) for efficient deployment

Although this document focuses on LLMs (e.g., Llama 3.x), the same concepts can be extended to other architectures.


QAT Workflow and Benefits

High-Level Workflow

A typical QAT pipeline consists of the following stages:

  1. Train or Fine-Tune a Baseline Model

    • Train or fine-tune the model in FP32/BF16 to obtain a strong baseline.

    • Example: fine-tune meta-llama/Llama-3.1-8B in BF16.

  2. Quantization-Aware Fine-Tuning (QAT)

    • Insert QAT modules into the model using prepare_qat.

    • Optionally load the PTQ model weights as initialization.

    • Fine-tune the quantized model with a small learning rate.

  3. Export and Deployment

    • Save the QAT model as a standard Hugging Face model directory.

    • Deploy with compatible inference engines (e.g., vLLM), or export using INC/AutoRound export utilities for specific runtimes.

Why QAT?

Compared with PTQ, QAT offers:

  1. Higher Accuracy Under Aggressive Compression

    • QAT significantly reduces accuracy degradation for low-bit formats (e.g., MXFP4).

  2. Realistic Simulation of Inference Behavior

    • Fake-quant modules simulate the exact quantization scheme that will be used at inference, including MX formats.

  3. Better Robustness on LLMs

    • LLMs are often highly sensitive to weight perturbations; QAT helps the model adapt to such changes.

  4. Flexible Integration

    • QAT is implemented via modular PyTorch components (QuantLinear, TensorQuantizer) that can be inserted into standard Transformer architectures.


Quick Start

This section walks through an end-to-end example.

1. Setup Environment

From the llm_qat directory:

pip install -r requirements.txt

requirements.txt includes (among others):

  • auto-round==0.9.3

  • neural-compressor-pt==3.7

  • transformers==4.53.0

  • accelerate

  • datasets

  • lm-eval


2. Baseline Fine-Tuning (BF16)

You can fine-tune Llama 3.1 in BF16 with FSDP as follows:

accelerate launch --config-file accelerate_config/fsdp1.yaml \
  main.py \
  --model_name_or_path meta-llama/Llama-3.1-8B \
  --model_max_length 4096 \
  --dataloader_drop_last True \
  --do_train True \
  --do_eval True \
  --output_dir ./llama3.1-finetuned \
  --dataset Daring-Anteater \
  --num_train_epochs 2.0 \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --eval_accumulation_steps 1 \
  --save_strategy steps \
  --save_steps 3000 \
  --eval_strategy steps \
  --eval_steps 3000 \
  --load_best_model_at_end True \
  --save_total_limit 2 \
  --learning_rate 1e-5 \
  --weight_decay 0.0 \
  --warmup_ratio 0.1 \
  --lr_scheduler_type linear \
  --logging_steps 1 \
  --report_to tensorboard

Key points (see main.py):

  • TrainingArguments are extended with model_max_length, bf16, etc.

  • Dataset is created via make_supervised_data_module (see utils.py).

  • FSDP configuration is controlled by accelerate_config/fsdp1.yaml or fsdp2.yaml.


3. QAT Fine-Tuning

The core QAT logic is driven from main.py using a QuantizationArguments dataclass:

@dataclass
class QuantizationArguments:
    quant_scheme: str | None = field(
        default=None,
        metadata={
            "help": "Specify the quantization format for PTQ/QAT. If specified, PTQ/QAT will be enabled.",
            "choices": ["MXFP8", "MXFP4"],
        },
    )

When --quant_scheme is provided, the model is prepared for QAT:

if quant_args.quant_scheme is not None:
    from neural_compressor.torch.quantization.quantize import prepare_qat

    model.train()
    if quant_args.quant_scheme == "MXFP8":
        # Default MXFP8 scheme
        prepare_qat(model)
    if quant_args.quant_scheme == "MXFP4":
        mappings = {torch.nn.Linear: "MXFP4"}
        prepare_qat(model, mappings)

    logger.info("Finish model preparation for QAT.")

Example: QAT with MXFP4

accelerate launch --config-file accelerate_config/fsdp1.yaml \
  main.py \
  --model_name_or_path ./llama3.1-finetuned \
  --model_max_length 4096 \
  --dataloader_drop_last True \
  --do_train True \
  --do_eval True \
  --quant_scheme MXFP4 \
  --output_dir ./llama3.1-finetuned-qat \
  --dataset Daring-Anteater \
  --max_steps 1000 \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --eval_accumulation_steps 1 \
  --save_strategy steps \
  --save_steps 3000 \
  --eval_strategy steps \
  --eval_steps 3000 \
  --load_best_model_at_end True \
  --save_total_limit 2 \
  --learning_rate 1e-5 \
  --weight_decay 0.0 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type linear \
  --logging_steps 1 \
  --report_to tensorboard

4. Evaluation and Deployment

Once QAT training finishes, the model is saved in Hugging Face format via QATTrainer.save_model. You can evaluate it with vLLM and lm_eval:

lm_eval \
  --model vllm \
  --model_args pretrained=./llama3.1-finetuned-qat,\
tensor_parallel_size=1,data_parallel_size=1,\
gpu_memory_utilization=0.8,max_model_len=32768 \
  --tasks gsm8k \
  --batch_size 8

QATTrainer also ensures the correct dtype is written back to config.json after FSDP training, which helps downstream inference libraries choose the right precision.


Core Components

This section describes the key building blocks used by QAT under the hood.

1. prepare_qat and Module Replacement

The QAT preparation uses utility functions in neural_compressor/torch/algorithms/qat/quant_utils.py:

def replace_with_quant_linear(model, quant_cfg=None):
    """Recursively replace modules with quantized modules."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            if "lm_head" in name:
                continue
            # Replace on the parent
            quantized = convert(child, quant_cfg, QuantLinear)
            setattr(model, name, quantized)
        replace_with_quant_linear(getattr(model, name), quant_cfg=quant_cfg)
    return model
  • All nn.Linear layers (except lm_head by default) are recursively replaced with QuantLinear.

  • The quantization configuration (quant_cfg) is retrieved from AutoRound schemes (preset_name_to_scheme).
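The traversal above is an ordinary recursive module-rewriting pattern. The toy sketch below (plain Python, no torch; all classes are hypothetical) shows the same idea: walk a tree of named children and swap every node of the target type on its parent, skipping excluded names.

```python
# Toy illustration of the recursive child-replacement pattern (hypothetical
# classes; mirrors the named_children()/setattr logic of replace_with_quant_linear).

class Module:
    def named_children(self):
        return [(k, v) for k, v in vars(self).items() if isinstance(v, Module)]

class Linear(Module):
    pass

class QuantLinear(Module):
    def __init__(self, wrapped):
        self.wrapped = wrapped  # keep the original layer around for the demo

def replace_linears(model, skip=("lm_head",)):
    """Swap every Linear child (except skipped names) for QuantLinear on its parent."""
    for name, child in model.named_children():
        if isinstance(child, Linear) and name not in skip:
            setattr(model, name, QuantLinear(child))
        else:
            replace_linears(child, skip=skip)
    return model

# Two-level model: a block with two Linears, plus an lm_head that must survive.
block = Module(); block.q_proj = Linear(); block.k_proj = Linear()
model = Module(); model.block = block; model.lm_head = Linear()
replace_linears(model)
print(type(model.block.q_proj).__name__, type(model.lm_head).__name__)  # → QuantLinear Linear
```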

2. QuantLinear

Defined in neural_compressor/torch/algorithms/qat/quant_linear.py:

class QuantLinear(nn.Module):
    """Quantized version of nn.Linear."""

    def forward(self, input: torch.Tensor):
        qw = self.weight_quantizer(self.weight)
        qi = self.input_quantizer(input)
        out = F.linear(qi, qw, self.bias)
        out = self.output_quantizer(out)
        return out

    def _setup(self, quant_cfg):
        self.weight_quantizer = TensorQuantizer(
            data_type=quant_cfg.data_type,
            block_size=quant_cfg.group_size,
            bits=quant_cfg.bits,
            sym=quant_cfg.sym,
            if_quant=True,
            learn_exponent=False,
        )
        self.input_quantizer = TensorQuantizer(
            data_type=quant_cfg.act_data_type,
            block_size=quant_cfg.act_group_size,
            bits=quant_cfg.act_bits,
            sym=quant_cfg.act_sym,
            if_quant=True,
            learn_exponent=False,
        )
        self.output_quantizer = TensorQuantizer(
            data_type=quant_cfg.act_data_type,
            block_size=quant_cfg.act_group_size,
            bits=quant_cfg.act_bits,
            sym=quant_cfg.act_sym,
            if_quant=False,
        )
        # Disable output quantization for now
        self.output_quantizer.disable()

Key points:

  • Weight quantizer: usually MXFP4 or MXFP8 with block-wise scaling.

  • Input quantizer: activation quantization; configurable via AutoRound scheme.

  • Output quantizer: currently disabled (acts as a pure passthrough). This can be extended later if full activation quantization is needed.

3. TensorQuantizer

Defined in neural_compressor/torch/algorithms/qat/tensor_quantizer.py, this module encapsulates the fake-quant logic:

class TensorQuantizer(nn.Module):
    def __init__(
        self,
        data_type="mx_fp8",
        bits=8,
        block_size=32,
        sym=True,
        if_quant=True,
        learn_exponent=False,
        amax=None,
        scale_shape=None,
        device=None,
    ):
        ...
        assert get_quant_func is not None, "get_quant_func could not be imported from AutoRound; please install auto-round."

        self.quant_func, self.data_type = get_quant_func(self.data_type, self.num_bits, self.sym)
        ...

In the forward pass:

def forward(self, inputs: torch.Tensor):
    if self._disabled or (not self._if_quant):
        self._input_dtype = inputs.dtype
        return inputs

    x = inputs.contiguous()
    if self.fake_quant:
        q = self._fake_quantize(x)[0]
    else:
        q = self._real_quantize(x)

    return q.to(inputs.dtype)

  • Uses AutoRound’s get_quant_func to obtain the proper quantizer for mx_fp4, mx_fp8, etc.

  • Supports block-wise exponent sharing (block_size) and optional saving of scales.

  • For MX formats, weight_pack can pack weights and E8M0 scales into efficient storage formats.
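To make "block-wise exponent sharing" concrete, here is a hedged pure-Python sketch (not AutoRound's kernel, and the function name is made up): every block of block_size values shares one power-of-two scale, which is exactly what an E8M0 scale byte encodes in the MX formats.

```python
import math

# Hypothetical sketch of block-wise shared-exponent scaling (the MX idea):
# each block of `block_size` elements shares one power-of-two scale,
# chosen so that 2**e covers the largest magnitude in the block.

def shared_exponent_scales(values, block_size=32):
    scales = []
    for i in range(0, len(values), block_size):
        block = values[i:i + block_size]
        amax = max(abs(v) for v in block)
        # smallest integer exponent e with 2**e >= amax (0 for an all-zero block)
        e = math.ceil(math.log2(amax)) if amax > 0 else 0
        scales.append(2.0 ** e)
    return scales

vals = [0.5, -3.0, 0.25, 1.0,  0.1, 0.05, -0.2, 0.0]
print(shared_exponent_scales(vals, block_size=4))  # → [4.0, 0.25]
```

Element values within each block are then quantized to the low-bit mantissa format (e.g., FP4) relative to that shared scale.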

4. QAT-Specific Trainer (QATTrainer)

Defined in examples/.../llm_qat/utils.py:

class QATTrainer(Trainer):
    def save_model(self, *args, **kwargs):
        # Handle FSDP state-dict types and dtype rewriting
        ...
        if (not self.is_in_train) and self.args.should_save:
            out_dir = args[0]
            self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])

  • Extends Hugging Face’s Trainer to handle:

    • FSDP full-state-dict export at the final checkpoint.

    • Ensuring that config.json contains the original model dtype (dtype or torch_dtype), which is important for downstream inference.
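The dtype rewrite can be pictured with a small stdlib-only sketch. The helper name and the "torch_dtype" key here are illustrative assumptions mirroring what the bullet describes, not the QATTrainer source:

```python
import json, os, tempfile

# Hypothetical sketch of rewriting the dtype field in a saved config.json,
# as QATTrainer does after FSDP training (illustrative, not the real code).

def update_config_dtype(out_dir, dtype_str):
    path = os.path.join(out_dir, "config.json")
    with open(path) as f:
        cfg = json.load(f)
    cfg["torch_dtype"] = dtype_str  # e.g. "bfloat16", so inference loads the right precision
    with open(path, "w") as f:
        json.dump(cfg, f, indent=2)

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "config.json"), "w") as f:
        json.dump({"model_type": "llama", "torch_dtype": "float32"}, f)
    update_config_dtype(d, "bfloat16")
    with open(os.path.join(d, "config.json")) as f:
        print(json.load(f)["torch_dtype"])  # → bfloat16
```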


Configuration and Key Parameters

Command-Line Arguments (from main.py)

  1. ModelArguments

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="meta-llama/Llama-3.1-8B")

  2. TrainingArguments (extends transformers.TrainingArguments)

Key fields:

  • model_max_length: maximum sequence length (e.g., 2048, 4096).

  • dataloader_drop_last: whether to drop the last batch (useful for distributed training).

  • bf16: enable BF16 for training.

  3. DataArguments

@dataclass
class DataArguments:
    dataset: str = field(
        default="Daring-Anteater",
        metadata={"choices": ["Daring-Anteater", "cnn_dailymail"]},
    )
    train_size: int = 0  # 0 = default / automatic
    eval_size: int = 0

AutoRound Quantization Schemes

The mapping from scheme names to quantization configs is handled by AutoRound:

from auto_round.schemes import preset_name_to_scheme

quant_cfg = preset_name_to_scheme(scheme)

This quant_cfg is then used to initialize TensorQuantizer instances inside QuantLinear.

Detecting Quantization Format

get_quantization_format inspects layers to determine the applied MX format:

if weight_quantizer.num_bits == 8 and weight_quantizer.data_type == "mx_fp8":
    return "MXFP8"
if weight_quantizer.num_bits == 4 and weight_quantizer.data_type == "mx_fp4":
    return "MXFP4"

Best Practices and Troubleshooting

Memory / Performance Issues

Issue: Out-of-Memory (OOM) during QAT

  • Reduce per_device_train_batch_size and/or model_max_length.

  • Enable gradient checkpointing (--gradient_checkpointing True in HF args, and ensure gradient_checkpointing_kwargs is set when needed).

  • Use FSDP configuration (fsdp1.yaml or fsdp2.yaml) for sharded training:

    • fsdp_activation_checkpointing: true

    • fsdp_cpu_ram_efficient_loading: true

Issue: QAT Training Is Too Slow

  • Reduce max_steps or number of epochs for initial experiments.

  • Use fewer training samples (train_size).

  • Ensure BF16 mixed precision is enabled on supported hardware.

Accuracy Issues

Issue: Large Accuracy Drop After QAT

  • Increase the number of training steps or epochs.

  • Use a slightly higher precision scheme first (e.g., MXFP8), then transition to MXFP4 if needed.

  • Ensure the dataset used for QAT matches your target task distribution as much as possible.

Issue: Model Loading / Inference Errors

  • Verify the model directory contains:

    • config.json

    • pytorch_model.bin / model.safetensors

    • tokenizer files

  • Ensure your inference stack (e.g., vLLM) supports the quantization format produced by AutoRound/INC. Some runtimes may require additional export steps.


References