neural_compressor.jax.quantization.config

The configs of algorithms for JAX.

Classes

OperatorConfig

Configuration pairing a quantization config with supported operators.

DynamicQuantConfig

Config class for JAX Dynamic quantization.

StaticQuantConfig

Config class for JAX Static quantization.

Functions

get_all_registered_configs() → Dict[str, BaseConfig]

Get all registered configs for JAX framework.

get_default_dynamic_config() → DynamicQuantConfig

Generate the default Dynamic quantization config.

get_default_static_config() → StaticQuantConfig

Generate the default Static quantization config.

Module Contents

class neural_compressor.jax.quantization.config.OperatorConfig[source]

Configuration pairing a quantization config with supported operators.

class neural_compressor.jax.quantization.config.DynamicQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', white_list: List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE] | None = DEFAULT_WHITE_LIST)[source]

Config class for JAX Dynamic quantization.

Dynamic quantization quantizes both weights and activations at runtime, computing quantization parameters on the fly rather than from a calibration pass. This configuration supports several data types for flexible quantization strategies; see the construction sketch after the format lists below.

Supported dtypes:
  • "fp8": 8-bit floating-point quantization (uses ml_dtypes.float8_e4m3 by default)

  • "int8": 8-bit integer quantization

FP8 formats available:
  • "fp8_e4m3": 4 exponent bits, 3 mantissa bits (default for "fp8")

  • "fp8_e5m2": 5 exponent bits, 2 mantissa bits
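
A minimal construction sketch, assuming the dtype strings listed above are accepted verbatim by the constructor:

    from neural_compressor.jax.quantization.config import DynamicQuantConfig

    # Default FP8 dynamic quantization: fp8_e4m3 for both weights
    # and activations (matches the constructor defaults).
    fp8_cfg = DynamicQuantConfig()

    # fp8_e5m2 trades a mantissa bit for an extra exponent bit,
    # giving a wider dynamic range at lower precision.
    wide_range_cfg = DynamicQuantConfig(
        weight_dtype="fp8_e5m2",
        activation_dtype="fp8_e5m2",
    )

    # 8-bit integer quantization for both tensors.
    int8_cfg = DynamicQuantConfig(
        weight_dtype="int8",
        activation_dtype="int8",
    )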

class neural_compressor.jax.quantization.config.StaticQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', white_list: List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE] | None = DEFAULT_WHITE_LIST)[source]

Config class for JAX Static quantization.

Static quantization quantizes weights offline and quantizes activations at runtime using parameters pre-computed from calibration data. This configuration supports several data types for flexible quantization strategies; see the construction sketch after the format lists below.

Supported dtypes:
  • "fp8": 8-bit floating-point quantization (uses ml_dtypes.float8_e4m3 by default)

  • "int8": 8-bit integer quantization

FP8 formats available:
  • "fp8_e4m3": 4 exponent bits, 3 mantissa bits (default for "fp8")

  • "fp8_e5m2": 5 exponent bits, 2 mantissa bits
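
A minimal construction sketch; the white_list entry below is a hypothetical operator name used only for illustration, since the valid op/module-type names are not listed here:

    from neural_compressor.jax.quantization.config import StaticQuantConfig

    # FP8 e4m3 static quantization for weights and activations
    # (the constructor defaults).
    cfg = StaticQuantConfig()

    # Scope quantization to particular operators via white_list.
    # "Dense" is a hypothetical op/module-type name; consult the
    # framework's registered JAX operator names for real values.
    scoped_cfg = StaticQuantConfig(
        weight_dtype="fp8_e5m2",
        activation_dtype="fp8_e5m2",
        white_list=["Dense"],
    )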

neural_compressor.jax.quantization.config.get_all_registered_configs() → Dict[str, neural_compressor.common.base_config.BaseConfig][source]

Get all registered configs for JAX framework.

Returns:

Mapping of config names to config classes.

Return type:

Dict[str, BaseConfig]
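
A short usage sketch; per the return type above, the mapping's values are config classes:

    from neural_compressor.jax.quantization.config import get_all_registered_configs

    # Discover which quantization algorithms are registered for the
    # JAX framework, e.g. for tooling or debugging.
    for name, config_cls in get_all_registered_configs().items():
        print(f"{name}: {config_cls}")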

neural_compressor.jax.quantization.config.get_default_dynamic_config() → DynamicQuantConfig[source]

Generate the default Dynamic quantization config.

Returns:

The default JAX Dynamic quantization config.

Return type:

DynamicQuantConfig
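
A usage sketch; given the constructor defaults documented above, the result is presumably equivalent to DynamicQuantConfig() with no arguments:

    from neural_compressor.jax.quantization.config import get_default_dynamic_config

    cfg = get_default_dynamic_config()
    # Presumably fp8_e4m3 weights and activations, per the
    # documented constructor defaults.
    print(cfg)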

neural_compressor.jax.quantization.config.get_default_static_config() → StaticQuantConfig[source]

Generate the default Static quantization config.

Returns:

The default JAX Static quantization config.

Return type:

StaticQuantConfig
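
A usage sketch along the same lines as the dynamic variant:

    from neural_compressor.jax.quantization.config import (
        StaticQuantConfig,
        get_default_static_config,
    )

    cfg = get_default_static_config()
    # Presumably matches StaticQuantConfig() with its documented
    # defaults: fp8_e4m3 for both weights and activations.
    assert isinstance(cfg, StaticQuantConfig)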