torch.compile for GPU (Beta)

Introduction

Intel® Extension for PyTorch* now empowers users to seamlessly harness graph compilation capabilities for optimal PyTorch model performance on Intel GPU via the flagship torch.compile API through the default “inductor” backend (TorchInductor). The Triton compiler is at the core of Inductor code generation and supports various accelerator devices; Intel has extended TorchInductor by adding Intel GPU support to Triton. Additionally, post-op fusions for convolution and matrix multiplication, powered by oneDNN fusion kernels, improve the efficiency of computationally intensive operations. Leveraging these features is as simple as using the default “inductor” backend, making it easier than ever to unlock the full potential of your PyTorch models on Intel GPU platforms.
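
For example, compiling a model for Intel GPU does not require an explicit backend argument, since “inductor” is the default. The sketch below uses a toy nn.Linear module purely for illustration:

import torch
import torch.nn as nn
import intel_extension_for_pytorch

# Any nn.Module moved to the "xpu" device works the same way
model = nn.Linear(128, 64).to("xpu")

# "inductor" is the default backend, so the explicit argument is optional
compiled_model = torch.compile(model, backend="inductor")

# The first call triggers graph capture and Triton/oneDNN code generation
output = compiled_model(torch.randn(32, 128, device="xpu"))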

Required Dependencies

Verified version:

  • torch : v2.5

  • intel_extension_for_pytorch : v2.5

  • triton : v3.1.0+91b14bf559

Install Intel® oneAPI DPC++/C++ Compiler 2025.0.4.

Follow the Intel® Extension for PyTorch* Installation guide to install torch and intel_extension_for_pytorch first.

Triton can be installed directly with the following command:

pip install --pre pytorch-triton-xpu==3.1.0+91b14bf559 --index-url https://download.pytorch.org/whl/nightly/xpu
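
After installation, a quick sanity check confirms that the expected Triton build is picked up; the printed version should match the verified version listed above:

python -c "import triton; print(triton.__version__)"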

Remember to activate the oneAPI DPC++/C++ Compiler environment with the following command.

# {dpcpproot} is the root of your oneAPI DPC++/C++ Compiler installation,
# usually /opt/intel/oneapi/compiler/latest or ~/intel/oneapi/compiler/latest
source {dpcpproot}/env/vars.sh
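
With the compiler environment activated, you can verify that PyTorch can see the Intel GPU before trying torch.compile. A minimal check using the torch.xpu device APIs:

import torch
import intel_extension_for_pytorch

# Expect True and the name of your Intel GPU
print(torch.xpu.is_available())
print(torch.xpu.get_device_name(0))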

Example Usage

Inference with torch.compile

import torch
import torch.nn as nn
import intel_extension_for_pytorch

# Define the SimpleNet model
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create model
model = SimpleNet().to("xpu")

# Compile model; the "freezing" option folds weights into the graph for inference
compiled_model = torch.compile(model, options={"freezing": True})

# Inference main
input = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
with torch.no_grad():
    with torch.xpu.amp.autocast(dtype=torch.float16):
        output = compiled_model(input)

# Print the output shape
print(output.shape)
print("Done for inference with torch.compile")

Training with torch.compile

import torch
import intel_extension_for_pytorch

# Create model and optimizer (SimpleNet is defined in the inference example above;
# the ellipses are placeholders for your own hyperparameter values)
model = SimpleNet().to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=..., momentum=..., weight_decay=...)

# compile model
compiled_model = torch.compile(model)

# Training main (loss_function is a placeholder for your own criterion)
input = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    output = compiled_model(input)
    loss = loss_function(output)
optimizer.zero_grad()
loss.backward()
optimizer.step()
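
To make the snippet above fully runnable, the placeholders need concrete values. The sketch below fills them with arbitrary choices (SGD hyperparameters, a cross-entropy criterion, and random labels) purely for illustration, not as recommendations:

import torch
import torch.nn as nn
import intel_extension_for_pytorch

model = SimpleNet().to("xpu")  # SimpleNet from the inference example
# Hyperparameter values are arbitrary examples
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

compiled_model = torch.compile(model)

input = torch.rand(64, 3, 224, 224, device="xpu")
target = torch.randint(0, 10, (64,), device="xpu")  # random labels for the 10 classes

with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    output = compiled_model(input)
    loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Done for training with torch.compile")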

Troubleshooting

If you encounter any issues related to torch.compile or Triton, please refer to the Library Dependencies section in known_issues.