neural_compressor.jax.quantization.config
=========================================

.. py:module:: neural_compressor.jax.quantization.config

.. autoapi-nested-parse::

   The configs of algorithms for JAX.

Classes
-------

.. autoapisummary::

   neural_compressor.jax.quantization.config.OperatorConfig
   neural_compressor.jax.quantization.config.DynamicQuantConfig
   neural_compressor.jax.quantization.config.StaticQuantConfig

Functions
---------

.. autoapisummary::

   neural_compressor.jax.quantization.config.get_all_registered_configs
   neural_compressor.jax.quantization.config.get_default_dynamic_config
   neural_compressor.jax.quantization.config.get_default_static_config

Module Contents
---------------

.. py:class:: OperatorConfig

   Configuration pairing a quantization config with supported operators.

.. py:class:: DynamicQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', const_scale: bool = False, const_weight: bool = False, white_list: Optional[List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST)

   Config class for JAX dynamic quantization.

   Dynamic quantization applies quantization to both weights and activations at
   runtime. This configuration supports several data types for flexible
   quantization strategies.

   Supported dtypes:

   - ``"fp8"``: 8-bit floating-point quantization (uses ``ml_dtypes.float8_e4m3`` by default)
   - ``"int8"``: 8-bit integer quantization

   Available FP8 formats:

   - ``"fp8_e4m3"``: 4 exponent bits, 3 mantissa bits (default for ``"fp8"``)
   - ``"fp8_e5m2"``: 5 exponent bits, 2 mantissa bits

.. py:class:: StaticQuantConfig(weight_dtype: str = 'fp8_e4m3', activation_dtype: str = 'fp8_e4m3', const_scale: bool = False, const_weight: bool = False, white_list: Optional[List[neural_compressor.common.base_config.OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST)

   Config class for JAX static quantization.

   Static quantization quantizes weights offline and quantizes activations at
   runtime using pre-computed calibration data. This configuration supports
   several data types for flexible quantization strategies.

   Supported dtypes:

   - ``"fp8"``: 8-bit floating-point quantization (uses ``ml_dtypes.float8_e4m3`` by default)
   - ``"int8"``: 8-bit integer quantization

   Available FP8 formats:

   - ``"fp8_e4m3"``: 4 exponent bits, 3 mantissa bits (default for ``"fp8"``)
   - ``"fp8_e5m2"``: 5 exponent bits, 2 mantissa bits

.. py:function:: get_all_registered_configs() -> Dict[str, neural_compressor.common.base_config.BaseConfig]

   Get all registered configs for the JAX framework.

   :returns: Mapping of config names to config classes.
   :rtype: Dict[str, BaseConfig]

.. py:function:: get_default_dynamic_config() -> DynamicQuantConfig

   Generate the default dynamic quantization config.

   :returns: The default JAX dynamic quantization config.
   :rtype: DynamicQuantConfig

.. py:function:: get_default_static_config() -> StaticQuantConfig

   Generate the default static quantization config.

   :returns: The default JAX static quantization config.
   :rtype: StaticQuantConfig