JAX
Introduction
neural_compressor.jax provides an API for applying quantization to Keras models such as ViT and Gemma3.
Since only JAX is supported as the Keras backend, the environment variable KERAS_BACKEND should be set to jax.
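The backend selection can also be done from Python, as long as it happens before Keras is first imported. The helper below is a minimal sketch; the function name `select_jax_backend` is hypothetical and not part of any library.

```python
import os
import sys


def select_jax_backend():
    """Set KERAS_BACKEND to "jax", refusing if Keras is already imported.

    Keras reads the backend when it is first imported, so changing the
    variable after that point has no effect on the running process.
    """
    if "keras" in sys.modules:
        raise RuntimeError("KERAS_BACKEND must be set before importing keras")
    os.environ["KERAS_BACKEND"] = "jax"


select_jax_backend()
# Any `import keras` placed after this call picks up the JAX backend.
```

Setting the variable in the shell (`export KERAS_BACKEND=jax`) achieves the same thing without touching the code.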
The following 8-bit floating-point formats are supported: fp8_e4m3 and fp8_e5m2.
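To make the trade-off between the two formats concrete, the sketch below derives the largest finite value and the machine epsilon of each layout from its bit counts. It assumes that fp8_e4m3 follows the "fn" variant of E4M3 (no infinities, a single NaN pattern) and that fp8_e5m2 is IEEE-like; the helper `fp8_summary` is hypothetical.

```python
def fp8_summary(exp_bits, man_bits, bias, ieee_like):
    """Return (max_finite, epsilon) for an 8-bit floating-point layout."""
    max_exp_field = (1 << exp_bits) - 1
    if ieee_like:
        # The all-ones exponent is reserved for inf/NaN (E5M2).
        max_exp = max_exp_field - 1 - bias
        max_mantissa = 2.0 - 2.0 ** -man_bits
    else:
        # Only the all-ones bit pattern is NaN, so the top exponent is usable
        # with every mantissa except all-ones (assumed E4M3 "fn" variant).
        max_exp = max_exp_field - bias
        max_mantissa = 2.0 - 2.0 ** -(man_bits - 1)
    return max_mantissa * 2.0 ** max_exp, 2.0 ** -man_bits


e4m3_max, e4m3_eps = fp8_summary(exp_bits=4, man_bits=3, bias=7, ieee_like=False)
e5m2_max, e5m2_eps = fp8_summary(exp_bits=5, man_bits=2, bias=15, ieee_like=True)

print(f"fp8_e4m3: max={e4m3_max}, eps={e4m3_eps}")
print(f"fp8_e5m2: max={e5m2_max}, eps={e5m2_eps}")
```

Under these assumptions, E4M3 offers finer precision over a narrower range, while E5M2 trades one mantissa bit for a much wider range.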
Quantized models can be saved and loaded using standard Keras APIs (save_model and load_model) or Keras Hub APIs (save_to_preset and from_preset). This lets users take advantage of pre-quantized models with minimal code changes; only one additional line is needed:
import neural_compressor.jax
Quantization was developed primarily to improve the performance of Keras models on Intel® Xeon® processors, but it can potentially be used on other platforms as well.
Quantization API
def quantize_model(
    model: keras.Model,
    quant_config: BaseConfig,
    calib_function: Callable = None,
    inplace: bool = True
):
    """Return a quantized Keras model according to the given configuration.

    Args:
        model: FP32 Keras model to be quantized.
        quant_config: Quantization configuration.
        calib_function: Function used for model calibration, required for static quantization.
        inplace: When True, the original model is modified in-place and should not be used
            afterward. A value of False is not yet supported.

    Returns:
        The quantized model.
    """
Post-Training Static Quantization
The maximum absolute values of weights and activations are collected offline using a calibration dataset. This dataset should be representative of the data distribution expected during inference. The calibration process runs on the original FP32 model and records tensor distributions for scale calculations. Typically, preparing several dozen samples is sufficient for calibration.
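The calibration step described above can be sketched in plain NumPy. This is an illustration of generic absmax calibration, not the library's implementation. The names `ObservedLayer`, `make_calib_function`, and `compute_scale` are hypothetical, the callback is assumed to receive the model as its only argument, the scale convention (FP8 maximum divided by the observed maximum) is one common choice, and the value 448 assumes the "fn" variant of E4M3.

```python
import numpy as np

# Largest finite value per format (448 assumes the E4M3 "fn" variant).
FP8_MAX = {"fp8_e4m3": 448.0, "fp8_e5m2": 57344.0}


class ObservedLayer:
    """Stand-in for an FP32 layer that records the largest |activation| seen."""

    def __init__(self, weight):
        self.weight = weight
        self.act_amax = 0.0

    def __call__(self, x):
        self.act_amax = max(self.act_amax, float(np.max(np.abs(x))))
        return x @ self.weight


def make_calib_function(samples, batch_size=8):
    """Build a callback that feeds the calibration samples through a model."""

    def calib_function(model):
        for start in range(0, len(samples), batch_size):
            model(samples[start:start + batch_size])

    return calib_function


def compute_scale(amax, fmt="fp8_e4m3"):
    """Scale that maps [-amax, amax] onto the representable FP8 range."""
    return FP8_MAX[fmt] / amax if amax > 0 else 1.0


rng = np.random.default_rng(0)
samples = rng.normal(size=(48, 16)).astype(np.float32)  # a few dozen samples
layer = ObservedLayer(rng.normal(size=(16, 4)).astype(np.float32))

make_calib_function(samples)(layer)  # calibration runs on the FP32 layer
act_scale = compute_scale(layer.act_amax)
weight_scale = compute_scale(float(np.max(np.abs(layer.weight))))
```

With a real Keras model, a callback of this shape would be passed to quantize_model through its calib_function argument.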
Examples
Examples of how to quantize a model and use a pre-quantized model can be found below:
Backend and Device
Although Intel® Neural Compressor can run on any platform supporting 8-bit floating point with Keras using the JAX backend, performance improvements from quantization will be visible on Intel® Xeon® processors (with AMX-FP8 extension) with JAX version greater than v0.9 (see the full JAX releases page).
To enable performance improvements from quantization, certain JAX/XLA features must be enabled by setting the following environment variable:
export XLA_FLAGS="\
--xla_cpu_experimental_onednn_custom_call=true --xla_cpu_use_onednn=false \
--xla_cpu_experimental_ynn_fusion_type=invalid --xla_cpu_use_xnnpack=false \
--xla_backend_extra_options=xla_cpu_disable_new_fusion_emitter"
Without these flags, the quantized model operates in fake quantization mode: tensors are rounded to the specified FP8 format, but computations are still performed in 32-bit floating point.
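Fake quantization can be illustrated with a short JAX sketch: values are scaled, rounded by casting to an FP8 dtype, and cast straight back to float32, so the subsequent matrix multiplication still runs in FP32. The function `fake_quantize` is hypothetical, and mapping fp8_e4m3 to `jnp.float8_e4m3fn` is an assumption rather than something the library documents here.

```python
import jax.numpy as jnp


def fake_quantize(x, scale=1.0, dtype=jnp.float8_e4m3fn):
    """Round x to an FP8 grid but return the result as float32."""
    fp8_max = float(jnp.finfo(dtype).max)
    # Clip first so out-of-range values saturate instead of overflowing.
    scaled = jnp.clip(x * scale, -fp8_max, fp8_max)
    rounded = scaled.astype(dtype).astype(jnp.float32)
    return rounded / scale


x = jnp.array([[1.0, 3.3]], dtype=jnp.float32)
w = jnp.array([[0.5], [2.0]], dtype=jnp.float32)

x_q = fake_quantize(x)  # 3.3 is not representable in E4M3 and gets rounded
w_q = fake_quantize(w)
y = x_q @ w_q           # the matmul itself is still computed in float32
```

The rounding error is already present in this mode, so accuracy can be evaluated on any platform, but no speedup should be expected because no FP8 kernels are involved.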