API Documentation
General
- ipex.optimize(model, dtype=None, optimizer=None, level='O1', inplace=False, conv_bn_folding=None, linear_bn_folding=None, weights_prepack=None, replace_dropout_with_identity=None, optimize_lstm=None, split_master_weight_for_bf16=None, fuse_update_step=None, auto_kernel_selection=None, sample_input=None, graph_mode=None, concat_linear=None)
Apply optimizations at the Python frontend to the given model (nn.Module) and, optionally, the given optimizer. If an optimizer is given, optimizations are applied for training; otherwise, they are applied for inference. Optimizations include conv+bn folding (inference only), weight prepacking, and so on.
Weight prepacking is a technique to accelerate oneDNN operators. To achieve better vectorization and cache reuse, oneDNN uses a specific memory layout called blocked layout. Although computation with the blocked layout is fast, it has drawbacks from the memory perspective: when running with the blocked layout, oneDNN splits one or several dimensions of the data into fixed-size blocks each time the operator is executed. More information about oneDNN memory formats is available in the oneDNN manual. To reduce this overhead, weight data is converted to the predefined block shapes before the oneDNN operator is executed. At runtime, if the data shape matches the requirements of the oneDNN operator, oneDNN skips the memory layout conversion and goes directly to the computation. This methodology, called weight prepacking, avoids runtime weight format conversion and thus increases performance.
- Parameters:
model (torch.nn.Module) – User model to apply optimizations on.
dtype (torch.dtype) – Only works for torch.bfloat16 and torch.half, a.k.a. torch.float16. Model parameters are cast to torch.bfloat16 or torch.half according to this setting. The default value is None, meaning do nothing. Note: data type conversion is only applied to nn.Conv2d, nn.Linear and nn.ConvTranspose2d for both training and inference. For inference, data type conversion is additionally applied to the weights of nn.Embedding and nn.LSTM.
optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning the inference case.
level (string) – "O0" or "O1". No optimizations are applied with "O0"; the optimize function simply returns the original model and optimizer. With "O1", the following optimizations are applied: conv+bn folding, weights prepack, dropout removal (inference model), master weight split and fused optimizer update step (training model). These optimizations can be further overridden by setting the following options explicitly. The default value is "O1".
inplace (bool) – Whether to perform inplace optimization. The default value is False.
conv_bn_folding (bool) – Whether to perform conv_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
linear_bn_folding (bool) – Whether to perform linear_bn folding. It only works for inference models. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
weights_prepack (bool) – Whether to perform weight prepack for convolution and linear layers to avoid oneDNN weight reorders. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob. Weight prepack works for CPU only.
replace_dropout_with_identity (bool) – Whether to replace nn.Dropout with nn.Identity. If replaced, aten::dropout is not included in the JIT graph, which may provide more fusion opportunities on the graph. This only works for inference models. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
optimize_lstm (bool) – Whether to replace nn.LSTM with IPEX LSTM, which takes advantage of oneDNN kernels to get better performance. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
split_master_weight_for_bf16 (bool) – Whether to split master weights update for BF16 training. This saves memory compared to the master weight update solution. The split master weights update methodology doesn't support all optimizers. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
fuse_update_step (bool) – Whether to use fused parameter updates for training, which have better performance. It doesn't support all optimizers. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
sample_input (tuple or torch.Tensor) – Sample input data to feed to ipex.optimize. The shape of the input data affects the block format of the packed weights. If no sample input is given, Intel® Extension for PyTorch* packs the weights according to predefined heuristics. If a sample input with the real input shape is given, Intel® Extension for PyTorch* can select the best block format. Sample input works for CPU only.
auto_kernel_selection (bool) – Different backends may have different performance for different dtypes and shapes. Intel® Extension for PyTorch* tries to optimize the kernel selection for better performance only when this knob is set to True. You might get better performance at the cost of extra memory usage. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob. Auto kernel selection works for CPU only.
graph_mode (bool) – [prototype] If True, automatically applies a combination of methods to generate a graph or multiple subgraphs. The default value is False.
concat_linear (bool) – Whether to perform concat_linear. It only works for inference models. The default value is None. Explicitly setting this knob overrides the configuration set by the level knob.
- Returns:
Model and optimizer (if given) modified according to the level knob or other user settings. conv+bn folding may take place and dropout may be replaced by identity. In inference scenarios, convolution, linear and LSTM are replaced with their optimized counterparts in Intel® Extension for PyTorch* (weight prepack for convolution and linear) for good performance. In bfloat16 or float16 scenarios, parameters of convolution and linear are cast to bfloat16 or float16.
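Conv+bn folding rewrites an inference-mode BatchNorm that follows a convolution (or a linear layer, for linear_bn_folding) into the weights and bias of that layer, so only one operator runs. The sketch below shows the standard folding arithmetic for a linear layer in plain Python; it illustrates the technique only, is not the Intel® Extension for PyTorch* implementation, and its helper names are hypothetical. A convolution folds the same way, per output channel.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]

def linear(weight: Matrix, bias: Sequence[float], x: Sequence[float]) -> List[float]:
    """y[o] = sum_i weight[o][i] * x[i] + bias[o]"""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def batch_norm(y, mean, var, gamma, beta, eps=1e-5):
    """Inference-mode BatchNorm: normalize with running statistics, then scale and shift."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, m, s, g, bt in zip(y, mean, var, gamma, beta)]

def fold_bn_into_linear(weight, bias, mean, var, gamma, beta, eps=1e-5) -> Tuple[Matrix, List[float]]:
    """Return (weight', bias') such that linear(weight', bias', x) matches batch_norm(linear(weight, bias, x))."""
    folded_w, folded_b = [], []
    for row, b, m, s, g, bt in zip(weight, bias, mean, var, gamma, beta):
        scale = g / math.sqrt(s + eps)            # per-output-channel BN scale
        folded_w.append([w * scale for w in row])
        folded_b.append((b - m) * scale + bt)
    return folded_w, folded_b

W = [[1.0, 2.0], [0.5, -1.0]]
b = [0.1, -0.2]
mean, var = [0.3, -0.1], [4.0, 0.25]
gamma, beta = [1.5, 0.8], [0.0, 0.2]
x = [2.0, -3.0]

reference = batch_norm(linear(W, b, x), mean, var, gamma, beta)
W2, b2 = fold_bn_into_linear(W, b, mean, var, gamma, beta)
folded = linear(W2, b2, x)
print(all(math.isclose(r, f) for r, f in zip(reference, folded)))  # True
```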
Warning
Please invoke the optimize function BEFORE invoking DDP in distributed training scenarios.
The optimize function deep-copies the original model. If DDP is invoked before the optimize function, DDP is applied to the original model rather than the one returned from the optimize function. In this case, some DDP operators, such as allreduce, are not invoked, which may cause unpredictable accuracy loss.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.
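The weight-prepacking idea described above can be illustrated with a toy reorder. The sketch below converts a row-major [OC][IC] weight into a hypothetical blocked layout that groups output channels into blocks of 4 (zero-padding the last block), so the values a vectorized kernel reads together are contiguous. The layout and helper names are assumptions for illustration; oneDNN chooses its real blocked formats per ISA and shape. The point is that the reorder is done once, ahead of execution, rather than on every call.

```python
from typing import List

def prepack_oc_blocked(weight: List[List[float]], block: int = 4) -> List[float]:
    """Reorder a row-major [OC][IC] weight into a hypothetical [OC/block][IC][block] layout.

    Output channels are zero-padded up to a multiple of `block`, so every block
    has the fixed size a vectorized kernel expects.
    """
    oc, ic = len(weight), len(weight[0])
    n_blocks = (oc + block - 1) // block
    packed = []
    for ob in range(n_blocks):
        for i in range(ic):
            for lane in range(block):
                o = ob * block + lane
                packed.append(weight[o][i] if o < oc else 0.0)  # pad missing channels
    return packed

def packed_index(o: int, i: int, ic: int, block: int = 4) -> int:
    """Offset of logical element (o, i) in the packed buffer."""
    return ((o // block) * ic + i) * block + (o % block)

# Pack once up front, then reuse the packed buffer on every call.
W = [[float(10 * o + i) for i in range(3)] for o in range(6)]  # OC=6, IC=3
packed = prepack_oc_blocked(W)
print(len(packed))                        # 24: OC padded from 6 to 8, times IC=3
print(packed[:4])                         # [0.0, 10.0, 20.0, 30.0]
print(packed[packed_index(5, 2, ic=3)])   # 52.0
```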
torch.xpu.optimize() is an alias of the optimize API in Intel® Extension for PyTorch*, with identical usage, for XPU devices only. It is provided to unify the coding style of user scripts based on the torch.xpu module.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = torch.xpu.optimize(model, dtype=torch.bfloat16)
>>> # running evaluation step.
>>> # bfloat16 training case.
>>> optimizer = ...
>>> model.train()
>>> optimized_model, optimized_optimizer = torch.xpu.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
>>> # running training step.
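split_master_weight_for_bf16 exploits the fact that a bfloat16 value has the same layout as the upper 16 bits of an IEEE float32. A minimal sketch, assuming the FP32 master weight is split into a bf16 "top half" used for computation and a 16-bit "bottom half" kept alongside it: the two halves together reconstruct the FP32 value bit for bit, so no separate FP32 copy of the weight has to be stored. Helper names are hypothetical, and this is plain Python rather than the extension's optimizer kernels.

```python
import struct

def f32_bits(x: float) -> int:
    """IEEE-754 single-precision bit pattern of x (x is rounded to float32 first)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def split_master_weight(x: float):
    """Split an FP32 weight into (top16, bottom16).

    top16 is a valid bfloat16 bit pattern (truncated, not rounded) that a bf16
    kernel can compute with; bottom16 keeps the remaining mantissa bits needed
    to recover the exact FP32 value for the optimizer update.
    """
    bits = f32_bits(x)
    return bits >> 16, bits & 0xFFFF

def merge_master_weight(top16: int, bottom16: int) -> float:
    return bits_f32((top16 << 16) | bottom16)

w = 0.1                                   # stored as the nearest float32
top, bottom = split_master_weight(w)
bf16_view = bits_f32(top << 16)           # the value a bf16 kernel sees
restored = merge_master_weight(top, bottom)
print(restored == bits_f32(f32_bits(w)))  # True: exact FP32 round trip
print(bf16_view)                          # 0.099609375
```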
- ipex.optimize_transformers(model, optimizer=None, dtype=torch.float32, inplace=False, device='cpu', quantization_config=None, qconfig_summary_file=None, low_precision_checkpoint=None, sample_inputs=None, deployment_mode=True)
Apply optimizations at the Python frontend to the given transformers model (nn.Module). This API focuses on transformers models, especially inference for generation tasks. Well-supported model families: Llama, GPT-J, GPT-NeoX, OPT, Falcon.
- Parameters:
model (torch.nn.Module) – User model to apply optimizations.
optimizer (torch.optim.Optimizer) – User optimizer to apply optimizations on, such as SGD. The default value is None, meaning the inference case.
dtype (torch.dtype) – Currently works for torch.bfloat16 and torch.float. The default value is torch.float. When working with quantization, it means the mixed dtype used together with quantization.
inplace (bool) – Whether to perform inplace optimization. The default value is False.
device (str) – The device on which the optimization is performed, either 'cpu' or 'xpu'. The default value is 'cpu'.
quantization_config (object) – Defines the IPEX quantization recipe (weight-only quantization or static quantization). The default value is None. When set, the IPEX quantized model is used for model.generate(). (CPU only)
qconfig_summary_file (str) – Path to the IPEX static quantization config JSON file. Used together with quantization_config in the static quantization use case; the file is produced by running IPEX static quantization calibration. The default value is None. (CPU only)
low_precision_checkpoint (dict or tuple of dict) – For weight-only quantization with INT4 weights. If it is a dict, it should be the state_dict of a checkpoint (.pt) generated by GPTQ, etc. If a tuple is provided, it should be (checkpoint, checkpoint config), where checkpoint is the state_dict and checkpoint config is a dict specifying keys of groups in the state_dict. The default config is { groups: ‘-1’ }; change the values of the dict to make a custom config. Weights should have shape N by K; they are quantized to UINT4, compressed along K, and stored as torch.int32. Zero points are also UINT4 and stored as INT32. Scales and bias are floating-point values. Bias is optional; if it is not in the state dict, the bias of the original model is used. Only per-channel quantization of weights is supported (group size = -1). The default value is None.
sample_inputs (tuple of tensors) – Sample inputs used for model quantization or TorchScript. The default value is None; for well-supported models, these sample inputs are provided automatically. (CPU only)
deployment_mode (bool) – Whether to apply the optimized model for deployment of model generation, meaning there is no need to further apply optimizations such as TorchScript. The default value is True. (CPU only)
- Returns:
Optimized model object for model.generate(); it also works with model.forward.
Warning
Please invoke the optimize_transformers function AFTER invoking DeepSpeed in Tensor Parallel inference scenarios.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # bfloat16 generation inference case.
>>> model = ...
>>> model.load_state_dict(torch.load(PATH))
>>> model.eval()
>>> optimized_model = ipex.optimize_transformers(model, dtype=torch.bfloat16)
>>> optimized_model.generate()
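The INT4 storage format described for low_precision_checkpoint packs UINT4 values along K into int32 words, eight values per word, so an N x K weight becomes N x K/8 int32 values. The sketch below shows one way to pack and unpack a single row in plain Python. The nibble order (first value in the lowest 4 bits) is an assumption, since checkpoints produced by GPTQ and similar tools may use a different order, and the helper names are hypothetical.

```python
from typing import List

def pack_uint4_row(values: List[int]) -> List[int]:
    """Pack one row of UINT4 values (0..15) along K into signed int32 words, 8 per word.

    Assumption: element k of each group of 8 goes into bits [4*k, 4*k + 4),
    i.e. the lowest nibble holds the first element.
    """
    if len(values) % 8:
        raise ValueError("K must be a multiple of 8")
    words = []
    for start in range(0, len(values), 8):
        word = 0
        for k, v in enumerate(values[start:start + 8]):
            if not 0 <= v <= 15:
                raise ValueError(f"not a UINT4 value: {v}")
            word |= v << (4 * k)
        # Reinterpret the 32-bit pattern as a signed int32, as torch.int32 would hold it.
        words.append(word - (1 << 32) if word >= (1 << 31) else word)
    return words

def unpack_uint4_row(words: List[int]) -> List[int]:
    out = []
    for word in words:
        word &= 0xFFFFFFFF  # back to the unsigned bit pattern
        out.extend((word >> (4 * k)) & 0xF for k in range(8))
    return out

row = [1, 2, 3, 4, 5, 6, 7, 8, 15, 0, 0, 0, 0, 0, 0, 15]  # K = 16
packed = pack_uint4_row(row)
print(len(packed))                      # 2 int32 words for K = 16
print(unpack_uint4_row(packed) == row)  # True
```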
- ipex.get_fp32_math_mode(device='cpu')
Get the current fpmath_mode setting.
- Parameters:
device (string) – cpu or xpu.
- Returns:
Fpmath mode. The value is FP32MathMode.FP32, FP32MathMode.BF32 or FP32MathMode.TF32 (GPU only). oneDNN fpmath mode is disabled by default when the mode is FP32MathMode.FP32. Implicit FP32 to TF32 data type conversion is enabled when the mode is FP32MathMode.TF32. Implicit FP32 to BF16 data type conversion is enabled when the mode is FP32MathMode.BF32.
Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to get the current fpmath mode
>>> ipex.get_fp32_math_mode(device="xpu")
torch.xpu.get_fp32_math_mode() is an alias of this function in Intel® Extension for PyTorch*, with identical usage, for XPU devices only. It is provided to unify the coding style of user scripts based on the torch.xpu module.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # to get the current fpmath mode
>>> torch.xpu.get_fp32_math_mode(device="xpu")
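To get a feel for what the implicit conversions cost in precision: TF32 keeps 10 explicit mantissa bits and BF16 keeps 7, versus 23 for FP32, while all three keep the 8-bit FP32 exponent. The sketch below emulates rounding an FP32 value to those widths in plain Python (round to nearest, ties to even, NaN/Inf not handled). It only illustrates the precision loss on inputs; it does not model how oneDNN kernels apply the conversion internally.

```python
import struct

def round_fp32_mantissa(x: float, kept_mantissa_bits: int) -> float:
    """Round x (as float32) to `kept_mantissa_bits` explicit mantissa bits, ties to even.

    kept_mantissa_bits=10 emulates TF32 and 7 emulates BF16 (FP32 keeps 23).
    NaN and Inf inputs are not handled.
    """
    if not 0 <= kept_mantissa_bits < 23:
        raise ValueError("kept_mantissa_bits must be in [0, 23)")
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    drop = 23 - kept_mantissa_bits                # number of low mantissa bits to discard
    lsb = (bits >> drop) & 1                      # lowest kept bit, used to break ties
    bits += (1 << (drop - 1)) - 1 + lsb           # round half to even
    bits &= 0xFFFFFFFF & ~((1 << drop) - 1)       # clear the discarded bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = 1.0 + 2.0 ** -9                  # exactly representable in FP32
print(round_fp32_mantissa(x, 10))    # 1.001953125 (TF32 keeps the bit)
print(round_fp32_mantissa(x, 7))     # 1.0 (BF16 rounds it away)
```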
- ipex.set_fp32_math_mode(mode=FP32MathMode.FP32, device='cpu')
Enable or disable implicit data type conversion.
- Parameters:
mode (FP32MathMode) – FP32MathMode.FP32, FP32MathMode.BF32 or FP32MathMode.TF32 (GPU only). oneDNN fpmath mode is disabled by default when mode is set to FP32MathMode.FP32. Implicit FP32 to TF32 data type conversion is enabled when mode is set to FP32MathMode.TF32. Implicit FP32 to BF16 data type conversion is enabled when mode is set to FP32MathMode.BF32.
device (string) – cpu or xpu.
Examples
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the implicit data type conversion
>>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
>>> # to disable the implicit data type conversion
>>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
torch.xpu.set_fp32_math_mode() is an alias of this function in Intel® Extension for PyTorch*, with identical usage, for XPU devices only. It is provided to unify the coding style of user scripts based on the torch.xpu module.
Examples
>>> import torch
>>> import intel_extension_for_pytorch as ipex
>>> # to enable the implicit data type conversion
>>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
>>> # to disable the implicit data type conversion
>>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
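Because set_fp32_math_mode changes a global setting, it can be convenient to scope a change to one block of code. The context manager below is a hypothetical helper, not part of Intel® Extension for PyTorch*: it saves the current mode with get_fp32_math_mode, applies the new one with set_fp32_math_mode, and restores the saved mode on exit even if the block raises. The ipex parameter exists only so the helper can be exercised without the extension installed; by default it imports intel_extension_for_pytorch.

```python
from contextlib import contextmanager

@contextmanager
def fp32_math_mode(mode, device="cpu", ipex=None):
    """Temporarily set the FP32 math mode, restoring the previous mode on exit.

    Hypothetical helper. `ipex` defaults to intel_extension_for_pytorch and only
    needs get_fp32_math_mode and set_fp32_math_mode.
    """
    if ipex is None:
        import intel_extension_for_pytorch as ipex
    previous = ipex.get_fp32_math_mode(device=device)
    ipex.set_fp32_math_mode(mode=mode, device=device)
    try:
        yield
    finally:
        ipex.set_fp32_math_mode(mode=previous, device=device)

# Usage (requires Intel® Extension for PyTorch* and an XPU device):
# import intel_extension_for_pytorch as ipex
# with fp32_math_mode(ipex.FP32MathMode.BF32, device="xpu"):
#     out = model(inputs)  # FP32 ops in this block may use implicit BF16 conversion
```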
Miscellaneous
- torch.xpu.current_device() -> int
Returns the index of the currently selected device.
- torch.xpu.current_stream(device: device | str | int | None = None) -> Stream
Returns the currently selected Stream for a given device.
- class torch.xpu.device(device: Any)
Context manager that changes the selected device. It also serves as a wrapper that encapsulates the SYCL device from the runtime.
- Parameters:
device (torch.device or int) – device index to select. It’s a no-op if this argument is a negative integer or None.
- torch.xpu.device_count() -> int
Returns the number of available XPU devices.
- class torch.xpu.device_of(obj)
Context-manager that changes the current device to that of given object.
You can use both tensors and storages as arguments. If a given object is not allocated on a GPU, this is a no-op.
- Parameters:
obj (Tensor or Storage) – object allocated on the selected device.
- torch.xpu.get_device_name(device: device | str | int | None = None) -> str
Gets the name of a device.
- Parameters:
device (torch.device or int, optional) – device for which to return the name. This function is a no-op if this argument is a negative integer. It uses the current device, given by current_device(), if device is None (default).
- torch.xpu.get_device_properties(device: device | str | int)
Gets the XPU properties of a device.
- Parameters:
device (torch.device or int, optional) – device for which to return the device properties. It uses the current device, given by current_device(), if device is None (default).
- Returns:
the properties of the device
- Return type:
_DeviceProperties
- torch.xpu.init()
Initializes the XPU state. XPU is initialized lazily, that is, not until the first time it is accessed. You should need to call this function explicitly only in rare cases, since IPEX may call this initialization automatically when XPU functionality is requested.
Calling this function repeatedly has no further effect.
- torch.xpu.is_available() -> bool
Returns a bool indicating if XPU is currently available.
- torch.xpu.is_initialized() -> bool
Returns whether XPU state has been initialized.
- torch.xpu.set_device(device: device | str | int) -> None
Sets the current device.
Usage of this function is discouraged in favor of the device context manager. In most cases it’s better to use the ZE_AFFINITY_MASK environment variable to restrict which devices are visible.
- Parameters:
device (torch.device or int) – selected device. This function is a no-op if this argument is negative.
- torch.xpu.stream(stream: Stream | None) -> StreamContext
Wrapper around the context manager StreamContext that selects a given stream.
- Parameters:
stream (Stream) – selected stream. This manager is a no-op if it’s None.
Note
Streams are per-device. If the selected stream is not on the current device, this function will also change the current device to match the stream.
- torch.xpu.synchronize(device: device | str | int | None = None) -> None
Waits for all kernels in all streams on an XPU device to complete.
- Parameters:
device (torch.device or int, optional) – device for which to synchronize. It uses the current device, given by current_device(), if device is None (default).
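The device-query functions above can be combined into a small helper that reports which XPUs are visible and falls back to CPU when none is available. `describe_xpus` is hypothetical; it takes the `torch.xpu` module as an optional argument (by default importing torch and intel_extension_for_pytorch) so the logic can be tried without hardware.

```python
def describe_xpus(xpu=None):
    """Return (device_string, [(index, name), ...]) for the visible XPU devices.

    Hypothetical helper built on torch.xpu.is_available, device_count,
    get_device_name and current_device.
    """
    if xpu is None:
        import torch
        import intel_extension_for_pytorch  # noqa: F401  (imported for its effect on torch.xpu)
        xpu = torch.xpu
    if not xpu.is_available():
        return "cpu", []
    devices = [(i, xpu.get_device_name(i)) for i in range(xpu.device_count())]
    return f"xpu:{xpu.current_device()}", devices
```

A typical use is `device, visible = describe_xpus()` at startup, then moving the model to `device`.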
- torch.xpu.fp8.fp8.fp8_autocast(enabled: bool = False, fp8_recipe: DelayedScaling | None = None) -> None
Random Number Generator
- torch.xpu.get_rng_state(device: int | str | device = 'xpu') -> Tensor
Returns the random number generator state of the specified GPU as a ByteTensor.
- Parameters:
device (torch.device or int, optional) – The device to return the RNG state of. Default: 'xpu' (i.e., torch.device('xpu'), the current XPU device).
Warning
This function eagerly initializes XPU.
- torch.xpu.get_rng_state_all() -> List[Tensor]
Returns a list of ByteTensor representing the random number states of all devices.
- torch.xpu.set_rng_state(new_state: Tensor, device: int | str | device = 'xpu') -> None
Sets the random number generator state of the specified GPU.
- Parameters:
new_state (torch.ByteTensor) – The desired state
device (torch.device or int, optional) – The device to set the RNG state. Default: 'xpu' (i.e., torch.device('xpu'), the current XPU device).
- torch.xpu.set_rng_state_all(new_states: Iterable[Tensor]) -> None
Sets the random number generator state of all devices.
- Parameters:
new_states (Iterable of torch.ByteTensor) – The desired state for each device
- torch.xpu.manual_seed(seed: int) -> None
Sets the seed for generating random numbers for the current GPU. It’s safe to call this function if XPU is not available; in that case, it is silently ignored.
- Parameters:
seed (int) – The desired seed.
Warning
If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
- torch.xpu.manual_seed_all(seed: int) -> None
Sets the seed for generating random numbers on all GPUs. It’s safe to call this function if XPU is not available; in that case, it is silently ignored.
- Parameters:
seed (int) – The desired seed.
- torch.xpu.seed() -> None
Sets the seed for generating random numbers to a random number for the current GPU. It’s safe to call this function if XPU is not available; in that case, it is silently ignored.
Warning
If you are working with a multi-GPU model, this function will only initialize the seed on one GPU. To initialize all GPUs, use seed_all().
- torch.xpu.seed_all() -> None
Sets the seed for generating random numbers to a random number on all GPUs. It’s safe to call this function if XPU is not available; in that case, it is silently ignored.
- torch.xpu.initial_seed() -> int
Returns the current random seed of the current GPU.
Warning
This function eagerly initializes XPU.
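The seeding and RNG-state functions can be combined to make one region of code reproducible without permanently changing the generators. The sketch below is a hypothetical context manager: it saves every device's state with get_rng_state_all, seeds all devices with manual_seed_all, and restores the saved states with set_rng_state_all on exit. Note that get_rng_state is documented to eagerly initialize XPU, and get_rng_state_all may do the same.

```python
from contextlib import contextmanager

@contextmanager
def xpu_rng_fork(seed, xpu=None):
    """Seed all XPU generators for the duration of the block, then restore their states.

    Hypothetical helper; `xpu` defaults to torch.xpu.
    """
    if xpu is None:
        import torch
        import intel_extension_for_pytorch  # noqa: F401
        xpu = torch.xpu
    saved = xpu.get_rng_state_all()   # one ByteTensor per device
    xpu.manual_seed_all(seed)
    try:
        yield
    finally:
        xpu.set_rng_state_all(saved)
```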
Streams and events
- class torch.xpu.Stream(device=None, priority=0, **kwargs)
- record_event(event=None)
Records an event.
- Parameters:
event (Event, optional) – event to record. If not given, a new one will be allocated.
- Returns:
Recorded event.
- property sycl_queue -> PyCapsule
Returns the SYCL queue of the corresponding Stream in a PyCapsule, which encapsulates a void pointer address. Its capsule name is torch.xpu.Stream.sycl_queue.
- Type:
sycl_queue(self)
- synchronize()
Wait for all the kernels in this stream to complete.
- wait_event(event)
Makes all future work submitted to the stream wait for an event.
- Parameters:
event (Event) – an event to wait for.
- wait_stream(stream)
Synchronizes with another stream.
All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call complete.
- Parameters:
stream (Stream) – a stream to synchronize.
Note
This function returns without waiting for currently enqueued kernels in stream: only future operations are affected.
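A common pattern built from these Stream methods is to run work on a side stream while keeping it ordered with respect to the current stream. The helper below is a hypothetical sketch that uses only calls documented on this page (current_stream, Stream, wait_stream and the torch.xpu.stream context manager); it does not handle the lifetime of tensors that are shared across streams.

```python
def run_on_side_stream(fn, xpu=None):
    """Run fn() on a fresh XPU stream, ordered after the current stream's queued work.

    Hypothetical helper; `xpu` defaults to torch.xpu.
    """
    if xpu is None:
        import torch
        import intel_extension_for_pytorch  # noqa: F401
        xpu = torch.xpu
    main = xpu.current_stream()
    side = xpu.Stream()
    side.wait_stream(main)        # side-stream kernels start after work already queued on main
    with xpu.stream(side):
        result = fn()
    main.wait_stream(side)        # later work on main waits for the side-stream kernels
    return result
```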
- class torch.xpu.Event(**kwargs)
- elapsed_time(end_event)
Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded.
- query()
Checks if all work currently captured by event has completed.
- Returns:
A boolean indicating if all work currently captured by event has completed.
- record(stream=None)
Records the event in a given stream.
Uses torch.xpu.current_stream() if no stream is specified.
- synchronize()
Waits for the event to complete.
Waits until the completion of all work currently captured in this event. This prevents the CPU thread from proceeding until the event completes.
- wait(stream=None)
Makes all future work submitted to the given stream wait for this event.
Uses torch.xpu.current_stream() if no stream is specified.
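Events can bracket device work to measure its duration. A minimal sketch of a hypothetical timing helper: it records a start and an end event on the current stream, waits on the end event, and reports elapsed_time in milliseconds. Whether timing must be explicitly enabled when an Event is created may depend on the build, so the helper forwards an optional, hypothetical event_kwargs dictionary to torch.xpu.Event unchanged.

```python
def time_xpu_region(fn, xpu=None, event_kwargs=None):
    """Return (result, milliseconds) for the device work enqueued by fn().

    Hypothetical helper; `xpu` defaults to torch.xpu.
    """
    if xpu is None:
        import torch
        import intel_extension_for_pytorch  # noqa: F401
        xpu = torch.xpu
    kwargs = event_kwargs or {}
    start, end = xpu.Event(**kwargs), xpu.Event(**kwargs)
    start.record()              # recorded on torch.xpu.current_stream()
    result = fn()
    end.record()
    end.synchronize()           # block the host until the timed work has finished
    return result, start.elapsed_time(end)
```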
Memory management
- torch.xpu.empty_cache() -> None
Releases all unoccupied cached memory currently held by the caching allocator so that it can be used by other GPU applications and is visible in the sysman toolkit.
Note
empty_cache() doesn’t increase the amount of GPU memory available for PyTorch. However, it may help reduce fragmentation of GPU memory in certain cases. See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.memory_stats(device: device | str | int | None = None) -> Dict[str, Any]
Returns a dictionary of XPU memory allocator statistics for a given device.
The return value of this function is a dictionary of statistics, each of which is a non-negative integer.
Core statistics:
"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of allocation requests received by the memory allocator.
"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of allocated memory.
"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of reserved segments from xpuMalloc().
"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of reserved memory.
"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of active memory blocks.
"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of active memory.
"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}": number of inactive, non-releasable memory blocks.
"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}": amount of inactive, non-releasable memory.
For these core statistics, values are broken down as follows.
Pool type:
all: combined statistics across all memory pools.
large_pool: statistics for the large allocation pool (as of October 2019, for size >= 1MB allocations).
small_pool: statistics for the small allocation pool (as of October 2019, for size < 1MB allocations).
Metric type:
current: current value of this metric.
peak: maximum value of this metric.
allocated: historical total increase in this metric.
freed: historical total decrease in this metric.
In addition to the core statistics, we also provide some simple event counters:
"num_alloc_retries": number of failed xpuMalloc calls that result in a cache flush and retry.
"num_ooms": number of out-of-memory errors thrown.
- Parameters:
device (torch.device or int, optional) – selected device. Returns statistics for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
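The core key names returned by memory_stats() are the cross product of the statistic, pool type, and metric type listed above. The sketch below enumerates them and picks a few commonly watched values out of a stats dictionary. The helper names are hypothetical, and `.get` with a default is used because it is an assumption here that every key is present in every build.

```python
from itertools import product

CORE_STATS = ["allocated", "allocated_bytes", "segment", "reserved_bytes",
              "active", "active_bytes", "inactive_split", "inactive_split_bytes"]
POOLS = ["all", "large_pool", "small_pool"]
METRICS = ["current", "peak", "allocated", "freed"]

def core_stat_keys():
    """All core key names, e.g. 'allocated_bytes.all.peak'."""
    return [f"{s}.{p}.{m}" for s, p, m in product(CORE_STATS, POOLS, METRICS)]

def summarize(stats):
    """Pick a few commonly watched values out of a memory_stats() dictionary."""
    return {
        "allocated_now": stats.get("allocated_bytes.all.current", 0),
        "allocated_peak": stats.get("allocated_bytes.all.peak", 0),
        "reserved_peak": stats.get("reserved_bytes.all.peak", 0),
        "oom_count": stats.get("num_ooms", 0),
    }

print(len(core_stat_keys()))  # 96 = 8 statistics x 3 pools x 4 metrics
# On a machine with an XPU: summarize(torch.xpu.memory_stats())
```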
- torch.xpu.memory_summary(device: device | str | int | None = None, abbreviated: bool = False) -> str
Returns a human-readable printout of the current memory allocator statistics for a given device.
This can be useful to display periodically during training, or when handling out-of-memory exceptions.
- Parameters:
device (torch.device or int, optional) – selected device. Returns printout for the current device, given by current_device(), if device is None (default).
abbreviated (bool, optional) – whether to return an abbreviated summary (default: False).
Note
See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.memory_snapshot()
Returns a snapshot of the XPU memory allocator state across all devices.
Interpreting the output of this function requires familiarity with the memory allocator internals.
Note
See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.memory_allocated(device: device | str | int | None = None) -> int
Returns the current GPU memory occupied by tensors in bytes for a given device.
- Parameters:
device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note
This is likely less than the amount shown in the sysman toolkit, since some unused memory can be held by the caching allocator and some context needs to be created on the GPU. See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.max_memory_allocated(device: device | str | int | None = None) -> int
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
By default, this returns the peak allocated memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak allocated memory usage of each iteration in a training loop.
- Parameters:
device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.memory_reserved(device: device | str | int | None = None) -> int
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
- Parameters:
device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.max_memory_reserved(device: device | str | int | None = None) -> int
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
By default, this returns the peak cached memory since the beginning of this program. reset_peak_memory_stats() can be used to reset the starting point in tracking this metric. For example, these two functions can measure the peak cached memory amount of each iteration in a training loop.
- Parameters:
device (torch.device or int, optional) – selected device. Returns statistic for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
- torch.xpu.reset_peak_memory_stats(device: device | str | int | None = None) -> None
Resets the “peak” stats tracked by the XPU memory allocator.
See memory_stats() for details. Peak stats correspond to the “peak” key in each individual stat dict.
- Parameters:
device (torch.device or int, optional) – selected device. Resets statistics for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
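As described under max_memory_allocated(), resetting the peak counters at the start of each iteration gives a per-iteration peak. The following is a hypothetical helper using reset_peak_memory_stats and max_memory_allocated; like the other sketches, it takes the torch.xpu module as an optional argument so it can be tried without hardware.

```python
def peak_bytes_per_step(step_fn, num_steps, device=None, xpu=None):
    """Run step_fn(i) num_steps times and return the peak allocated bytes of each step.

    Hypothetical helper; `xpu` defaults to torch.xpu.
    """
    if xpu is None:
        import torch
        import intel_extension_for_pytorch  # noqa: F401
        xpu = torch.xpu
    peaks = []
    for i in range(num_steps):
        xpu.reset_peak_memory_stats(device)   # start tracking a fresh peak for this step
        step_fn(i)
        peaks.append(xpu.max_memory_allocated(device))
    return peaks
```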
- torch.xpu.memory_stats_as_nested_dict(device: device | str | int | None = None) -> Dict[str, Any]
Returns the result of memory_stats() as a nested dictionary.
- torch.xpu.reset_accumulated_memory_stats(device: device | str | int | None = None) -> None
Resets the “accumulated” (historical) stats tracked by the XPU memory allocator.
See memory_stats() for details. Accumulated stats correspond to the “allocated” and “freed” keys in each individual stat dict, as well as “num_alloc_retries” and “num_ooms”.
- Parameters:
device (torch.device or int, optional) – selected device. Resets statistics for the current device, given by current_device(), if device is None (default).
Note
See Memory Management [GPU] for more details about GPU memory management.
C++ API
- enum xpu::FP32_MATH_MODE
specifies the available floating-point math modes.
Values:
- enumerator FP32
set floating-point math mode to FP32.
- enumerator TF32
set floating-point math mode to TF32.
- enumerator BF32
set floating-point math mode to BF32.
- enumerator FP32_MATH_MODE_MIN
- enumerator FP32
- bool xpu::set_fp32_math_mode(FP32_MATH_MODE mode)
Enables or disables implicit floating-point type conversion during computation for oneDNN kernels. Setting FP32MathMode.FP32 disables floating-point type conversion. Setting FP32MathMode.TF32 enables implicit down-conversion from fp32 to tf32. Setting FP32MathMode.BF32 enables implicit down-conversion from fp32 to bf16. Refer to Primitive Attributes: floating-point math mode in the oneDNN documentation for a detailed description of the definition and numerical behavior of floating-point math modes.
- Parameters:
mode (FP32MathMode) – Only works for FP32MathMode.FP32, FP32MathMode.TF32 and FP32MathMode.BF32. oneDNN fpmath mode is disabled by default if mode is set to FP32MathMode.FP32. The implicit FP32 to TF32 data type conversion is enabled if mode is set to FP32MathMode.TF32. The implicit FP32 to BF16 data type conversion is enabled if mode is set to FP32MathMode.BF32.
- sycl::queue &xpu::get_queue_from_stream(c10::Stream stream)
Gets a SYCL queue from a c10 stream: generates a DPC++ stream from the c10 stream and returns its DPC++ queue.
- Parameters:
stream – c10 stream.
- Returns:
DPC++ queue.