Microscaling Quantization
Introduction
Large language models (LLMs) have fueled breakthroughs across many fields, such as text analysis, language translation, and chatbot technologies. Their growing capability, however, comes with an explosive growth in parameters, which poses obstacles to practical deployment. To balance memory limits against accuracy preservation for AI models, the Microscaling (MX) specification was developed from the well-known Microsoft Floating Point (MSFP) data type [1, 2]:
| Format Name | Element Data Type | Element Bits | Scaling Block Size | Scale Data Type | Scale Bits |
|---|---|---|---|---|---|
| MXFP8 | FP8 (E5M2) / FP8 (E4M3) | 8 | 32 | E8M0 | 8 |
| MXFP6 | FP6 (E3M2) / FP6 (E2M3) | 6 | 32 | E8M0 | 8 |
| MXFP4 | FP4 (E2M1) | 4 | 32 | E8M0 | 8 |
| MXINT8 | INT8 | 8 | 32 | E8M0 | 8 |
At an equivalent accuracy level, the MX data type occupies a smaller area and incurs lower energy costs for multiply-accumulate operations than conventional data types on the same silicon [1].
Neural Compressor seamlessly applies the MX data type to post-training quantization, offering meticulously crafted recipes that let users quantize LLMs without sacrificing accuracy. The workflow is shown below.
The memory and computational limits of LLMs are more severe than those of other neural networks, so our exploration focuses on LLMs first. The following table shows the basic MX quantization recipes in Neural Compressor and highlights the distinctions among data types. The MX data type replaces the general float scale with a power of two to be more hardware-friendly.
| | MX Format | INT8 | FP8 |
|---|---|---|---|
| Scale | $2^{exp}$ | $\frac{MAX}{amax}$ | $\frac{MAX}{amax}$ |
| Zero point | 0 (None) | $2^{bits - 1}$ or $-min \cdot scale$ | 0 (None) |
| Granularity | per-block (default block size is 32) | per-channel or per-tensor | per-channel or per-tensor |
The exponent is $exp = \mathrm{clamp}(\lfloor \log_2(amax) \rfloor - maxExp, -127, 127)$, where maxExp is the largest exponent of the element data type, MAX is the maximum representable value of the data type, amax is the maximum absolute value of the per-block tensor, and min is the minimum value of the tensor (per channel or per tensor for INT8).
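For illustration, the shared per-block power-of-two scale could be computed as in the following PyTorch sketch (an illustrative helper, not the Neural Compressor kernel; `max_exp` is the largest exponent of the element data type, e.g. 2 for FP4 E2M1, and rounding the scaled elements to the FP4/FP6/FP8/INT8 grid is omitted):

```python
import torch

def mx_block_scale(block: torch.Tensor, max_exp: int = 2) -> torch.Tensor:
    """Shared power-of-two scale: 2^clamp(floor(log2(amax)) - maxExp, -127, 127)."""
    amax = block.abs().max().clamp_min(torch.finfo(block.dtype).tiny)  # avoid log2(0)
    exp = torch.clamp(torch.floor(torch.log2(amax)) - max_exp, -127, 127)
    return torch.pow(torch.tensor(2.0), exp)

# A 32-element block shares a single E8M0-style scale.
block = torch.randn(32)
scale = mx_block_scale(block)     # power of two
quantized = block / scale         # would then be rounded to the element grid
dequantized = quantized * scale
```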
Get Started with Microscaling Quantization API
To quantize a model with Microscaling data types, use the AutoRound quantization API as follows.
Basic Usage
The following example demonstrates how to quantize a model using MX data types:
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
from transformers import AutoModelForCausalLM, AutoTokenizer

fp32_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)
output_dir = "./saved_inc"

# quantization configuration
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,          # Tokenizer for processing calibration data
    nsamples=32,                  # Number of calibration samples (default: 128)
    seqlen=32,                    # Sequence length of calibration data (default: 2048)
    iters=20,                     # Number of optimization iterations (default: 200)
    scheme="MXFP4",               # MX quantization scheme: "MXFP4", "MXFP8"
    export_format="auto_round",   # Export format for the quantized model
    output_dir=output_dir,        # Directory to save the quantized model (default: "temp_auto_round")
)

# quantize the model and save to output_dir
model = prepare(model=fp32_model, quant_config=quant_config)
model = convert(model)

# loading
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto")

# inference
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
Advantages of MX Quantization
- Hardware-Friendly: Uses power-of-two scaling factors for efficient hardware implementation
- Fine-Grained Quantization: Per-block scaling (block size = 32) provides better accuracy than per-tensor or per-channel methods (see the sketch after this list)
- Zero-Point Free: No zero-point overhead, simplifying computation
- Memory Efficient: Significantly reduces model size while maintaining competitive accuracy
- Energy Efficient: Lower energy consumption for multiply-accumulate operations compared to traditional data types
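To see why per-block scaling helps, the toy comparison below quantizes a tensor whose rows differ widely in magnitude with a single per-tensor scale versus one scale per 32-element block (illustrative only; it uses a simple symmetric integer grid rather than the actual MX element formats):

```python
import torch

torch.manual_seed(0)
# Four rows with very different magnitudes, as often happens across channels.
x = torch.randn(4, 128) * torch.logspace(-2, 2, 4).unsqueeze(1)

def fake_quant(t, scale, qmax=7):
    # Symmetric round-to-nearest onto a small integer grid, then dequantize.
    return torch.clamp(torch.round(t / scale), -qmax, qmax) * scale

per_tensor = fake_quant(x, x.abs().max() / 7)   # one scale for the whole tensor

blocks = x.reshape(-1, 32)                      # one scale per 32-element block
per_block = fake_quant(blocks, blocks.abs().amax(dim=1, keepdim=True) / 7).reshape_as(x)

print("per-tensor MSE:", torch.mean((x - per_tensor) ** 2).item())
print("per-block  MSE:", torch.mean((x - per_block) ** 2).item())   # noticeably smaller
```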
Mix Precision (MXFP4 + MXFP8)
To achieve optimal compression ratios with acceptable accuracy, we integrate the AutoRound automatic mixed-precision algorithm. This approach combines the MXFP4 and MXFP8 formats, quantizing different layers of the model according to their sensitivity to quantization.
Benefits of Mix Precision
- Better Accuracy-Compression Trade-off: Sensitive layers use MXFP8 (higher precision) while less sensitive layers use MXFP4 (higher compression), optimizing overall model performance.
- Flexible Configuration: Users can customize the precision assignment strategy based on their specific accuracy and compression requirements.
- Automatic Layer Selection: The AutoRound algorithm automatically identifies which layers should use which precision level, reducing manual tuning effort.
Target Bits Configuration
To achieve optimal compression ratios in mixed-precision quantization, we provide the `target_bits` parameter for automated precision configuration.
- Single target bit: If you pass a single float number, an optimal quantization recipe is automatically generated to achieve that target average bit-width.
- Multiple target bits: If you pass multiple float numbers, multiple recipes are generated for the different target bit-widths, allowing you to compare the trade-offs between model size and accuracy.
Note: For MX data types, `target_bits` ranges from 4.25 to 8.25 due to the overhead of the shared scale bits; see the quick check below.
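The bounds follow from amortizing the 8-bit shared scale over each 32-element block; a quick sanity check (the mixed-precision split in the comment is only a rough illustration of how intermediate targets arise):

```python
# Effective bits per element = element bits + scale bits / block size
scale_bits, block_size = 8, 32
print(4 + scale_bits / block_size)   # MXFP4 -> 4.25
print(8 + scale_bits / block_size)   # MXFP8 -> 8.25
# Intermediate target_bits values correspond to mixing the two formats, e.g. keeping
# roughly 70% of weights in MXFP8 and 30% in MXFP4 gives about 0.7*8.25 + 0.3*4.25 = 7.05 bits.
```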
Usage Example
AutoTune with Multiple Target Bits
To automatically find the best configuration across multiple target bits:
from neural_compressor.torch.quantization import AutoRoundConfig, autotune, TuningConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

fp32_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Define evaluation function
def eval_fn(model):
    # Implement your evaluation logic here and return a float
    # accuracy score (higher is better); see the toy example below.
    pass

# Configuration with multiple target bits
config = AutoRoundConfig(
    tokenizer=tokenizer,
    nsamples=128,
    seqlen=2048,
    iters=200,
    target_bits=[7.2, 7.5, 7.8],  # Try multiple target bits
    options=["MXFP4", "MXFP8"],
    shared_layers=[
        ["k_proj", "v_proj", "q_proj"],
        ["gate_proj", "up_proj"],
    ],
    export_format="auto_round",
    output_dir="./llama3.1-8B-MXFP4-MXFP8",
)

# AutoTune to find the best configuration
tuning_config = TuningConfig(config_set=[config], tolerable_loss=0.01)
model = autotune(fp32_model, tuning_config, eval_fn=eval_fn)
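The eval_fn above is only a stub. For illustration, a toy evaluator that reuses the tokenizer from the example and returns teacher-forced next-token accuracy on a couple of prompts is sketched below (the prompts and metric are arbitrary placeholders; use a real benchmark such as lm-eval-harness for meaningful tuning):

```python
import torch

@torch.no_grad()
def eval_fn(model):
    # Toy metric: fraction of next tokens predicted correctly under teacher forcing.
    prompts = [
        "The capital of France is",
        "Water is composed of hydrogen and",
    ]
    correct, total = 0, 0
    for text in prompts:
        ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
        logits = model(ids).logits                    # (1, seq_len, vocab)
        preds = logits[:, :-1, :].argmax(dim=-1)      # predict token t+1 from the prefix up to t
        correct += (preds == ids[:, 1:]).sum().item()
        total += ids.shape[1] - 1
    return correct / total                            # higher is better
```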
Key Parameters for Mix Precision
- `target_bits`: Target average bit-width for the model. Can be a single float or a list of floats (see the sketch after this list).
  - Single value: Generates one recipe for that specific target bit-width
  - Multiple values: Generates multiple recipes for comparison and selects the best one via autotune
- `options`: List of available data types for mixed precision (e.g., `["MXFP4", "MXFP8"]`)
- `shared_layers`: List of layer groups that should use the same precision. Each group is a list of layer name patterns.
  - Ensures architectural consistency (e.g., all attention projections use the same precision)
  - Improves model performance by maintaining balanced computation
- `tolerable_loss`: Maximum acceptable accuracy loss compared to the FP32 baseline (used with autotune)
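Complementing the autotune flow above, a single target bit-width generates just one recipe. A minimal sketch reusing the prepare/convert flow from the basic example (the values are illustrative, and it assumes a single-float `target_bits` is accepted alongside the same parameters shown above):

```python
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert

# One mixed-precision recipe targeting ~6.0 average bits (illustrative values).
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,               # tokenizer from the example above
    nsamples=128,
    seqlen=2048,
    iters=200,
    target_bits=6.0,                   # single float -> one generated recipe
    options=["MXFP4", "MXFP8"],
    shared_layers=[
        ["k_proj", "v_proj", "q_proj"],
        ["gate_proj", "up_proj"],
    ],
    export_format="auto_round",
    output_dir="./llama3.1-8B-mx-6bit",
)
model = prepare(model=fp32_model, quant_config=quant_config)
model = convert(model)
```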
Examples
PyTorch Examples
- Multimodal Models: Llama-4-Scout-17B-16E-Instruct with MXFP4
- Language Models: Llama3 series with MXFP4/MXFP8 and Mix Precision
  - Llama 3.1 8B: MXFP8, MXFP4, and Mix Precision (`target_bits=7.8`)
  - Llama 3.3 70B: MXFP8, MXFP4, and Mix Precision (`target_bits=5.8`)
Best Practices and Tips
Choosing the Right Data Type
| Data Type | Compression | Accuracy | Use Case | Export Format |
|---|---|---|---|---|
| MXFP8 | Moderate (8-bit) | High | Production models where accuracy is critical | auto_round |
| MXFP4 | High (4-bit) | Moderate | Aggressive compression with acceptable accuracy loss | auto_round |
| MXFP4+MXFP8 Mix | Configurable (4.25-8.25 bits) | High | Best balance between compression and accuracy | auto_round |
Common Issues and Solutions
- Issue: Out of Memory (OOM) during quantization
  - Solution: Use `low_gpu_mem_usage=True`, enable `enable_torch_compile`, reduce `nsamples`, or use a smaller `seqlen` (see the sketch after this list)
- Issue: Accuracy drop is too large
  - Solution: Increase `iters`, use more `nsamples`, or try mixed precision with a higher `target_bits`
- Issue: Quantization is too slow
  - Solution: Reduce `iters` (or set it to 0 for RTN), decrease `nsamples`, enable `enable_torch_compile` (see the sketch after this list)
- Issue: Model loading fails after quantization
  - Solution: Refer to auto_round/llama3/inference
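For the OOM and speed issues above, the relevant options can be combined in one configuration. A sketch with illustrative values (it assumes `low_gpu_mem_usage` and `enable_torch_compile` are accepted by AutoRoundConfig as keyword arguments, as suggested by the solutions above):

```python
from neural_compressor.torch.quantization import AutoRoundConfig

# Lighter-weight settings for constrained GPUs or faster runs (illustrative values).
quant_config = AutoRoundConfig(
    tokenizer=tokenizer,
    scheme="MXFP4",
    nsamples=32,                  # fewer calibration samples
    seqlen=512,                   # shorter calibration sequences
    iters=0,                      # 0 -> RTN-style quantization, fastest
    low_gpu_mem_usage=True,       # reduce GPU memory pressure during tuning
    enable_torch_compile=True,    # compile the tuning loop for speed
    export_format="auto_round",
)
```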
References
[1]: Darvish Rouhani, Bita, et al. “Pushing the Limits of Narrow Precision Inferencing at Cloud Scale with Microsoft Floating Point.” Advances in Neural Information Processing Systems 33 (2020): 10271–10281.
[2]: OCP Microscaling Formats (MX) Specification.
[3]: Rouhani, Bita Darvish, et al. “Microscaling Data Formats for Deep Learning.” arXiv preprint arXiv:2310.10537 (2023).