API Documentation
General
- ipex.optimize(model, dtype=None, optimizer=None, level='O1', inplace=False, conv_bn_folding=None, linear_bn_folding=None, weights_prepack=None, replace_dropout_with_identity=None, optimize_lstm=None, split_master_weight_for_bf16=None, fuse_update_step=None, auto_kernel_selection=None, sample_input=None, graph_mode=None)
Apply optimizations at the Python frontend to the given model (nn.Module), as well as to the given optimizer (optional). If an optimizer is given, optimizations are applied for training; otherwise, optimizations are applied for inference. Optimizations include conv+bn folding (for inference only), weight prepacking and so on.
Weight prepacking is a technique to accelerate the performance of oneDNN operators. To achieve better vectorization and cache reuse, oneDNN uses a specific memory layout called blocked layout. Although the computation itself with blocked layout is fast enough, it has drawbacks from a memory-usage perspective. Running with blocked layout, oneDNN splits one or several dimensions of data into blocks of fixed size each time the operator is executed. More details about oneDNN data memory formats are available in the oneDNN manual. To reduce this overhead, data is converted to the predefined block shapes before the oneDNN operator executes. At runtime, if the data shape matches the oneDNN operator's execution requirements, oneDNN does not perform a memory layout conversion and goes directly to computation. This methodology, called weight prepacking, avoids runtime weight data format conversion and thus increases performance.
- Parameters
model (torch.nn.Module) – User model to apply optimizations on.
dtype (torch.dtype) – Only works for torch.bfloat16 and torch.half, a.k.a. torch.float16. Model parameters will be cast to torch.bfloat16 or torch.half according to the dtype setting. The default value is None, meaning do nothing. Note: data type conversion is only applied to nn.Conv2d, nn.Linear and nn.ConvTranspose2d for both training and inference cases. For inference mode, additional data type conversion is applied to the weights of nn.Embedding and nn.LSTM.
optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning the inference case.
level (string) – "O0" or "O1". No optimizations are applied with "O0"; the optimize function just returns the original model and optimizer. With "O1", the following optimizations are applied: conv+bn folding, weights prepack, dropout removal (inference model), master weight split and fused optimizer update step (training model). These optimizations can be further overridden by setting the following options explicitly. The default value is "O1".
inplace (bool) – Whether to perform inplace optimization. The default value is False.
conv_bn_folding (bool) – Whether to perform conv_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
linear_bn_folding (bool) – Whether to perform linear_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
weights_prepack (bool) – Whether to perform weight prepack for convolution and linear to avoid oneDNN weights reorder. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
replace_dropout_with_identity (bool) – Whether to replace nn.Dropout with nn.Identity. If replaced, aten::dropout won’t be included in the JIT graph, which may provide more fusion opportunities on the graph. This only works for inference models. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
optimize_lstm (bool) – Whether to replace nn.LSTM with IPEX LSTM, which takes advantage of oneDNN kernels to get better performance. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
split_master_weight_for_bf16 (bool) – Whether to split the master weights update for BF16 training. This saves memory compared to the master weight update solution. The split master weights update methodology doesn’t support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
fuse_update_step (bool) – Whether to use fused parameter updates for training, which have better performance. It doesn’t support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
sample_input (tuple or torch.Tensor) – Sample input data to feed to ipex.optimize. The shape of the input data impacts the block format of the packed weight. If no sample input is fed, Intel® Extension for PyTorch* packs the weight according to predefined heuristics. If a sample input with the real input shape is fed, Intel® Extension for PyTorch* can get the best block format.
auto_kernel_selection (bool) – Different backends may have different performance with different dtypes/shapes. If this knob is set to True, Intel® Extension for PyTorch* will try to optimize the kernel selection for better performance. You might get better performance at the cost of extra memory usage. The default value is None, under which this selection is not enabled (the same behavior as False). Explicitly setting this knob overwrites the configuration set by the level knob.
graph_mode (bool) – [experimental] If True, a combination of methods is automatically applied to generate a graph or multiple subgraphs. The default value is False.
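The interaction between level and the individual knobs can be pictured as "start from the level's preset, then apply every knob the caller set explicitly, where None means defer to the level". The plain-Python sketch below models only that precedence rule. resolve_knobs and O1_PRESET are hypothetical names built from the descriptions above, not the library's internal code; the preset lists only the optimizations the level description names (plus auto_kernel_selection, described as off by default) and ignores the inference/training distinction.

```python
# Hypothetical sketch: explicit knobs override the preset chosen by `level`.
# The preset mirrors the prose above; it is NOT the library's internal table.
O1_PRESET = {
    "conv_bn_folding": True,
    "weights_prepack": True,
    "replace_dropout_with_identity": True,
    "split_master_weight_for_bf16": True,
    "fuse_update_step": True,
    "auto_kernel_selection": False,
}


def resolve_knobs(level="O1", **knobs):
    """Return effective settings: level preset first, explicit knobs second."""
    if level == "O0":
        settings = {name: False for name in O1_PRESET}  # "O0" applies nothing
    elif level == "O1":
        settings = dict(O1_PRESET)
    else:
        raise ValueError(f"unknown level: {level!r}")
    for name, value in knobs.items():
        if name not in settings:
            raise TypeError(f"unknown knob: {name!r}")
        if value is not None:  # None means "keep whatever the level chose"
            settings[name] = value
    return settings


# Usage: keep "O1" but switch off conv+bn folding only.
print(resolve_knobs("O1", conv_bn_folding=False))
```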
- Returns
Model and optimizer (if given) modified according to the level knob or other user settings. conv+bn folding may take place and dropout may be replaced by identity. In inference scenarios, convolution, linear and LSTM will be replaced with their optimized counterparts in Intel® Extension for PyTorch* (weight prepack for convolution and linear) for good performance. In bfloat16 or float16 scenarios, parameters of convolution and linear will be cast to the bfloat16 or float16 dtype.
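conv+bn folding is possible because, in inference mode, batch normalization uses fixed running statistics and is therefore a per-channel affine transform that can be absorbed into the preceding convolution's weight and bias. A minimal plain-Python sketch for one output channel of a 1x1 convolution follows; fold_conv_bn, conv1x1 and batch_norm are hypothetical helpers, and a real folding pass applies the same per-channel scale to every weight of that output channel.

```python
import math


def fold_conv_bn(weight, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold inference-mode BatchNorm into one output channel's conv weight and bias."""
    scale = gamma / math.sqrt(running_var + eps)
    folded_weight = [w * scale for w in weight]
    folded_bias = (bias - running_mean) * scale + beta
    return folded_weight, folded_bias


def conv1x1(weight, bias, x):
    # One output channel of a 1x1 convolution at one pixel: dot product plus bias.
    return sum(w * xi for w, xi in zip(weight, x)) + bias


def batch_norm(y, gamma, beta, running_mean, running_var, eps=1e-5):
    # Inference-mode BatchNorm for one channel (running statistics, no batch stats).
    return gamma * (y - running_mean) / math.sqrt(running_var + eps) + beta


w, b = [0.5, -1.2, 2.0], 0.3
bn = dict(gamma=1.5, beta=-0.2, running_mean=0.1, running_var=0.8)
x = [1.0, 2.0, -0.5]

reference = batch_norm(conv1x1(w, b, x), **bn)  # conv followed by BN
fw, fb = fold_conv_bn(w, b, **bn)               # single folded conv
print(reference, conv1x1(fw, fb, x))            # equal up to float rounding
```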
Warning
Please invoke the optimize function AFTER loading weights into the model via model.load_state_dict(torch.load(PATH)).
Warning
Please invoke the optimize function BEFORE invoking DDP in distributed training scenarios. The optimize function deep-copies the original model. If DDP is invoked before the optimize function, DDP is applied to the original model rather than to the one returned from optimize. In this case, some operators in DDP, such as allreduce, will not be invoked, which may cause unpredictable accuracy loss.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.
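The weight prepacking idea described at the start of this entry can be illustrated without oneDNN: reorder a plain [OC][IC] weight matrix into fixed-size tiles once, ahead of time, and let every later call consume the tiles directly. The following plain-Python sketch is conceptual only; to_blocked and PrepackedLinear are hypothetical names, and oneDNN's real blocked formats (and the block sizes it picks) differ from this simple [OC/b][IC/b][b][b] tiling.

```python
def to_blocked(weight, block):
    """Reorder a plain [OC][IC] matrix into [OC/b][IC/b][b][b] tiles."""
    oc, ic = len(weight), len(weight[0])
    if oc % block or ic % block:
        raise ValueError("this sketch assumes dimensions divisible by the block size")
    return [
        [
            [[weight[ob * block + i][ib * block + j] for j in range(block)]
             for i in range(block)]
            for ib in range(ic // block)
        ]
        for ob in range(oc // block)
    ]


class PrepackedLinear:
    """y = W x, where W is reordered once at construction time ("prepacked")."""

    def __init__(self, weight, block):
        self.block = block
        self.tiles = to_blocked(weight, block)  # one-time layout conversion

    def __call__(self, x):
        b = self.block
        out = []
        for row_tiles in self.tiles:          # one block of output rows
            for i in range(b):                # row inside that block
                acc = 0
                for ib, tile in enumerate(row_tiles):
                    for j in range(b):
                        acc += tile[i][j] * x[ib * b + j]
                out.append(acc)
        return out


W = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
layer = PrepackedLinear(W, block=2)
print(layer([1, 0, 2, 1]))  # [11, 27, 43, 59]; no reorder happens per call
```

The design point is where the reorder cost is paid: in __init__, once, rather than on every forward call.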
- class ipex.verbose(level)
On-demand oneDNN verbosing functionality
To make it easier to debug performance issues, oneDNN can dump verbose messages containing information such as kernel size, input data size and execution duration while executing the kernel. The verbosing functionality can be invoked via an environment variable named DNNL_VERBOSE. However, this method dumps messages at every step, which produces a large amount of verbose output. Moreover, for investigating performance issues, verbose messages from a single iteration are generally enough.
This on-demand verbosing functionality makes it possible to control the scope of verbose message dumping. In the following example, verbose messages will be dumped out for the second inference only.
import intel_extension_for_pytorch as ipex
model(data)
with ipex.verbose(ipex.utils.VERBOSE_ON):
    model(data)
- Parameters
level – Verbose level
- VERBOSE_OFF: Disable verbosing
- VERBOSE_ON: Enable verbosing
- VERBOSE_ON_CREATION: Enable verbosing, including oneDNN kernel creation
Fast Bert (Experimental)
- ipex.fast_bert(model, dtype=torch.float32, optimizer=None, unpad=False)
Uses TPP (Tensor Processing Primitives) to speed up training/inference. The fast_bert API is still an experimental feature and is currently optimized only for BERT models.
- Parameters
model (torch.nn.Module) – User model to apply optimizations on.
dtype (torch.dtype) – Only works for torch.bfloat16 and torch.float. The default value is torch.float.
optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning the inference case.
unpad (bool) – Unpad the sequence to reduce sparsity.
seed (string) – The seed used for the libxsmm kernel. In general, it should be the same as torch.seed.
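The unpad option removes padding tokens so that kernels only process real tokens. The plain-Python sketch below shows the idea on token-id lists: pack the non-padded tokens of every sequence into one flat list plus per-sequence lengths, then restore the padded shape afterwards. unpad and repad are hypothetical helpers, the sketch assumes right-padded sequences with a 1/0 attention mask, and it is not the TPP implementation.

```python
def unpad(batch, mask):
    """Drop padded positions; return the packed tokens and each sequence's length."""
    packed, lengths = [], []
    for tokens, keep in zip(batch, mask):
        real = [t for t, k in zip(tokens, keep) if k]
        packed.extend(real)
        lengths.append(len(real))
    return packed, lengths


def repad(packed, lengths, max_len, pad_id=0):
    """Inverse of unpad for right-padded sequences."""
    batch, start = [], 0
    for n in lengths:
        batch.append(packed[start:start + n] + [pad_id] * (max_len - n))
        start += n
    return batch


batch = [[101, 7592, 102, 0, 0], [101, 2088, 2003, 2307, 102]]
mask = [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
packed, lengths = unpad(batch, mask)
print(len(packed), "of", sum(map(len, batch)), "positions are real tokens")  # 8 of 10
```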
Warning
Please invoke the fast_bert function AFTER loading weights into the model via model.load_state_dict(torch.load(PATH)).
Warning
This API can’t be used after ipex.optimize has been applied to the model.
Warning
Please invoke the optimize function BEFORE invoking DDP in distributed training scenarios.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.fast_bert(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.fast_bert(model, dtype=torch.bfloat16, optimizer=optimizer, unpad=True, seed=args.seed)
>>> # running training step.
Graph Optimization
- ipex.enable_onednn_fusion(enabled)
Enables or disables oneDNN fusion functionality. If enabled, oneDNN operators will be fused at runtime when intel_extension_for_pytorch is imported.
- Parameters
enabled (bool) – Whether to enable oneDNN fusion functionality or not. The default value is True.
Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the oneDNN fusion
>>> ipex.enable_onednn_fusion(True)
>>> # to disable the oneDNN fusion
>>> ipex.enable_onednn_fusion(False)
Quantization
- ipex.quantization.prepare(model, configure, example_inputs, inplace=False)
Prepare an FP32 torch.nn.Module model for calibration or for conversion to a quantized model.
- Parameters
model (torch.nn.Module) – The FP32 model to be prepared.
configure (torch.quantization.qconfig.QConfig) – The observer settings about activation and weight.
example_inputs (tuple or torch.Tensor) – A tuple of example inputs that will be passed to the model while running, to initialize the quantization state.
inplace (bool) – Whether to change the given model in place. The default value is False.
- Returns
torch.nn.Module
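Calibration of a prepared model amounts to running representative inputs through observers that record value ranges, from which quantization parameters are derived. The sketch below is a minimal min/max observer computing an asymmetric uint8 scale and zero point. MinMaxObserverSketch is hypothetical, and the observers selected through the configure argument may use different schemes (for example, symmetric int8 for weights).

```python
import math


class MinMaxObserverSketch:
    """Track the running min/max of observed values (hypothetical helper)."""

    def __init__(self):
        self.min_val = math.inf
        self.max_val = -math.inf

    def observe(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def qparams(self, qmin=0, qmax=255):
        # Extend the range to include 0 so that zero is exactly representable.
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid a zero scale for all-zero input
        zero_point = int(qmin - round(lo / scale))
        return scale, max(qmin, min(qmax, zero_point))


obs = MinMaxObserverSketch()
for calib_batch in ([-1.0, 0.2, 0.7], [0.1, 1.55, -0.3]):  # two calibration batches
    obs.observe(calib_batch)
scale, zero_point = obs.qparams()
print(scale, zero_point)  # scale is about 0.01, zero_point is 100
```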
- ipex.quantization.convert(model, inplace=False)
Convert an FP32 prepared model to a model which will automatically insert fake quant before a quantizable module or operator.
- Parameters
model (torch.nn.Module) – The FP32 model to be converted.
inplace (bool) – Whether to change the given model in place. The default value is False.
- Returns
torch.nn.Module
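The fake quant inserted by convert simulates INT8 arithmetic in floating point: each value is quantized with the calibrated scale and zero point, clamped to the integer range, and immediately dequantized. The sketch below shows that round trip for a single value with asymmetric uint8 parameters; fake_quantize is a hypothetical helper, not an IPEX function.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize x onto the integer grid, clamp, and map back to float."""
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))  # saturate values outside the representable range
    return (q - zero_point) * scale


scale, zero_point = 0.1, 128
for x in (0.26, -3.0, 100.0):
    # 0.26 snaps to the nearest grid point; 100.0 saturates at the top of the range.
    print(x, "->", fake_quantize(x, scale, zero_point))
```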
Experimental API; an introduction is available on the feature page.
- ipex.quantization.autotune(prepared_model, calib_dataloader, eval_func, sampling_sizes=[100], accuracy_criterion={'relative': 0.01}, tuning_time=0)
Automatic accuracy-driven tuning helps users quickly find an advanced recipe for INT8 inference.
- Parameters
prepared_model (torch.nn.Module) – the FP32 prepared model returned from ipex.quantization.prepare.
calib_dataloader (generator) – A dataloader for calibration.
eval_func (function) – An evaluation function. This function takes “model” as its input parameter, executes the entire evaluation process with self-contained metrics, and returns an accuracy value as a scalar number. The higher, the better.
sampling_sizes (list) – A list of sample sizes used in calibration, which the tuning algorithm explores. The default value is [100].
accuracy_criterion ({accuracy_criterion_type (str, 'relative' or 'absolute'): accuracy_criterion_value (float)}) – Sets the maximum allowed accuracy loss, either relative or absolute. The default value is {'relative': 0.01}.
tuning_time (seconds) – Tuning timeout. The default value is 0, which means early stop.
- Returns
FP32 tuned model (torch.nn.Module)
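To make accuracy_criterion concrete, the sketch below checks a tuned accuracy against an FP32 baseline under either criterion type. It assumes that "relative" means the accuracy drop divided by the baseline and "absolute" means the raw drop, for a higher-is-better metric as eval_func is required to return. meets_criterion is hypothetical and not part of ipex.quantization.

```python
def meets_criterion(baseline, tuned, criterion):
    """Return True if the tuned accuracy stays within the allowed loss (assumed semantics)."""
    (kind, limit), = criterion.items()  # exactly one entry, e.g. {'relative': 0.01}
    drop = baseline - tuned
    if kind == "relative":
        loss = drop / baseline
    elif kind == "absolute":
        loss = drop
    else:
        raise ValueError(f"unknown criterion type: {kind!r}")
    return loss <= limit


print(meets_criterion(0.800, 0.795, {"relative": 0.01}))  # True: 0.625% relative drop
print(meets_criterion(0.800, 0.780, {"relative": 0.01}))  # False: 2.5% relative drop
```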
CPU Runtime
- ipex.cpu.runtime.is_runtime_ext_enabled()
Helper function to check whether runtime extension is enabled or not.
- Parameters
None
- Returns
- Whether the runtime extension is enabled or not. If the Intel OpenMP Library is preloaded, this API will return True. Otherwise, it will return False.
- Return type
bool
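Because the result depends on whether the Intel OpenMP library was preloaded, one way to anticipate it on Linux is to inspect LD_PRELOAD, whose entries may be separated by colons or spaces. The helper below is a hypothetical heuristic only: it looks for a libiomp5 entry and does not replicate the check the library performs internally, so is_runtime_ext_enabled() remains the authoritative answer.

```python
import os


def iomp_in_ld_preload(environ=None):
    """Heuristic: is an Intel OpenMP library (libiomp5*) listed in LD_PRELOAD?"""
    environ = os.environ if environ is None else environ
    # LD_PRELOAD entries may be separated by colons or spaces.
    entries = environ.get("LD_PRELOAD", "").replace(" ", ":").split(":")
    return any(os.path.basename(e).startswith("libiomp5") for e in entries if e)


print(iomp_in_ld_preload())  # reflects the current process environment
```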
- class ipex.cpu.runtime.CPUPool(core_ids: Optional[list] = None, node_id: Optional[int] = None)
An abstraction of a pool of CPU cores used for intra-op parallelism.
- Parameters
core_ids (list) – A list of CPU cores’ ids used for intra-op parallelism.
node_id (int) – A NUMA node id; all CPU cores on that NUMA node are used. node_id doesn’t take effect if core_ids is set.
- Returns
Generated ipex.cpu.runtime.CPUPool object.
- Return type
ipex.cpu.runtime.CPUPool
- class ipex.cpu.runtime.pin(cpu_pool: CPUPool)
Apply the given CPU pool to the master thread that runs the scoped code region or the function/method def.
- Parameters
cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object that contains all CPU cores used by the designated operations.
- Returns
Generated ipex.cpu.runtime.pin object, which can be used as a with-statement context manager or as a function decorator.
- Return type
ipex.cpu.runtime.pin
- class ipex.cpu.runtime.MultiStreamModuleHint(*args, **kwargs)
MultiStreamModuleHint is a hint to MultiStreamModule about how to split the inputs or concatenate the outputs. Each argument should be None, an int, or a container of ints or None, such as (0, None, …) or [0, None, …]. If an argument is None, the corresponding input or output will not be split or concatenated. If an argument is an int, its value is the dim along which the corresponding input or output will be split or concatenated.
- Parameters
*args – Variable length argument list.
**kwargs – Arbitrary keyword arguments.
- Returns
Generated ipex.cpu.runtime.MultiStreamModuleHint object.
- Return type
ipex.cpu.runtime.MultiStreamModuleHint
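The hint semantics above (None passes the argument unchanged to every stream, an int names the dim to split along) can be sketched on nested Python lists. apply_hint, split_along and chunk_sizes are hypothetical helpers, not MultiStreamModuleHint internals; uneven splits follow the "first N streams get one extra piece" rule described for MultiStreamModule.

```python
def chunk_sizes(n, parts):
    """Split n items into `parts` sizes; the first n % parts chunks get one extra."""
    base, rem = divmod(n, parts)
    return [base + 1 if i < rem else base for i in range(parts)]


def split_along(x, dim, parts):
    """Split a nested list along `dim` into `parts` nested lists."""
    if dim == 0:
        chunks, start = [], 0
        for size in chunk_sizes(len(x), parts):
            chunks.append(x[start:start + size])
            start += size
        return chunks
    # For dim > 0, split every row one level down, then regroup by chunk index.
    per_row = [split_along(row, dim - 1, parts) for row in x]
    return [[row_chunks[i] for row_chunks in per_row] for i in range(parts)]


def apply_hint(args, hint, parts):
    """Build each stream's argument tuple from a per-argument hint such as (0, None)."""
    streams = [[] for _ in range(parts)]
    for arg, dim in zip(args, hint):
        pieces = [arg] * parts if dim is None else split_along(arg, dim, parts)
        for stream, piece in zip(streams, pieces):
            stream.append(piece)
    return [tuple(s) for s in streams]


batch = [[1, 2], [3, 4], [5, 6]]  # 3 samples, split along dim 0
print(apply_hint((batch, "shared"), (0, None), parts=2))
```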
- class ipex.cpu.runtime.MultiStreamModule(model, num_streams: ~typing.Union[int, str] = 'AUTO', cpu_pool: ~ipex.cpu.runtime.cpupool.CPUPool = <ipex.cpu.runtime.cpupool.CPUPool object>, concat_output: bool = True, input_split_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>, output_concat_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>)
MultiStreamModule supports inference with multi-stream throughput mode.
If the number of cores inside cpu_pool is divisible by num_streams, the cores will be allocated equally to each stream. If the number of cores inside cpu_pool is not divisible by num_streams, with remainder N, one extra core will be allocated to each of the first N streams. We suggest setting num_streams to a divisor of the number of cores inside cpu_pool.
If the inputs’ batch size is larger than and divisible by num_streams, the batch will be allocated equally to each stream. If the batch size is not divisible by num_streams, with remainder N, one extra sample will be allocated to each of the first N streams. If the inputs’ batch size is less than num_streams, only the first batch-size streams are used, each with a mini-batch of one. We suggest setting the inputs’ batch size larger than and divisible by num_streams. If you don’t want to tune the number of streams and leave it as “AUTO”, we suggest setting the inputs’ batch size larger than and divisible by the number of cores.
- Parameters
model (torch.jit.ScriptModule or torch.nn.Module) – The input model.
num_streams (Union[int, str]) – Number of instances (int) or “AUTO” (str). “AUTO” means the stream number will be selected automatically. Although “AUTO” usually provides reasonable performance, it may not be optimal in some cases, where the number of streams needs to be tuned manually.
cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object that contains all CPU cores used to run multi-stream inference.
concat_output (bool) – A flag indicating whether the outputs of the streams will be concatenated. The default value is True. Note: if the outputs of the streams can’t be concatenated, set this flag to False to get the raw output (a list of each stream’s output).
input_split_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to split the inputs.
output_concat_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to concatenate the outputs.
- Returns
Generated ipex.cpu.runtime.MultiStreamModule object.
- Return type
ipex.cpu.runtime.MultiStreamModule
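The core and batch allocation rules described for MultiStreamModule can be written down directly, which is handy for checking how a given core count, stream count and batch size will be divided. The sketch below is a hypothetical plain-Python restatement of those rules; allocate_cores and allocate_batch are not library functions.

```python
def allocate_cores(num_cores, num_streams):
    """Equal share per stream; the first (num_cores % num_streams) streams get one extra core."""
    base, rem = divmod(num_cores, num_streams)
    return [base + 1 if i < rem else base for i in range(num_streams)]


def allocate_batch(batch_size, num_streams):
    """Per-stream mini-batch sizes, following the rules described above."""
    if batch_size < num_streams:
        return [1] * batch_size  # only the first batch_size streams run, one sample each
    base, rem = divmod(batch_size, num_streams)
    return [base + 1 if i < rem else base for i in range(num_streams)]


print(allocate_cores(14, 4))   # [4, 4, 3, 3]
print(allocate_batch(10, 4))   # [3, 3, 2, 2]
print(allocate_batch(3, 4))    # [1, 1, 1]
```

With 14 cores and 4 streams the split is uneven, which is why the text recommends choosing num_streams as a divisor of the core count.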
- class ipex.cpu.runtime.Task(module, cpu_pool: CPUPool)
An abstraction of a computation based on a PyTorch module, which is scheduled asynchronously.
- Parameters
module (torch.jit.ScriptModule or torch.nn.Module) – The input module.
cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object that contains all CPU cores used to run the Task asynchronously.
- Returns
Generated ipex.cpu.runtime.Task object.
- Return type
ipex.cpu.runtime.Task
- ipex.cpu.runtime.get_core_list_of_node_id(node_id)
Helper function to get the CPU cores’ ids of the input numa node.
- Parameters
node_id (int) – Input numa node id.
- Returns
List of CPU cores’ ids on this numa node.
- Return type
list
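On Linux, the CPUs that belong to a NUMA node are also listed in sysfs in the compact cpulist format (for example "0-3,8-11"). The parser below expands that format into a list of core ids, which may be useful for cross-checking get_core_list_of_node_id or for building the core_ids argument of CPUPool by hand. parse_cpulist and cores_of_node are hypothetical helpers; cores_of_node assumes a Linux system exposing /sys/devices/system/node, and the library is not claimed to read this file.

```python
from pathlib import Path


def parse_cpulist(text):
    """Expand a cpulist string such as '0-3,8-11' into [0, 1, 2, 3, 8, 9, 10, 11]."""
    cores = []
    for part in text.strip().split(","):
        if not part:
            continue
        if "-" in part:
            first, last = part.split("-")
            cores.extend(range(int(first), int(last) + 1))
        else:
            cores.append(int(part))
    return cores


def cores_of_node(node_id):
    """Read /sys/devices/system/node/node<N>/cpulist (Linux only)."""
    path = Path(f"/sys/devices/system/node/node{node_id}/cpulist")
    return parse_cpulist(path.read_text())


print(parse_cpulist("0-3,8-11"))  # [0, 1, 2, 3, 8, 9, 10, 11]
```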