neural_compressor.jax.utils.utility
The utility functions and classes for JAX.
Functions
- check_backend: Check if the current Keras backend is JAX and log a warning or error if not.
- add_fp8_support: Extend a dtype size function to support FP8 dtypes.
- register_algo: Decorator function to register algorithms in the algos_mapping dictionary.
- get_quantize_fun: Create a quantization function for the specified dtype.
- get_dequantize_fun: Create a dequantization function for the specified dtype.
- get_scale: Compute the quantization scale for a weight tensor.
- get_q_params: Compute quantization scale and zero-point for a weight tensor.
- print_model: Print the model structure.
- causal_lm_make_replace_generate_function: Replace the generate function for calibration and restore it on demand.
- iterate_over_layers: Apply operations to model layers matching the filter function.
- verify_api: Check whether a quantized layer's method API matches the original layer's.
Module Contents
- neural_compressor.jax.utils.utility.check_backend(raise_error=True)[source]
Check if the current Keras backend is JAX and log a warning or error if not.
- neural_compressor.jax.utils.utility.add_fp8_support(function)[source]
Extend a dtype size function to support FP8 dtypes.
- Parameters:
function (Callable) – Function that returns the size of a dtype in bits.
- Returns:
Wrapped function that handles FP8 dtypes.
- Return type:
Callable
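A minimal sketch of this wrapping pattern, kept self-contained by looking dtypes up by name rather than importing ml_dtypes; the base size function, its lookup table, and the FP8 dtype spellings (float8_e4m3, float8_e5m2) are illustrative assumptions, not the library's actual implementation:

```python
def dtype_size_in_bits(dtype_name: str) -> int:
    # Base lookup for common dtypes; unknown names raise KeyError.
    sizes = {"float32": 32, "float16": 16, "bfloat16": 16, "int8": 8}
    return sizes[dtype_name]

def add_fp8_support(function):
    """Wrap a dtype-size function so FP8 dtypes report 8 bits."""
    fp8_names = {"float8_e4m3", "float8_e5m2"}  # assumed FP8 spellings
    def wrapper(dtype_name):
        if dtype_name in fp8_names:
            return 8  # handled here instead of by the wrapped function
        return function(dtype_name)
    return wrapper

size_fn = add_fp8_support(dtype_size_in_bits)
print(size_fn("float8_e4m3"))  # 8
print(size_fn("float32"))      # 32
```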
- neural_compressor.jax.utils.utility.register_algo(name)[source]
Decorator function to register algorithms in the algos_mapping dictionary.
- Usage example:
    @register_algo(name="example_algo")
    def example_algo(model: tf.keras.Model, quant_config: StaticQuantConfig) -> tf.keras.Model:
        ...
- Parameters:
name (str) – The name under which the algorithm function will be registered.
- Returns:
The decorator function to be used with algorithm functions.
- Return type:
decorator
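The registry-decorator pattern this describes can be sketched as follows; the module-level dictionary and the example entry function are stand-ins, not the library's actual code:

```python
algos_mapping: dict = {}  # maps algorithm name -> algorithm function

def register_algo(name: str):
    """Return a decorator that registers a function under `name`."""
    def decorator(algo_fn):
        algos_mapping[name] = algo_fn
        return algo_fn  # the function itself is returned unchanged
    return decorator

@register_algo(name="static_quant")
def static_quant_entry(model, quant_config):
    # A real entry point would quantize and return the model.
    return model

print("static_quant" in algos_mapping)  # True
```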
- neural_compressor.jax.utils.utility.get_quantize_fun(dtype=ml_dtypes.float8_e4m3, asymmetric=False)[source]
Create a quantization function for the specified dtype.
- Parameters:
dtype (jnp.dtype) – Target quantization dtype.
asymmetric (bool) – Whether to use asymmetric quantization for integer dtypes.
- Returns:
Quantization function that maps tensors to the target dtype.
- Return type:
Callable
- neural_compressor.jax.utils.utility.get_dequantize_fun(dtype=jnp.float32, asymmetric=False)[source]
Create a dequantization function for the specified dtype.
- Parameters:
dtype (jnp.dtype) – Output dtype after dequantization.
asymmetric (bool) – Whether to use asymmetric dequantization.
- Returns:
Function that dequantizes tensors.
- Return type:
Callable
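The factory pattern behind these two functions can be illustrated with a symmetric int8 round trip in plain Python; the closures and the absmax scale below are a conceptual sketch, not the library's returned functions, whose exact call signatures are not shown in this reference:

```python
def get_quantize_fun(qmax: int = 127):
    """Symmetric per-tensor quantizer: x -> round(x / scale), clipped to [-qmax, qmax]."""
    def quantize(values, scale):
        return [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return quantize

def get_dequantize_fun():
    """Inverse mapping: q -> q * scale."""
    def dequantize(qvalues, scale):
        return [q * scale for q in qvalues]
    return dequantize

weights = [0.5, -1.0, 0.25]
scale = max(abs(w) for w in weights) / 127  # absmax scale for int8
q = get_quantize_fun()(weights, scale)
restored = get_dequantize_fun()(q, scale)
# Round-trip error is bounded by one quantization step.
print(all(abs(w - r) < scale for w, r in zip(weights, restored)))  # True
```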
- neural_compressor.jax.utils.utility.get_scale(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32)[source]
Compute the quantization scale for a weight tensor.
- Parameters:
orig_weight (jnp.ndarray) – Weight tensor to analyze.
dtype (jnp.dtype) – Target quantized dtype.
compute_dtype (jnp.dtype) – dtype for scale computation.
- Returns:
Computed scale tensor.
- Return type:
jnp.ndarray
- neural_compressor.jax.utils.utility.get_q_params(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False)[source]
Compute quantization scale and zero-point for a weight tensor.
- Parameters:
orig_weight (jnp.ndarray) – Weight tensor to analyze.
dtype (jnp.dtype) – Target quantized dtype.
compute_dtype (jnp.dtype) – dtype for scale computation.
asymmetric (bool) – Whether to compute asymmetric quantization parameters.
- Returns:
Scale and zero-point. Zero-point is None for floating-point dtypes or symmetric quantization.
- Return type:
Tuple[jnp.ndarray, Optional[jnp.ndarray]]
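The math behind these two entries can be sketched in plain Python for both the symmetric and asymmetric cases; the qmax default of 448.0 assumes the e4m3 "fn" FP8 variant's maximum finite value, and the int8 range in get_q_params is likewise an illustrative assumption:

```python
def get_scale(weights, qmax: float = 448.0):
    """Symmetric scale: absmax(weights) / qmax (448.0 assumed for float8_e4m3fn)."""
    return max(abs(w) for w in weights) / qmax

def get_q_params(weights, qmin: int = -128, qmax: int = 127, asymmetric: bool = False):
    """Return (scale, zero_point); zero_point is None for symmetric quantization."""
    if not asymmetric:
        return max(abs(w) for w in weights) / qmax, None
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / (qmax - qmin)           # map [lo, hi] onto [qmin, qmax]
    zero_point = round(qmin - lo / scale)        # integer offset so lo maps to qmin
    return scale, zero_point

scale, zero_point = get_q_params([-0.5, 1.5], asymmetric=True)
print(zero_point)  # -64
```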
- neural_compressor.jax.utils.utility.print_model(container, max_lines=999999, internal=True, str_length=(0, 0), path='')[source]
Print the model structure.
- Parameters:
container (keras.Model) – The model or layer to be printed.
max_lines (int) – The maximum number of elements to print.
internal (bool) – Whether to print layers from internal _layers (True) or public layers API (False).
str_length (Tuple[int, int]) – Tuple with max lengths for class name and path.
path (str) – Prefix path for the current layer.
- Returns:
Logs model structure via the logger.
- Return type:
None
- neural_compressor.jax.utils.utility.causal_lm_make_replace_generate_function(self, revert=False)[source]
Replace the model's generate function for calibration, or restore the original on demand.
- Parameters:
self (keras.Model) – Causal language model instance to modify.
revert (bool) – When True, restore the original generate function.
- Returns:
Updated generate function.
- Return type:
Callable
- neural_compressor.jax.utils.utility.iterate_over_layers(model, operations, /, *, filter_function: Callable | None = lambda _: ...)[source]
Apply operations to model layers matching the filter function.
- Parameters:
model (keras.Model) – Keras model with a _flatten_layers iterator.
operations (Iterable[Callable]) – Operations to apply to each layer.
filter_function (Callable, optional) – Predicate to select layers. Defaults to always True.
- Returns:
The original model after operations have been applied.
- Return type:
keras.Model
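A self-contained sketch of this traversal pattern, using toy layer objects in place of a keras.Model (the real function walks the model via its _flatten_layers iterator):

```python
def iterate_over_layers(layers, operations, *, filter_function=lambda _: True):
    """Apply every operation to each layer accepted by the filter."""
    for layer in layers:
        if filter_function(layer):
            for operation in operations:
                operation(layer)
    return layers

# Toy stand-ins for model layers.
class Layer:
    def __init__(self, name):
        self.name = name
        self.tagged = False

layers = [Layer("dense"), Layer("dropout"), Layer("dense_1")]

# Tag only the layers whose name starts with "dense".
iterate_over_layers(
    layers,
    [lambda layer: setattr(layer, "tagged", True)],
    filter_function=lambda layer: layer.name.startswith("dense"),
)
print([layer.tagged for layer in layers])  # [True, False, True]
```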
- neural_compressor.jax.utils.utility.verify_api(orig_cls, quant_cls, method_name)[source]
Check whether a quantized layer's method API matches the original layer's.
- Parameters:
orig_cls (type) – Original layer class.
quant_cls (type) – Quantized layer class.
method_name (str) – Method name to compare.
- Returns:
Logs an error if the method signatures differ.
- Return type:
None
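One way such a check can be implemented is by comparing inspect.signature objects; the toy layer classes below are hypothetical stand-ins, and the error message format is an assumption:

```python
import inspect
import logging

logger = logging.getLogger("verify_api_demo")

def verify_api(orig_cls, quant_cls, method_name):
    """Log an error if the quantized method's signature differs from the original's."""
    orig_sig = inspect.signature(getattr(orig_cls, method_name))
    quant_sig = inspect.signature(getattr(quant_cls, method_name))
    if orig_sig != quant_sig:
        logger.error("API mismatch for %s: %s vs %s", method_name, orig_sig, quant_sig)

# Toy stand-ins for an original layer and its quantized counterpart.
class Dense:
    def call(self, inputs, training=False):
        return inputs

class QDense:
    def call(self, inputs, training=False):
        return inputs

verify_api(Dense, QDense, "call")  # signatures match, so nothing is logged
print(inspect.signature(Dense.call) == inspect.signature(QDense.call))  # True
```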