Microscaling Quantization

  1. Introduction

  2. Get Started with Microscaling Quantization API

  3. Examples

  4. Reference

Introduction

Large language models (LLMs) have driven numerous breakthroughs in fields such as text analysis, language translation, and chatbots. However, their growing capability comes with explosive growth in parameter counts, which makes them hard to deploy in practice. To balance memory limits against accuracy, the Microscaling (MX) specification was developed as an evolution of the well-known Microsoft Floating Point (MSFP) data type [1, 2]:

| Format Name | Element Data Type | Element Bits | Scaling Block Size | Scale Data Type | Scale Bits |
|---|---|---|---|---|---|
| MXFP8 | FP8 (E5M2) or FP8 (E4M3) | 8 | 32 | E8M0 | 8 |
| MXFP6 | FP6 (E3M2) or FP6 (E2M3) | 6 | 32 | E8M0 | 8 |
| MXFP4 | FP4 (E2M1) | 4 | 32 | E8M0 | 8 |
| MXINT8 | INT8 | 8 | 32 | E8M0 | 8 |
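To make the element data types in the table concrete, the sketch below decodes all 16 bit patterns of FP4 (E2M1). It assumes the OCP MX encoding for this type: 1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit, subnormal values when the exponent field is 0, and no Inf or NaN encodings. The helper name decode_e2m1 is hypothetical and not part of Neural Compressor.

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit FP4 (E2M1) code into its float value (hypothetical helper)."""
    sign = -1.0 if (code >> 3) & 0x1 else 1.0
    exp_field = (code >> 1) & 0x3  # 2 exponent bits
    mantissa = code & 0x1          # 1 mantissa bit
    bias = 1
    if exp_field == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias = 0
        return sign * (mantissa / 2) * 2.0 ** (1 - bias)
    # Normal: implicit leading 1
    return sign * (1 + mantissa / 2) * 2.0 ** (exp_field - bias)


# Codes 0..7 have the sign bit cleared, so they cover every non-negative value
positive_values = sorted({decode_e2m1(c) for c in range(8)})
print(positive_values)  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```

With only eight magnitudes per sign and a maximum of 6.0, FP4 relies heavily on the shared per-block scale to cover the actual range of the weights.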

At equivalent accuracy, MX data types need less silicon area and lower energy for multiply-accumulate operations than other conventional data types implemented on the same silicon [1].

Neural Compressor applies MX data types to post-training quantization and provides carefully tuned recipes that let users quantize LLMs while preserving accuracy. The workflow is shown below.

Workflow of MX Quant (source [3])

LLMs face more severe memory and compute constraints than other general neural networks, so our exploration focuses on LLMs first. The following table summarizes the basic MX quantization recipe in Neural Compressor and contrasts it with INT8 and FP8 quantization. MX replaces the general floating-point scale with a power of two, which is more hardware-friendly.

| | MX Format | INT8 | FP8 |
|---|---|---|---|
| Scale | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ |
| Zero point | 0 (none) | $2^{bits - 1}$ or $-rmin \cdot scale$ | 0 (none) |
| Granularity | per-block (default block size: 32) | per-channel or per-tensor | per-channel or per-tensor |

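Here, the shared exponent is $exp = \mathrm{clamp}(\lfloor \log_2(amax) \rfloor - maxExp, -127, 127)$, where maxExp is the exponent of the largest normal value of the element data type (for example, 2 for FP4 (E2M1), whose maximum is $6 = 1.5 \times 2^2$). MAX is the largest representable value of the target data type, amax is the maximum absolute value within a quantization group (a 32-element block for MX; a channel or the whole tensor for INT8 and FP8), and rmin is the minimum value within that group. The clamp range $[-127, 127]$ matches the exponents an 8-bit E8M0 scale can encode.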
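As a minimal sketch of this recipe for MXFP4, the NumPy routine below computes one shared power-of-two scale per 32-element block with the clamped-exponent formula, rounds the scaled elements to the FP4 (E2M1) grid, and scales them back. It is an illustrative quantize-dequantize ("fake quant") routine under our own assumptions: a 1-D input whose length is a multiple of 32, saturation of out-of-range values to 6.0, and nearest-value rounding with ties resolved toward the smaller magnitude. The name mx_fp4_fake_quant is hypothetical, not a Neural Compressor API.

```python
import numpy as np

# Magnitudes representable in FP4 (E2M1). Its largest normal value is 6.0 = 1.5 * 2**2, so maxExp = 2.
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
FP4_MAX_EXP = 2
BLOCK_SIZE = 32  # default MX scaling block size


def mx_fp4_fake_quant(x: np.ndarray) -> np.ndarray:
    """Quantize-dequantize a 1-D array to MXFP4, one 32-element block at a time (sketch)."""
    if x.ndim != 1 or x.size % BLOCK_SIZE != 0:
        raise ValueError("expected a 1-D array whose length is a multiple of 32")
    out = np.zeros(x.size, dtype=np.float64)
    for start in range(0, x.size, BLOCK_SIZE):
        block = x[start:start + BLOCK_SIZE].astype(np.float64)
        amax = np.max(np.abs(block))
        if amax == 0:
            continue  # an all-zero block stays zero
        # Shared exponent, stored as an E8M0 scale: clamp(floor(log2(amax)) - maxExp, -127, 127)
        exp = int(np.clip(np.floor(np.log2(amax)) - FP4_MAX_EXP, -127, 127))
        scaled = block / 2.0 ** exp
        # Saturate magnitudes above 6.0, then snap each one to the nearest FP4 value.
        # np.argmin keeps the first minimum, so ties go to the smaller magnitude
        # (a simplification; IEEE-style conversions round ties to even).
        mags = np.minimum(np.abs(scaled), FP4_GRID[-1])
        idx = np.argmin(np.abs(mags[:, None] - FP4_GRID[None, :]), axis=1)
        out[start:start + BLOCK_SIZE] = np.sign(scaled) * FP4_GRID[idx] * 2.0 ** exp
    return out


print(mx_fp4_fake_quant(np.full(32, 100.0))[0])  # 96.0: exp = 4, and 100 / 16 = 6.25 saturates to 6
```

A block such as [6.0, -3.0, 1.5, 0.5] repeated eight times gets exp = 0 and is reproduced exactly, while a block of 100.0 values shows the saturation effect: the power-of-two scale can leave the largest element slightly above the FP4 maximum.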

Get Started with Microscaling Quantization API

To get a model quantized with Microscaling Data Types, users can use the AutoRound Quantization API as follows.

Basic Usage

The following example demonstrates how to quantize a model using MX data types:

from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
from transformers import AutoModelForCausalLM, AutoTokenizer

fp32_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)
output_dir = "./saved_inc"

# quantization configuration
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,  # Tokenizer for processing calibration data
    nsamples=32,  # Number of calibration samples (default: 128)
    seqlen=32,  # Sequence length of calibration data (default: 2048)
    iters=20,  # Number of optimization iterations (default: 200)
    scheme="MXFP4",  # MX quantization scheme: "MXFP4", "MXFP8"
    export_format="auto_round",  # Export format for the quantized model
    output_dir=output_dir,  # Directory to save the quantized model (default: "temp_auto_round")
)

# quantize the model and save to output_dir
model = prepare(model=fp32_model, quant_config=quant_config)
model = convert(model)

# loading
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto")

# inference
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))

Advantages of MX Quantization

  1. Hardware-Friendly: Uses power-of-2 scaling factors for efficient hardware implementation

  2. Fine-Grained Quantization: Per-block scaling (block size = 32) typically preserves accuracy better than per-tensor or per-channel scaling

  3. Zero-Point Free: No zero-point overhead, simplifying computation

  4. Memory Efficient: Significantly reduces model size while maintaining competitive accuracy

  5. Energy Efficient: Lower energy consumption for multiply-accumulate operations compared to traditional data types
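To put the memory savings in numbers, the sketch below estimates weight storage from the effective bits per element, that is, the element bits plus the 8-bit shared scale amortized over a 32-element block. It assumes a hypothetical 8B-parameter model in which every weight is quantized; real checkpoints may keep some layers (for example, embeddings or lm_head) in higher precision, so actual sizes will be somewhat larger. Both helper names are hypothetical.

```python
def mx_bits_per_element(element_bits: int, scale_bits: int = 8, block_size: int = 32) -> float:
    """Effective storage bits per element: element bits plus the shared scale amortized over a block."""
    return element_bits + scale_bits / block_size


def weight_storage_gb(num_params: float, bits_per_element: float) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_element / 8 / 1e9


num_params = 8e9  # assumption: a hypothetical 8B-parameter model with every weight quantized
for name, bits in [("BF16", 16.0), ("MXFP8", mx_bits_per_element(8)), ("MXFP4", mx_bits_per_element(4))]:
    print(f"{name}: {bits} bits/element, ~{weight_storage_gb(num_params, bits):.2f} GB")
```

For this assumed model, the estimate is 16 GB in BF16 versus 8.25 GB in MXFP8 and 4.25 GB in MXFP4, roughly a 3.8x reduction for MXFP4.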

Mixed Precision (MXFP4 + MXFP8)

To reach high compression ratios with acceptable accuracy, we integrate the AutoRound automatic mixed-precision algorithm. The mixed-precision approach combines the MXFP4 and MXFP8 formats and assigns each layer a format based on its sensitivity to quantization.

Benefits of Mixed Precision

  • Better Accuracy-Compression Trade-off: Sensitive layers use MXFP8 (higher precision) while less sensitive layers use MXFP4 (higher compression), improving the balance between model size and accuracy.

  • Flexible Configuration: Users can customize the precision assignment strategy based on their specific accuracy and compression requirements.

  • Automatic Layer Selection: The AutoRound algorithm automatically identifies which layers should use which precision level, reducing manual tuning effort.

Target Bits Configuration

To achieve optimal compression ratios in mixed-precision quantization, we provide the target_bits parameter for automated precision configuration.

  • Single target bit: If you pass a single float, it automatically generates a quantization recipe aimed at that average bit-width.

  • Multiple target bits: If you pass multiple float numbers, it will generate multiple recipes for different target bit-widths, allowing you to compare trade-offs between model size and accuracy.

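Note: For MX data types, target_bits ranges from 4.25 to 8.25 because each 32-element block also stores an 8-bit shared scale, which adds 8/32 = 0.25 bits per element (4 + 0.25 when every layer is MXFP4, 8 + 0.25 when every layer is MXFP8).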
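The average bit-width of a mixed recipe can be understood as a parameter-weighted mean of the per-layer effective bit-widths. The sketch below computes it for a toy decoder block; the layer sizes and the attention-in-MXFP8, MLP-in-MXFP4 split are made up for illustration, and average_bits is a hypothetical helper, not necessarily the exact accounting AutoRound uses.

```python
MX_BITS = {"MXFP4": 4 + 8 / 32, "MXFP8": 8 + 8 / 32}  # element bits + amortized 8-bit E8M0 scale


def average_bits(layer_params: dict, assignment: dict) -> float:
    """Parameter-weighted average bit-width of a per-layer MX format assignment."""
    total = sum(layer_params.values())
    return sum(n * MX_BITS[assignment[name]] for name, n in layer_params.items()) / total


# Hypothetical parameter counts for one toy decoder block (for illustration only)
layer_params = {"q_proj": 10, "k_proj": 5, "v_proj": 5, "o_proj": 5,
                "gate_proj": 25, "up_proj": 25, "down_proj": 25}
attention = {"q_proj", "k_proj", "v_proj", "o_proj"}

all_fp4 = {name: "MXFP4" for name in layer_params}
all_fp8 = {name: "MXFP8" for name in layer_params}
mixed = {name: "MXFP8" if name in attention else "MXFP4" for name in layer_params}

print(average_bits(layer_params, all_fp4))  # 4.25
print(average_bits(layer_params, all_fp8))  # 8.25
print(average_bits(layer_params, mixed))    # 5.25
```

Moving the attention projections (25% of these toy parameters) from MXFP4 to MXFP8 raises the average from 4.25 to 5.25 bits, since each upgraded parameter adds 4 bits.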

Usage Example

AutoTune with Multiple Target Bits

For automatically finding the best configuration across multiple target bits:

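import torch
from neural_compressor.torch.quantization import AutoRoundConfig, autotune, TuningConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

fp32_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")


# Define evaluation function (must return a score where higher is better)
def eval_fn(model):
    # Placeholder metric: replace with your real evaluation (e.g., task accuracy).
    # Here the score is exp(-loss), i.e. 1 / perplexity, on a single sample sentence.
    inputs = tokenizer("There is a girl who likes adventure.", return_tensors="pt").to(model.device)
    with torch.no_grad():
        loss = model(**inputs, labels=inputs["input_ids"]).loss
    return float(torch.exp(-loss))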


# Configuration with multiple target bits
config = AutoRoundConfig(
    tokenizer=tokenizer,
    nsamples=128,
    seqlen=2048,
    iters=200,
    target_bits=[7.2, 7.5, 7.8],  # Try multiple target bits
    options=["MXFP4", "MXFP8"],
    shared_layers=[
        ["k_proj", "v_proj", "q_proj"],
        ["gate_proj", "up_proj"],
    ],
    export_format="auto_round",
    output_dir="./llama3.1-8B-MXFP4-MXFP8",
)

# AutoTune to find the best configuration
tuning_config = TuningConfig(config_set=[config], tolerable_loss=0.01)
model = autotune(fp32_model, tuning_config, eval_fn=eval_fn)

Key Parameters for Mixed Precision

  • target_bits: Target average bit-width for the model. Can be a single float or a list of floats.

    • Single value: Generates one recipe for that specific target bit-width

    • Multiple values: Generates multiple recipes for comparison and selects the best one via autotune

  • options: List of available data types for mixed precision (e.g., ["MXFP4", "MXFP8"])

  • shared_layers: List of layer groups that should use the same precision. Each group is a list of layer name patterns.

    • Ensures architectural consistency (e.g., all attention projections use the same precision)

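    • Improves inference efficiency, since layers in a group (for example, q/k/v projections or gate/up projections) are often fused into a single kernel that expects one data type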

  • tolerable_loss: Maximum acceptable relative accuracy loss compared to the FP32 baseline (a TuningConfig argument used with autotune; 0.01 means 1%)
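To illustrate how shared_layers constrains an automatic precision search, below is a toy greedy heuristic; it is not AutoRound's actual algorithm. Layers whose names end with patterns from the same shared_layers group (within the same block prefix) are merged into one unit, every unit starts in MXFP4, and units are upgraded to MXFP8 in order of decreasing sensitivity while the parameter-weighted average stays within target_bits. The sensitivity scores, layer sizes, and the helpers group_key and greedy_assign are all hypothetical.

```python
MX_BITS = {"MXFP4": 4 + 8 / 32, "MXFP8": 8 + 8 / 32}  # element bits + amortized E8M0 scale


def group_key(name, shared_layers):
    """Layers in the same block that match the same shared_layers group share a key."""
    for gid, patterns in enumerate(shared_layers):
        for pat in patterns:
            if name.endswith(pat):
                return (name[: -len(pat)], gid)  # (block prefix, group index)
    return (name, None)  # not in any shared group: the layer forms its own unit


def greedy_assign(layer_params, sensitivity, shared_layers, target_bits):
    """Toy greedy mixed-precision search (hypothetical, for illustration only)."""
    units = {}
    for name in layer_params:
        units.setdefault(group_key(name, shared_layers), []).append(name)

    total = sum(layer_params.values())
    budget = target_bits * total     # total bits allowed by the target average
    used = MX_BITS["MXFP4"] * total  # start with everything in MXFP4
    extra = MX_BITS["MXFP8"] - MX_BITS["MXFP4"]
    assignment = {name: "MXFP4" for name in layer_params}

    # Upgrade the most sensitive units first, skipping any that would exceed the budget
    for members in sorted(units.values(), key=lambda m: sum(sensitivity[n] for n in m), reverse=True):
        cost = extra * sum(layer_params[n] for n in members)
        if used + cost <= budget:
            for n in members:
                assignment[n] = "MXFP8"
            used += cost
    return assignment


layer_params = {  # hypothetical parameter counts for one decoder block
    "layers.0.self_attn.q_proj": 10, "layers.0.self_attn.k_proj": 5,
    "layers.0.self_attn.v_proj": 5, "layers.0.self_attn.o_proj": 5,
    "layers.0.mlp.gate_proj": 25, "layers.0.mlp.up_proj": 25, "layers.0.mlp.down_proj": 25,
}
sensitivity = {  # hypothetical per-layer sensitivity scores (higher = more sensitive)
    "layers.0.self_attn.q_proj": 0.3, "layers.0.self_attn.k_proj": 0.2,
    "layers.0.self_attn.v_proj": 0.4, "layers.0.self_attn.o_proj": 0.1,
    "layers.0.mlp.gate_proj": 0.2, "layers.0.mlp.up_proj": 0.2, "layers.0.mlp.down_proj": 1.0,
}
shared_layers = [["k_proj", "v_proj", "q_proj"], ["gate_proj", "up_proj"]]

assignment = greedy_assign(layer_params, sensitivity, shared_layers, target_bits=5.5)
for name, dtype in assignment.items():
    print(name, dtype)
```

In this toy run, down_proj and the small o_proj fit the 5.5-bit budget, while the q/k/v unit and the gate/up unit each stay in MXFP4 as a whole, which is the consistency shared_layers is meant to enforce.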

Examples

PyTorch Examples

Best Practices and Tips

Choosing the Right Data Type

| Data Type | Compression | Accuracy | Use Case | Export Format |
|---|---|---|---|---|
| MXFP8 | Moderate (8-bit) | High | Production models where accuracy is critical | auto_round |
| MXFP4 | High (4-bit) | Moderate | Aggressive compression with acceptable accuracy loss | auto_round |
| MXFP4 + MXFP8 mixed | Configurable (4.25-8.25 bits) | High | Best balance between compression and accuracy | auto_round |

Common Issues and Solutions

Issue: Out of Memory (OOM) during quantization

  • Solution: Set low_gpu_mem_usage=True, enable enable_torch_compile, reduce nsamples, or use a smaller seqlen

Issue: Accuracy drop is too large

  • Solution: Increase iters, use more nsamples, or try mixed precision with higher target_bits

Issue: Quantization is too slow

  • Solution: Reduce iters (or set iters=0 to fall back to round-to-nearest (RTN) quantization), decrease nsamples, or enable enable_torch_compile

Issue: Model loading fails after quantization

Reference

[1]: Rouhani, Bita Darvish, et al. “Pushing the Limits of Narrow Precision Inferencing at Cloud Scale with Microsoft Floating Point.” Advances in Neural Information Processing Systems 33 (2020): 10271-10281.

[2]: OCP Microscaling Formats (MX) Specification

[3]: Rouhani, Bita Darvish, et al. “Microscaling Data Formats for Deep Learning.” arXiv preprint arXiv:2310.10537 (2023).