Examples

These examples will help you get started using Intel® Extension for PyTorch* with Intel GPUs.

Note: For examples on Intel CPUs, check these CPU examples.

Note: You need to install torchvision and transformers to run the examples.

Python

Training

Single-Instance Training

Code Changes Highlight

You only need to change a few lines of code to use Intel® Extension for PyTorch* for training:

  1. Use the ipex.optimize function, which applies optimizations to the model object and, when one is passed, to the optimizer object.

  2. Use Auto Mixed Precision (AMP) with the BFloat16 data type.

  3. Move the input tensors, loss criterion, and model to the XPU device.

The complete single-instance training examples for Float32 and BFloat16 are shown in the following sections.

...
import torch
import intel_extension_for_pytorch as ipex
...
model = Model()
criterion = ...
optimizer = ...
model.train()
# For Float32
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# For BFloat16
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
...
# For Float32
output = model(data)
...
# For BFloat16
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    output = model(data)
...
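
The snippets in this document hard-code the "xpu" device string. For scripts that also need to run on machines without an Intel GPU, a small fallback helper can be useful. The following is a minimal sketch, not part of the extension: the helper name pick_device is hypothetical, and it assumes the backend exposes an is_available() function under torch.xpu once the extension is imported.

```python
def pick_device(preferred: str = "xpu") -> str:
    """Return `preferred` if PyTorch reports it as usable, otherwise "cpu".

    Hypothetical helper for illustration; it is not an API of the extension.
    """
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch itself is not installed

    try:
        # Importing the extension registers the "xpu" backend on PyTorch
        # builds that do not provide one natively.
        import intel_extension_for_pytorch  # noqa: F401
    except Exception:
        pass  # extension missing or failed to load; rely on the check below

    backend = getattr(torch, preferred, None)
    is_available = getattr(backend, "is_available", None)
    if callable(is_available) and is_available():
        return preferred
    return "cpu"


device = pick_device()
print(f"running on: {device}")
```

With this in place, calls such as model.to("xpu") in the examples can be written as model.to(device).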
Complete - Float32 Example
import torch
import torchvision

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
model.train()
######################## code changes #######################
model = model.to("xpu")
criterion = criterion.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
######################## code changes #######################

for batch_idx, (data, target) in enumerate(train_loader):
    ########## code changes ##########
    data = data.to("xpu")
    target = target.to("xpu")
    ########## code changes ##########
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint.pth",
)

print("Execution finished")
Complete - BFloat16 Example
import torch
import torchvision

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

LR = 0.001
DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
model.train()
##################################### code changes ################################
model = model.to("xpu")
criterion = criterion.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
##################################### code changes ################################

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    ######################### code changes #########################
    data = data.to("xpu")
    target = target.to("xpu")
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    ######################### code changes #########################
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint.pth",
)

print("Execution finished")

Inference

Get additional performance boosts for your computer vision and NLP workloads by applying the Intel® Extension for PyTorch* optimize function against your model object.
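
To check whether the optimize function actually helps a given workload, it is worth timing the model before and after the change. The following is a minimal timing sketch with hypothetical names (measure_latency, run_once, sync). It assumes device work is queued asynchronously, so a synchronization callback such as torch.xpu.synchronize should be passed when timing on XPU.

```python
import statistics
import time
from typing import Callable, Optional


def measure_latency(
    run_once: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 3,
    iters: int = 10,
) -> float:
    """Return the median wall-clock time in seconds of one `run_once()` call."""
    for _ in range(warmup):  # warm-up runs absorb one-time setup costs
        run_once()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run_once()
        if sync is not None:
            sync()  # wait for queued device work before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload so the sketch runs anywhere. With the inference examples
# in this section you would pass `lambda: model(data)` as `run_once` and
# `torch.xpu.synchronize` as `sync`.
latency = measure_latency(lambda: sum(i * i for i in range(10_000)))
print(f"median latency: {latency * 1e3:.3f} ms")
```

The median is used rather than the mean so that a single slow iteration does not dominate the result.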

Float32

Imperative Mode
Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

######## code changes #######
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model)
######## code changes #######

with torch.no_grad():
    model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

######## code changes #######
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model)
######## code changes #######

with torch.no_grad():
    model(data)

print("Execution finished")
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

######## code changes #######
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model)
######## code changes #######

with torch.no_grad():
    d = torch.rand(1, 3, 224, 224)
    ##### code changes #####
    d = d.to("xpu")
    ##### code changes #####
    model = torch.jit.trace(model, d)
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

######## code changes #######
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model)
######## code changes #######

with torch.no_grad():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    ##### code changes #####
    d = d.to("xpu")
    ##### code changes #####
    model = torch.jit.trace(model, (d,), strict=False)
    model = torch.jit.freeze(model)

    model(data)

print("Execution finished")

BFloat16

The optimize function works for both the Float32 and BFloat16 data types. For BFloat16, set the dtype parameter to torch.bfloat16. We recommend using Auto Mixed Precision (AMP) with the BFloat16 data type.
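
For intuition about what BFloat16 trades away, the sketch below emulates the format in plain Python. BFloat16 keeps the sign bit, the 8 exponent bits, and the top 7 mantissa bits of a Float32 value, so the range is preserved while precision drops to roughly two to three decimal digits. This is an illustration of the number format only; the helper name to_bfloat16 is hypothetical and it does not reproduce how the extension performs conversions.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 is the upper 16 bits of float32; add a bias so that the
    # truncation below rounds to nearest, with ties going to the even value.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


print(to_bfloat16(1.0))     # 1.0 is exactly representable
print(to_bfloat16(1.01))    # 1.0078125, the nearest value with 7 mantissa bits
print(to_bfloat16(3.0e38))  # still finite: the float32 exponent range is kept
```

The coarse rounding is one reason AMP leaves numerically sensitive operations in Float32 instead of casting everything.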

Imperative Mode
Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################

with torch.no_grad():
    ############################# code changes #####################
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    ############################ code changes ######################
        model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################

with torch.no_grad():
    ########################### code changes ########################
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    ########################### code changes ########################
        model(data)

print("Execution finished")
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################

with torch.no_grad():
    d = torch.rand(1, 3, 224, 224)
    ############################# code changes #####################
    d = d.to("xpu")
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    ############################# code changes #####################
        model = torch.jit.trace(model, d)
        model = torch.jit.freeze(model)
        model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################

with torch.no_grad():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    ############################# code changes #####################
    d = d.to("xpu")
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    ############################# code changes #####################
        model = torch.jit.trace(model, (d,), strict=False)
        model = torch.jit.freeze(model)

        model(data)

print("Execution finished")

Float16

The optimize function works for both the Float32 and Float16 data types. For Float16, set the dtype parameter to torch.float16. We recommend using Auto Mixed Precision (AMP) with the Float16 data type.
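
Float16 has 5 exponent bits and 10 mantissa bits, so it is more precise than BFloat16 but covers a much smaller range: the largest finite value is 65504. The sketch below uses the half-precision format code of Python's struct module to show both properties. It illustrates the number format only and makes no claim about how the extension handles overflow; the helper name to_float16 is hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_float16(0.1))      # nearest float16 value to 0.1
print(to_float16(65504.0))  # largest finite float16 value, stored exactly

try:
    to_float16(70000.0)
except OverflowError as exc:
    # struct refuses values that are too large for the format.
    print(f"70000.0 does not fit in float16: {exc}")
```

Activations or intermediate results that exceed this range are a common reason to validate model accuracy after switching a workload to Float16.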

Imperative Mode
Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################

with torch.no_grad():
    ############################# code changes #####################
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
    ############################# code changes #####################
        model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################

with torch.no_grad():
    ############################# code changes #####################
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
    ############################# code changes #####################
        model(data)

print("Execution finished")
TorchScript Mode

We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.

Resnet50
import torch
import torchvision.models as models

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################

with torch.no_grad():
    d = torch.rand(1, 3, 224, 224)
    ############################# code changes #####################
    d = d.to("xpu")
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
    ############################# code changes #####################
        model = torch.jit.trace(model, d)
        model = torch.jit.freeze(model)
        model(data)

print("Execution finished")
BERT
import torch
from transformers import BertModel

############# code changes ###############
import intel_extension_for_pytorch as ipex

############# code changes ###############

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################

with torch.no_grad():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    ############################# code changes #####################
    d = d.to("xpu")
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16):
    ############################# code changes #####################
        model = torch.jit.trace(model, (d,), strict=False)
        model = torch.jit.freeze(model)

        model(data)

print("Execution finished")

INT8

We recommend using TorchScript for INT8 models because it supports a wider range of models and automatically enables our optimizations. For a TorchScript INT8 model, observers are inserted with prepare_jit and the model is quantized with convert_jit. A calibration step between the two is required to collect statistics from real data. After conversion, optimizations such as operator fusion are enabled automatically.
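
To make the role of calibration concrete, the following pure-Python sketch shows the kind of arithmetic a min/max observer with a symmetric per-tensor scheme and an unsigned 8-bit range is generally understood to perform: track the smallest and largest values seen, then derive a scale and a zero point from them. The class and function names are hypothetical, the epsilon value is an assumption, and the real observer may differ in details.

```python
from typing import Iterable, Tuple


class MinMaxSketch:
    """Tracks the running minimum and maximum of all observed values."""

    def __init__(self) -> None:
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch: Iterable[float]) -> None:
        for v in batch:
            self.min_val = min(self.min_val, v)
            self.max_val = max(self.max_val, v)

    def qparams(self, quant_min: int = 0, quant_max: int = 255) -> Tuple[float, int]:
        """Symmetric per-tensor scale and zero point for an unsigned 8-bit range."""
        max_abs = max(-min(self.min_val, 0.0), max(self.max_val, 0.0))
        # Assumed epsilon keeps the scale positive for all-zero data.
        scale = max(max_abs / ((quant_max - quant_min) / 2), 1e-12)
        zero_point = (quant_min + quant_max + 1) // 2  # midpoint of the range
        return scale, zero_point


def quantize(x: float, scale: float, zero_point: int) -> int:
    """Map a real value to an integer in [0, 255]."""
    return max(0, min(255, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map a quantized integer back to an approximate real value."""
    return (q - zero_point) * scale


# "Calibration": show the observer representative data, batch by batch.
observer = MinMaxSketch()
for batch in ([-1.0, 0.5], [2.55, 0.0]):
    observer.observe(batch)

scale, zero_point = observer.qparams()
print(scale, zero_point)
print(quantize(-1.0, scale, zero_point), quantize(2.55, scale, zero_point))
```

Because the scale depends entirely on the observed range, calibration data that does not resemble real inputs leads to clipping or wasted resolution.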

import torch
from torch.quantization.quantize_jit import (
    convert_jit,
    prepare_jit,
)

#################### code changes ####################
import intel_extension_for_pytorch as ipex

######################################################

##### Example Model #####
import torchvision.models as models

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
model = model.to("xpu")

with torch.no_grad():
    data = torch.rand(1, 3, 224, 224)
    data = data.to("xpu")
    modelJit = torch.jit.trace(model, data)
#########################

qconfig = torch.quantization.QConfig(
    activation=torch.quantization.observer.MinMaxObserver.with_args(
        qscheme=torch.per_tensor_symmetric, reduce_range=False, dtype=torch.quint8
    ),
    weight=torch.quantization.default_weight_observer,
)
modelJit = prepare_jit(modelJit, {"": qconfig}, True)

##### Example Dataloader #####
import torchvision

DOWNLOAD = True
DATA = "datasets/cifar10/"

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=DOWNLOAD,
)
calibration_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=128
)

for batch_idx, (d, target) in enumerate(calibration_data_loader):
    print(f"calibrated on batch {batch_idx} out of {len(calibration_data_loader)}")
    d = d.to("xpu")
    modelJit(d)
##############################

modelJit = convert_jit(modelJit, True)

data = torch.rand(1, 3, 224, 224)
data = data.to("xpu")
modelJit(data)

print("Execution finished")

torch.xpu.optimize

The torch.xpu.optimize function is an alternative to ipex.optimize in Intel® Extension for PyTorch* and has identical usage, but for XPU devices only. This alias was added to unify the coding style of user scripts built on the torch.xpu module. Refer to the example below for usage.

import torch
import torchvision.models as models

############# code changes #########
import intel_extension_for_pytorch

############# code changes #########

model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
model.eval()
data = torch.rand(1, 3, 224, 224)

model = model.to(memory_format=torch.channels_last)
data = data.to(memory_format=torch.channels_last)

########## code changes #########
model = model.to("xpu")
data = data.to("xpu")
model = torch.xpu.optimize(model)
########## code changes #########

with torch.no_grad():
    model(data)

print("Execution finished")

C++

To work with libtorch, the PyTorch C++ library, Intel® Extension for PyTorch* provides its own C++ dynamic library. The C++ library only handles inference workloads, such as service deployment. For regular development, use the Python interface. Compared with plain libtorch usage, no specific code changes are required. Compilation follows the recommended methodology with CMake. Detailed instructions can be found in the PyTorch tutorial.

During compilation, Intel optimizations will be activated automatically after the C++ dynamic library of Intel® Extension for PyTorch* is linked.

The example code below works for all data types.

Basic Usage

example-app.cpp

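#include <torch/script.h>
#include <iostream>
#include <memory>
#include <vector>

int main(int argc, const char* argv[]) {
  // The path to the TorchScript model is expected as the first argument.
  if (argc < 2) {
    std::cerr << "usage: example-app <path-to-model>\n";
    return -1;
  }
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  }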
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  module.to(at::kXPU);

  std::vector<torch::jit::IValue> inputs;
  torch::Tensor input = torch::rand({1, 3, 224, 224}).to(at::kXPU);
  inputs.push_back(input);

  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << std::endl;
  std::cout << "Execution finished" << std::endl;

  return 0;
}
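
The program above loads a TorchScript file whose path is given as its first argument, but this page does not show how to produce that file. One possible way, sketched here under assumptions, is to trace and freeze a model in Python as in the TorchScript examples earlier and then save it. The function name export_traced_resnet50 and the output file name are hypothetical.

```python
def export_traced_resnet50(path: str, device: str = "xpu") -> None:
    """Trace, freeze, and save ResNet50 as a TorchScript file at `path`.

    Hypothetical helper for illustration. Imports are kept inside the
    function so that defining it does not require PyTorch to be installed.
    """
    import torch
    import torchvision.models as models

    if device == "xpu":
        # Registers the XPU backend, as in the Python examples on this page.
        import intel_extension_for_pytorch  # noqa: F401

    model = models.resnet50(weights="ResNet50_Weights.DEFAULT")
    model.eval()
    model = model.to(device)
    example = torch.rand(1, 3, 224, 224).to(device)

    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced = torch.jit.freeze(traced)
    traced.save(path)


# Example call (downloads the pretrained weights on first use):
# export_traced_resnet50("resnet50_traced.pt")
```

The saved file can then be passed to the compiled binary, for example ./example-app resnet50_traced.pt.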

CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(IPEX REQUIRED)

set(target example-app)
add_executable(${target} example-app.cpp)
target_link_libraries(${target} ${TORCH_IPEX_LIBRARIES})

set_property(TARGET ${target} PROPERTY CXX_STANDARD 17)

Command for compilation

$ cd examples/gpu/inference/cpp/example-app
$ mkdir build
$ cd build
$ CC=icx CXX=icpx cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> ..
$ make

If the CMake output contains a Found IPEX line that lists dynamic library paths, the extension was found and is linked into the binary when it is built. You can verify the linkage with the Linux command ldd.

$ CC=icx CXX=icpx cmake -DCMAKE_PREFIX_PATH=/workspace/libtorch ..
-- The C compiler identification is IntelLLVM 2023.2.0
-- The CXX compiler identification is IntelLLVM 2023.2.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /workspace/intel/oneapi/compiler/2023.2.0/linux/bin/icx - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /workspace/intel/oneapi/compiler/2023.2.0/linux/bin/icpx - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- Found Torch: /workspace/libtorch/lib/libtorch.so
-- Found IPEX: /workspace/libtorch/lib/libintel-ext-pt-cpu.so;/workspace/libtorch/lib/libintel-ext-pt-gpu.so
-- Configuring done
-- Generating done
-- Build files have been written to: examples/gpu/inference/cpp/example-app/build

$ ldd example-app
        ...
        libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007fd5bb927000)
        libc10.so => /workspace/libtorch/lib/libc10.so (0x00007fd5bb895000)
        libtorch_cpu.so => /workspace/libtorch/lib/libtorch_cpu.so (0x00007fd5a44d8000)
        libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007fd5a1a1b000)
        libintel-ext-pt-gpu.so => /workspace/libtorch/lib/libintel-ext-pt-gpu.so (0x00007fd5862b0000)
        ...
        libmkl_intel_lp64.so.2 => /workspace/intel/oneapi/mkl/2023.2.0/lib/intel64/libmkl_intel_lp64.so.2 (0x00007fd584ab0000)
        libmkl_core.so.2 => /workspace/intel/oneapi/mkl/2023.2.0/lib/intel64/libmkl_core.so.2 (0x00007fd5806cc000)
        libmkl_gnu_thread.so.2 => /workspace/intel/oneapi/mkl/2023.2.0/lib/intel64/libmkl_gnu_thread.so.2 (0x00007fd57eb1d000)
        libmkl_sycl.so.3 => /workspace/intel/oneapi/mkl/2023.2.0/lib/intel64/libmkl_sycl.so.3 (0x00007fd55512c000)
        libOpenCL.so.1 => /workspace/intel/oneapi/compiler/2023.2.0/linux/lib/libOpenCL.so.1 (0x00007fd55511d000)
        libsvml.so => /workspace/intel/oneapi/compiler/2023.2.0/linux/compiler/lib/intel64_lin/libsvml.so (0x00007fd553b11000)
        libirng.so => /workspace/intel/oneapi/compiler/2023.2.0/linux/compiler/lib/intel64_lin/libirng.so (0x00007fd553600000)
        libimf.so => /workspace/intel/oneapi/compiler/2023.2.0/linux/compiler/lib/intel64_lin/libimf.so (0x00007fd55321b000)
        libintlc.so.5 => /workspace/intel/oneapi/compiler/2023.2.0/linux/compiler/lib/intel64_lin/libintlc.so.5 (0x00007fd553a9c000)
        libsycl.so.6 => /workspace/intel/oneapi/compiler/2023.2.0/linux/lib/libsycl.so.6 (0x00007fd552f36000)
        ...
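
The ldd check can also be automated, for example in a deployment script. The sketch below parses ldd output and reports which of the extension's libraries are missing. The function names are hypothetical, and the two required library names are taken from the sample output above, so they may differ for other releases.

```python
from typing import Dict, List

# Library names as they appear in the sample ldd output above (assumption:
# other releases may use different names).
REQUIRED = ("libintel-ext-pt-cpu.so", "libintel-ext-pt-gpu.so")


def parse_ldd(output: str) -> Dict[str, str]:
    """Map each shared-library name in `ldd` output to its resolved path."""
    resolved = {}
    for line in output.splitlines():
        name, arrow, rest = line.strip().partition(" => ")
        if arrow and rest:
            # Drop the trailing load address, e.g. " (0x00007fd5bb927000)".
            resolved[name] = rest.split(" (")[0]
    return resolved


def missing_ipex_libs(output: str) -> List[str]:
    """Return the required libraries that are absent or unresolved."""
    libs = parse_ldd(output)
    return [n for n in REQUIRED if libs.get(n, "not found") == "not found"]


sample = """
        libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007fd5bb927000)
        libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007fd5a1a1b000)
        libintel-ext-pt-gpu.so => /workspace/libtorch/lib/libintel-ext-pt-gpu.so (0x00007fd5862b0000)
"""
print(missing_ipex_libs(sample))
```

In a real script, the text would come from running ldd on the binary, for instance with subprocess.run(["ldd", "example-app"], capture_output=True, text=True).stdout.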

Use SYCL code

Using SYCL code in a C++ application is also possible. The example below shows how to invoke SYCL code. You need to explicitly pass -fsycl into CMAKE_CXX_FLAGS.

example-usm.cpp

#include <iostream>
#include <memory>
#include <vector>
#include <torch/script.h>
#include <ipex.h>
#include <CL/sycl.hpp>

using namespace sycl;
using namespace xpu::dpcpp;

int main(int argc, const char* argv[]) {
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  std::cout << "load model done " << std::endl;
  module.to(at::kXPU);

  std::vector<torch::jit::IValue> inputs;
  // fetch sycl queue from c10::Stream on XPU device.
  auto device_type = c10::DeviceType::XPU;
  c10::impl::VirtualGuardImpl impl(device_type);
  c10::Stream xpu_stream = impl.getStream(c10::Device(device_type));
  auto& sycl_queue = xpu::get_queue_from_stream(xpu_stream);
  float *input_ptr = malloc_device<float>(224 * 224 * 3, sycl_queue);
  auto input = fromUSM(input_ptr, at::ScalarType::Float, {1, 3, 224, 224}, c10::nullopt, -1).to(at::kXPU);
  std::cout << "input tensor created from usm " << std::endl;
  inputs.push_back(input);

  at::IValue output = module.forward(inputs);
  torch::Tensor output_tensor;
  output_tensor = output.toTensor();
  std::cout << output_tensor.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << std::endl;
  std::cout << "Execution finished" << std::endl;

  return 0;
}

CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-usm)

find_package(IPEX REQUIRED)

set(target example-usm)
add_executable(${target} example-usm.cpp)
target_link_libraries(${target} ${TORCH_IPEX_LIBRARIES})
# CMAKE_CXX_FLAGS is a space-separated string, so extend it with set().
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl")

set_property(TARGET ${target} PROPERTY CXX_STANDARD 17)

Customize DPC++ kernels

Intel® Extension for PyTorch* provides its C++ dynamic library to allow users to implement custom DPC++ kernels to run on the XPU device. Refer to the DPC++ extension for details.

Model Zoo

Use cases that have already been optimized by Intel engineers are available at Model Zoo for Intel® Architecture. A number of PyTorch use cases for benchmarking are also available on the GitHub page. Models verified on Intel GPUs are marked in the Model Documentation column. You can get performance benefits out of the box by simply running the scripts in the Model Zoo.