torch.compile for GPU (Beta)

Introduction

Intel® Extension for PyTorch* now lets users apply graph compilation to get optimal PyTorch model performance on Intel GPU through the flagship torch.compile API and its default “inductor” backend (TorchInductor). The Triton compiler is the core of Inductor's code generation for various accelerator devices, and Intel has extended TorchInductor by adding Intel GPU support to Triton. In addition, post-op fusions for convolution and matrix multiplication, implemented with oneDNN fusion kernels, improve efficiency for computationally intensive operations. Using these features is as simple as using the default “inductor” backend, so no backend-specific code is needed to benefit from them on Intel GPU platforms.
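To make “post-op fusion” concrete, the sketch below is a plain-Python illustration of the idea only; it is not oneDNN or TorchInductor code, and the function names are hypothetical. A matrix multiplication followed by a bias add and a ReLU is computed once as three separate passes that each materialize an intermediate matrix, and once with the bias add and ReLU applied to each output element as soon as it is produced. It says nothing about the performance of the real kernels.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain triple-loop matrix multiplication: (n x k) @ (k x m) -> (n x m).
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)] for i in range(n)]


def unfused_linear_relu(x: Matrix, w: Matrix, bias: List[float]) -> Matrix:
    # Three separate passes, each producing a full intermediate matrix.
    y = matmul(x, w)
    y = [[v + bias[j] for j, v in enumerate(row)] for row in y]
    return [[max(v, 0.0) for v in row] for row in y]


def fused_linear_relu(x: Matrix, w: Matrix, bias: List[float]) -> Matrix:
    # One pass: bias add and ReLU act as "post-ops" on each output element
    # right after its dot product is accumulated, so no intermediate is stored.
    k, m = len(w), len(w[0])
    out = []
    for row in x:
        out_row = []
        for j in range(m):
            acc = 0.0
            for p in range(k):
                acc += row[p] * w[p][j]
            out_row.append(max(acc + bias[j], 0.0))
        out.append(out_row)
    return out


if __name__ == "__main__":
    x = [[1.0, -2.0], [3.0, 0.5]]
    w = [[1.0, 0.0], [0.0, 1.0]]
    bias = [0.5, -1.0]
    print(fused_linear_relu(x, w, bias))  # prints [[1.5, 0.0], [3.5, 0.0]]
```

Both versions return the same values; the fused one avoids writing and re-reading the intermediate matrices, which is the memory traffic a fused kernel saves on real hardware.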

Note: torch.compile for GPU is a beta feature, available starting from version 2.1.10. Currently, the feature is functional only on Intel® Data Center GPU Max Series.

Required Dependencies

Verified versions:

  • torch : v2.1.0

  • intel_extension_for_pytorch : >= v2.1.10

  • triton : v2.1.0 with the Intel® XPU Backend for Triton* enabled.
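The versions listed above can be checked from Python. The sketch below is a hypothetical helper, not part of either package: it reads installed versions with the standard-library importlib.metadata, strips local suffixes such as “+xpu”, and treats each verified version as a minimum. That last point is a simplification made for this sketch; a newer torch or triton release is not necessarily a verified one. The names VERIFIED, version_tuple, meets_minimum and check_versions are illustrative.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Verified versions from the list above. Treating them as minimums is an
# assumption of this sketch, not a documented compatibility rule.
VERIFIED = {
    "torch": "2.1.0",
    "intel_extension_for_pytorch": "2.1.10",
    "triton": "2.1.0",
}


def version_tuple(version: str) -> Tuple[int, ...]:
    # Drop a local suffix such as "+xpu", then keep the leading digits of
    # each dot-separated component ("2.1.0a0" -> (2, 1, 0)).
    parts = []
    for piece in version.split("+", 1)[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    # Pad both tuples with zeros so "2.1" compares equal to "2.1.0".
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(dist_name: str) -> Optional[str]:
    # Return the installed distribution version, or None if it is missing.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_versions(verified: Dict[str, str] = VERIFIED) -> Dict[str, Tuple[Optional[str], str, bool]]:
    # Map each package to (installed version, verified version, ok flag).
    report = {}
    for name, wanted in verified.items():
        found = installed_version(name)
        report[name] = (found, wanted, found is not None and meets_minimum(found, wanted))
    return report


if __name__ == "__main__":
    for name, (found, wanted, ok) in check_versions().items():
        print(f"{name:30} installed={found or 'missing':12} verified={wanted:8} {'OK' if ok else 'CHECK'}")
```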

Follow the Intel® Extension for PyTorch* Installation guide to install torch and intel_extension_for_pytorch first.

Then install the Intel® XPU Backend for Triton* for the triton package. You can install it from a prebuilt wheel package or build it from source; we recommend the prebuilt package:

  • Download the wheel package from the release page. You don’t need to install the LLVM release manually.

  • Install the wheel package with pip install. This wheel is a triton package with Intel GPU support, so you don’t need to pip install triton again. The file name below is the CPython 3.10 (cp310) build.

python -m pip install --force-reinstall triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

Please follow the Intel® XPU Backend for Triton* Installation for more detailed installation steps.
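After installation, a quick sanity check can confirm that the packages are importable and that PyTorch sees an Intel GPU. This is a minimal sketch: the function name xpu_stack_status is hypothetical, and it assumes that importing intel_extension_for_pytorch exposes torch.xpu.is_available(), as the GPU builds of the extension do. It only checks that a triton package is present; it cannot tell whether that triton was built with Intel GPU support.

```python
import importlib.util
from typing import Dict


def xpu_stack_status() -> Dict[str, bool]:
    # find_spec only locates the packages; it does not import them.
    status = {
        "triton_installed": importlib.util.find_spec("triton") is not None,
        "ipex_installed": importlib.util.find_spec("intel_extension_for_pytorch") is not None,
        "xpu_available": False,
    }
    if status["ipex_installed"]:
        # Importing the extension is what makes torch.xpu usable here.
        import torch
        import intel_extension_for_pytorch  # noqa: F401

        status["xpu_available"] = bool(torch.xpu.is_available())
    return status


if __name__ == "__main__":
    for key, value in xpu_stack_status().items():
        print(f"{key}: {value}")
```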

Note that if you install triton using the make triton command inside the PyTorch* repo, the installed triton is not compiled with Intel GPU support by default; you need to manually set TRITON_CODEGEN_INTEL_XPU_BACKEND=1 to enable Intel GPU support. In addition, when building from source via the triton repo, the commit must be pinned to a tested triton commit. See the “build from the source” section of the Intel® XPU Backend for Triton* Installation guide for more information about building the triton package from source.
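One way to make sure the variable is set for a make triton build is to pass it in the environment of that subprocess, the Python equivalent of running TRITON_CODEGEN_INTEL_XPU_BACKEND=1 make triton in a shell. This is a hypothetical helper (build_env and make_triton are illustrative names), and it assumes the variable only needs to be present in the environment while make triton runs; the Installation guide above remains the reference for the actual build steps.

```python
import os
import subprocess
from pathlib import Path
from typing import Mapping, Optional

XPU_FLAG = "TRITON_CODEGEN_INTEL_XPU_BACKEND"


def build_env(base: Optional[Mapping[str, str]] = None) -> dict:
    # Copy the base environment (default: this process's environment) so
    # os.environ itself is never modified, then enable the Intel GPU flag.
    env = dict(os.environ if base is None else base)
    env[XPU_FLAG] = "1"
    return env


def make_triton(pytorch_root: str) -> None:
    # Run `make triton` from the root of a PyTorch source checkout with the
    # flag set; raise if the directory does not look like a checkout.
    root = Path(pytorch_root)
    if not (root / "Makefile").is_file():
        raise FileNotFoundError(f"no Makefile in {root}; expected a PyTorch source checkout")
    subprocess.run(["make", "triton"], cwd=root, env=build_env(), check=True)
```

For example, make_triton("/path/to/pytorch") runs the build with the flag set, where the path is a placeholder for your own checkout.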

Inference with torch.compile

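import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401  (enables Intel GPU "xpu" support)

# SimpleNet is a placeholder; replace it with your own nn.Module
class SimpleNet(nn.Sequential):
    def __init__(self, num_classes=10):
        super().__init__(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

# create model on the Intel GPU and switch to inference mode
model = SimpleNet().to("xpu").eval()

# compile model; "freezing" lets Inductor treat the weights as constants
compiled_model = torch.compile(model, options={"freezing": True})

# inference main
x = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
with torch.no_grad():
    with torch.xpu.amp.autocast(dtype=torch.float16):
        output = compiled_model(x)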

Training with torch.compile

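import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401  (enables Intel GPU "xpu" support)

# SimpleNet is a placeholder; replace it with your own nn.Module
class SimpleNet(nn.Sequential):
    def __init__(self, num_classes=10):
        super().__init__(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

# create model, loss function and optimizer
model = SimpleNet().to("xpu").train()
loss_function = nn.CrossEntropyLoss()  # assumed classification loss
# fill in lr, momentum and weight_decay for your workload
optimizer = torch.optim.SGD(model.parameters(), lr=..., momentum=..., weight_decay=...)

# compile model
compiled_model = torch.compile(model)

# training main (one step; the labels are random placeholders)
x = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
target = torch.randint(0, 10, (64,), device=torch.device("xpu"))
optimizer.zero_grad()
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    output = compiled_model(x)
    loss = loss_function(output, target)
loss.backward()
optimizer.step()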