neural_compressor.jax.utils.utility

Utility functions and classes for the JAX backend.

Functions

check_backend([raise_error])

Check if the current Keras backend is JAX and log a warning or error if not.

add_fp8_support(function)

Extend a dtype size function to support FP8 dtypes.

register_algo(name)

Decorator function to register algorithms in the algos_mapping dictionary.

get_quantize_fun([dtype, asymmetric])

Create a quantization function for the specified dtype.

get_dequantize_fun([dtype, asymmetric])

Create a dequantization function for the specified dtype.

get_scale(orig_weight[, dtype, compute_dtype])

Compute the quantization scale for a weight tensor.

get_q_params(orig_weight[, dtype, compute_dtype, ...])

Compute quantization scale and zero-point for a weight tensor.

print_model(container[, max_lines, internal, ...])

Print the model structure.

causal_lm_make_replace_generate_function(self[, revert])

Replace the model's generate function for calibration, or restore the original on demand.

iterate_over_layers(model, operations, /, *[, ...])

Apply operations to model layers matching the filter function.

verify_api(orig_cls, quant_cls, method_name)

Check if quantized layer method API matches original layer method API.

Module Contents

neural_compressor.jax.utils.utility.check_backend(raise_error=True)[source]

Check if the current Keras backend is JAX and log a warning or error if not.

neural_compressor.jax.utils.utility.add_fp8_support(function)[source]

Extend a dtype size function to support FP8 dtypes.

Parameters:

function (Callable) – Function that returns the size of a dtype in bits.

Returns:

Wrapped function that handles FP8 dtypes.

Return type:

Callable
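The wrapper pattern described above can be sketched as follows. This is a hypothetical, simplified re-implementation, not the library's actual code; the `FP8_DTYPE_NAMES` set and the toy base function are assumptions for illustration. The key idea is that every FP8 format is 8 bits wide, so the wrapper can answer for FP8 names without touching the wrapped function.

```python
import functools

# Assumed set of FP8 dtype names; the real utility may recognize more variants.
FP8_DTYPE_NAMES = {"float8_e4m3", "float8_e4m3fn", "float8_e5m2"}

def add_fp8_support(function):
    """Extend a dtype-size function so it also returns 8 for FP8 dtypes."""
    @functools.wraps(function)
    def wrapper(dtype):
        if str(dtype) in FP8_DTYPE_NAMES:
            return 8  # all FP8 formats occupy 8 bits
        return function(dtype)
    return wrapper

@add_fp8_support
def dtype_size_in_bits(dtype):
    # Toy base implementation covering a few common dtypes.
    return {"float32": 32, "float16": 16, "int8": 8}[str(dtype)]
```

With the decorator applied, `dtype_size_in_bits("float8_e4m3")` returns 8 even though the base mapping has no FP8 entry.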

neural_compressor.jax.utils.utility.register_algo(name)[source]

Decorator function to register algorithms in the algos_mapping dictionary.

Usage example:

    @register_algo(name="example_algo")
    def example_algo(model: tf.keras.Model, quant_config: StaticQuantConfig) -> tf.keras.Model:
        ...

Parameters:

name (str) – The name under which the algorithm function will be registered.

Returns:

The decorator function to be used with algorithm functions.

Return type:

decorator
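A registry decorator of this shape can be sketched in a few lines. This is a minimal stand-in, assuming the documented behavior (the decorated function is stored in a module-level `algos_mapping` dictionary under the given name and returned unchanged), not the library's actual implementation.

```python
# Module-level registry, mirroring the documented algos_mapping dictionary.
algos_mapping = {}

def register_algo(name):
    """Return a decorator that registers an algorithm function under `name`."""
    def decorator(algo_fn):
        algos_mapping[name] = algo_fn
        return algo_fn  # the function itself is returned unmodified
    return decorator

@register_algo(name="example_algo")
def example_algo(model, quant_config):
    # A real algorithm would transform and return a quantized model here.
    return model
```

After decoration, `algos_mapping["example_algo"]` resolves to `example_algo`, so callers can look up algorithms by name.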

neural_compressor.jax.utils.utility.get_quantize_fun(dtype=ml_dtypes.float8_e4m3, asymmetric=False)[source]

Create a quantization function for the specified dtype.

Parameters:
  • dtype (jnp.dtype) – Target quantization dtype.

  • asymmetric (bool) – Whether to use asymmetric quantization for integer dtypes.

Returns:

Quantization function that maps tensors to the target dtype.

Return type:

Callable

neural_compressor.jax.utils.utility.get_dequantize_fun(dtype=jnp.float32, asymmetric=False)[source]

Create a dequantization function for the specified dtype.

Parameters:
  • dtype (jnp.dtype) – Output dtype after dequantization.

  • asymmetric (bool) – Whether to use asymmetric dequantization.

Returns:

Function that dequantizes tensors.

Return type:

Callable
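The quantize/dequantize factory pair can be illustrated with a symmetric int8 round trip. This sketch uses numpy instead of JAX and plain int8 instead of FP8; the function names mirror the API above but the bodies are simplified assumptions, not the library code.

```python
import numpy as np

def get_quantize_fun(qmax=127):
    # Symmetric integer quantization: x_q = round(x / scale), clipped to [-qmax, qmax].
    def quantize(x, scale):
        return np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)
    return quantize

def get_dequantize_fun(out_dtype=np.float32):
    # Dequantization is the inverse mapping: x ≈ x_q * scale, in the output dtype.
    def dequantize(x_q, scale):
        return x_q.astype(out_dtype) * scale
    return dequantize

w = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
scale = np.max(np.abs(w)) / 127.0  # per-tensor symmetric scale
q = get_quantize_fun()(w, scale)
deq = get_dequantize_fun()(q, scale)
```

The round-trip error is bounded by the scale: each dequantized value lies within one quantization step of the original.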

neural_compressor.jax.utils.utility.get_scale(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32)[source]

Compute the quantization scale for a weight tensor.

Parameters:
  • orig_weight (jnp.ndarray) – Weight tensor to analyze.

  • dtype (jnp.dtype) – Target quantized dtype.

  • compute_dtype (jnp.dtype) – dtype for scale computation.

Returns:

Computed scale tensor.

Return type:

jnp.ndarray
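A common way to compute such a scale, sketched here with numpy as an assumed simplification of the above, is to map the tensor's largest magnitude onto the target dtype's largest finite value. The `448.0` constant below is the max of the OCP `float8_e4m3fn` format; other FP8 variants have different ranges, so treat it as an illustrative assumption.

```python
import numpy as np

# Largest finite value of float8_e4m3fn (OCP spec); other FP8 variants differ.
FLOAT8_E4M3FN_MAX = 448.0

def get_scale(orig_weight, dtype_max=FLOAT8_E4M3FN_MAX, compute_dtype=np.float32):
    # Per-tensor symmetric scale: the largest magnitude maps onto dtype_max.
    w = orig_weight.astype(compute_dtype)
    return np.max(np.abs(w)) / dtype_max

w = np.array([[-3.0, 1.5], [0.25, 2.0]], dtype=np.float32)
scale = get_scale(w)  # 3.0 / 448.0
```

Dividing a weight by this scale guarantees the result fits in the representable range of the target dtype.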

neural_compressor.jax.utils.utility.get_q_params(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False)[source]

Compute quantization scale and zero-point for a weight tensor.

Parameters:
  • orig_weight (jnp.ndarray) – Weight tensor to analyze.

  • dtype (jnp.dtype) – Target quantized dtype.

  • compute_dtype (jnp.dtype) – dtype for scale computation.

  • asymmetric (bool) – Whether to compute asymmetric quantization parameters.

Returns:

Scale and zero-point. Zero-point is None for floating-point dtypes or symmetric quantization.

Return type:

Tuple[jnp.ndarray, Optional[jnp.ndarray]]
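The symmetric/asymmetric split described above can be sketched for int8 (a simplifying assumption; the real function also handles FP8 targets, where the zero-point is always None). In asymmetric mode the full range [min, max] is mapped onto [qmin, qmax] and a zero-point is derived; in symmetric mode only a scale is returned.

```python
import numpy as np

def get_q_params(orig_weight, qmin=-128, qmax=127, compute_dtype=np.float32,
                 asymmetric=False):
    """Return (scale, zero_point); zero_point is None for symmetric mode."""
    w = orig_weight.astype(compute_dtype)
    if not asymmetric:
        # Symmetric: largest magnitude maps to qmax, no zero-point needed.
        return np.max(np.abs(w)) / qmax, None
    w_min, w_max = np.min(w), np.max(w)
    scale = (w_max - w_min) / (qmax - qmin)
    # Chosen so that w_min quantizes exactly to qmin: q = round(x/scale) + zp.
    zero_point = np.round(qmin - w_min / scale).astype(np.int32)
    return scale, zero_point

w = np.array([0.0, 1.0, 2.0], dtype=np.float32)
scale_sym, zp_sym = get_q_params(w)                    # zp_sym is None
scale_asym, zp_asym = get_q_params(w, asymmetric=True)  # zp_asym is -128
```

Asymmetric parameters use the full integer range even when the data is one-sided (here all non-negative), which symmetric quantization would waste half the range on.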

neural_compressor.jax.utils.utility.print_model(container, max_lines=999999, internal=True, str_length=(0, 0), path='')[source]

Print the model structure.

Parameters:
  • container (keras.Model) – The model or layer to be printed.

  • max_lines (int) – The maximum number of elements to print.

  • internal (bool) – Whether to print layers from internal _layers (True) or public layers API (False).

  • str_length (Tuple[int, int]) – Tuple with max lengths for class name and path.

  • path (str) – Prefix path for the current layer.

Returns:

Logs model structure via the logger.

Return type:

None

neural_compressor.jax.utils.utility.causal_lm_make_replace_generate_function(self, revert=False)[source]

Replace the model's generate function for calibration, or restore the original on demand.

Parameters:
  • self (keras.Model) – Causal language model instance to modify.

  • revert (bool) – When True, restore the original generate function.

Returns:

Updated generate function.

Return type:

Callable
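The replace-and-restore pattern behind this utility can be sketched with a toy model. Everything below (the `ToyCausalLM` class, the `_original_generate` attribute, the calibration counter) is a hypothetical illustration of the swap/revert mechanics, not the library's implementation.

```python
class ToyCausalLM:
    def generate(self, prompt):
        return f"generated:{prompt}"

def make_replace_generate_function(model, revert=False):
    """Swap in an instrumented generate, or restore the original on revert."""
    if revert:
        original = getattr(model, "_original_generate", None)
        if original is not None:
            model.generate = original  # restore the stashed bound method
        return model.generate

    model._original_generate = model.generate  # keep a handle for revert

    def calibrating_generate(prompt):
        # A real implementation would record activation statistics here.
        model.calibration_calls = getattr(model, "calibration_calls", 0) + 1
        return model._original_generate(prompt)

    model.generate = calibrating_generate
    return model.generate
```

Calling with `revert=True` reinstalls the saved method, so generation behaves exactly as before calibration.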

neural_compressor.jax.utils.utility.iterate_over_layers(model, operations, /, *, filter_function=lambda _: True)[source]

Apply operations to model layers matching the filter function.

Parameters:
  • model (keras.Model) – Keras model with a _flatten_layers iterator.

  • operations (Iterable[Callable]) – Operations to apply to each layer.

  • filter_function (Callable, optional) – Predicate to select layers. Defaults to always True.

Returns:

The original model after operations have been applied.

Return type:

keras.Model
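The iteration pattern can be sketched as below. A plain dict with a `"layers"` list stands in for a Keras model's `_flatten_layers()` iterator, and the operation/filter callables are hypothetical examples; only the control flow mirrors the documented behavior.

```python
def iterate_over_layers(model, operations, *, filter_function=lambda _: True):
    """Apply every operation to each layer accepted by the filter."""
    for layer in model["layers"]:  # stand-in for model._flatten_layers()
        if filter_function(layer):
            for op in operations:
                op(layer)
    return model  # the same model, mutated in place by the operations

model = {"layers": [{"name": "dense_1"}, {"name": "relu"}, {"name": "dense_2"}]}
tagged = []
iterate_over_layers(
    model,
    [lambda layer: tagged.append(layer["name"])],
    filter_function=lambda layer: layer["name"].startswith("dense"),
)
```

Here the filter selects only the dense layers, so `tagged` ends up as `["dense_1", "dense_2"]` and the activation layer is skipped.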

neural_compressor.jax.utils.utility.verify_api(orig_cls, quant_cls, method_name)[source]

Check if quantized layer method API matches original layer method API.

Parameters:
  • orig_cls (type) – Original layer class.

  • quant_cls (type) – Quantized layer class.

  • method_name (str) – Method name to compare.

Returns:

Logs an error if the method signatures differ.

Return type:

None
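A signature check of this kind is typically built on `inspect.signature`, which supports direct equality comparison. The sketch below returns a boolean instead of logging (an assumption for testability); the toy `Dense`/`QDense` classes are hypothetical.

```python
import inspect

def verify_api(orig_cls, quant_cls, method_name):
    """Return True if both classes expose `method_name` with the same signature."""
    orig_sig = inspect.signature(getattr(orig_cls, method_name))
    quant_sig = inspect.signature(getattr(quant_cls, method_name))
    # The real utility logs an error on mismatch; here we just report it.
    return orig_sig == quant_sig

class Dense:
    def call(self, inputs, training=False): ...

class QDense:  # matching API: same parameters and defaults
    def call(self, inputs, training=False): ...

class BadQDense:  # drifted API: dropped the training parameter
    def call(self, inputs): ...
```

`verify_api(Dense, QDense, "call")` passes, while `verify_api(Dense, BadQDense, "call")` flags the missing `training` parameter.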