neural_compressor.jax.utils.utility
===================================

.. py:module:: neural_compressor.jax.utils.utility

.. autoapi-nested-parse::

   The utility functions and classes for JAX.



Functions
---------

.. autoapisummary::

   neural_compressor.jax.utils.utility.check_backend
   neural_compressor.jax.utils.utility.add_fp8_support
   neural_compressor.jax.utils.utility.register_algo
   neural_compressor.jax.utils.utility.get_quantize_fun
   neural_compressor.jax.utils.utility.get_dequantize_fun
   neural_compressor.jax.utils.utility.get_scale
   neural_compressor.jax.utils.utility.get_q_params
   neural_compressor.jax.utils.utility.print_model
   neural_compressor.jax.utils.utility.causal_lm_make_replace_generate_function
   neural_compressor.jax.utils.utility.iterate_over_layers
   neural_compressor.jax.utils.utility.verify_api


Module Contents
---------------

.. py:function:: check_backend(raise_error=True)

   Check whether the current Keras backend is JAX, and log a warning or an error if it is not.


.. py:function:: add_fp8_support(function)

   Extend a dtype-size function to support FP8 dtypes.

   :param function: Function that returns the size of a dtype in bits.
   :type function: Callable

   :returns: Wrapped function that also handles FP8 dtypes.
   :rtype: Callable


.. py:function:: register_algo(name)

   Decorator that registers an algorithm function in the ``algos_mapping`` dictionary.

   Usage example:

   .. code-block:: python

      import keras
      from neural_compressor.jax.utils.utility import register_algo

      @register_algo(name="example_algo")
      def example_algo(model: keras.Model, quant_config: "StaticQuantConfig") -> keras.Model:
          return model  # placeholder: apply the algorithm to `model` here

   :param name: The name under which the algorithm function will be registered.
   :type name: str

   :returns: The decorator to apply to algorithm functions.
   :rtype: decorator


.. py:function:: get_quantize_fun(dtype=ml_dtypes.float8_e4m3, asymmetric=False)

   Create a quantization function for the specified dtype.

   :param dtype: Target quantization dtype.
   :type dtype: jnp.dtype
   :param asymmetric: Whether to use asymmetric quantization for integer dtypes.
   :type asymmetric: bool

   :returns: Quantization function that maps tensors to the target dtype.
   :rtype: Callable


.. py:function:: get_dequantize_fun(dtype=jnp.float32, asymmetric=False)

   Create a dequantization function for the specified dtype.

   :param dtype: Output dtype after dequantization.
   :type dtype: jnp.dtype
   :param asymmetric: Whether to use asymmetric dequantization.
   :type asymmetric: bool

   :returns: Function that dequantizes tensors.
   :rtype: Callable


.. py:function:: get_scale(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32)

   Compute the quantization scale for a weight tensor.

   :param orig_weight: Weight tensor to analyze.
   :type orig_weight: jnp.ndarray
   :param dtype: Target quantized dtype.
   :type dtype: jnp.dtype
   :param compute_dtype: Dtype used for the scale computation.
   :type compute_dtype: jnp.dtype

   :returns: Computed scale tensor.
   :rtype: jnp.ndarray


.. py:function:: get_q_params(orig_weight, dtype=ml_dtypes.float8_e4m3, compute_dtype=jnp.float32, asymmetric=False)

   Compute the quantization scale and zero-point for a weight tensor.

   :param orig_weight: Weight tensor to analyze.
   :type orig_weight: jnp.ndarray
   :param dtype: Target quantized dtype.
   :type dtype: jnp.dtype
   :param compute_dtype: Dtype used for the scale computation.
   :type compute_dtype: jnp.dtype
   :param asymmetric: Whether to compute asymmetric quantization parameters.
   :type asymmetric: bool

   :returns: Scale and zero-point. The zero-point is ``None`` for floating-point dtypes and for symmetric quantization.
   :rtype: Tuple[jnp.ndarray, Optional[jnp.ndarray]]


.. py:function:: print_model(container, max_lines=999999, internal=True, str_length=(0, 0), path='')

   Print the model structure.

   :param container: The model or layer to print.
   :type container: keras.Model
   :param max_lines: The maximum number of elements to print.
   :type max_lines: int
   :param internal: Whether to print layers from the internal ``_layers`` attribute (True) or from the public ``layers`` API (False).
   :type internal: bool
   :param str_length: Maximum display lengths for the class name and the path.
   :type str_length: Tuple[int, int]
   :param path: Prefix path for the current layer.
   :type path: str

   :returns: Nothing; the model structure is written to the logger.
   :rtype: None


.. py:function:: causal_lm_make_replace_generate_function(self, revert=False)

   Replace the model's generate function for calibration, and restore the original on demand.

   :param self: Causal language model instance to modify.
   :type self: keras.Model
   :param revert: When True, restore the original generate function.
   :type revert: bool

   :returns: Updated generate function.
   :rtype: Callable


.. py:function:: iterate_over_layers(model, operations, /, *, filter_function: Optional[Callable] = lambda _: True)

   Apply operations to the model layers that match the filter function.

   :param model: Keras model that provides the ``_flatten_layers`` iterator.
   :type model: keras.Model
   :param operations: Operations to apply to each matching layer.
   :type operations: Iterable[Callable]
   :param filter_function: Predicate that selects the layers to process. Defaults to a predicate that always returns True.
   :type filter_function: Callable, optional

   :returns: The original model after the operations have been applied.
   :rtype: keras.Model


.. py:function:: verify_api(orig_cls, quant_cls, method_name)

   Check whether a method of the quantized layer class has the same signature as the same method of the original layer class.

   :param orig_cls: Original layer class.
   :type orig_cls: type
   :param quant_cls: Quantized layer class.
   :type quant_cls: type
   :param method_name: Method name to compare.
   :type method_name: str

   :returns: Nothing; an error is logged if the method signatures differ.
   :rtype: None
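The quantization helpers in this module (``get_scale``, ``get_q_params``, ``get_quantize_fun`` and ``get_dequantize_fun``) are documented here only by signature. The sketch below illustrates what symmetric and asymmetric per-tensor parameters typically look like, assuming max-abs scaling for the symmetric case and min-max scaling for the asymmetric integer case. ``compute_q_params``, ``quantize`` and ``dequantize`` are hypothetical names, not part of ``neural_compressor``, and the library's actual scale formula, rounding and clipping may differ. ``jnp.float8_e4m3fn`` is used as the FP8 example type; it is not the same format as the ``ml_dtypes.float8_e4m3`` default shown in the signatures, which may have a different representable range.

```python
import jax.numpy as jnp


def compute_q_params(w, dtype=jnp.int8, asymmetric=False, compute_dtype=jnp.float32):
    """Hypothetical per-tensor scale and zero-point (not the library's API)."""
    w = w.astype(compute_dtype)
    tiny = jnp.finfo(compute_dtype).tiny  # guards against a zero scale
    if jnp.issubdtype(dtype, jnp.floating):
        # FP8 and other float targets: symmetric max-abs scaling, no zero-point.
        qmax = float(jnp.finfo(dtype).max)
        return jnp.maximum(jnp.max(jnp.abs(w)) / qmax, tiny), None
    info = jnp.iinfo(dtype)
    if asymmetric:
        # Min-max scaling: map [w_min, w_max] (widened to include 0) onto [info.min, info.max].
        w_min = jnp.minimum(jnp.min(w), 0.0)
        w_max = jnp.maximum(jnp.max(w), 0.0)
        scale = jnp.maximum((w_max - w_min) / (info.max - info.min), tiny)
        zero_point = jnp.round(info.min - w_min / scale)
        return scale, zero_point
    # Symmetric integer: max-abs scaling, so real 0.0 maps to integer 0.
    return jnp.maximum(jnp.max(jnp.abs(w)) / info.max, tiny), None


def quantize(w, scale, zero_point=None, dtype=jnp.int8):
    """Scale, shift, clip to the target range, and cast to the target dtype."""
    q = w / scale
    if zero_point is not None:
        q = q + zero_point
    if jnp.issubdtype(dtype, jnp.integer):
        info = jnp.iinfo(dtype)
        q = jnp.clip(jnp.round(q), info.min, info.max)
    else:
        qmax = float(jnp.finfo(dtype).max)
        q = jnp.clip(q, -qmax, qmax)  # the cast below rounds to a representable FP8 value
    return q.astype(dtype)


def dequantize(q, scale, zero_point=None, compute_dtype=jnp.float32):
    """Invert quantize(): undo the shift, then multiply by the scale."""
    x = q.astype(compute_dtype)
    if zero_point is not None:
        x = x - zero_point
    return x * scale


# Usage: asymmetric int8 round trip on a small example tensor.
w = jnp.array([-1.0, -0.25, 0.0, 0.5, 2.0], dtype=jnp.float32)
scale, zp = compute_q_params(w, jnp.int8, asymmetric=True)
w_hat = dequantize(quantize(w, scale, zp, jnp.int8), scale, zp)
```

Returning ``None`` as the zero-point for floating-point targets and for symmetric integer quantization mirrors the return contract stated for ``get_q_params``. With min-max scaling, every value inside the observed range round-trips with an error of at most half a quantization step, and real zero is represented exactly.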