Fast BERT (Prototype)

Feature Description

Intel proposed a technique to speed up BERT workloads. The implementation builds on ideas from the paper "Tensor Processing Primitives: A Programming Abstraction for Efficiency and Portability in Deep Learning & HPC Workloads".

Currently, the ipex.fast_bert API is well optimized only for training. For inference it is functional but not tuned for peak performance; to get peak inference performance, use the ipex.optimize API together with TorchScript.
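The inference path recommended above can be sketched roughly as follows. This is a minimal sketch, not the library's prescribed recipe: the helper name optimize_bert_for_inference is hypothetical, and the choice of bfloat16 with CPU autocast is an assumption carried over from the usage example on this page. The calls themselves (ipex.optimize, torch.jit.trace, torch.jit.freeze) are existing APIs.

```python
def optimize_bert_for_inference(model, example_input):
    """Hypothetical helper: apply ipex.optimize, then trace and freeze with TorchScript."""
    # Imported inside the function so this file can be imported on machines
    # that do not have PyTorch or the extension installed.
    import torch
    import intel_extension_for_pytorch as ipex

    model.eval()
    # Assumption: bfloat16 is the target dtype, as in the fast_bert example.
    model = ipex.optimize(model, dtype=torch.bfloat16)

    with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
        # strict=False lets tracing accept the dict-like output of
        # Hugging Face models; check_trace=False skips the re-run check.
        traced = torch.jit.trace(
            model, (example_input,), check_trace=False, strict=False
        )
        traced = torch.jit.freeze(traced)
    return traced
```

With a BertModel and an input tensor of token ids prepared as in the usage example below, calling optimize_bert_for_inference(model, data) should return a frozen TorchScript module that can then be invoked under torch.no_grad().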

Prerequisite

  • Transformers 4.6.0 to 4.31.0
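One way to guard against an unsupported Transformers version is a small check before calling the API. The sketch below assumes the stated range is inclusive at both ends; the names parse_release and is_supported are hypothetical helpers, not part of any library.

```python
import re

# Assumption: both bounds of the documented range are inclusive.
MIN_SUPPORTED = (4, 6, 0)
MAX_SUPPORTED = (4, 31, 0)


def parse_release(version):
    """Extract (major, minor, patch) from a version string such as '4.30.2'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component is treated as 0; suffixes like '.dev0' are ignored.
    return tuple(int(part or 0) for part in match.groups())


def is_supported(version):
    """Return True if the version falls inside the documented range."""
    return MIN_SUPPORTED <= parse_release(version) <= MAX_SUPPORTED
```

The installed version can be obtained with importlib.metadata.version("transformers") and passed to is_supported.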

Usage Example

The ipex.fast_bert API is provided for simple usage, and it follows the same pattern as the ipex.optimize function. A more detailed description of the API is available in the Fast BERT API doc.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
torch.manual_seed(43)  # seed before sampling so the random input is reproducible
data = torch.randint(vocab_size, size=[batch_size, seq_length])

#################### code changes ####################
import intel_extension_for_pytorch as ipex
# Replace the BERT modules with their Fast BERT counterparts, using bfloat16
model = ipex.fast_bert(model, dtype=torch.bfloat16)
######################################################

with torch.no_grad():
    model(data)

print("Execution finished")
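Since the API is mainly optimized for training, a training-side sketch may also be useful. It assumes that, in line with the ipex.optimize pattern, ipex.fast_bert accepts an optimizer keyword and returns a (model, optimizer) pair when one is passed; please confirm this against the Fast BERT API doc. The helper names, the use of torch.optim.AdamW, the learning rate, and the CPU autocast context are illustrative assumptions, and the model is assumed to be a Hugging Face BERT model with a task head that returns a loss when labels are included in the batch.

```python
def prepare_fast_bert_training(model, lr=1e-5):
    """Hypothetical helper: wrap a BERT model and its optimizer with ipex.fast_bert."""
    # Imported inside the function so this file can be imported on machines
    # that do not have PyTorch or the extension installed.
    import torch
    import intel_extension_for_pytorch as ipex

    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Assumption: when an optimizer is passed, a (model, optimizer) pair is
    # returned, mirroring the ipex.optimize pattern.
    model, optimizer = ipex.fast_bert(
        model, dtype=torch.bfloat16, optimizer=optimizer
    )
    return model, optimizer


def train_step(model, optimizer, batch):
    """Run one optimization step; batch is a dict of tensors that includes labels."""
    import torch

    optimizer.zero_grad()
    # Assumption: the forward pass runs under CPU autocast for bfloat16.
    with torch.autocast("cpu", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    return loss.item()
```

The wrapping is done once, before the training loop, so that each call to train_step reuses the same optimized model and optimizer.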