API Documentation
Device-Agnostic
- ipex.optimize(model, dtype=None, optimizer=None, level='O1', inplace=False, conv_bn_folding=None, linear_bn_folding=None, weights_prepack=None, replace_dropout_with_identity=None, optimize_lstm=None, split_master_weight_for_bf16=None, fuse_update_step=None, auto_kernel_selection=None, sample_input=None, graph_mode=None)
- Apply optimizations at the Python frontend to the given model (nn.Module), as well as to the given optimizer (optional). If an optimizer is given, optimizations are applied for training; otherwise they are applied for inference. Optimizations include conv+bn folding (inference only), weight prepacking, and so on. - Weight prepacking is a technique to accelerate the performance of oneDNN operators. To achieve better vectorization and cache reuse, oneDNN uses a specific memory layout called blocked layout. Although the computation itself with the blocked layout is fast, it has drawbacks from a memory usage perspective: when running with the blocked layout, oneDNN splits one or several dimensions of the data into fixed-size blocks each time the operator is executed. More detailed information about oneDNN memory formats is available in the oneDNN manual. To reduce this overhead, weight data is converted to the predefined block shapes before oneDNN operators are executed. At runtime, if the data shape already matches the oneDNN operator's requirements, oneDNN skips the memory layout conversion and goes directly to the computation. Through this methodology, called weight prepacking, it is possible to avoid runtime weight data format conversion and thus increase performance. - Parameters
- model (torch.nn.Module) – User model to apply optimizations on. 
- dtype (torch.dtype) – Only works for torch.bfloat16 and torch.half (a.k.a. torch.float16). Model parameters will be cast to torch.bfloat16 or torch.half according to the dtype setting. The default value is None, meaning do nothing. Note: Data type conversion is only applied to nn.Conv2d, nn.Linear and nn.ConvTranspose2d for both training and inference cases. For inference mode, additional data type conversion is applied to the weights of nn.Embedding and nn.LSTM.
- optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning the inference case.
- level (string) – "O0" or "O1". No optimizations are applied with "O0"; the optimize function just returns the original model and optimizer. With "O1", the following optimizations are applied: conv+bn folding, weights prepack, dropout removal (inference model), master weight split and fused optimizer update step (training model). The optimization options can be further overridden by setting the following options explicitly. The default value is "O1".
- inplace (bool) – Whether to perform inplace optimization. Default value is False.
- conv_bn_folding (bool) – Whether to perform conv_bn folding. It only works for the inference model. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- linear_bn_folding (bool) – Whether to perform linear_bn folding. It only works for the inference model. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- weights_prepack (bool) – Whether to perform weight prepack for convolution and linear to avoid the oneDNN weights reorder. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob. For now, XPU doesn't support weights prepack.
- replace_dropout_with_identity (bool) – Whether to replace nn.Dropout with nn.Identity. If replaced, aten::dropout won't be included in the JIT graph. This may provide more fusion opportunities on the graph. This only works for the inference model. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- optimize_lstm (bool) – Whether to replace nn.LSTM with IPEX LSTM, which takes advantage of oneDNN kernels to get better performance. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- split_master_weight_for_bf16 (bool) – Whether to split master weights update for BF16 training. This saves memory compared to the master weight update solution. The split master weights update methodology doesn't support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- fuse_update_step (bool) – Whether to use fused params update for training, which has better performance. It doesn't support all optimizers. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- sample_input (tuple or torch.Tensor) – Whether to feed sample input data to ipex.optimize. The shape of the input data will impact the block format of the packed weight. If a sample input is not fed, Intel® Extension for PyTorch* will pack the weight per some predefined heuristics. If a sample input with the real input shape is fed, Intel® Extension for PyTorch* can get the best block format.
- auto_kernel_selection (bool) – Different backends may have different performance with different dtypes/shapes. Intel® Extension for PyTorch* will try to optimize the kernel selection for better performance if this knob is set to True. You might get better performance at the cost of extra memory usage. The default value is None. Explicitly setting this knob overwrites the configuration set by the level knob.
- graph_mode – (bool) [experimental]: If True, it will automatically apply a combination of methods to generate a graph or multiple subgraphs. The default value is False.
 
- Returns
- Model and optimizer (if given) modified according to the level knob or other user settings. conv+bn folding may take place and dropout may be replaced by identity. In inference scenarios, convolution, linear and lstm will be replaced with the optimized counterparts in Intel® Extension for PyTorch* (weight prepack for convolution and linear) for good performance. In bfloat16 or float16 scenarios, parameters of convolution and linear will be cast to bfloat16 or float16 dtype.
- Warning - Please invoke the optimize function AFTER loading weights to the model via model.load_state_dict(torch.load(PATH)).
- Warning - Please invoke the optimize function BEFORE invoking DDP in distributed training scenarios.
- The optimize function deep-copies the original model. If DDP is invoked before the optimize function, DDP is applied to the original model rather than the one returned from the optimize function. In this case, some operators in DDP, such as allreduce, will not be invoked and thus may cause unpredictable accuracy loss.
- Examples
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.
- torch.xpu.optimize() is an alternative to the optimize API in Intel® Extension for PyTorch*, providing identical usage for the XPU device only. The motivation for adding this alias is to unify the coding style in user scripts based on the torch.xpu module.
- Examples
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = torch.xpu.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = torch.xpu.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.
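- When the inference input shape is known in advance, passing it through sample_input lets weight prepacking choose a block format that matches that shape instead of relying only on heuristics. A minimal sketch, assuming a small convolutional model; the 8x3x224x224 input shape is a hypothetical example:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU())
>>> model.eval()
>>> sample = torch.randn(8, 3, 224, 224)  # hypothetical real input shape
>>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16, sample_input=sample)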
- ipex.get_fp32_math_mode(device='cpu')
- Get the current fpmath_mode setting. - Parameters
- device (string) – cpu, xpu
- Returns
- Fpmath mode. The value will be FP32MathMode.FP32, FP32MathMode.BF32 or FP32MathMode.TF32 (GPU only). oneDNN fpmath mode will be disabled by default if dtype is set to FP32MathMode.FP32. The implicit FP32 to TF32 data type conversion will be enabled if dtype is set to FP32MathMode.TF32. The implicit FP32 to BF16 data type conversion will be enabled if dtype is set to FP32MathMode.BF32.
- Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to get the current fpmath mode
>>> ipex.get_fp32_math_mode(device="xpu")
- torch.xpu.get_fp32_math_mode() is an alternative function in Intel® Extension for PyTorch*, providing identical usage for the XPU device only. The motivation for adding this alias is to unify the coding style in user scripts based on the torch.xpu module.
- Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to get the current fpmath mode
>>> torch.xpu.get_fp32_math_mode(device="xpu")
- ipex.set_fp32_math_mode(mode=FP32MathMode.FP32, device='cpu')
- Enable or disable implicit data type conversion. - Parameters
- mode (FP32MathMode) – FP32MathMode.FP32, FP32MathMode.BF32 or FP32MathMode.TF32 (GPU only). oneDNN fpmath mode will be disabled by default if dtype is set to FP32MathMode.FP32. The implicit FP32 to TF32 data type conversion will be enabled if dtype is set to FP32MathMode.TF32. The implicit FP32 to BF16 data type conversion will be enabled if dtype is set to FP32MathMode.BF32.
- device (string) – cpu, xpu

- Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the implicit data type conversion
>>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
>>> # to disable the implicit data type conversion
>>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
- torch.xpu.set_fp32_math_mode() is an alternative function in Intel® Extension for PyTorch*, providing identical usage for the XPU device only. The motivation for adding this alias is to unify the coding style in user scripts based on the torch.xpu module.
- Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the implicit data type conversion
>>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
>>> # to disable the implicit data type conversion
>>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
- class ipex.verbose(level)
- On-demand oneDNN verbosing functionality. - To make it easier to debug performance issues, oneDNN can dump verbose messages containing information such as kernel size, input data size and execution duration while executing the kernel. The verbosing functionality can be invoked via an environment variable named DNNL_VERBOSE. However, this methodology dumps messages in all steps, which produces a large amount of verbose messages. Moreover, for investigating performance issues, verbose messages for one single iteration are generally enough. - This on-demand verbosing functionality makes it possible to control the scope of verbose message dumping. In the following example, verbose messages will be dumped out for the second inference only.
import intel_extension_for_pytorch as ipex
model(data)
with ipex.verbose(ipex.VERBOSE_ON):
    model(data)
- Parameters
- level – - Verbose level - VERBOSE_OFF: Disable verbosing
- VERBOSE_ON: Enable verbosing
- VERBOSE_ON_CREATION: Enable verbosing, including oneDNN kernel creation
 
 
GPU-Specific
Miscellaneous
- torch.xpu.current_device() int
- Returns the index of a currently selected device. 
- torch.xpu.current_stream(device: Optional[Union[device, str, int]] = None) Stream
- Returns the currently selected Stream for a given device.
- class torch.xpu.device(device)
- Context-manager that changes the selected device. - Parameters
- device (torch.device or int) – device index to select. It’s a no-op if this argument is a negative integer or None.
 
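- A minimal sketch of the context manager, assuming an XPU build of Intel® Extension for PyTorch* with at least two devices available:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> with torch.xpu.device(1):
...     x = torch.ones(2, 2, device='xpu')  # allocated on device 1
>>> torch.xpu.current_device()               # back to the previously selected device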
- torch.xpu.device_count() int
- Returns the number of XPU devices available.
- class torch.xpu.device_of(obj)
- Context-manager that changes the current device to that of the given object. - You can use both tensors and storages as arguments. If a given object is not allocated on a GPU, this is a no-op. - Parameters
- obj (Tensor or Storage) – object allocated on the selected device. 
 
- torch.xpu.getDeviceIdListForCard(card_id=-1) list
- Returns the device list of card_id. By default, returns the device list of the card that contains the maximum number of devices.
- torch.xpu.get_device_name(device: Optional[Union[device, str, int]] = None) str
- Gets the name of a device. - Parameters
- device (torch.device or int, optional) – device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by current_device(), if device is None (default).
 
- torch.xpu.get_device_properties(device: Union[device, str, int])
- Gets the xpu properties of a device. - Parameters
- device (torch.device or int, optional) – device for which to return the device properties. It uses the current device, given by current_device(), if device is None (default).
- Returns
- the properties of the device 
- Return type
- _DeviceProperties 
 
- torch.xpu.init()
- Initialize the XPU’s state. This is a Python API for lazy initialization that avoids initializing XPU until the first time it is accessed. You may need to call this function explicitly in very rare cases, since IPEX can call this initialization automatically when XPU functionality is used on demand. - Calling this function repeatedly does nothing.
- torch.xpu.is_available() bool
- Returns a bool indicating if XPU is currently available. 
- torch.xpu.is_initialized()
- Returns whether XPU state has been initialized. 
- torch.xpu.set_device(device: Union[device, str, int]) None
- Sets the current device. - Usage of this function is discouraged in favor of device. In most cases it’s better to use the xpu_VISIBLE_DEVICES environment variable. - Parameters
- device (torch.device or int) – selected device. This function is a no-op if this argument is negative. 
 
- torch.xpu.stream(stream: Optional[Stream]) StreamContext
- Wrapper around the Context-manager StreamContext that selects a given stream. - Parameters
- stream (Stream) – selected stream. This manager is a no-op if it’s None.
 - Note - Streams are per-device. If the selected stream is not on the current device, this function will also change the current device to match the stream. 
- torch.xpu.synchronize(device: Optional[Union[device, str, int]] = None) None
- Waits for all kernels in all streams on an XPU device to complete. - Parameters
- device (torch.device or int, optional) – device for which to synchronize. It uses the current device, given by current_device(), if device is None (default).
 
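- Because kernels are launched asynchronously, host-side timing is only meaningful after a synchronize. A minimal timing sketch, assuming an XPU device is available:
>>> import time
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> a = torch.randn(4096, 4096, device='xpu')
>>> torch.xpu.synchronize()            # make sure setup work has finished
>>> start = time.time()
>>> b = a @ a                          # the matmul kernel is enqueued asynchronously
>>> torch.xpu.synchronize()            # wait for the matmul to complete
>>> print(f"elapsed: {time.time() - start:.4f}s")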
Random Number Generator
- torch.xpu.get_rng_state(device: Union[int, str, device] = 'xpu') Tensor
- Returns the random number generator state of the specified GPU as a ByteTensor. - Parameters
- device (torch.device or int, optional) – The device to return the RNG state of. Default: 'xpu' (i.e., torch.device('xpu'), the current XPU device).
 - Warning - This function eagerly initializes XPU. 
- torch.xpu.get_rng_state_all() List[Tensor]
- Returns a list of ByteTensor representing the random number states of all devices. 
- torch.xpu.set_rng_state(new_state: Tensor, device: Union[int, str, device] = 'xpu') None
- Sets the random number generator state of the specified GPU. - Parameters
- new_state (torch.ByteTensor) – The desired state 
- device (torch.device or int, optional) – The device to set the RNG state. Default: 'xpu' (i.e., torch.device('xpu'), the current XPU device).
 
 
- torch.xpu.set_rng_state_all(new_states: Iterable[Tensor]) None
- Sets the random number generator state of all devices. - Parameters
- new_states (Iterable of torch.ByteTensor) – The desired state for each device 
 
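- A common pattern is to snapshot and later restore the generator state so a stochastic section can be replayed exactly. A minimal sketch, assuming an XPU device and that random ops on the device draw from this generator:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> state = torch.xpu.get_rng_state()        # snapshot the current device RNG state
>>> first = torch.rand(3, device='xpu')
>>> torch.xpu.set_rng_state(state)           # restore the snapshot
>>> second = torch.rand(3, device='xpu')
>>> torch.equal(first.cpu(), second.cpu())   # True: the same random numbers are produced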
- torch.xpu.manual_seed(seed: int) None
- Sets the seed for generating random numbers for the current GPU. It’s safe to call this function if XPU is not available; in that case, it is silently ignored. - Parameters
- seed (int) – The desired seed. 
 - Warning - If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
- torch.xpu.manual_seed_all(seed: int) None
- Sets the seed for generating random numbers on all GPUs. It’s safe to call this function if XPU is not available; in that case, it is silently ignored. - Parameters
- seed (int) – The desired seed. 
 
- torch.xpu.seed() None
- Sets the seed for generating random numbers to a random number for the current GPU. It’s safe to call this function if XPU is not available; in that case, it is silently ignored. - Warning - If you are working with a multi-GPU model, this function will only initialize the seed on one GPU. To initialize all GPUs, use seed_all().
- torch.xpu.seed_all() None
- Sets the seed for generating random numbers to a random number on all GPUs. It’s safe to call this function if XPU is not available; in that case, it is silently ignored. 
- torch.xpu.initial_seed() int
- Returns the current random seed of the current GPU. - Warning - This function eagerly initializes XPU. 
Streams and events
- class torch.xpu.Stream(device=None, priority=0, **kwargs)
- record_event(event=None)
- Records an event. - Parameters
- event (Event, optional) – event to record. If not given, a new one will be allocated. 
- Returns
- Recorded event. 
 
 - synchronize()
- Wait for all the kernels in this stream to complete. 
 - wait_event(event)
- Makes all future work submitted to the stream wait for an event. - Parameters
- event (Event) – an event to wait for. 
 
 - wait_stream(stream)
- Synchronizes with another stream. - All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete. - Parameters
- stream (Stream) – a stream to synchronize. 
 - Note - This function returns without waiting for currently enqueued kernels in stream: only future operations are affected.
 
- class torch.xpu.Event(**kwargs)
- elapsed_time(end_event)
- Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded. 
 - query()
- Checks if all work currently captured by event has completed. - Returns
- A boolean indicating if all work currently captured by event has completed. 
 
 - record(stream=None)
- Records the event in a given stream. - Uses torch.xpu.current_stream() if no stream is specified.
 - synchronize()
- Waits for the event to complete. - Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes. 
 - wait(stream=None)
- Makes all future work submitted to the given stream wait for this event. - Uses torch.xpu.current_stream() if no stream is specified.
 
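- Events recorded on a stream can also be used for device-side timing. A minimal sketch, assuming an XPU device; whether timing must be enabled explicitly when constructing the events is an assumption that may differ by version:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> start, end = torch.xpu.Event(), torch.xpu.Event()
>>> a = torch.randn(4096, 4096, device='xpu')
>>> start.record()                 # record on the current stream
>>> b = a @ a
>>> end.record()
>>> end.synchronize()              # wait until the end event completes
>>> print(start.elapsed_time(end), 'ms')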
Memory management
- torch.xpu.empty_cache() None
- Releases all unoccupied cached memory currently held by the caching allocator so that it can be used by other GPU applications and is visible in the sysman toolkit. - Note - empty_cache() doesn’t increase the amount of GPU memory available for PyTorch. However, it may help reduce fragmentation of GPU memory in certain cases.
- torch.xpu.memory_stats(device: Optional[Union[device, str, int]] = None) Dict[str, Any]
- Returns a dictionary of XPU memory allocator statistics for a given device. - The return value of this function is a dictionary of statistics, each of which is a non-negative integer. - Core statistics: - "allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of allocation requests received by the memory allocator.
- "allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of allocated memory.
- "segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of reserved segments from- xpuMalloc().
- "reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of reserved memory.
- "active.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of active memory blocks.
- "active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of active memory.
- "inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of inactive, non-releasable memory blocks.
- "inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of inactive, non-releasable memory.
 - For these core statistics, values are broken down as follows. - Pool type: - all: combined statistics across all memory pools.
- large_pool: statistics for the large allocation pool (as of October 2019, for size >= 1MB allocations).
- small_pool: statistics for the small allocation pool (as of October 2019, for size < 1MB allocations).
 - Metric type: - current: current value of this metric.
- peak: maximum value of this metric.
- allocated: historical total increase in this metric.
- freed: historical total decrease in this metric.
 - In addition to the core statistics, we also provide some simple event counters: - "num_alloc_retries": number of failed xpuMalloc calls that result in a cache flush and retry.
- "num_ooms": number of out-of-memory errors thrown.
 - Parameters
- device (torch.device or int, optional) – selected device. Returns statistics for the current device, given by current_device(), if device is None (default).
 
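- For example, to inspect a couple of the core counters for the current device (a minimal sketch; the keys follow the naming scheme above):
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> x = torch.empty(1024, 1024, device='xpu')
>>> stats = torch.xpu.memory_stats()
>>> stats['allocated_bytes.all.current']   # bytes currently allocated
>>> stats['num_alloc_retries']             # allocation retries so far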
- torch.xpu.memory_summary(device: Optional[Union[device, str, int]] = None, abbreviated: bool = False) str
- Returns a human-readable printout of the current memory allocator statistics for a given device. - This can be useful to display periodically during training, or when handling out-of-memory exceptions. - Parameters
- device (torch.device or int, optional) – selected device. Returns printout for the current device, given by current_device(), if device is None (default).
- abbreviated (bool, optional) – whether to return an abbreviated summary (default: False). 
 
 
- torch.xpu.memory_snapshot()
- Returns a snapshot of the XPU memory allocator state across all devices. - Interpreting the output of this function requires familiarity with the memory allocator internals. 
- torch.xpu.memory_allocated(device: Optional[Union[device, str, int]] = None) int
- Returns the current GPU memory occupied by tensors in bytes for a given device. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 - Note - This is likely less than the amount shown in the sysman toolkit, since some unused memory can be held by the caching allocator and some context needs to be created on the GPU.
- torch.xpu.max_memory_allocated(device: Optional[Union[device, str, int]] = None) int
- Returns the maximum GPU memory occupied by tensors in bytes for a given device. - By default, this returns the peak allocated memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak allocated memory usage of each iteration in a training loop. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 
- torch.xpu.memory_reserved(device: Optional[Union[device, str, int]] = None) int
- Returns the current GPU memory managed by the caching allocator in bytes for a given device. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 
- torch.xpu.max_memory_reserved(device: Optional[Union[device, str, int]] = None) int
- Returns the maximum GPU memory managed by the caching allocator in bytes for a given device. - By default, this returns the peak cached memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak cached memory amount of each iteration in a training loop. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 
- torch.xpu.reset_peak_memory_stats(device: Optional[Union[device, str, int]] = None) None
- Resets the “peak” stats tracked by the XPU memory allocator. - See memory_stats() for details. Peak stats correspond to the “peak” key in each individual stat dict. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 
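- Together with max_memory_allocated(), this allows measuring the peak allocation of each training iteration, as described above. A minimal sketch; num_steps and train_one_iteration are hypothetical user-supplied pieces:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> for step in range(num_steps):              # num_steps defined elsewhere (hypothetical)
...     torch.xpu.reset_peak_memory_stats()
...     train_one_iteration()                  # hypothetical user training step
...     peak = torch.xpu.max_memory_allocated()
...     print(f"step {step}: peak {peak / 1024**2:.1f} MiB")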
- torch.xpu.memory_stats_as_nested_dict(device: Optional[Union[device, str, int]] = None) Dict[str, Any]
- Returns the result of memory_stats() as a nested dictionary.
- torch.xpu.reset_accumulated_memory_stats(device: Optional[Union[device, str, int]] = None) None
- Resets the “accumulated” (historical) stats tracked by the XPU memory allocator. - See memory_stats() for details. Accumulated stats correspond to the “allocated” and “freed” keys in each individual stat dict, as well as “num_alloc_retries” and “num_ooms”. - Parameters
- device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
 
C++ API
- 
enum xpu::FP32_MATH_MODE
- Specifies the available FP32 math modes. - Values: - 
enumerator FP32
- set floating-point math mode to FP32. 
 - 
enumerator TF32
- set floating-point math mode to TF32. 
 - 
enumerator BF32
- set floating-point math mode to BF32. 
 
- 
bool xpu::set_fp32_math_mode(FP32_MATH_MODE mode)
- Enable or disable implicit data type conversion. If mode is FP32MathMode.FP32, the oneDNN fpmath mode is disabled. If mode is FP32MathMode.TF32, the oneDNN fpmath mode is enabled by implicitly down-converting to tf32. If mode is FP32MathMode.BF32, the oneDNN fpmath mode is enabled by implicitly down-converting to bfloat16. - Parameters
- mode – (FP32MathMode): Only works for FP32MathMode.FP32, FP32MathMode.TF32 and FP32MathMode.BF32. oneDNN fpmath mode will be disabled by default if dtype is set to FP32MathMode.FP32. The implicit FP32 to TF32 data type conversion will be enabled if dtype is set to FP32MathMode.TF32. The implicit FP32 to BF16 data type conversion will be enabled if dtype is set to FP32MathMode.BF32.
 
- 
sycl::queue &xpu::get_queue_from_stream(c10::Stream stream)
- Get a sycl queue from a c10 stream: generate a dpcpp stream from the c10 stream, and get the corresponding dpcpp queue. - Parameters
- stream – c10 stream. 
- Returns
- The dpcpp queue.
 
CPU-Specific
Miscellaneous
- ipex.enable_onednn_fusion(enabled)
- Enables or disables oneDNN fusion functionality. If enabled, oneDNN operators will be fused at runtime when intel_extension_for_pytorch is imported. - Parameters
- enabled (bool) – Whether to enable oneDNN fusion functionality or not. Default value is True.
- Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the oneDNN fusion
>>> ipex.enable_onednn_fusion(True)
>>> # to disable the oneDNN fusion
>>> ipex.enable_onednn_fusion(False)
Quantization
- ipex.quantization.prepare(model, configure, example_inputs, inplace=False)
- Prepare an FP32 torch.nn.Module model to do calibration or to convert to quantized model. - Parameters
- model (torch.nn.Module) – The FP32 model to be prepared. 
- configure (torch.quantization.qconfig.QConfig) – The observer settings about activation and weight. 
- example_inputs (tuple or torch.Tensor) – A tuple of example inputs that will be passed to the function while running to init quantization state. 
- inplace – (bool): It will change the given model in-place if True. The default value is False.
 
- Returns
- torch.nn.Module 
 
- ipex.quantization.convert(model, inplace=False)
- Convert an FP32 prepared model to a model which will automatically insert fake quant before a quantizable module or operator. - Parameters
- model (torch.nn.Module) – The FP32 model to be converted.
- inplace – (bool): It will change the given model in-place if True. The default value is False.
 
- Returns
- torch.nn.Module 
 
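- A minimal end-to-end static quantization sketch combining prepare and convert; the default_static_qconfig helper and the calibration dataloader name are assumptions here, and the exact recipe for a given model may differ:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> model = ...                                          # FP32 model in eval mode
>>> qconfig = ipex.quantization.default_static_qconfig   # assumed helper providing a static QConfig
>>> example_inputs = torch.randn(1, 3, 224, 224)
>>> prepared = ipex.quantization.prepare(model, qconfig, example_inputs=example_inputs)
>>> for data in calib_dataloader:                        # hypothetical calibration dataloader
...     prepared(data)
>>> converted = ipex.quantization.convert(prepared)
>>> traced = torch.jit.trace(converted, example_inputs)  # trace and freeze for deployment
>>> traced = torch.jit.freeze(traced)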
Experimental API; an introduction is available on the feature page.
- ipex.quantization.autotune(prepared_model, calib_dataloader, eval_func, sampling_sizes=[100], accuracy_criterion={'relative': 0.01}, tuning_time=0)
- Automatic accuracy-driven tuning helps users quickly find out the advanced recipe for INT8 inference. - Parameters
- prepared_model (torch.nn.Module) – the FP32 prepared model returned from ipex.quantization.prepare. 
- calib_dataloader (generator) – set a dataloader for calibration. 
- eval_func (function) – set an evaluation function. This function takes “model” as an input parameter, executes the entire evaluation process with self-contained metrics, and returns an accuracy value which is a scalar number. The higher, the better.
- sampling_sizes (list) – a list of sample sizes used in calibration, which the tuning algorithm explores from. The default value is [100].
- accuracy_criterion (dict) – set the maximum allowed accuracy loss as {accuracy_criterion_type: accuracy_criterion_value}, where the type (str) is 'relative' or 'absolute' and the value is a float. The default value is {'relative': 0.01}.
- tuning_time (seconds) – tuning timeout. The default value is 0, which means early stop.
 
- Returns
- FP32 tuned model (torch.nn.Module) 
 
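- A minimal sketch of driving autotune; eval_func and calib_dataloader are user-supplied (the evaluate helper below is hypothetical), and eval_func must return a scalar accuracy:
>>> import intel_extension_for_pytorch as ipex
>>> def eval_func(model):
...     accuracy = evaluate(model)   # hypothetical evaluation returning a scalar accuracy
...     return accuracy
>>> tuned_model = ipex.quantization.autotune(
...     prepared_model, calib_dataloader, eval_func,
...     sampling_sizes=[100], accuracy_criterion={'relative': 0.01}, tuning_time=0)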
CPU Runtime
- ipex.cpu.runtime.is_runtime_ext_enabled()
- Helper function to check whether runtime extension is enabled or not. - Parameters
- None (None) – None 
- Returns
- Whether the runtime extension is enabled or not. If the Intel OpenMP Library is preloaded, this API will return True. Otherwise, it will return False.
 
- Return type
- bool 
 
- class ipex.cpu.runtime.CPUPool(core_ids: Optional[list] = None, node_id: Optional[int] = None)
- An abstraction of a pool of CPU cores used for intra-op parallelism. - Parameters
- core_ids (list) – A list of CPU cores’ ids used for intra-op parallelism. 
- node_id (int) – A numa node id with all CPU cores on the numa node. node_id doesn’t work if core_ids is set.
 
- Returns
- Generated ipex.cpu.runtime.CPUPool object. 
- Return type
- ipex.cpu.runtime.CPUPool
- class ipex.cpu.runtime.pin(cpu_pool: CPUPool)
- Apply the given CPU pool to the master thread that runs the scoped code region or the function/method def. - Parameters
- cpu_pool (ipex.cpu.runtime.CPUPool) – ipex.cpu.runtime.CPUPool object, contains all CPU cores used by the designated operations. 
- Returns
- Generated ipex.cpu.runtime.pin object which can be used as a with context or a function decorator. 
- Return type
- ipex.cpu.runtime.pin
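- A minimal sketch of pinning a scoped region to a subset of cores; cores 0-3 and the model and input below are hypothetical:
>>> import intel_extension_for_pytorch as ipex
>>> cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=[0, 1, 2, 3])
>>> with ipex.cpu.runtime.pin(cpu_pool):
...     y = model(x)   # runs with the master thread pinned to cores 0-3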
- class ipex.cpu.runtime.MultiStreamModuleHint(*args, **kwargs)
- MultiStreamModuleHint is a hint to MultiStreamModule about how to split the inputs or concat the outputs. Each argument should be None, an int, or a container that contains int or None, such as (0, None, …) or [0, None, …]. If the argument is None, this argument will not be split or concatenated. If the argument is an int, its value indicates along which dim this argument will be split or concatenated. - Parameters
- *args – Variable length argument list. 
- **kwargs – Arbitrary keyword arguments. 
 
- Returns
- Generated ipex.cpu.runtime.MultiStreamModuleHint object. 
- Return type
- ipex.cpu.runtime.MultiStreamModuleHint
- class ipex.cpu.runtime.MultiStreamModule(model, num_streams: ~typing.Union[int, str] = 'AUTO', cpu_pool: ~ipex.cpu.runtime.cpupool.CPUPool = <ipex.cpu.runtime.cpupool.CPUPool object>, concat_output: bool = True, input_split_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>, output_concat_hint: ~ipex.cpu.runtime.multi_stream.MultiStreamModuleHint = <ipex.cpu.runtime.multi_stream.MultiStreamModuleHint object>)
- MultiStreamModule supports inference with multi-stream throughput mode. - If the number of cores inside cpu_pool is divisible by num_streams, the cores will be allocated equally to each stream. If the number of cores inside cpu_pool is not divisible by num_streams with remainder N, one extra core will be allocated to the first N streams. We suggest setting num_streams as a divisor of the core number inside cpu_pool. - If the inputs’ batchsize is larger than and divisible by num_streams, the batchsize will be allocated equally to each stream. If the batchsize is not divisible by num_streams with remainder N, one extra piece will be allocated to the first N streams. If the inputs’ batchsize is less than num_streams, only the first batchsize’s streams are used, with mini batch as one. We suggest setting the inputs’ batchsize larger than and divisible by num_streams. If you don’t want to tune the number of streams and leave it as “AUTO”, we suggest setting the inputs’ batchsize larger than and divisible by the number of cores. - Parameters
- model (torch.jit.ScriptModule or torch.nn.Module) – The input model. 
- num_streams (Union[int, str]) – Number of instances (int) or “AUTO” (str). “AUTO” means the stream number will be selected automatically. Although “AUTO” usually provides reasonable performance, it may still not be optimal for some cases, which means manual tuning of the number of streams is needed.
- cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object, contains all CPU cores used to run multi-stream inference.
- concat_output (bool) – A flag indicating whether the output of each stream will be concatenated or not. The default value is True. Note: if the output of each stream can’t be concatenated, set this flag to False to get the raw output (a list of each stream’s output).
- input_split_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to split the inputs. 
- output_concat_hint (MultiStreamModuleHint) – Hint to MultiStreamModule about how to concat the outputs. 
 
- Returns
- Generated ipex.cpu.runtime.MultiStreamModule object. 
- Return type
- ipex.cpu.runtime.MultiStreamModule
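- A minimal sketch: two streams over a traced model, using all cores of NUMA node 0; the model and the input x are hypothetical:
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> traced_model = torch.jit.freeze(torch.jit.trace(model, x))  # hypothetical model and input
>>> cpu_pool = ipex.cpu.runtime.CPUPool(node_id=0)
>>> multi_stream_model = ipex.cpu.runtime.MultiStreamModule(
...     traced_model, num_streams=2, cpu_pool=cpu_pool)
>>> y = multi_stream_model(x)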
- class ipex.cpu.runtime.Task(module, cpu_pool: CPUPool)
- An abstraction of a computation based on a PyTorch module that is scheduled asynchronously. - Parameters
- module (torch.jit.ScriptModule or torch.nn.Module) – The input module.
- cpu_pool (ipex.cpu.runtime.CPUPool) – An ipex.cpu.runtime.CPUPool object, contains all CPU cores used to run Task asynchronously. 
 
- Returns
- Generated ipex.cpu.runtime.Task object. 
- Return type
- ipex.cpu.runtime.Task
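- A minimal sketch of running two Tasks asynchronously on disjoint core pools; the model and inputs are hypothetical, and the assumption that each submitted call returns a future-like handle with a get() method follows the runtime extension feature page:
>>> import intel_extension_for_pytorch as ipex
>>> pool_a = ipex.cpu.runtime.CPUPool(core_ids=[0, 1])
>>> pool_b = ipex.cpu.runtime.CPUPool(core_ids=[2, 3])
>>> task_a = ipex.cpu.runtime.Task(model, pool_a)
>>> task_b = ipex.cpu.runtime.Task(model, pool_b)
>>> future_a = task_a(x1)     # submitted asynchronously
>>> future_b = task_b(x2)
>>> y1, y2 = future_a.get(), future_b.get()   # wait for both results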
- ipex.cpu.runtime.get_core_list_of_node_id(node_id)
- Helper function to get the CPU cores’ ids of the input numa node. - Parameters
- node_id (int) – Input numa node id. 
- Returns
- List of CPU cores’ ids on this numa node. 
- Return type
- list
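- For example, the helper can feed a CPUPool directly (a minimal sketch, assuming NUMA node 0 exists on the machine):
>>> import intel_extension_for_pytorch as ipex
>>> cores = ipex.cpu.runtime.get_core_list_of_node_id(0)   # e.g. [0, 1, ..., N-1]
>>> cpu_pool = ipex.cpu.runtime.CPUPool(core_ids=cores)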