OpenXLA Support on GPU
This guide gives an overview of the OpenXLA high-level integration structure, and demonstrates how to build Intel® Extension for TensorFlow* and run a JAX example with OpenXLA.
1. Overview
Intel® Extension for TensorFlow* includes a PJRT plugin implementation, which seamlessly runs JAX models on Intel® GPUs. The PJRT API simplifies the integration: the Intel GPU plugin can be developed separately and quickly integrated into JAX. The same PJRT implementation also enables initial Intel GPU support for TensorFlow and PyTorch models with XLA acceleration. Refer to the OpenXLA PJRT Plugin RFC for more details.
JAX provides a familiar NumPy-style API along with composable function transformations for compilation, batching, automatic differentiation, and parallelization, and the same code executes on multiple backends.
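As a quick illustration of that programming model, here is a minimal sketch (not part of the Intel® Extension for TensorFlow* sources) that composes jax.jit and jax.grad over a plain NumPy-style function; the same script runs unchanged on CPU or on an Intel GPU once the PJRT plugin is registered:
import jax
import jax.numpy as jnp

def loss(w, x):
    # A NumPy-style computation: mean squared value of w * x.
    return jnp.mean((w * x) ** 2)

# grad differentiates with respect to w; jit compiles the result via XLA.
loss_grad = jax.jit(jax.grad(loss))

w = jnp.ones((4,))
x = jnp.arange(4.0)
print(loss_grad(w, x))  # runs on whichever backend JAX selected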
In the JAX Python package, jax/_src/lib/xla_bridge.py contains:
register_pjrt_plugin_factories(os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', ''))
register_pjrt_plugin_factories registers a backend for each PJRT plugin. For Intel XPU, PJRT_NAMES_AND_LIBRARY_PATHS is set to 'xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so', where xpu is the backend name and libitex_xla_extension.so is the PJRT plugin library.
In the jaxlib Python package, jaxlib/xla_extension.so is built from the latest TensorFlow code, which calls the PJRT C API interface. The backend needs to implement these APIs: libitex_xla_extension.so implements the PJRT C API interface, which is obtained through GetPjrtApi.
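As a simple sanity check (not from the original guide), you can ask JAX which backend the plugin registration produced; with PJRT_NAMES_AND_LIBRARY_PATHS set as above, the xpu devices should appear:
import jax

# Lists every device JAX discovered, including those from PJRT plugins.
print(jax.devices())
# Name of the backend JAX uses by default, e.g. 'xpu' if the plugin
# loaded successfully, otherwise 'cpu'.
print(jax.default_backend())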
2. Hardware and Software Requirements
Hardware Requirements
Verified Hardware Platforms:
- Intel® Data Center GPU Flex Series
- Intel® Data Center GPU Max Series
Software Requirements
- Ubuntu 22.04, Red Hat 8.6 (64-bit)
  - Intel® Data Center GPU Flex Series
- Ubuntu 22.04, Red Hat 8.6 (64-bit), SUSE Linux Enterprise Server (SLES) 15 SP3/SP4
  - Intel® Data Center GPU Max Series
- Intel® oneAPI Base Toolkit 2023.1
- TensorFlow 2.12.0
- Python 3.8-3.10
- pip 19.0 or later (requires manylinux2014 support)
Install GPU Drivers
| Release | OS | Intel GPU | Install Intel GPU Driver |
|---|---|---|---|
| v1.2.0 | Ubuntu 22.04, Red Hat 8.6 | Intel® Data Center GPU Flex Series | Refer to the Installation Guides for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series 602 driver, append the specific version after each component, such as sudo apt-get install intel-opencl-icd=23.05.25593.18-601~22.04 |
| v1.2.0 | Ubuntu 22.04, Red Hat 8.6, SLES 15 SP3/SP4 | Intel® Data Center GPU Max Series | Refer to the Installation Guides for the latest driver installation. To install the verified Intel® Data Center GPU Max Series/Intel® Data Center GPU Flex Series 602 driver, append the specific version after each component, such as sudo apt-get install intel-opencl-icd=23.05.25593.18-601~22.04 |
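Once the driver is installed, a quick way to confirm the GPU is visible to the oneAPI runtime (an optional check, not part of the official steps) is to list the SYCL devices:
$ source /opt/intel/oneapi/setvars.sh   # default oneAPI install path; adjust if yours differs
$ sycl-ls                               # the GPU should appear as a Level-Zero device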
3. Build Library for JAX
There are some differences from the standard source build procedure:
Make sure you get the Intel® Extension for TensorFlow* main branch code and that your Python version is >= 3.8.
In the TensorFlow installation step, make sure to install jax and jaxlib at the same time:
$ pip install tensorflow==2.12.0 jax==0.4.4 jaxlib==0.4.4
In the “Configure the build” step, run ./configure and select yes for JAX support:
=> “Do you wish to build for JAX support? [y/N]: Y”
Build command:
$ bazel build --config=jax -c opt //itex:libitex_xla_extension.so
Then you can find the library with the XLA extension at ./bazel-bin/itex/libitex_xla_extension.so
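As an optional check that the build produced a usable PJRT plugin (a sketch, not part of the official steps), verify that the library exports the GetPjrtApi entry point:
$ nm -D ./bazel-bin/itex/libitex_xla_extension.so | grep GetPjrtApi   # should print the exported symbol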
4. Run JAX Example
Set the library paths:
$ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so'
$ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib # Some functions defined in xla_extension.so are needed by libitex_xla_extension.so
$ export ITEX_VERBOSE=1 # Optional variable setting. It shows detailed optimization/compilation/execution info.
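Before running the full example, you can optionally confirm that the xpu backend loads (a quick check, not part of the original steps):
$ python -c "import jax; print(jax.devices())"   # should list xpu device(s) if the plugin loaded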
Run the JAX Python code below:
import jax
import jax.numpy as jnp

@jax.jit
def lax_conv():
    key = jax.random.PRNGKey(0)
    lhs = jax.random.uniform(key, (2,1,9,9), jnp.float32)   # input, NCHW layout
    rhs = jax.random.uniform(key, (1,1,4,4), jnp.float32)   # filter, OIHW layout
    side = jax.random.uniform(key, (1,1,1,1), jnp.float32)
    out = jax.lax.conv_with_general_padding(lhs, rhs, (1,1), ((0,0),(0,0)), (1,1), (1,1))
    out = jax.nn.relu(out)
    out = jnp.multiply(out, side)
    return out

print(lax_conv())
Reference result:
I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
I itex/core/compiler/xla/service/service.cc:176] XLA service 0x56060b5ae740 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): <undefined>, <undefined>
I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (1): <undefined>, <undefined>
[[[[2.0449753 2.093208 2.1844783 1.9769732 1.5857391 1.6942389]
[1.9218378 2.2862523 2.1549542 1.8367321 1.3978379 1.3860377]
[1.9456574 2.062028 2.0365305 1.901286 1.5255247 1.1421617]
[2.0621 2.2933435 2.1257985 2.1095486 1.5584903 1.1229166]
[1.7746235 2.2446113 1.7870374 1.8216239 1.557919 0.9832508]
[2.0887792 2.5433128 1.9749291 2.2580051 1.6096935 1.264905 ]]]
[[[2.175818 2.0094342 2.005763 1.6559253 1.3896458 1.4036925]
[2.1342552 1.8239582 1.6091168 1.434404 1.671778 1.7397764]
[1.930626 1.659667 1.6508744 1.3305787 1.4061482 2.0829628]
[2.130649 1.6637266 1.594426 1.2636002 1.7168686 1.8598001]
[1.9009514 1.7938274 1.4870623 1.6193901 1.5297288 2.0247464]
[2.0905268 1.7598859 1.9362347 1.9513799 1.9403584 2.1483061]]]]
If ITEX_VERBOSE=1 is set, the log looks like this:
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:301] Running HLO pass pipeline on module jit_lax_conv: optimization
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion_merger
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass multi_output_fusion
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass gpu-conv-rewriter
I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass onednn-fused-convolution-rewriter
I itex/core/compiler/xla/service/gpu/gpu_compiler.cc:1221] Build kernel via LLVM kernel compilation.
I itex/core/compiler/xla/service/gpu/spir_compiler.cc:255] CompileTargetBinary - CompileToSpir time: 11 us (cumulative: 99.2 ms, max: 74.9 ms, #called: 8)
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2201] Executing computation jit_lax_conv; num_replicas=1 num_partitions=1 num_addressable_devices=1
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete.
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete
I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1299] PjRtStreamExecutorBuffer::ToLiteral
More JAX examples
Get examples from https://github.com/google/jax and run them:
$ git clone https://github.com/google/jax.git
$ cd jax && git checkout jax-v0.4.4
$ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so'
$ python -m examples.mnist_classifier
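Other scripts under the repository's examples/ directory can be launched the same way, for instance (assuming that script is still present at the jax-v0.4.4 tag):
$ python -m examples.kernel_lsq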