torch.compile for GPU (Beta)
============================

# Introduction

Intel® Extension for PyTorch\* lets users apply graph compilation to PyTorch models on Intel GPU through the [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile) API with the default "inductor" backend ([TorchInductor](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747/1)). The Triton compiler is the core of Inductor code generation and supports various accelerator devices. Intel has extended TorchInductor by adding Intel GPU support to Triton. In addition, post-op fusions for convolution and matrix multiplication, implemented with oneDNN fusion kernels, improve efficiency for computationally intensive operations. Using these features requires nothing more than the default "inductor" backend.

# Required Dependencies

**Verified version**:
- `torch` : v2.6
- `intel_extension_for_pytorch` : v2.6

Follow the [installation guide](https://pytorch-extension.intel.com/installation?platform=gpu&version=v2.6.10%2Bxpu) to install `torch` and `intel_extension_for_pytorch`.

Triton is installed along with torch. If you ran `torch.compile` with a previous version of Triton, it generated cache files that generally conflict with the new version. If the folder `~/.triton` exists before you first run a `torch.compile` script in the current environment, delete it.

```bash
# delete the cache files generated by a previous version of triton, if they exist
rm -rf ~/.triton
```

# Example Usage

## Inference with torch.compile

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch  # registers Intel GPU ("xpu") support


# Define the SimpleNet model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # Two 2x2 poolings reduce a 224x224 input to 56x56
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Create the model on the Intel GPU and switch it to inference mode
model = SimpleNet().to("xpu").eval()

# Compile the model; freezing treats the weights as constants for inference
compiled_model = torch.compile(model, options={"freezing": True})

# Inference main
input = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
with torch.no_grad():
    with torch.xpu.amp.autocast(dtype=torch.float16):
        output = compiled_model(input)

# Print the output shape
print(output.shape)
print("Done for inference with torch.compile")
```

## Training with torch.compile

```python
import torch
import intel_extension_for_pytorch  # registers Intel GPU ("xpu") support

# SimpleNet is the model class defined in the inference example above.
# The `...` values and `loss_function` are placeholders: set the
# hyperparameters and supply the loss for your own workload.

# create model and optimizer
model = SimpleNet().to("xpu")
optimizer = torch.optim.SGD(model.parameters(), lr=..., momentum=..., weight_decay=...)

# compile model
compiled_model = torch.compile(model)

# training main
input = torch.rand(64, 3, 224, 224, device=torch.device("xpu"))
with torch.xpu.amp.autocast(dtype=torch.bfloat16):
    output = compiled_model(input)
    loss = loss_function(output)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

## Troubleshooting

If you encounter any issue related to `torch.compile` or `triton`, refer to the Library Dependencies section in [known_issues](../known_issues.md).
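
Before digging into a dependency problem, it can help to record which versions are installed and whether a Triton cache folder from an earlier run is still present. The following is a minimal sketch of such a check, not a tool shipped with the extension. The helper name `collect_environment_report` is hypothetical, and the distribution names in `PACKAGES` are assumptions that may differ in your environment.

```python
from importlib import metadata
from pathlib import Path

# Assumed distribution names; adjust them to match your environment.
PACKAGES = ("torch", "intel_extension_for_pytorch", "triton", "pytorch-triton-xpu")


def collect_environment_report(packages=PACKAGES, cache_dir=None):
    """Report installed package versions and whether the Triton cache folder exists.

    Hypothetical helper for illustration only. It reads package metadata
    and never imports torch, so it also works in a broken environment.
    """
    cache_dir = Path(cache_dir) if cache_dir is not None else Path.home() / ".triton"
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed under this name
    return {
        "versions": versions,
        "triton_cache": str(cache_dir),
        "triton_cache_exists": cache_dir.is_dir(),
    }


if __name__ == "__main__":
    report = collect_environment_report()
    for name, version in report["versions"].items():
        print(f"{name}: {version or 'not installed'}")
    if report["triton_cache_exists"]:
        print(f"Triton cache found at {report['triton_cache']}; "
              "consider deleting it if it predates the current install.")
```

The sketch only reports what it finds and leaves deleting `~/.triton` to you, so it is safe to run at any time.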