BERT Training for Classifying Text on Intel CPU and GPU


Intel® Extension for TensorFlow* is compatible with stock TensorFlow*. This example uses the Classify text with BERT tutorial to show BERT model training without changing the original code.

Once Intel® Extension for TensorFlow* is installed in an existing TensorFlow environment, TensorFlow executes the training on Intel CPU or GPU.

No code changes are needed.
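Because the extension plugs into stock TensorFlow, one quick way to confirm the setup is to check that both packages are importable and to list the devices TensorFlow reports. The minimal sketch below assumes the extension's Python package is named `intel_extension_for_tensorflow`; the helper name `itex_status` is hypothetical. It imports TensorFlow only when it is installed, so it also runs in an environment without it. With the GPU setup, the device list is expected to include an Intel GPU entry besides `CPU`; the exact device type name may depend on the extension version, so it is not checked here.

```python
import importlib.util


def itex_status():
    """Report whether TensorFlow and Intel Extension for TensorFlow* can be imported,
    plus the physical device types TensorFlow sees (empty list if TensorFlow is absent)."""
    status = {
        "tensorflow": importlib.util.find_spec("tensorflow") is not None,
        # Assumption: the extension is importable as intel_extension_for_tensorflow.
        "itex": importlib.util.find_spec("intel_extension_for_tensorflow") is not None,
        "devices": [],
    }
    if status["tensorflow"]:
        # Imported lazily so this check also runs where TensorFlow is not installed.
        import tensorflow as tf

        # Plain stock-TensorFlow API: no extension-specific calls are used.
        status["devices"] = [d.device_type for d in tf.config.list_physical_devices()]
    return status


if __name__ == "__main__":
    print(itex_status())
```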

Hardware Requirements

Verified Hardware Platforms:

  • Intel® Data Center GPU Max Series

  • Intel CPU


Prepare for GPU (Skip this step for CPU)

Refer to Prepare

Set Up Running Environment

  • Set up for GPU

  • Set up for CPU


If your system is Ubuntu 22.04, we suggest installing the following g++ version in the conda environment:

conda install -c conda-forge gxx_linux-64==12.1.0
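To decide whether this g++ pin applies, you can read `/etc/os-release`, where Ubuntu records `ID=ubuntu` and `VERSION_ID="22.04"`. This is a minimal sketch; the function names `parse_os_release` and `needs_gxx_pin` are hypothetical. It prints the suggested command only when the file describes Ubuntu 22.04.

```python
from pathlib import Path


def parse_os_release(text):
    """Parse os-release style KEY=value lines into a dict, stripping surrounding quotes."""
    info = {}
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, comments, and anything that is not an assignment.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        info[key] = value.strip().strip('"').strip("'")
    return info


def needs_gxx_pin(text):
    """Return True if the os-release text describes Ubuntu 22.04."""
    info = parse_os_release(text)
    return info.get("ID") == "ubuntu" and info.get("VERSION_ID") == "22.04"


if __name__ == "__main__":
    path = Path("/etc/os-release")
    if path.exists() and needs_gxx_pin(path.read_text()):
        print("conda install -c conda-forge gxx_linux-64==12.1.0")
```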

Enable Running Environment

  1. Enable the oneAPI running environment (GPU only) and the virtual running environment.

    • For GPU, refer to Running

    • For CPU, run:

source env_itex_cpu/bin/activate
  2. Install the Python packages used in Classify text with BERT.

The Jupyter notebook installs the other required packages while it runs, so there is no need to pre-install them.
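Packages should go into the environment activated in step 1 rather than into the system Python. A hedged sketch for checking this from Python: a venv such as `env_itex_cpu` changes `sys.prefix` but leaves `sys.base_prefix` pointing at the base interpreter, and `conda activate` exports `CONDA_PREFIX`. The function name `active_environment` is hypothetical.

```python
import os
import sys


def active_environment():
    """Describe the active Python environment, or return None if none is detected."""
    # A venv changes sys.prefix but not sys.base_prefix.
    if sys.prefix != sys.base_prefix:
        return f"venv: {sys.prefix}"
    # `conda activate` exports CONDA_PREFIX for the activated environment.
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix:
        return f"conda: {conda_prefix}"
    return None


if __name__ == "__main__":
    env = active_environment()
    print(env or "No virtual environment detected; activate it before installing packages.")
```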

Download Jupyter Code:


Start Jupyter Notebook

jupyter notebook --notebook-dir=./ --ip= --no-browser --allow-root &

Open the URL shown in the Jupyter startup output in your web browser.
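Because the server runs in the background with `&`, its startup output, including the login URL with its `token` parameter, can scroll past. Assuming that output was captured to a file, this sketch extracts every token-bearing URL. The regular expression is an assumption about the log format, not a documented Jupyter interface, and the script and log file names are hypothetical.

```python
import re
import sys

# Assumption: the startup output contains URLs carrying a token query parameter,
# e.g. "http://127.0.0.1:8888/?token=<hex>".
TOKEN_URL = re.compile(r"https?://\S+?[?&]token=\w+")


def find_notebook_urls(log_text):
    """Return every token-bearing URL in the captured startup output, in order of appearance."""
    return TOKEN_URL.findall(log_text)


if __name__ == "__main__":
    # Usage (hypothetical names): python find_url.py jupyter.log
    if len(sys.argv) > 1:
        with open(sys.argv[1], encoding="utf-8") as f:
            for url in find_notebook_urls(f.read()):
                print(url)
```

For example, appending `> jupyter.log 2>&1` before the trailing `&` of the start command leaves a file this helper can read.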

Run the Tutorial

  1. Open classify_text_with_bert.ipynb in Jupyter Notebook.

  2. Run the tutorial according to the description in the Jupyter notebook.

  3. TensorFlow trains the BERT model and runs inference with it on Intel CPU or GPU.
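As an optional alternative to stepping through the notebook by hand, `jupyter nbconvert --execute` can run every cell non-interactively and save the executed copy. This sketch only builds the command and runs it when the notebook and `jupyter` are present; the output filename and the helper name `nbconvert_command` are hypothetical choices. `--ExecutePreprocessor.timeout=-1` disables the per-cell timeout, since training can take a while.

```python
import shlex
import shutil
import subprocess
from pathlib import Path


def nbconvert_command(notebook="classify_text_with_bert.ipynb",
                      output="classify_text_with_bert.executed.ipynb"):
    """Build a jupyter nbconvert command that executes every cell and saves a copy."""
    return [
        "jupyter", "nbconvert",
        "--to", "notebook",
        "--execute",
        # Training can take a long time; -1 disables the per-cell execution timeout.
        "--ExecutePreprocessor.timeout=-1",
        "--output", output,
        notebook,
    ]


if __name__ == "__main__":
    cmd = nbconvert_command()
    print(shlex.join(cmd))
    # Only run it when the notebook is present and jupyter is on PATH.
    if Path(cmd[-1]).exists() and shutil.which("jupyter"):
        subprocess.run(cmd, check=True)
```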

Known Issues

  1. The following error log is a known issue that is not caused by Intel® Extension for TensorFlow*. The crash happens after the code finishes, when it tries to release resources, and it does not affect the results of BERT training and inference/testing.

Traceback (most recent call last):
  File "/home/xxx/xxx/env_itex/lib/python3.9/site-packages/tensorflow/python/training/tracking/", line 174, in __del__
TypeError: 'NoneType' object is not callable
  2. The Jupyter IPython kernel crashing after importing TensorFlow 2.14 on Ubuntu 22.04 is a known issue. To solve it, install the following g++ version in the conda environment:

conda install -c conda-forge gxx_linux-64==12.1.0
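After installing, you can confirm which `gxx_linux-64` version ended up in the active conda environment. `conda list --json` prints the installed packages as a JSON list of objects with `name` and `version` fields; the helper `installed_version` below is a hypothetical sketch that parses that output.

```python
import json
import shutil
import subprocess


def installed_version(conda_list_json, package="gxx_linux-64"):
    """Return the version of `package` in `conda list --json` output, or None if absent."""
    for entry in json.loads(conda_list_json):
        if entry.get("name") == package:
            return entry.get("version")
    return None


if __name__ == "__main__":
    if shutil.which("conda") is None:
        print("conda is not on PATH.")
    else:
        # Lists the packages of the currently active conda environment.
        result = subprocess.run(["conda", "list", "--json"], capture_output=True, text=True)
        if result.returncode == 0:
            print(f"gxx_linux-64: {installed_version(result.stdout) or 'not installed'}")
        else:
            print(result.stderr)
```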