API Documentation

General

ipex.optimize is generally used for generic PyTorch models.

ipex.optimize(model, dtype=None, optimizer=None, level='O1', inplace=False, conv_bn_folding=None, linear_bn_folding=None, weights_prepack=None, replace_dropout_with_identity=None, optimize_lstm=None, split_master_weight_for_bf16=None, fuse_update_step=None, auto_kernel_selection=None, sample_input=None, graph_mode=None, concat_linear=None)

Apply optimizations at the Python frontend to the given model (nn.Module) and, optionally, to the given optimizer. If an optimizer is given, optimizations are applied for training; otherwise, optimizations are applied for inference. Optimizations include conv+bn folding (inference only), weight prepacking, and so on.

Weight prepacking is a technique to accelerate oneDNN operators. To achieve better vectorization and cache reuse, oneDNN uses a specific memory layout called the blocked layout. Computation with the blocked layout is fast, but converting data into it has a cost: if the weights are kept in the plain layout, oneDNN has to split one or several dimensions of the data into fixed-size blocks each time the operator is executed. More details about oneDNN memory formats are available in the oneDNN manual. To reduce this overhead, weights are converted to the predefined blocked shapes once, before the oneDNN operator is executed. At runtime, if the data layout matches the requirements of the oneDNN operator, oneDNN skips the memory layout conversion and goes directly to computation. This methodology, called weight prepacking, avoids runtime weight format conversion and thus increases performance.
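The effect of prepacking can be pictured with a toy example. The pure-Python sketch below reorders a plain row-major weight matrix into fixed-size tiles once, ahead of time, and then computes a matrix-vector product directly from the tiled copy on every call. It assumes square tiles and dimensions divisible by the tile size, and the helper names `prepack` and `blocked_matvec` are hypothetical; oneDNN's real blocked formats are more elaborate.

```python
from typing import List

Matrix = List[List[float]]


def prepack(weight: Matrix, block: int) -> List[List[Matrix]]:
    """Reorder a [rows][cols] weight into [rows/block][cols/block][block][block] tiles.

    Assumes rows and cols are multiples of `block`. Done once, ahead of time.
    """
    rows, cols = len(weight), len(weight[0])
    assert rows % block == 0 and cols % block == 0
    return [
        [
            [[weight[rb * block + i][cb * block + j] for j in range(block)] for i in range(block)]
            for cb in range(cols // block)
        ]
        for rb in range(rows // block)
    ]


def blocked_matvec(packed: List[List[Matrix]], x: List[float], block: int) -> List[float]:
    """Compute y = W @ x directly from the prepacked tiles (no per-call reorder)."""
    y = [0.0] * (len(packed) * block)
    for rb, row_of_tiles in enumerate(packed):
        for cb, tile in enumerate(row_of_tiles):
            for i in range(block):
                acc = 0.0
                for j in range(block):
                    acc += tile[i][j] * x[cb * block + j]
                y[rb * block + i] += acc
    return y


def plain_matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Reference product on the plain (unpacked) layout."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


weight = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
x = [1.0, 2.0, 3.0, 4.0]
packed = prepack(weight, block=2)  # pay the layout conversion cost once
print(blocked_matvec(packed, x, block=2) == plain_matvec(weight, x))  # → True
```

In a real workload the `prepack` step happens once at optimization time, while `blocked_matvec` runs on every iteration.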

Parameters:
  • model (torch.nn.Module) – User model to apply optimizations on.

  • dtype (torch.dtype) – Only torch.bfloat16 and torch.half (a.k.a. torch.float16) are supported. Model parameters are cast to torch.bfloat16 or torch.half according to this setting. The default value is None, meaning no conversion. Note: Data type conversion is applied only to nn.Conv2d, nn.Linear and nn.ConvTranspose2d, for both training and inference. In inference mode, data type conversion is additionally applied to the weights of nn.Embedding and nn.LSTM.

  • optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning inference case.

  • level (string) – "O0" or "O1". With "O0", no optimizations are applied and the function simply returns the original model and optimizer. With "O1", the following optimizations are applied: conv+bn folding, weights prepack, dropout removal (inference model), master weight split and fused optimizer update step (training model). Each of these optimizations can be further overridden by setting the corresponding option below explicitly. The default value is "O1".

  • inplace (bool) – Whether to perform inplace optimization. Default value is False.

  • conv_bn_folding (bool) – Whether to perform conv_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • linear_bn_folding (bool) – Whether to perform linear_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • weights_prepack (bool) – Whether to perform weight prepack for convolution and linear to avoid oneDNN weights reorder. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob. For now, XPU doesn’t support weights prepack.

  • replace_dropout_with_identity (bool) – Whether to replace nn.Dropout with nn.Identity. If replaced, aten::dropout is not included in the JIT graph, which may provide more fusion opportunities on the graph. This only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • optimize_lstm (bool) – Whether to replace nn.LSTM with IPEX LSTM which takes advantage of oneDNN kernels to get better performance. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • split_master_weight_for_bf16 (bool) – Whether to split the master weight update for BF16 training. This saves memory compared to the common master weight update solution. The split master weight update methodology doesn't support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • fuse_update_step (bool) – Whether to use a fused parameter update step for training, which has better performance. It doesn't support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • sample_input (tuple or torch.Tensor) – Sample input data to feed to ipex.optimize. The shape of the input data impacts the block format of the packed weights. If no sample input is given, Intel® Extension for PyTorch* packs the weights according to predefined heuristics. If a sample input with the real input shape is given, Intel® Extension for PyTorch* can select the best block format.

  • auto_kernel_selection (bool) – Different backends may perform differently depending on dtypes and shapes. If this knob is set to True, Intel® Extension for PyTorch* tries to select the kernel with better performance. You might get better performance at the cost of extra memory usage. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.

  • graph_mode (bool) – [prototype] If True, a combination of methods is applied automatically to generate a graph or multiple subgraphs. The default value is False.

  • concat_linear (bool) – Whether to perform concat_linear. It only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by level knob.
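The interplay between level and the individual knobs can be read as "start from the level's defaults, then apply whatever the user set explicitly". The helper below, `resolve_options`, is a hypothetical sketch of that rule. It covers only the optimizations listed for "O1" in the level description and treats None as "not set". It is not IPEX's implementation and ignores device-specific restrictions such as weights prepack on XPU.

```python
from typing import Dict, Optional

# Optimizations enabled by "O1", as listed in the level description above.
O1_DEFAULTS = {
    "conv_bn_folding": True,                # inference only
    "weights_prepack": True,
    "replace_dropout_with_identity": True,  # inference only
    "split_master_weight_for_bf16": True,   # training only
    "fuse_update_step": True,               # training only
}
INFERENCE_ONLY = {"conv_bn_folding", "replace_dropout_with_identity"}
TRAINING_ONLY = {"split_master_weight_for_bf16", "fuse_update_step"}


def resolve_options(level: str, training: bool, **overrides: Optional[bool]) -> Dict[str, bool]:
    """Hypothetical: level defaults first, then explicit (non-None) knobs win."""
    if level not in ("O0", "O1"):
        raise ValueError(f"unsupported level {level!r}")
    options = {k: (v if level == "O1" else False) for k, v in O1_DEFAULTS.items()}
    for name, value in overrides.items():
        if value is not None:  # None means "not set explicitly": keep the level default
            options[name] = value
    # Mode-specific optimizations are switched off when they do not apply.
    for name in (INFERENCE_ONLY if training else TRAINING_ONLY):
        options[name] = False
    return options


# Inference with "O1", but weight prepacking explicitly disabled by the user.
print(resolve_options("O1", training=False, weights_prepack=False))
```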

Returns:

Model and optimizer (if given) modified according to the level knob or other user settings. conv+bn folding may take place and dropout may be replaced by identity. In inference scenarios, convolution, linear and LSTM modules are replaced with their optimized counterparts in Intel® Extension for PyTorch* (with weight prepacking for convolution and linear) for better performance. In bfloat16 or float16 scenarios, parameters of convolution and linear are cast to bfloat16 or float16.
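conv+bn folding works because BatchNorm in inference mode is a fixed per-channel affine transform, so it can be merged into the preceding convolution's weight and bias. The pure-Python sketch below checks this algebra for a single output channel written as a dot product, which is equivalent to a 1x1 convolution. The function names are hypothetical and illustrate the math, not IPEX's code path.

```python
import math
from typing import List, Tuple


def conv_then_bn(w: List[float], b: float, x: List[float], gamma: float, beta: float,
                 mean: float, var: float, eps: float = 1e-5) -> float:
    z = sum(wi * xi for wi, xi in zip(w, x)) + b              # 1x1 conv, one output channel
    return gamma * (z - mean) / math.sqrt(var + eps) + beta    # BatchNorm in eval mode


def fold_bn(w: List[float], b: float, gamma: float, beta: float, mean: float, var: float,
            eps: float = 1e-5) -> Tuple[List[float], float]:
    # w' = w * gamma / sqrt(var + eps);  b' = (b - mean) * gamma / sqrt(var + eps) + beta
    scale = gamma / math.sqrt(var + eps)
    return [wi * scale for wi in w], (b - mean) * scale + beta


w, b = [0.5, -1.0, 2.0], 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
x = [1.0, 2.0, 3.0]
w_f, b_f = fold_bn(w, b, gamma, beta, mean, var)
folded = sum(wi * xi for wi, xi in zip(w_f, x)) + b_f  # a single conv, no BatchNorm
print(math.isclose(folded, conv_then_bn(w, b, x, gamma, beta, mean, var)))  # → True
```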

Warning

Please invoke optimize function BEFORE invoking DDP in distributed training scenario.

The optimize function deep-copies the original model. If DDP is invoked before the optimize function, DDP is applied to the original model rather than to the one returned by optimize. In this case, some DDP operators, such as allreduce, are not invoked, which may cause unpredictable accuracy loss.

Note

Please use torch.save(model.state_dict()) after invoking ipex.optimize if you want to save the model into a checkpoint file. As mentioned in the PyTorch documentation, saving the entire model with torch.save(model) does not save the model class itself, so it is not recommended.

Examples

>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.

torch.xpu.optimize() is an alias of the optimize API in Intel® Extension for PyTorch* that provides identical usage for XPU devices only. The motivation for adding this alias is to unify the coding style in user scripts based on the torch.xpu module.

Examples

>>> import torch
>>> import intel_extension_for_pytorch
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = torch.xpu.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = torch.xpu.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.

ipex.llm.optimize is used for Large Language Models (LLM).

ipex.llm.optimize(model, optimizer=None, dtype=torch.float32, inplace=False, device='cpu', quantization_config=None, qconfig_summary_file=None, low_precision_checkpoint=None, sample_inputs=None, deployment_mode=True)

Apply optimizations at the Python frontend to the given transformers model (nn.Module). This API focuses on transformers models, especially inference for generation tasks.

Well-supported model families with full functionality: Llama, GPT-J, GPT-Neox, OPT, Falcon, Bloom, CodeGen, Baichuan, ChatGLM, GPTBigCode, T5, Mistral, MPT, Mixtral, StableLM, QWen, Git, Llava, Yuan, Phi.

For models outside the supported model families listed above, the default ipex.optimize is applied transparently to obtain some benefit (quantization is not included, and only the torch.bfloat16, torch.half and torch.float dtypes are supported).

Parameters:
  • model (torch.nn.Module) – User model to apply optimizations.

  • optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning inference case.

  • dtype (torch.dtype) – Supports torch.bfloat16, torch.half and torch.float. The default value is torch.float. When working with quantization, it specifies the dtype mixed with the quantized computation.

  • inplace (bool) – Whether to perform inplace optimization. Default value is False.

  • device (str) – Specifying the device on which the optimization will be performed. Can be either ‘cpu’ or ‘xpu’ (‘xpu’ is not applicable for cpu only packages). The default value is ‘cpu’.

  • quantization_config (object) – Defines the IPEX quantization recipe (weight-only quantization or static quantization). Default value is None. If set, the IPEX quantized model is used for model.generate().

  • qconfig_summary_file (str) – Path to the IPEX static quantization config JSON file. Default value is None. Works together with quantization_config in the static quantization use case. The file must be generated beforehand by running IPEX static quantization calibration.

  • low_precision_checkpoint (dict or tuple of dict) – For weight-only quantization with INT4 weights. If it's a dict, it should be the state_dict of a checkpoint (.pt) generated by GPTQ, etc. If a tuple is provided, it should be (checkpoint, checkpoint config), where checkpoint is the state_dict and checkpoint config is a dict specifying the keys of weight/scale/zero point/bias in the state_dict. The default config is {‘weight_key’: ‘packed_weight’, ‘scale_key’: ‘scale’, ‘zero_point_key’: ‘packed_zp’, ‘bias_key’: ‘bias’}. Change the values of the dict to make a custom config. Weights should have shape N by K; they are quantized to UINT4, compressed along K, and stored as torch.int32. Zero points are also UINT4 and stored as INT32. Scales and bias are floating-point values. Bias is optional; if bias is not in the state_dict, the bias of the original model is used. Default value is None.

  • sample_inputs (tuple of tensors) – Sample inputs used for model quantization or TorchScript. Default value is None; for well-supported models, these sample inputs are provided automatically.

  • deployment_mode (bool) – Whether to apply the optimized model for deployment of model generation. If True, there is no need to apply further optimizations such as TorchScript. Default value is True.
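To illustrate what "quantized to UINT4, compressed along K, then stored as torch.int32" in low_precision_checkpoint means, the sketch below packs eight 4-bit values into one 32-bit integer and unpacks them again. The nibble order (the lowest bits hold the first element) is an assumption for illustration. Check the exact order expected by IPEX or by your GPTQ exporter against that tool. Plain Python ints stand in for torch.int32, so the packed word is wrapped to the signed 32-bit range explicitly.

```python
from typing import List


def pack_uint4_row(values: List[int]) -> List[int]:
    """Pack one row of UINT4 values (length divisible by 8) into signed int32 words.

    Assumption: element k of each group of 8 goes into bits [4*k, 4*k + 4).
    """
    assert len(values) % 8 == 0 and all(0 <= v <= 15 for v in values)
    words = []
    for start in range(0, len(values), 8):
        word = 0
        for k, v in enumerate(values[start:start + 8]):
            word |= v << (4 * k)
        # Reinterpret the 32-bit pattern as a signed int32, as torch.int32 would store it.
        words.append(word - (1 << 32) if word >= (1 << 31) else word)
    return words


def unpack_uint4_row(words: List[int]) -> List[int]:
    """Inverse of pack_uint4_row under the same nibble-order assumption."""
    out = []
    for word in words:
        word &= 0xFFFFFFFF  # back to the unsigned bit pattern
        out.extend((word >> (4 * k)) & 0xF for k in range(8))
    return out


row = [1, 2, 3, 4, 5, 6, 7, 15, 0, 0, 0, 0, 0, 0, 0, 8]  # K = 16 -> 2 int32 words
packed = pack_uint4_row(row)
print(len(packed), unpack_uint4_row(packed) == row)  # → 2 True
```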

Returns:

Optimized model object for model.generate(); it also works with model.forward.

Warning

Please invoke ipex.llm.optimize function AFTER invoking DeepSpeed in Tensor Parallel inference scenario.

Examples

>>> # bfloat16 generation inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.llm.optimize(model, dtype=torch.bfloat16)
>>> optimized_model.generate()
class ipex.verbose(level)

On-demand oneDNN verbosing functionality

To make it easier to debug performance issues, oneDNN can dump verbose messages containing information like kernel size, input data size and execution duration while executing a kernel. The verbose functionality can be enabled via an environment variable named DNNL_VERBOSE. However, this method dumps messages for every step, which produces a large amount of output. Moreover, for investigating performance issues, verbose messages from a single iteration are generally enough.

This on-demand verbosing functionality makes it possible to control scope for verbose message dumping. In the following example, verbose messages will be dumped out for the second inference only.

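import intel_extension_for_pytorch as ipex

# `model` and `data` are assumed to be defined elsewhere.
model(data)  # first inference: no oneDNN verbose output
with ipex.verbose(ipex.verbose.VERBOSE_ON):
    model(data)  # second inference: oneDNN verbose messages are dumped
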
Parameters:

level

Verbose level

  • VERBOSE_OFF: Disable verbosing

  • VERBOSE_ON: Enable verbosing

  • VERBOSE_ON_CREATION: Enable verbosing, including oneDNN kernel creation

LLM Module Level Optimizations (Prototype)

Module level optimization APIs are provided for optimizing customized LLMs.

class ipex.llm.modules.LinearSilu(linear)

Applies a linear transformation to the input data, and then applies PyTorch SILU (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) to the result:

result = torch.nn.functional.silu(linear(input))
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with silu.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearSilu(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> result = ipex_fusion(input)
class ipex.llm.modules.LinearSiluMul(linear)

Applies a linear transformation to the input data, then applies PyTorch SILU (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) to the result, and multiplies the result by other:

result = torch.nn.functional.silu(linear(input)) * other
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with silu and mul.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearSiluMul(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> other = torch.randn(4096, 4096)
>>> result = ipex_fusion(input, other)
class ipex.llm.modules.Linear2SiluMul(linear_s, linear_m)

Applies two linear transformations (linear_s and linear_m) to the input data, then applies PyTorch SILU (see https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html) to the result of linear_s, and multiplies it by the result of linear_m:

result = torch.nn.functional.silu(linear_s(input)) * linear_m(input)
Parameters:
  • linear_s (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with silu.

  • linear_m (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with mul.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_s_module = torch.nn.Linear(4096, 4096)
>>> linear_m_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.Linear2SiluMul(linear_s_module, linear_m_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> result = ipex_fusion(input)
class ipex.llm.modules.LinearRelu(linear)

Applies a linear transformation to the input data, and then applies PyTorch RELU (see https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html) to the result:

result = torch.nn.functional.relu(linear(input))
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with relu.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearRelu(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> result = ipex_fusion(input)
class ipex.llm.modules.LinearNewGelu(linear)

Applies a linear transformation to the input data, and then applies NewGELUActivation (see https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L50) to the result:

result = NewGELUActivation()(linear(input))
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with new_gelu.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearNewGelu(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> result = ipex_fusion(input)
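
NewGELUActivation in Hugging Face transformers is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))). The scalar sketch below compares it with the exact erf-based GELU, which is what torch.nn.functional.gelu computes by default. It only illustrates the math of the activation, not the fused kernel.

```python
import math


def new_gelu(x: float) -> float:
    # Tanh approximation used by NewGELUActivation.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def exact_gelu(x: float) -> float:
    # x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(f"{x:5.1f}  new_gelu={new_gelu(x):.5f}  exact={exact_gelu(x):.5f}")
```

The two curves agree closely, which is why the approximation is used; LinearNewGelu and LinearGelu are therefore not interchangeable bit for bit.
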
class ipex.llm.modules.LinearGelu(linear)

Applies a linear transformation to the input data, and then applies PyTorch GELU (see https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html) to the result:

result = torch.nn.functional.gelu(linear(input))
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with gelu.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearGelu(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> result = ipex_fusion(input)
class ipex.llm.modules.LinearMul(linear)

Applies a linear transformation to the input data, and then multiplies the result by other:

result = linear(input) * other
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with mul.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearMul(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> other = torch.randn(4096, 4096)
>>> result = ipex_fusion(input, other)
class ipex.llm.modules.LinearAdd(linear)

Applies a linear transformation to the input data, and then adds other to the result:

result = linear(input) + other
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with add.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearAdd(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> other = torch.randn(4096, 4096)
>>> result = ipex_fusion(input, other)
class ipex.llm.modules.LinearAddAdd(linear)

Applies a linear transformation to the input data, and then adds other_1 and other_2 to the result:

result = linear(input) + other_1 + other_2
Parameters:

linear (torch.nn.Linear module) – the original torch.nn.Linear module to be fused with add and add.

Shape:

Input and output shapes are the same as torch.nn.Linear.

Examples

>>> # module init:
>>> linear_module = torch.nn.Linear(4096, 4096)
>>> ipex_fusion = ipex.llm.modules.LinearAddAdd(linear_module)
>>> # module forward:
>>> input = torch.randn(4096, 4096)
>>> other_1 = torch.randn(4096, 4096)
>>> other_2 = torch.randn(4096, 4096)
>>> result = ipex_fusion(input, other_1, other_2)
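
When checking results or reading model code, it helps to write the fused Linear* modules above as their unfused reference formulas. The pure-Python sketch below does this for a tiny linear layer on plain lists. `linear`, `silu`, `relu`, `gelu`, `ew`, `mul`, `add` and the `REFERENCE` table are hypothetical illustration helpers, not IPEX APIs. GELU uses the exact erf form, matching torch.nn.functional.gelu's default.

```python
import math
from typing import Callable, Dict, List

Vec = List[float]


def linear(w: List[Vec], b: Vec, x: Vec) -> Vec:
    """y = W x + b, with W stored as [out_features][in_features] like torch.nn.Linear."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def relu(v: float) -> float:
    return max(v, 0.0)


def gelu(v: float) -> float:
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def ew(f: Callable[[float], float], v: Vec) -> Vec:
    """Apply a scalar activation elementwise."""
    return [f(t) for t in v]


def mul(a: Vec, c: Vec) -> Vec:
    return [p * q for p, q in zip(a, c)]


def add(a: Vec, c: Vec) -> Vec:
    return [p + q for p, q in zip(a, c)]


W, B = [[1.0, -2.0], [0.5, 0.25]], [0.1, -0.3]
W_m, B_m = [[0.0, 1.0], [1.0, 0.0]], [0.0, 0.0]  # second linear, used by Linear2SiluMul
x, other, other_2 = [1.0, 2.0], [2.0, 3.0], [-1.0, 1.0]
y = linear(W, B, x)

REFERENCE: Dict[str, Vec] = {
    "LinearSilu": ew(silu, y),
    "LinearSiluMul": mul(ew(silu, y), other),
    "Linear2SiluMul": mul(ew(silu, y), linear(W_m, B_m, x)),
    "LinearRelu": ew(relu, y),
    "LinearGelu": ew(gelu, y),
    "LinearMul": mul(y, other),
    "LinearAdd": add(y, other),
    "LinearAddAdd": add(add(y, other), other_2),
}
for name, value in REFERENCE.items():
    print(name, [round(v, 4) for v in value])
```
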
class ipex.llm.modules.RotaryEmbedding(max_position_embeddings: int, pos_embd_dim: int, base=10000, backbone: str | None = None, extra_rope_config: dict | None = None)

[module init and forward] Applies RotaryEmbedding (see https://huggingface.co/papers/2104.09864) on the query or key before their multi-head attention computation.

module init

Parameters:

forward()

Parameters:
  • input (torch.Tensor) – input to be applied with position embeddings, taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim] (as well as the output shape).

  • position_ids (torch.Tensor) – the corresponding position_ids for the input. The shape should be [batch size, sequence length]. In some cases, it holds only one element, past_kv_length, and the position ids can be constructed as past_kv_length + current_position.

  • num_head (int) – head num from the input shape.

  • head_dim (int) – head dim from the input shape.

  • offset (int) – the offset value. For example, for GPT-J 6B/ChatGLM, cos/sin is applied to neighboring pairs of elements, so the offset is 1. For Llama, cos/sin is applied to elements rotary_dim/2 apart, so the offset is rotary_dim/2.

  • rotary_ndims (int) – the rotary dimension, e.g., 64 for GPT-J and the head size for Llama.

Examples

>>> # module init:
>>> rope_module = ipex.llm.modules.RotaryEmbedding(2048, 64, base=10000, backbone="GPTJForCausalLM")
>>> # forward:
>>> query = torch.randn(1, 32, 16, 256)
>>> position_ids  = torch.arange(32).unsqueeze(0)
>>> query_rotary = rope_module(query, position_ids, 16, 256, 1, 64)
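
The difference between offset = 1 (GPT-J/ChatGLM style) and offset = rotary_dim/2 (Llama style) is which pairs of elements are rotated together. The pure-Python sketch below applies rotary embedding to one head vector at one position under both pairings. It assumes the standard inverse frequencies base^(-2i/rotary_dim). `rope_single_head` is a hypothetical reference, not the IPEX kernel, and it leaves dimensions beyond rotary_ndims unchanged.

```python
import math
from typing import List


def rope_single_head(x: List[float], pos: int, rotary_ndims: int, offset: int,
                     base: float = 10000.0) -> List[float]:
    """Rotate the first `rotary_ndims` elements of one head vector at position `pos`.

    offset == 1: pairs (0, 1), (2, 3), ...            (GPT-J / ChatGLM)
    offset == rotary_ndims // 2: pairs (i, i + half)  (Llama)
    """
    half = rotary_ndims // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2.0 * i / rotary_ndims)
        if offset == 1:
            a, b = 2 * i, 2 * i + 1
        elif offset == half:
            a, b = i, i + half
        else:
            raise ValueError("offset must be 1 or rotary_ndims // 2 in this sketch")
        cos, sin = math.cos(theta), math.sin(theta)
        out[a] = x[a] * cos - x[b] * sin
        out[b] = x[a] * sin + x[b] * cos
    return out


head = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # head_dim = 6, rotary_ndims = 4
gptj = rope_single_head(head, pos=3, rotary_ndims=4, offset=1)
llama = rope_single_head(head, pos=3, rotary_ndims=4, offset=2)
print(gptj[4:] == head[4:], llama[4:] == head[4:])  # → True True (non-rotary dims untouched)
```

Each pair is rotated by the same angle in both styles; only the pairing changes, so the two layouts produce different outputs for the same weights.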

[Direct function call] This module also provides an .apply_function function call that can be applied to query and key at the same time without initializing the module (assuming the rotary embedding sin/cos values are provided).

apply_function()

Parameters:
  • query (torch.Tensor) – inputs to be applied with position embeddings, taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).

  • key (torch.Tensor) – inputs to be applied with position embeddings, taking shape of [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim] (as well as the output shape).

  • sin/cos (torch.Tensor) – [num_tokens, rotary_dim] the sin/cos value tensor generated to be applied on query/key.

  • rotary_ndims (int) – the rotary dimension, e.g., 64 for GPT-J and the head size for Llama.

  • head_dim (int) – head dim from the input shape.

  • rotary_half (bool) – If False (e.g., GPT-J 6B/ChatGLM), cos/sin is applied to neighboring pairs of elements, so the offset is 1. If True (e.g., Llama), cos/sin is applied to elements rotary_dim/2 apart, so the offset is rotary_dim/2.

  • position_ids (torch.Tensor) – The corresponding position_ids for the input. Default is None; optional if sin/cos is provided. The shape should be [batch size, sequence length].

Returns:

[batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim].

Return type:

query, key (torch.Tensor)

class ipex.llm.modules.RMSNorm(hidden_size: int, eps: float = 1e-06, weight: Tensor | None = None)

[module init and forward] Applies RMSnorm on the input (hidden states). (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76)

module init

Parameters:
  • hidden_size (int) – the size of the hidden states.

  • eps (float) – the variance_epsilon to apply RMSnorm, default using 1e-6.

  • weight (torch.Tensor) – the weight to apply RMSnorm, default None and will use torch.ones(hidden_size).

forward()

Parameters:

hidden_states (torch.Tensor) – the input to apply RMSnorm to, usually of shape [batch size, sequence length, hidden_size] (the output has the same shape).

Examples

>>> # module init:
>>> rmsnorm_module = ipex.llm.modules.RMSNorm(4096)
>>> # forward:
>>> input = torch.randn(1, 32, 4096)
>>> result = rmsnorm_module(input)

[Direct function call] This module also provides a .apply_function function call to apply RMSNorm without initializing the module.

apply_function()

Parameters:
  • hidden_states (torch.Tensor) – the input tensor to apply RMSNorm.

  • weight (torch.Tensor) – the weight to apply RMSnorm.

  • eps (float) – the variance_epsilon to apply RMSnorm.
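
RMSNorm scales each hidden-state vector by the reciprocal of its root mean square and then by a learned weight: y = x / sqrt(mean(x^2) + eps) * weight, as in the Hugging Face Llama implementation linked above. A pure-Python reference for a single vector can be handy for checking outputs. The `rms_norm` helper below is hypothetical and not part of IPEX.

```python
import math
from typing import List, Optional


def rms_norm(hidden: List[float], weight: Optional[List[float]] = None,
             eps: float = 1e-6) -> List[float]:
    """Reference RMSNorm over the last dimension of one hidden-state vector."""
    if weight is None:
        weight = [1.0] * len(hidden)  # mirrors the documented torch.ones(hidden_size) default
    variance = sum(h * h for h in hidden) / len(hidden)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * h * inv_rms for w, h in zip(weight, hidden)]


y = rms_norm([3.0, 4.0])
print(y)  # the mean of the squared outputs is ~1.0
```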

class ipex.llm.modules.FastLayerNorm(normalized_shape: Tuple[int, ...], eps: float, weight: Tensor, bias: Tensor | None = None)

[module init and forward] Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) on the input (hidden states).

module init

Parameters:
  • normalized_shape ((int or list) or torch.Size) – the shape over which normalization is applied, as in torch.nn.LayerNorm.

  • eps (float) – a value added to the denominator for numerical stability.

  • weight (torch.Tensor) – the weight of Layernorm to apply normalization.

  • bias (torch.Tensor) – an additive bias for normalization.

forward()

Parameters:

hidden_states (torch.Tensor) – the input to apply Layernorm to, usually of shape [batch size, sequence length, hidden_size] (the output has the same shape).

Examples

>>> # module init:
>>> layernorm = torch.nn.LayerNorm(4096)
>>> layernorm_module = ipex.llm.modules.FastLayerNorm(4096, eps=1e-05, weight=layernorm.weight, bias=layernorm.bias)
>>> # forward:
>>> input = torch.randn(1, 32, 4096)
>>> result = layernorm_module(input)

[Direct function call] This module also provides a .apply_function function call to apply fast layernorm without initializing the module.

apply_function()

Parameters:
  • hidden_states (torch.Tensor) – the input tensor to apply normalization.

  • normalized_shape ((int or list) or torch.Size) – the shape over which normalization is applied.

  • weight (torch.Tensor) – the weight to apply normalization.

  • bias (torch.Tensor) – an additive bias for normalization.

  • eps (float) – a value added to the denominator for numerical stability.
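
Unlike RMSNorm, LayerNorm also subtracts the mean and adds a bias: y = (x - mean) / sqrt(var + eps) * weight + bias, where var is the biased (population) variance, as in torch.nn.LayerNorm. The sketch below is a hypothetical pure-Python reference for a 1-D normalized_shape, useful for sanity-checking FastLayerNorm outputs.

```python
import math
from typing import List, Optional


def layer_norm(hidden: List[float], weight: List[float], bias: Optional[List[float]] = None,
               eps: float = 1e-5) -> List[float]:
    """Reference LayerNorm over the last dimension (normalized_shape = len(hidden))."""
    n = len(hidden)
    mean = sum(hidden) / n
    var = sum((h - mean) ** 2 for h in hidden) / n  # biased variance, as in torch.nn.LayerNorm
    inv_std = 1.0 / math.sqrt(var + eps)
    if bias is None:
        bias = [0.0] * n
    return [(h - mean) * inv_std * w + b for h, w, b in zip(hidden, weight, bias)]


y = layer_norm([1.0, 2.0, 3.0, 4.0], weight=[1.0] * 4, bias=[0.5] * 4)
print(y)
```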

class ipex.llm.modules.IndirectAccessKVCacheAttention(text_max_length=2048)

kv_cache is used to reduce computation in the decoder layer, but it also brings memory overhead. For example, when using beam search, the kv_cache must be reordered according to the latest beam idx, and the current key/value must be concatenated with the kv_cache in the attention layer to get the entire context for the scaled dot product. When the sequence is very long, this memory overhead becomes the performance bottleneck. This module provides an Indirect Access KV_cache (IAKV). First, IAKV pre-allocates buffers (key and value use different buffers) to store all key/value hidden states and beam index information. It then uses the beam index history to decide which beam to read at each timestep, and this information generates an offset to access the kv_cache buffer.

Data Format:

The shape of the pre-allocated key (value) buffer is [max_seq, beam*batch, head_num, head_size]; the key/value hidden states, of shape [beam*batch, head_num, head_size], are stored token by token. The beam idx information of every timestep is also stored in a tensor of shape [max_seq, beam*batch].
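To see how storing beam indices can replace a physical reorder, consider one way to read the cache back. The sketch below assumes a hypothetical convention: beam_idx[t][j] is the slot at step t - 1 from which slot j at step t was extended. The kernel's exact convention is not specified here. Starting from a slot at the latest step and following these parent pointers backwards gives, for every timestep, the column of the [max_seq, beam*batch] buffer to read, so no key/value data is copied.

```python
from typing import List


def history_slots(beam_idx: List[List[int]], slot: int, cur_step: int) -> List[int]:
    """Return, for t = 0..cur_step, the buffer column holding `slot`'s token at step t.

    Hypothetical convention: beam_idx[t][j] = parent slot (at step t - 1) of slot j at step t.
    """
    cols = [0] * (cur_step + 1)
    s = slot
    for t in range(cur_step, -1, -1):
        cols[t] = s
        if t > 0:
            s = beam_idx[t][s]  # follow the parent pointer one step back
    return cols


# 3 steps, beam*batch = 2. Row t of key_cache holds the tokens written at step t.
key_cache = [["a0", "b0"], ["a1", "b1"], ["a2", "b2"]]
beam_idx = [[0, 1],   # step 0 has no parent
            [1, 1],   # both beams at step 1 continue from slot 1
            [0, 1]]   # at step 2, slot 0 continues slot 0 and slot 1 continues slot 1
cols = history_slots(beam_idx, slot=0, cur_step=2)
context = [key_cache[t][c] for t, c in enumerate(cols)]
print(cols, context)  # → [1, 0, 0] ['b0', 'a1', 'a2']
```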

module init

Parameters:

text_max_length (int) – the max length of kv cache to be used for generation (allocate the pre-cache buffer).

forward()

Parameters:
  • query (torch.Tensor) – Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • key (torch.Tensor) – Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • value (torch.Tensor) – Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • scale_attn (float) – scale used by the attention layer; should be sqrt(head_size).

  • layer_past (tuple(torch.Tensor)) –

    tuple(seq_info, key_cache, value_cache, beam-idx).

    • key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);

    • value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);

    • beam-idx: history beam idx, shape:(max_seq, beam*batch);

    • seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).

  • head_mask (torch.Tensor) – Head mask tensor which is not supported by kernel yet.

  • attention_mask (torch.Tensor) – Attention mask information.

Returns:

attn_output: Weighted value, i.e., the output of the scaled dot product; shape (beam*batch, seq_len, head_num, head_size).

attn_weights: The output tensor of the first matmul in the scaled dot product, which is not yet supported by the kernel.

new_layer_past: Updated layer_past (seq_info, key_cache, value_cache, beam-idx).

Notes

How to reorder KV cache when using the format of IndirectAccessKVCacheAttention (e.g., on llama model see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318)

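from typing import Tuple

import torch


def _reorder_cache(
    self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
    # IAKV format: each layer_past is (seq_info, key_cache, value_cache, beam_idx).
    if (
        len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1
    ):
        for layer_past in past_key_values:
            # Record the beam indices for the current step instead of
            # physically reordering the key/value caches.
            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
        return past_key_values
    # Other cache formats: fall back to the model's original reordering (not shown).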

[Direct function call] This module also provides a .apply_function function call to apply IndirectAccessKVCacheAttention without initializing the module.

The parameters of apply_function() are the same as the forward() call.

class ipex.llm.modules.PagedAttention

This module follows the API of two class methods in [vLLM](https://blog.vllm.ai/2023/06/20/vllm.html) to enable the paged attention kernel, and uses the layout (num_blocks, self.block_size, num_heads, head_size) for the key/value cache. The basic logic is as follows. First, a DRAM buffer of num_blocks blocks is pre-allocated to store the key or value cache, and each block can store block_size tokens. In the forward pass, the cache manager first allocates some slots from this buffer, uses the reshape_and_cache API to store the key/value, and then uses the single_query_cached_kv_attention API to compute the scaled dot product of MHA. The block is the basic allocation unit of paged attention, and tokens within a block are stored one by one. Block tables map the logical blocks of a sequence to physical blocks.

[class method]: reshape_and_cache

ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

This operator is used to store the key/value token states into the pre-allocated kv_cache buffers of paged attention.

Parameters:
  • key (torch.Tensor) – The key tensor. The shape should be [num_seqs, num_heads, head_size].

  • value (torch.Tensor) – The value tensor. The shape should be [num_seqs, num_heads, head_size].

  • key_cache (torch.Tensor) – The pre-allocated buffer to store the key cache. The shape should be [num_blocks, block_size, num_heads, head_size].

  • value_cache (torch.Tensor) – The pre-allocated buffer to store the value cache. The shape should be [num_blocks, block_size, num_heads, head_size].

  • slot_mapping (torch.Tensor) – The positions at which to store the key/value in the pre-allocated buffers. The shape should be [num_seqs]. For sequence i, slot_mapping[i] // block_size gives the block index, and slot_mapping[i] % block_size gives the offset within that block.
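
The slot arithmetic for slot_mapping can be sketched in pure Python. `reshape_and_cache_ref` below is a hypothetical reference that writes each token's key vector into a nested-list cache of shape [num_blocks][block_size][...], using block = slot // block_size and offset = slot % block_size. To keep the example small, the heads are collapsed into a single vector per token.

```python
from typing import List, Optional


def reshape_and_cache_ref(key: List[List[float]], key_cache: List[List[Optional[List[float]]]],
                          slot_mapping: List[int], block_size: int) -> None:
    """Store key[i] (one token's key) at slot_mapping[i] in the paged cache, in place."""
    for token_key, slot in zip(key, slot_mapping):
        block, offset = divmod(slot, block_size)  # block index, position within the block
        key_cache[block][offset] = token_key


num_blocks, block_size = 3, 4
key_cache = [[None] * block_size for _ in range(num_blocks)]
key = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
slot_mapping = [9, 2, 4]  # block 2 offset 1, block 0 offset 2, block 1 offset 0
reshape_and_cache_ref(key, key_cache, slot_mapping, block_size)
print(key_cache[2][1], key_cache[0][2], key_cache[1][0])  # → [0.1, 0.2] [0.3, 0.4] [0.5, 0.6]
```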

[class method]: single_query_cached_kv_attention

ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
                                                    out,
                                                    query,
                                                    key_cache,
                                                    value_cache,
                                                    head_mapping,
                                                    scale,
                                                    block_tables,
                                                    context_lens,
                                                    block_size,
                                                    max_context_len,
                                                    alibi_slopes
                                                    )

This operator computes the scaled dot product based on paged attention.

Parameters:
  • out (torch.Tensor) – The output tensor with shape [num_seqs, num_heads, head_size], where num_seqs is the number of sequences in this batch, num_heads is the number of query heads, and head_size is the head dimension.

  • query (torch.Tensor) – The query tensor. The shape should be [num_seqs, num_heads, head_size].

  • key_cache (torch.Tensor) – The pre-allocated buffer to store the key cache. The shape should be [num_blocks, block_size, num_heads, head_size].

  • value_cache (torch.Tensor) – The pre-allocated buffer to store the value cache. The shape should be [num_blocks, block_size, num_heads, head_size].

  • head_mapping (torch.Tensor) – The mapping from each query head to its KV head. The shape should be [num_heads] (the number of query heads).

  • scale (float) – The scale used by the scaled dot product. Typically it is float(1.0 / (head_size ** 0.5)).

  • block_tables (torch.Tensor) – The table that maps the logical blocks of each sequence to physical blocks in the cache. The shape should be [num_seqs, max_num_blocks_per_seq].

  • context_lens (torch.Tensor) – The sequence length for every sequence. The size is [num_seqs].

  • block_size (int) – The block size, i.e. the number of tokens in every block.

  • max_context_len (int) – The max sequence length.

  • alibi_slopes (torch.Tensor, optional) – The ALiBi slopes, with shape [num_heads].
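
As a rough reference for what this operator computes, the pure-Python sketch below handles one sequence and one head: it walks the block table to locate each cached token, then applies the usual softmax(scale * q·k) weighting to the values. The function names are hypothetical, head_mapping and alibi_slopes are left out, and the paged result is compared against the same computation on contiguous keys/values.

```python
import math
from typing import List

Vec = List[float]

def softmax_weighted_sum(scores: List[float], values: List[Vec]) -> Vec:
    # Numerically stable softmax over the scores, then a weighted sum of the values.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]

def paged_single_query_attention(query: Vec, key_cache: List[List[Vec]], value_cache: List[List[Vec]],
                                 block_table: List[int], context_len: int, block_size: int,
                                 scale: float) -> Vec:
    # One sequence, one head; caches are [num_blocks][block_size][head_size].
    scores, gathered = [], []
    for pos in range(context_len):
        block = block_table[pos // block_size]   # logical block -> physical block
        offset = pos % block_size                # token position inside the block
        key = key_cache[block][offset]
        scores.append(scale * sum(q * k for q, k in zip(query, key)))
        gathered.append(value_cache[block][offset])
    return softmax_weighted_sum(scores, gathered)

def dense_attention(query: Vec, keys: List[Vec], values: List[Vec], scale: float) -> Vec:
    # Same computation on contiguous (non-paged) keys/values, for comparison.
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    return softmax_weighted_sum(scores, values)

block_size, head_size, num_blocks = 2, 2, 4
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
block_table = [3, 0]  # logical block 0 -> physical block 3, logical block 1 -> physical block 0
key_cache = [[[0.0] * head_size for _ in range(block_size)] for _ in range(num_blocks)]
value_cache = [[[0.0] * head_size for _ in range(block_size)] for _ in range(num_blocks)]
for pos, (kv, vv) in enumerate(zip(keys, values)):
    blk, off = block_table[pos // block_size], pos % block_size
    key_cache[blk][off], value_cache[blk][off] = kv, vv

query = [0.5, -0.5]
scale = 1.0 / math.sqrt(head_size)
paged = paged_single_query_attention(query, key_cache, value_cache, block_table, len(keys), block_size, scale)
dense = dense_attention(query, keys, values, scale)
print(paged, dense)
```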

class ipex.llm.modules.VarlenAttention

[module init and forward] Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), and accepts variable (different) sequence lengths for the sequences packed into query, key and value.

The module takes no arguments at initialization.

forward()

Parameters:
  • query (torch.Tensor) – shape [query_tokens, num_head, head_size], where query_tokens is the total number of query tokens across all sequences in the batch.

  • key (torch.Tensor) – shape [key_tokens, num_head, head_size], where key_tokens is the total number of key tokens across all sequences in the batch.

  • value (torch.Tensor) – shape [value_tokens, num_head, head_size], where value_tokens is the total number of value tokens across all sequences in the batch.

  • out (torch.Tensor) – output buffer that receives the results; its shape is the same as query.

  • seqlen_q (torch.Tensor) – shape [batch_size + 1]; cumulative offsets marking where the tokens of each sequence start in query (the last entry equals query_tokens).

  • seqlen_k (torch.Tensor) – shape [batch_size + 1]; cumulative offsets marking where the tokens of each sequence start in key (the last entry equals key_tokens).

  • max_seqlen_q (int) – max/total sequence length of query.

  • max_seqlen_k (int) – max/total sequence length of key.

  • pdropout (float) – dropout probability; if greater than 0.0, dropout is applied, default is 0.0.

  • softmax_scale (float) – scaling factor applied before the softmax.

  • is_causal (bool) – whether to apply causal attention masking, default is True.
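
Assuming seqlen_q/seqlen_k hold cumulative token offsets (our reading of the [batch_size + 1] shape described above), they can be derived from per-sequence lengths as in this pure-Python sketch. The helper cumulative_seqlens is hypothetical and takes max_seqlen as the longest sequence; in practice the lists would be wrapped in torch tensors.

```python
from itertools import accumulate
from typing import List, Tuple

def cumulative_seqlens(lengths: List[int]) -> Tuple[List[int], int]:
    """Return ([0, l0, l0 + l1, ...], max(lengths)) for sequences packed back to back."""
    return [0] + list(accumulate(lengths)), max(lengths)

# Three sequences of 3, 5 and 2 tokens packed into 10 rows of query/key/value.
lengths = [3, 5, 2]
seqlen_q, max_seqlen_q = cumulative_seqlens(lengths)
print(seqlen_q, max_seqlen_q)  # [0, 3, 8, 10] 5

# Rows of sequence i are seqlen_q[i]:seqlen_q[i + 1].
rows = list(range(seqlen_q[-1]))
per_seq = [rows[seqlen_q[i]:seqlen_q[i + 1]] for i in range(len(lengths))]
print(per_seq)  # [[0, 1, 2], [3, 4, 5, 6, 7], [8, 9]]
```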

Examples

>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # module init:
>>> varlenAttention_module = ipex.llm.modules.VarlenAttention()
>>> # forward: one sequence of 32 tokens, 16 heads, head size 256
>>> query = torch.randn(32, 16, 256)
>>> key = torch.randn(32, 16, 256)
>>> value = torch.randn(32, 16, 256)
>>> out = torch.empty_like(query)
>>> # cumulative sequence offsets, shape [batch_size + 1]
>>> seqlen_q = torch.tensor([0, 32], dtype=torch.int32)
>>> seqlen_k = torch.tensor([0, 32], dtype=torch.int32)
>>> max_seqlen_q = 32
>>> max_seqlen_k = 32
>>> pdropout = 0.0
>>> softmax_scale = 1.0 / 256 ** 0.5
>>> varlenAttention_module(query, key, value, out, seqlen_q, seqlen_k, max_seqlen_q, max_seqlen_k, pdropout, softmax_scale)

[Direct function call] This module also provides a .apply_function function call to apply VarlenAttention without initializing the module.

The parameters of apply_function() are the same as the forward() call.

ipex.llm.functional.rotary_embedding(query: Tensor, key: Tensor, sin: Tensor, cos: Tensor, rotary_dim: int, rotary_half: bool, position_ids: Tensor | None = None)

Applies rotary position embedding (RoPE; see https://huggingface.co/papers/2104.09864) to query or key before the multi-head attention computation.

Parameters:
  • query (torch.Tensor) – input to which the position embeddings are applied, with shape [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim]; the output has the same shape.

  • key (torch.Tensor) – input to which the position embeddings are applied, with shape [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim]; the output has the same shape.

  • sin/cos (torch.Tensor) – the precomputed sin/cos values applied to query/key, with shape [num_tokens, rotary_dim].

  • rotary_dim (int) – the rotary dimension, e.g., 64 for GPT-J and the head size for LLaMA.

  • head_dim (int) – the head dimension, taken from the input shape.

  • rotary_half (bool) –

    If False (e.g., GPT-J 6B, ChatGLM), cos/sin is applied to pairs of neighboring elements, so the offset is 1.

    If True (e.g., LLaMA), cos/sin is applied to pairs of elements that are rotary_dim/2 apart (the first and second halves of the rotary dimensions), so the offset is rotary_dim/2.

  • position_ids (torch.Tensor) – the position ids of the input, with shape [batch size, sequence length]. Optional (default None) if sin/cos is provided.

Return

query, key (torch.Tensor): [batch size, sequence length, num_head/num_kv_head, head_dim] or [num_tokens, num_head/num_kv_head, head_dim].
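
The difference between the two rotary_half layouts can be seen in a pure-Python sketch that rotates a single head vector. This is an illustration, not the IPEX kernel: apply_rotary is a hypothetical helper that computes its own angles with the common RoPE frequency base of 10000 (an assumption), whereas the real function receives precomputed sin/cos.

```python
import math
from typing import List

def apply_rotary(x: List[float], position: int, rotary_dim: int,
                 rotary_half: bool, base: float = 10000.0) -> List[float]:
    """Rotate the first rotary_dim elements of one head vector (hypothetical helper)."""
    out = list(x)  # elements beyond rotary_dim pass through unchanged
    half = rotary_dim // 2
    for i in range(half):
        theta = position * base ** (-2.0 * i / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        # rotary_half=False pairs neighbors (offset 1); True pairs i with i + rotary_dim/2.
        a, b = (i, i + half) if rotary_half else (2 * i, 2 * i + 1)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out

x = [1.0, 0.0, 0.0, 0.0]
print(apply_rotary(x, 1, 4, rotary_half=False))  # cos(1), sin(1) land in slots 0 and 1
print(apply_rotary(x, 1, 4, rotary_half=True))   # cos(1), sin(1) land in slots 0 and 2
```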

ipex.llm.functional.rms_norm(hidden_states: Tensor, weight: Tensor, eps: float)

Applies RMSNorm to the input (hidden states) (see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L76).

Parameters:
  • hidden_states (torch.Tensor) – the input tensor to normalize.

  • weight (torch.Tensor) – the learnable weight that scales the normalized values element-wise.

  • eps (float) – the variance_epsilon added to the variance for numerical stability.
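
For reference, the RMSNorm computation (as in the linked LLaMA implementation) on a single hidden-state vector can be sketched in pure Python; rms_norm here is a hypothetical helper, not the IPEX kernel.

```python
import math
from typing import List

def rms_norm(hidden_states: List[float], weight: List[float], eps: float) -> List[float]:
    """RMSNorm over one hidden-state vector (hypothetical helper)."""
    variance = sum(h * h for h in hidden_states) / len(hidden_states)  # mean of squares
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * h * inv_rms for w, h in zip(weight, hidden_states)]

normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0], eps=1e-6)
print(normed)
```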

ipex.llm.functional.fast_layer_norm(hidden_states: Tensor, normalized_shape: Tuple[int, ...], weight: Tensor, bias: Tensor, eps: float)

Applies PyTorch Layernorm (see https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) on the input (hidden states).

Parameters:
  • hidden_states (torch.Tensor) – the input tensor to apply normalization.

  • normalized_shape (int or list or torch.Size) – the shape of the trailing input dimensions to normalize over, as in torch.nn.LayerNorm.

  • weight (torch.Tensor) – the weight to apply normalization.

  • bias (torch.Tensor) – an additive bias for normalization.

  • eps (float) – a value added to the denominator for numerical stability.
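
A minimal pure-Python sketch of the same normalization on one vector, which corresponds to normalized_shape equal to the vector length. layer_norm is a hypothetical helper; like torch.nn.LayerNorm, it uses the biased variance.

```python
import math
from typing import List

def layer_norm(hidden_states: List[float], weight: List[float],
               bias: List[float], eps: float) -> List[float]:
    """LayerNorm over one vector (hypothetical helper)."""
    n = len(hidden_states)
    mean = sum(hidden_states) / n
    var = sum((h - mean) ** 2 for h in hidden_states) / n  # biased variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(h - mean) * inv_std * w + b for h, w, b in zip(hidden_states, weight, bias)]

print(layer_norm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4, eps=1e-5))
```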

ipex.llm.functional.indirect_access_kv_cache_attention(query: Tensor, key: Tensor, value: Tensor, scale_attn: float, layer_past: Tuple[Tensor] | None = None, head_mask: Tuple[Tensor] | None = None, attention_mask: Tuple[Tensor] | None = None, alibi: Tensor | None = None, add_casual_mask: bool | None = True, seq_info: Tensor | None = None, text_max_length: int | None = 0)

The KV cache reduces computation in the decoder layers, but it also brings memory overhead. For example, with beam search the KV cache has to be reordered according to the latest beam indices, and the current key/value has to be concatenated with the KV cache in the attention layer to obtain the full context for the scaled dot product. For very long sequences, this memory overhead becomes the performance bottleneck. This module provides an Indirect Access KV cache (IAKV). IAKV pre-allocates buffers (separate ones for key and value) to store all key/value hidden states and the beam index information. From the beam index history it determines which beam to read at each timestep, and this information yields an offset into the KV cache buffer, instead of physically reordering the cache.

Data Format:

The pre-allocated key (value) buffer has shape [max_seq, beam*batch, head_num, head_size]; the key/value hidden states, each of shape [beam*batch, head_num, head_size], are stored token by token. The beam indices of every timestep are stored in a tensor of shape [max_seq, beam*batch].
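
The indirect access idea can be sketched in pure Python: instead of copying the cache whenever beams are reordered, follow the recorded beam indices backwards to find which slot to read at each timestep. The convention used here (beam_idx[t][slot] names the slot at step t - 1 that the beam in that slot at step t extends) and the helper names are assumptions for illustration; the kernel's exact indexing is not described above. The sketch checks that following beam_idx yields the same history as explicitly reordering the cache at every step.

```python
from typing import List

def gather_history(cache: List[List[str]], beam_idx: List[List[int]],
                   slot: int, step: int) -> List[str]:
    """Read one beam's entries for steps 0..step by following beam indices (no copying)."""
    history = []
    for t in range(step, -1, -1):
        history.append(cache[t][slot])
        slot = beam_idx[t][slot]  # slot this beam occupied at step t - 1
    return history[::-1]

def explicit_reorder(cache: List[List[str]], beam_idx: List[List[int]], step: int) -> List[List[str]]:
    """Baseline: copy and reorder every beam's history at each step."""
    beams = [[entry] for entry in cache[0]]
    for t in range(1, step + 1):
        beams = [beams[beam_idx[t][s]] + [cache[t][s]] for s in range(len(cache[t]))]
    return beams

# cache[t][slot]: entry written at step t by the beam in that slot (stands in for key/value tensors).
cache = [["t0b0", "t0b1"], ["t1b0", "t1b1"], ["t2b0", "t2b1"]]
# beam_idx[t][slot]: parent slot at step t - 1 (row 0 is unused).
beam_idx = [[0, 1], [0, 0], [1, 0]]
print([gather_history(cache, beam_idx, s, 2) for s in range(2)])
print(explicit_reorder(cache, beam_idx, 2))
```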

Parameters:
  • query (torch.Tensor) – Query tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • key (torch.Tensor) – Key tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • value (torch.Tensor) – Value tensor; shape: (beam*batch, seq_len, head_num, head_dim).

  • scale_attn (float) – scale used by the attention layer; it should be sqrt(head_size).

  • layer_past (tuple(torch.Tensor)) –

    tuple(seq_info, key_cache, value_cache, beam-idx).

    • key_cache: key cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);

    • value_cache: value cache tensor, shape: (max_seq, beam*batch, head_num, head_dim);

    • beam-idx: history beam idx, shape:(max_seq, beam*batch);

    • seq_info: Sequence info tensor, shape:(1, 1, max_seq, max_seq).

  • head_mask (torch.Tensor) – head mask tensor; not yet supported by the kernel.

  • attention_mask (torch.Tensor) – Attention mask information.

  • text_max_length (int) – the maximum KV cache length used for generation; it determines the size of the pre-allocated cache buffer.

Returns:

attn_output: the weighted value, i.e. the output of the scaled dot product, with shape (beam*batch, seq_len, head_num, head_size).

attn_weights: the output of the first matmul in the scaled dot product; not yet supported by the kernel.

new_layer_past: updated layer_past (seq_info, key_cache, value_cache, beam-idx).

Return type:

attn_output

Notes

How to reorder the KV cache when using the IndirectAccessKVCacheAttention format (e.g., for the LLaMA model, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1318):

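import torch
from typing import Tuple

# Method of the model class (for example, the Hugging Face LLaMA causal LM).
def _reorder_cache(
    self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
    # IAKV layer_past is a 4-tuple (seq_info, key_cache, value_cache, beam_idx).
    if (
        len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1
    ):
        for layer_past in past_key_values:
            # Record the new beam indices at the current position; the caches are not copied.
            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
        return past_key_values
    # Other cache formats: fall back to the model's default reordering (not shown).
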
ipex.llm.functional.varlen_attention(query: Tensor, key: Tensor, value: Tensor, out: Tensor, seqlen_q: Tensor, seqlen_k: Tensor, max_seqlen_q: int, max_seqlen_k: int, pdropout: float, softmax_scale: float, zero_tensors: bool, is_causal: bool, return_softmax: bool, gen_: Generator)

Applies PyTorch scaled_dot_product_attention on the inputs of query, key and value (see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), and accepts variable (different) sequence lengths for the sequences packed into query, key and value.

This is the functional counterpart of ipex.llm.modules.VarlenAttention, so no module initialization is needed.

Parameters:
  • query (torch.Tensor) – shape [query_tokens, num_head, head_size], where query_tokens is the total number of query tokens across all sequences in the batch.

  • key (torch.Tensor) – shape [key_tokens, num_head, head_size], where key_tokens is the total number of key tokens across all sequences in the batch.

  • value (torch.Tensor) – shape [value_tokens, num_head, head_size], where value_tokens is the total number of value tokens across all sequences in the batch.

  • out (torch.Tensor) – output buffer that receives the results; its shape is the same as query.

  • seqlen_q (torch.Tensor) – shape [batch_size + 1]; cumulative offsets marking where the tokens of each sequence start in query (the last entry equals query_tokens).

  • seqlen_k (torch.Tensor) – shape [batch_size + 1]; cumulative offsets marking where the tokens of each sequence start in key (the last entry equals key_tokens).

  • max_seqlen_q (int) – max/total sequence length of query.

  • max_seqlen_k (int) – max/total sequence length of key.

  • pdropout (float) – dropout probability; if greater than 0.0, dropout is applied, default is 0.0.

  • softmax_scale (float) – scaling factor applied before the softmax.

  • is_causal (bool) – whether to apply causal attention masking, default is True.
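
To make the packed layout and the causal flag concrete, here is a pure-Python reference for one head. It assumes seqlen_q/seqlen_k are cumulative offsets and, for causal masking, that each query sequence has the same length as its key sequence; dropout, max_seqlen_q/max_seqlen_k and the remaining flags are omitted, and varlen_attention_ref is a hypothetical name.

```python
import math
from typing import List

Vec = List[float]

def varlen_attention_ref(query: List[Vec], key: List[Vec], value: List[Vec],
                         seqlen_q: List[int], seqlen_k: List[int],
                         softmax_scale: float, is_causal: bool = True) -> List[Vec]:
    """One-head reference over packed sequences (hypothetical helper, no dropout)."""
    out: List[Vec] = [[] for _ in query]
    for b in range(len(seqlen_q) - 1):
        q0, q1 = seqlen_q[b], seqlen_q[b + 1]
        k0, k1 = seqlen_k[b], seqlen_k[b + 1]
        for i in range(q0, q1):
            # Keys of the same sequence only; with is_causal, none after position i - q0.
            end = min(k0 + (i - q0) + 1, k1) if is_causal else k1
            idx = range(k0, end)
            scores = [softmax_scale * sum(a * c for a, c in zip(query[i], key[j])) for j in idx]
            m = max(scores)
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            out[i] = [sum(w / total * value[j][d] for w, j in zip(weights, idx))
                      for d in range(len(value[0]))]
    return out

# Two packed sequences of lengths 1 and 2.
query = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
key = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
value = [[1.0, 1.0], [2.0, 4.0], [4.0, 8.0]]
seqlen = [0, 1, 3]
print(varlen_attention_ref(query, key, value, seqlen, seqlen, softmax_scale=1.0))
```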

Fast Bert (Prototype)

ipex.fast_bert(model, dtype=torch.float32, optimizer=None, unpad=False)

Uses TPP (Tensor Processing Primitives) to speed up training and inference. The fast_bert API is still a prototype feature and is currently optimized only for BERT models.

Parameters:
  • model (torch.nn.Module) – User model to apply optimizations on.

  • dtype (torch.dtype) – Only works for torch.bfloat16 and torch.float. The default value is torch.float.

  • optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning inference case.

  • unpad (bool) – Remove padding from the sequences to reduce sparsity.

  • seed (string) – The seed used for the libxsmm kernel. In general it should be the same as torch.seed.

Note

The ipex.fast_bert API is currently well optimized for training tasks. It also works for inference tasks, but use the ipex.optimize API with TorchScript to achieve peak inference performance.

Warning

Please invoke the fast_bert function AFTER loading the weights into the model via model.load_state_dict(torch.load(PATH)).

Warning

This API can’t be used after ipex.optimize has been applied.

Warning

Please invoke the optimize function BEFORE wrapping the model with DDP in distributed training scenarios.

Examples

>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.fast_bert(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.fast_bert(model, dtype=torch.bfloat16,
...         optimizer=optimizer, unpad=True, seed=args.seed)
>>> # running training step.

Graph Optimization

ipex.enable_onednn_fusion(enabled)

Enables or disables oneDNN fusion functionality. If enabled, oneDNN operators are fused at runtime once intel_extension_for_pytorch has been imported.

Parameters:

enabled (bool) – Whether to enable oneDNN fusion functionality or not. Default value is True.

Examples

>>> import intel_extension_for_pytorch as ipex
>>> # to enable the oneDNN fusion
>>> ipex.enable_onednn_fusion(True)
>>> # to disable the oneDNN fusion
>>> ipex.enable_onednn_fusion(False)

Quantization

ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5, act_observer=None, act_ic_observer=None, wei_observer=None, wei_ic_observer=None, share_weight_observers=True)

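Configuration with SmoothQuant for static quantization of large language models (LLMs). For SmoothQuant, see https://arxiv.org/pdf/2211.10438.pdf

Parameters:
  • alpha – Hyper-parameter for SmoothQuant: the migration strength that balances the quantization difficulty between activations and weights.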

  • act_observer – Observer for activation of ops other than nn.Linear. HistogramObserver by default. For nn.Linear with SmoothQuant enabled, q-param is calculated based on act_ic_observer’s and wei_ic_observer’s min/max. It is not affected by this argument. Example: torch.ao.quantization.MinMaxObserver

  • act_ic_observer – Per-input-channel Observer for activation. For nn.Linear with SmoothQuant enabled only. PerChannelMinMaxObserver by default. Example: torch.ao.quantization.PerChannelMinMaxObserver.with_args(ch_axis=1)

  • wei_observer – Observer for weight of all weighted ops. For nn.Linear with SmoothQuant enabled, it calculates q-params after applying scaling factors. PerChannelMinMaxObserver by default. Example: torch.ao.quantization.PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)

  • wei_ic_observer – Per-input-channel Observer for weight. For nn.Linear with SmoothQuant enabled only. PerChannelMinMaxObserver by default. Example: torch.ao.quantization.PerChannelMinMaxObserver.with_args(ch_axis=1)
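
The role of alpha follows the SmoothQuant paper linked above: per input channel j, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha); activations are divided by s and weights multiplied by s, so the product is unchanged while activation outliers shrink. The pure-Python sketch below shows only that arithmetic, not what the IPEX observers do internally, and uses the paper's [in_ch, out_ch] weight layout (nn.Linear stores its weight as [out_features, in_features]).

```python
from typing import List

Matrix = List[List[float]]

def smooth_scales(x: Matrix, w: Matrix, alpha: float) -> List[float]:
    """s_j = max|X[:, j]|**alpha / max|W[j, :]|**(1 - alpha) for each input channel j."""
    scales = []
    for j in range(len(w)):
        act_max = max(abs(row[j]) for row in x)
        wei_max = max(abs(v) for v in w[j])
        scales.append(act_max ** alpha / wei_max ** (1 - alpha))
    return scales

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

x = [[10.0, 0.1], [-20.0, 0.2]]  # activations [tokens, in_ch]; channel 0 has outliers
w = [[0.5, -0.5], [1.0, 2.0]]    # weights [in_ch, out_ch]
s = smooth_scales(x, w, alpha=0.5)
x_s = [[v / s[j] for j, v in enumerate(row)] for row in x]  # X * diag(s)^-1
w_s = [[v * s[j] for v in row] for j, row in enumerate(w)]  # diag(s) * W
print(matmul(x, w))
print(matmul(x_s, w_s))  # same product up to rounding
```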

Returns:

torch.ao.quantization.QConfig

ipex.quantization.prepare(model, configure, example_inputs=None, inplace=False, bn_folding=True, example_kwarg_inputs=None)

Prepare an FP32 torch.nn.Module model to do calibration or to convert to quantized model.

Parameters:
  • model (torch.nn.Module) – The FP32 model to be prepared.

  • configure (torch.quantization.qconfig.QConfig) – The observer settings about activation and weight.

  • example_inputs (tuple or torch.Tensor) – A tuple of example inputs that will be passed to the model when it is run to initialize the quantization state. Only one of this argument or example_kwarg_inputs should be specified.

  • inplace (bool) – If True, the given model is changed in place. The default value is False. Note that if bn_folding is True, the returned model is a different object from the original model even if inplace=True. So, after prepared_model = prepare(original_model, …, inplace=True), use prepared_model for later operations to avoid unexpected behavior.

  • bn_folding (bool) – whether to perform conv_bn and linear_bn folding. The default value is True.

  • example_kwarg_inputs (dict) – A dict of example inputs that will be passed to the model when it is run to initialize the quantization state. Only one of this argument or example_inputs should be specified.

Returns:

torch.nn.Module

ipex.quantization.convert(model, inplace=False)

Convert an FP32 prepared model to a model which will automatically insert fake quant before a quantizable module or operator.

Parameters:
  • model (torch.nn.Module) – The FP32 model to be converted.

  • inplace (bool) – If True, the given model is changed in place. The default value is False.

Returns:

torch.nn.Module

Prototype API; an introduction is available on the feature page.

ipex.quantization.autotune(model, calib_dataloader, calib_func=None, eval_func=None, op_type_dict=None, smoothquant_args=None, sampling_sizes=None, accuracy_criterion=None, tuning_time=0)

Automatic accuracy-driven tuning helps users quickly find an advanced recipe for INT8 inference.

Parameters:
  • model (torch.nn.Module) – fp32 model.

  • calib_dataloader (generator) – the dataloader used for calibration.

  • calib_func (function) – calibration function for post-training static quantization. It is optional. This function takes “model” as its input parameter and executes the entire inference process.

  • eval_func (function) – evaluation function. It takes “model” as its input parameter, executes the entire evaluation process with self-contained metrics, and returns an accuracy value as a scalar number. Higher is better.

  • op_type_dict (dict) – op-type-wise tuning constraints that let advanced users reduce the tuning space. Users can specify the quantization config by op type:

  • smoothquant_args (dict) – SmoothQuant recipes for automatic global alpha tuning and automatic layer-by-layer alpha tuning to achieve the best INT8 accuracy.

  • sampling_sizes (list) – a list of calibration sample sizes for the tuning algorithm to explore. The default value is [100].

  • accuracy_criterion ({accuracy_criterion_type (str, 'relative' or 'absolute'): accuracy_criterion_value (float)}) – the maximum allowed accuracy loss, either relative or absolute. The default value is {'relative': 0.01}.

  • tuning_time (seconds) – tuning timeout. The default value is 0, which means early stop.
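
One way to read accuracy_criterion, assuming higher accuracy is better as eval_func requires, is sketched below: the relative form bounds the drop as a fraction of the FP32 baseline, and the absolute form bounds the raw drop. meets_criterion is a hypothetical helper; the tuner's exact comparison is not specified here.

```python
def meets_criterion(baseline: float, candidate: float, criterion: dict) -> bool:
    """Check the accuracy drop against {'relative': r} or {'absolute': a} (hypothetical helper)."""
    (kind, limit), = criterion.items()
    drop = baseline - candidate  # higher accuracy is better
    if kind == "relative":
        return drop <= limit * baseline
    if kind == "absolute":
        return drop <= limit
    raise ValueError(f"unknown accuracy_criterion type: {kind}")

print(meets_criterion(0.80, 0.795, {"relative": 0.01}))  # True: drop 0.005 <= 0.008
print(meets_criterion(0.80, 0.790, {"relative": 0.01}))  # False: drop 0.010 > 0.008
```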

Returns:

the prepared model with the tuned qconfig loaded.

Return type:

prepared_model (torch.nn.Module)

CPU Runtime

ipex.cpu.runtime.is_runtime_ext_enabled()

Helper function to check whether runtime extension is enabled or not.

Parameters:

None

Returns:

Whether the runtime extension is enabled or not. If the Intel OpenMP library is preloaded, this API returns True. Otherwise, it returns False.

Return type:

bool

class ipex.cpu.runtime.CPUPool(core_ids: list | None = None, node_id: int | None = None)

An abstraction of a pool of CPU cores used for intra-op parallelism.

Parameters:
  • core_ids (list) – A list of CPU cores’ ids used for intra-op parallelism.

  • node_id (int) – A NUMA node ID; all CPU cores on that NUMA node are used. node_id is ignored if core_ids is set.

Returns:

Generated ipex.cpu.runtime.CPUPool object.

Return type:

ipex.cpu.runtime.CPUPool

class ipex.cpu.runtime.pin(cpu_pool: CPUPool)

Apply the given CPU pool to the master thread that runs the scoped code region or the decorated function/method.

Parameters:

cpu_pool (ipex.cpu.runtime.CPUPool) – an ipex.cpu.runtime.CPUPool object containing all CPU cores used by the designated operations.

Returns:

Generated ipex.cpu.runtime.pin object which can be used as a with context or a function decorator.

Return type:

ipex.cpu.runtime.pin

class ipex.cpu.runtime.MultiStreamModuleHint(*args, **kwargs)

MultiStreamModuleHint is a hint to MultiStreamModule about how to split the inputs or concatenate the outputs. Each argument should be None, an int, or a container of int or None values, such as (0, None, …) or [0, None, …]. If the argument is None, this argument will not be split or concatenated. If the argument is an int, its value is the dim along which this argument will be split or concatenated.

Parameters:
  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.

Returns:

Generated ipex.cpu.runtime.MultiStreamModuleHint object.

Return type:

ipex.cpu.runtime.MultiStreamModuleHint

class ipex.cpu.runtime.MultiStreamModule(model, num_streams: int | str = 'AUTO', cpu_pool: ~ipex.cpu.runtime.cpupool.CPUPool = <ipex.cpu.runtime.cpupool.CPUPool object>, concat_output: bool = True, input_split_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>, output_concat_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>)

MultiStreamModule supports inference with multi-stream throughput mode.

If the number of cores in cpu_pool is divisible by num_streams, the cores are allocated equally to each stream. If it is not divisible by num_streams and leaves remainder N, one extra core is allocated to each of the first N streams. We suggest setting num_streams to a divisor of the number of cores in cpu_pool.

If the inputs’ batch size is larger than and divisible by num_streams, the batch is split equally across the streams. If the batch size is not divisible by num_streams and leaves remainder N, one extra sample is allocated to each of the first N streams. If the inputs’ batch size is less than num_streams, only the first batch-size streams are used, each with a mini-batch of one. We suggest setting the inputs’ batch size larger than and divisible by num_streams. If you leave num_streams as “AUTO” instead of tuning it, we suggest setting the inputs’ batch size larger than and divisible by the number of cores.
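
The allocation rules in the two paragraphs above amount to an even split with the remainder going to the first streams. The pure-Python sketch below restates them; the helper names are hypothetical and it does not call the ipex runtime.

```python
from typing import List

def split_evenly(total: int, parts: int) -> List[int]:
    """Split total items into parts chunks; the first total % parts chunks get one extra."""
    base, rem = divmod(total, parts)
    return [base + 1 if i < rem else base for i in range(parts)]

def stream_batch_sizes(batch_size: int, num_streams: int) -> List[int]:
    # With fewer samples than streams, only the first batch_size streams run, one sample each.
    if batch_size < num_streams:
        return [1] * batch_size
    return split_evenly(batch_size, num_streams)

print(split_evenly(10, 4))        # cores per stream: [3, 3, 2, 2]
print(stream_batch_sizes(3, 4))   # [1, 1, 1]
print(stream_batch_sizes(9, 4))   # [3, 2, 2, 2]
```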

Parameters:
  • model (torch.jit.ScriptModule or torch.nn.Module) – The input model.

  • num_streams (Union[int, str]) – Number of instances (int) or “AUTO” (str). “AUTO” means the number of streams is selected automatically. Although “AUTO” usually provides reasonable performance, it may not be optimal in some cases, which then require manual tuning of the number of streams.

  • cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object, contains all CPU cores used to run multi-stream inference.

  • concat_output (bool) – A flag indicating whether the outputs of the streams are concatenated. The default value is True. Note: if the outputs of the streams can’t be concatenated, set this flag to False to get the raw output (a list of each stream’s output).

  • input_split_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to split the inputs.

  • output_concat_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to concat the outputs.

Returns:

Generated ipex.cpu.runtime.MultiStreamModule object.

Return type:

ipex.cpu.runtime.MultiStreamModule

class ipex.cpu.runtime.Task(module, cpu_pool: CPUPool)

An abstraction of a computation based on a PyTorch module that is scheduled asynchronously.

Parameters:
  • module (torch.jit.ScriptModule or torch.nn.Module) – The input module.

  • cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object, contains all CPU cores used to run Task asynchronously.

Returns:

Generated ipex.cpu.runtime.Task object.

Return type:

ipex.cpu.runtime.Task

ipex.cpu.runtime.get_core_list_of_node_id(node_id)

Helper function to get the CPU cores’ ids of the input numa node.

Parameters:

node_id (int) – Input numa node id.

Returns:

List of CPU cores’ ids on this numa node.

Return type:

list