Keras 3 Overview
Keras is a deep learning API written in Python that can run on top of JAX, TensorFlow, or PyTorch. Both the JAX and TensorFlow backends compile the model with XLA, which generally delivers the best training and prediction performance on GPU. Results vary from model to model, however, and non-XLA TensorFlow is occasionally faster on GPU. The following image shows how ITEX works with XLA, the Keras 3 TensorFlow backend, and legacy Keras.
Use cases with different performance
Several use cases can lead to different performance.
Default: Users run Keras 3 and the model supports JIT compilation, so the model runs through XLA. If the user script contains no Keras-related code and does not enable XLA in TensorFlow, there will be a performance regression. Set the environment variable `ITEX_DISABLE_XLA=1` to avoid the regression. Once ITEX XLA is disabled, users can choose between NPD (NextPluggableDevice, the default) and the stream executor for better performance through the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.
Legacy Keras: To continue using Keras 2.0, do the following.
1. Install `tf-keras` via `pip install tf-keras`.
2. To switch `tf.keras` to use Keras 2 (`tf-keras`), set the environment variable `TF_USE_LEGACY_KERAS=1` directly, or in your Python program with `import os; os.environ["TF_USE_LEGACY_KERAS"]="1"`. Note that this sets it for all packages in your Python runtime.
3. Change the Keras import: replace `import keras` with `import tf_keras as keras`, and update any `from keras import` to `from tf_keras import`.
Users can choose between NPD (the default) and the stream executor for better performance through the environment variable `ITEX_ENABLE_NEXTPLUGGABLE_DEVICE`.
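The legacy Keras switch from step 2 can be sketched in Python as follows. This is a minimal sketch, not part of ITEX or TensorFlow: the helper name `select_legacy_keras` is hypothetical, and it assumes the variable should be in place before TensorFlow is first imported. The `env` parameter exists only so the helper can be tried on a plain dict.

```python
import os
import sys


def select_legacy_keras(env=None):
    """Set TF_USE_LEGACY_KERAS=1 in `env` (defaults to os.environ).

    Returns True if TensorFlow had not been imported yet, i.e. the
    setting was applied early enough under the assumption stated above.
    """
    if env is None:
        env = os.environ
    tensorflow_not_loaded = "tensorflow" not in sys.modules
    env["TF_USE_LEGACY_KERAS"] = "1"
    return tensorflow_not_loaded


if __name__ == "__main__":
    if not select_legacy_keras():
        print("warning: tensorflow was already imported")
    # The imports from step 3 would follow from here, e.g.:
    #   import tensorflow as tf
    #   import tf_keras as keras
```

Setting the variable from inside the program keeps the choice next to the imports it affects, at the cost of applying to every package in the same Python process.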
Keras 3 with jit_compile disabled: Users can disable jit_compile with `model.jit_compile=False` or `model.compile(..., jit_compile=False)`. Using the ITEX ops override can also disable jit_compile. In this case, `ITEX_DISABLE_XLA=1` must be set.
Enable XLA through TensorFlow: Users can enable XLA through TensorFlow by adding the environment variable `TF_XLA_FLAGS="--tf_xla_auto_jit=1"`. Use `tf_xla_auto_jit=1` to auto-cluster TF ops into XLA, or `tf_xla_auto_jit=2` to compile everything with XLA. Users should set `model.jit_compile=False` if a Keras model is used. If ITEX custom ops are used or `ITEX_OPS_OVERRIDE` is set, users should use `tf_xla_auto_jit=1` to avoid errors.
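As a rough illustration of the rule above, the flag value can be assembled and checked before TensorFlow starts. The helper `xla_env` is hypothetical and not part of ITEX or TensorFlow; it only builds the environment variables named in this section, and it assumes `ITEX_OPS_OVERRIDE=1` is how the override is switched on.

```python
import os


def xla_env(auto_jit, ops_override=False):
    """Return the environment variables for TensorFlow-driven XLA.

    auto_jit: 1 (auto-cluster TF ops) or 2 (compile everything).
    ops_override: True if ITEX custom ops / ITEX_OPS_OVERRIDE are in use.
    """
    if auto_jit not in (1, 2):
        raise ValueError("auto_jit must be 1 or 2")
    if ops_override and auto_jit != 1:
        # Per the text above, custom ops require tf_xla_auto_jit=1.
        raise ValueError("use auto_jit=1 when the ITEX ops override is active")
    env = {"TF_XLA_FLAGS": f"--tf_xla_auto_jit={auto_jit}"}
    if ops_override:
        env["ITEX_OPS_OVERRIDE"] = "1"
    return env


if __name__ == "__main__":
    # Apply before importing tensorflow so the flags are seen at startup.
    os.environ.update(xla_env(auto_jit=1, ops_override=True))
    # With a Keras model, also turn off Keras-level JIT, e.g.:
    #   model.compile(..., jit_compile=False)
```

Rejecting `auto_jit=2` together with the ops override up front surfaces the conflict as a clear Python error instead of a failure inside XLA.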
Situations that lead to a warning or error
All invalid cases are listed here. A Keras version of 0 means the model script does not use Keras.
Note that in every case, running `import keras` before `import tensorflow` causes an error due to a circular import in ITEX.
OPS_OVERRIDE | TF_AUTO_JIT_FLAG | Keras version | NPD | Jit Compile | Warning or Error | Solution |
---|---|---|---|---|---|---|
Any | 0 | 0 | 0 | NA | PluggableDevice cannot work with latest Keras. | ITEX_DISABLE_XLA=1 |
Any | 0 | 0 | 1 | NA | Performance regression | ITEX_DISABLE_XLA=1 |
Any | Any | 2 | Any | 1 | Unknown behavior, not supported. Use TF_XLA_FLAGS="--tf_xla_auto_jit=1" or 2 to enable XLA. | |
Any | 0 | 3 | 0 | Any | Cannot close NPD when using Keras 3 | ITEX_DISABLE_XLA=1 |
Any | 0 | 3 | 1 | 0 | Performance regression | ITEX_DISABLE_XLA=1 |
Any | 1 | Any | 0 | Any | Cannot close NPD | ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1 |
Any | 2 | Any | 0 | Any | Cannot close NPD | ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1 |
1 | 2 | Any | 1 | Any | Custom op not supported by XLA | ITEX_OPS_OVERRIDE=0 |
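
The table can also be read as a set of match rules. The sketch below transcribes its rows into Python so a configuration can be checked before a run. The function name `find_issues`, the tuple layout, and the treatment of both "Any" and "NA" as wildcards are illustrative assumptions, not an ITEX API.

```python
ANY = None  # wildcard: stands for "Any" and "NA" in the table

# (ops_override, auto_jit, keras_version, npd, jit_compile, message, solution)
INVALID_CASES = [
    (ANY, 0, 0, 0, ANY, "PluggableDevice cannot work with latest Keras",
     "ITEX_DISABLE_XLA=1"),
    (ANY, 0, 0, 1, ANY, "Performance regression", "ITEX_DISABLE_XLA=1"),
    (ANY, ANY, 2, ANY, 1, "Unknown behavior, not supported",
     'TF_XLA_FLAGS="--tf_xla_auto_jit=1" (or 2)'),
    (ANY, 0, 3, 0, ANY, "Cannot close NPD when using Keras 3",
     "ITEX_DISABLE_XLA=1"),
    (ANY, 0, 3, 1, 0, "Performance regression", "ITEX_DISABLE_XLA=1"),
    (ANY, 1, ANY, 0, ANY, "Cannot close NPD",
     "ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1"),
    (ANY, 2, ANY, 0, ANY, "Cannot close NPD",
     "ITEX_ENABLE_NEXTPLUGGABLE_DEVICE=1"),
    (1, 2, ANY, 1, ANY, "Custom op not supported by XLA",
     "ITEX_OPS_OVERRIDE=0"),
]


def find_issues(ops_override, auto_jit, keras_version, npd, jit_compile):
    """Return (message, solution) pairs for every table row that matches."""
    config = (ops_override, auto_jit, keras_version, npd, jit_compile)
    issues = []
    for *pattern, message, solution in INVALID_CASES:
        # A row matches when each cell is a wildcard or equals the config.
        if all(p is ANY or p == c for p, c in zip(pattern, config)):
            issues.append((message, solution))
    return issues


if __name__ == "__main__":
    # Keras 3, NPD on, jit_compile off, no TF auto-jit flag.
    for message, solution in find_issues(0, 0, 3, 1, 0):
        print(f"{message}: try {solution}")
```

A configuration that matches no row returns an empty list, which corresponds to a combination the table does not flag.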