neural_compressor.jax.quantization.config
The configs of algorithms for JAX.
Classes
OperatorConfig | Configuration pairing a quantization config with supported operators.
DynamicQuantConfig | Config class for JAX Dynamic quantization.
StaticQuantConfig | Config class for JAX Static quantization.
Functions
get_all_registered_configs() | Get all registered configs for the JAX framework.
get_default_dynamic_config() | Generate the default Dynamic quantization config.
get_default_static_config() | Generate the default Static quantization config.
Module Contents
- class neural_compressor.jax.quantization.config.OperatorConfig
Configuration pairing a quantization config with supported operators.
- class neural_compressor.jax.quantization.config.DynamicQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', white_list: List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE] | None = DEFAULT_WHITE_LIST)
Config class for JAX Dynamic quantization.
Dynamic quantization quantizes both weights and activations at runtime. The weight and activation data types are set independently through weight_dtype and activation_dtype.
- Supported dtypes:
“fp8”: 8-bit floating-point quantization (uses ml_dtypes.float8_e4m3 by default)
“int8”: 8-bit integer quantization
- FP8 formats available:
“fp8_e4m3”: 4 exponent bits, 3 mantissa bits (default for “fp8”)
“fp8_e5m2”: 5 exponent bits, 2 mantissa bits
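The trade-off between the two FP8 formats can be made concrete with a little arithmetic. The sketch below is not part of neural_compressor; `fp8_summary` is a hypothetical helper that assumes an IEEE-style layout in which the all-ones exponent is reserved for infinities and NaN. Under that assumption E4M3 tops out at 240 and E5M2 at 57344. The "fn" variants found in some libraries reclaim the top exponent (E4M3 then reaches 448), so the exact limits of the dtype the library selects should be checked against ml_dtypes.

```python
def fp8_summary(exp_bits: int, man_bits: int) -> dict:
    """Range and precision of an IEEE-style minifloat (illustrative only).

    Assumption: the all-ones exponent is reserved for inf/NaN, as in IEEE 754.
    """
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias   # largest non-reserved exponent
    max_significand = 2 - 2 ** -man_bits   # binary 1.11...1
    return {
        "max": max_significand * 2 ** max_exp,  # largest finite value
        "min_normal": 2.0 ** (1 - bias),        # smallest normal value
        "epsilon": 2.0 ** -man_bits,            # spacing just above 1.0
    }


e4m3 = fp8_summary(exp_bits=4, man_bits=3)
e5m2 = fp8_summary(exp_bits=5, man_bits=2)

print("e4m3:", e4m3)
print("e5m2:", e5m2)
```

E4M3 spends its extra mantissa bit on finer steps, while E5M2 spends its extra exponent bit on a much wider range.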
- class neural_compressor.jax.quantization.config.StaticQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', white_list: List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE] | None = DEFAULT_WHITE_LIST)
Config class for JAX Static quantization.
Static quantization quantizes weights offline and activations at runtime, using pre-computed calibration data. The weight and activation data types are set independently through weight_dtype and activation_dtype.
- Supported dtypes:
“fp8”: 8-bit floating-point quantization (uses ml_dtypes.float8_e4m3 by default)
“int8”: 8-bit integer quantization
- FP8 formats available:
“fp8_e4m3”: 4 exponent bits, 3 mantissa bits (default for “fp8”)
“fp8_e5m2”: 5 exponent bits, 2 mantissa bits
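The difference between a calibrated (static) activation scale and one recomputed per input (dynamic) can be sketched in plain Python. This is only an illustration under simplifying assumptions, not the library's implementation: it uses symmetric per-tensor int8 with an abs-max scale, and every name in it (`absmax_scale`, `quantize`, `dequantize`, the sample data) is hypothetical.

```python
from typing import List, Sequence

INT8_MAX = 127


def absmax_scale(values: Sequence[float]) -> float:
    """Symmetric per-tensor scale mapping the largest magnitude to 127."""
    peak = max(abs(v) for v in values)
    return peak / INT8_MAX if peak > 0 else 1.0


def quantize(values: Sequence[float], scale: float) -> List[int]:
    """Round to the nearest step and clamp to [-127, 127]."""
    return [max(-INT8_MAX, min(INT8_MAX, round(v / scale))) for v in values]


def dequantize(codes: Sequence[int], scale: float) -> List[float]:
    """Map integer codes back to approximate real values."""
    return [c * scale for c in codes]


# Static: the scale is fixed offline from calibration batches.
calibration_batches = [[0.5, -1.0, 0.25], [2.0, -0.75, 1.5]]
static_scale = absmax_scale([v for batch in calibration_batches for v in batch])

# Dynamic: the scale is recomputed from each runtime input.
runtime_input = [4.0, -1.0, 0.5]
dynamic_scale = absmax_scale(runtime_input)

static_restored = dequantize(quantize(runtime_input, static_scale), static_scale)
dynamic_restored = dequantize(quantize(runtime_input, dynamic_scale), dynamic_scale)

print("static :", static_restored)   # the 4.0 entry clips to the calibrated peak
print("dynamic:", dynamic_restored)  # the 4.0 entry is preserved
```

In this toy setting the static scale avoids per-input statistics at runtime but clips values outside the calibrated range, while the dynamic scale adapts to each input at the cost of computing it on every call.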
- neural_compressor.jax.quantization.config.get_all_registered_configs() → Dict[str, neural_compressor.common.base_config.BaseConfig]
Get all registered configs for the JAX framework.
- Returns:
Mapping of config names to config classes.
- Return type:
Dict[str, BaseConfig]
- neural_compressor.jax.quantization.config.get_default_dynamic_config() → DynamicQuantConfig
Generate the default Dynamic quantization config.
- Returns:
The default JAX Dynamic quantization config.
- Return type:
DynamicQuantConfig
- neural_compressor.jax.quantization.config.get_default_static_config() → StaticQuantConfig
Generate the default Static quantization config.
- Returns:
The default JAX Static quantization config.
- Return type:
StaticQuantConfig