Fast BERT (Experimental)

Feature Description

Intel proposed a technique to speed up BERT workloads. The implementation leverages ideas from the paper Tensor Processing Primitives: A Programming Abstraction for Efficiency and Portability in Deep Learning & HPC Workloads.

The implementation is integrated into Intel® Extension for PyTorch*. BERT models can benefit from this technique in both training and inference.
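The TPP paper builds deep-learning operators from a small set of 2D tensor primitives. One of them is the batch-reduce GEMM (BRGEMM), which accumulates C += Σ_i A_i × B_i over a batch of blocks. The sketch below illustrates that idea in NumPy under stated assumptions: it is not the kernel shipped in Intel® Extension for PyTorch*, the names brgemm and blocked_matmul are hypothetical, and the default 32×32 block sizes are arbitrary. It tiles a matrix product into blocks and produces each output tile with a single batch-reduce call.

```python
import numpy as np


def brgemm(c_block, a_blocks, b_blocks):
    """Batch-reduce GEMM sketch: accumulate sum_i A_i @ B_i into c_block in place."""
    for a, b in zip(a_blocks, b_blocks):
        c_block += a @ b
    return c_block


def blocked_matmul(a, b, bm=32, bn=32, bk=32):
    """Compute a @ b by tiling into (bm x bk) and (bk x bn) blocks.

    Each (bm x bn) output tile is produced by one brgemm call that reduces
    over all K-blocks. Assumes every dimension divides evenly by its block size.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    assert m % bm == 0 and n % bn == 0 and k % bk == 0, "shapes must divide by block sizes"
    c = np.zeros((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, bm):
        for j in range(0, n, bn):
            # Gather the row strip of A and column strip of B as lists of blocks
            a_blocks = [a[i:i + bm, p:p + bk] for p in range(0, k, bk)]
            b_blocks = [b[p:p + bk, j:j + bn] for p in range(0, k, bk)]
            # c[i:i+bm, j:j+bn] is a view, so brgemm writes straight into c
            brgemm(c[i:i + bm, j:j + bn], a_blocks, b_blocks)
    return c


# Usage: the blocked result matches a plain matrix product
rng = np.random.default_rng(0)
x = rng.standard_normal((64, 96))
w = rng.standard_normal((96, 128))
assert np.allclose(blocked_matmul(x, w), x @ w)
```

Keeping each block small is what lets real TPP kernels stay in fast on-chip memory; this NumPy version only shows the loop structure, not the performance.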

Prerequisite

  • Transformers versions 4.6.0 to 4.20.0
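To confirm that an environment meets the prerequisite before applying the optimization, a small version check can help. This is a minimal sketch assuming the range is inclusive at both ends (so 4.20.0 passes and 4.20.1 does not) and that only the leading major.minor.patch numbers matter; parse_version and transformers_supported are hypothetical helper names.

```python
import re

# Range from the Prerequisite section, assumed inclusive at both ends
MIN_VERSION = (4, 6, 0)
MAX_VERSION = (4, 20, 0)


def parse_version(version: str) -> tuple:
    """Turn '4.20.0' or '4.20.0.dev0' into (4, 20, 0) using the leading numeric parts."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups(default="0")
    return (int(major), int(minor), int(patch))


def transformers_supported(version: str) -> bool:
    """Return True if the given Transformers version lies in the documented range."""
    return MIN_VERSION <= parse_version(version) <= MAX_VERSION


if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("transformers")
        print(installed, "supported:", transformers_supported(installed))
    except PackageNotFoundError:
        print("transformers is not installed")
```

Comparing tuples rather than raw strings avoids the usual pitfall where "4.6.0" sorts after "4.20.0" lexically.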

Usage Example

The ipex.fast_bert API is provided for simple usage. It follows the same usage pattern as the ipex.optimize function. A more detailed description of the API is available in the Fast BERT API doc.

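import torch
from transformers import BertModel

# Load a pretrained BERT model and switch it to inference mode
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# Build a random batch of token IDs as dummy input
torch.manual_seed(43)  # seed before sampling so the input is reproducible
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex
# Apply the Fast BERT optimization, using bfloat16 as the data type
model = ipex.fast_bert(model, dtype=torch.bfloat16)
######################################################

# Run inference without tracking gradients
with torch.no_grad():
    model(data)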
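Since the point of the feature is speed, a simple way to see its effect on your own hardware is to time the model with and without ipex.fast_bert. The following is a minimal timing helper in plain Python, assuming the median wall-clock latency over a few iterations is an adequate measure; measure_latency_ms and its warmup and iters parameters are hypothetical names, not part of Intel® Extension for PyTorch*.

```python
import statistics
import time
from typing import Callable


def measure_latency_ms(fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock latency of fn() in milliseconds.

    Warm-up calls run first and are discarded, because the first calls after
    optimization may include one-time setup cost.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

For example, load a second, unmodified BertModel (called stock_model here, a hypothetical name) and, inside torch.no_grad(), compare measure_latency_ms(lambda: stock_model(data)) with measure_latency_ms(lambda: model(data)). The median is used instead of the mean so that a single slow outlier does not dominate the result.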