Fast BERT (Prototype)
Feature Description
Intel proposed a technique to speed up BERT workloads. The implementation leverages ideas from the paper "Tensor Processing Primitives: A Programming Abstraction for Efficiency and Portability in Deep Learning & HPC Workloads".
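One of the central primitives described in the TPP paper is the batch-reduce GEMM (BRGEMM), which accumulates a sequence of small matrix products into a single output block, C += sum_i A_i @ B_i. The following is only an illustrative PyTorch sketch of that computation, not the optimized kernel IPEX actually uses. The function name brgemm and the block sizes are hypothetical.

```python
import torch

def brgemm(a_blocks: torch.Tensor, b_blocks: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference batch-reduce GEMM: c += sum_i a_blocks[i] @ b_blocks[i].

    a_blocks: (n, M, K) stack of A blocks
    b_blocks: (n, K, N) stack of B blocks
    c:        (M, N) accumulator, updated in place and returned
    """
    assert a_blocks.shape[0] == b_blocks.shape[0], "block counts must match"
    for a_i, b_i in zip(a_blocks, b_blocks):
        # Each step is a small GEMM whose result is reduced into the same C block.
        c.addmm_(a_i, b_i)
    return c

# A full (M x n*K) @ (n*K x N) product equals a BRGEMM over n blocks along K.
torch.manual_seed(0)
M, N, K, n = 4, 3, 8, 5
a = torch.randn(M, n * K)
b = torch.randn(n * K, N)
a_blocks = a.reshape(M, n, K).permute(1, 0, 2)  # (n, M, K): column blocks of a
b_blocks = b.reshape(n, K, N)                   # (n, K, N): row blocks of b
c = brgemm(a_blocks, b_blocks, torch.zeros(M, N))
print(torch.allclose(c, a @ b, atol=1e-4))
```

Keeping every partial product in the same small C block is what lets an optimized kernel hold the accumulator in registers or cache instead of writing intermediate results back to memory.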
Currently, the ipex.fast_bert API is well optimized only for training. For inference it ensures functionality, but for peak performance use the ipex.optimize API together with TorchScript.
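For the inference path, a common combination is ipex.optimize followed by TorchScript tracing and freezing. The following is a minimal sketch under stated assumptions: prepare_for_inference is a hypothetical helper, the small nn.Sequential model stands in for BERT so the snippet runs without downloads, and the code falls back to plain TorchScript when intel_extension_for_pytorch is not installed.

```python
import torch

def prepare_for_inference(model: torch.nn.Module, example_input: torch.Tensor,
                          use_bf16: bool = True) -> torch.jit.ScriptModule:
    """Hypothetical helper: ipex.optimize (if available), then TorchScript trace + freeze."""
    model.eval()
    dtype = torch.bfloat16 if use_bf16 else torch.float32
    try:
        import intel_extension_for_pytorch as ipex
        model = ipex.optimize(model, dtype=dtype)
    except ImportError:
        # Assumption: without IPEX, fall back to plain TorchScript so the sketch still runs.
        pass
    with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16, enabled=use_bf16):
        # strict=False lets the tracer record dict/list outputs of a constant structure.
        traced = torch.jit.trace(model, example_input, strict=False)
        # Freezing inlines weights and attributes as constants for further graph optimization.
        traced = torch.jit.freeze(traced)
    return traced

# Usage with a small stand-in model.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
x = torch.randn(2, 16)
traced = prepare_for_inference(model, x, use_bf16=False)
with torch.no_grad():
    out = traced(x)
print(out.shape)
```

Tracing a Hugging Face model this way may additionally require loading it with torchscript=True, as described in the Transformers TorchScript guide.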
Prerequisite
Transformers versions 4.6.0 through 4.45.0
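A quick way to check the prerequisite before calling the API is to compare the installed Transformers version against the stated range. This is a minimal stdlib-only sketch; the helper names are hypothetical, and treating both bounds as inclusive (so 4.45.1 counts as out of range) is an assumption about how the range is meant.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Bounds from the prerequisite; both ends treated as inclusive (assumption).
MIN_TRANSFORMERS = (4, 6, 0)
MAX_TRANSFORMERS = (4, 45, 0)

def release_tuple(ver: str) -> tuple:
    """Turn '4.38.2' or '4.45.0.dev0' into a comparable (major, minor, patch) tuple."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))

def transformers_supported(ver: str) -> bool:
    """Hypothetical check: is this Transformers version within the supported range?"""
    return MIN_TRANSFORMERS <= release_tuple(ver) <= MAX_TRANSFORMERS

try:
    installed = version("transformers")
    print(f"transformers {installed} supported: {transformers_supported(installed)}")
except PackageNotFoundError:
    print("transformers is not installed")
```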
Usage Example
The ipex.fast_bert API is provided for simple usage. It follows the same calling pattern as the ipex.optimize function. A more detailed description of the API is available in the Fast BERT API doc.
import torch
from transformers import BertModel
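# Load pretrained BERT with the eager attention implementation
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
model.eval()
# Build a random batch of token IDs as dummy input (seed first so it is reproducible)
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
torch.manual_seed(43)
data = torch.randint(vocab_size, size=[batch_size, seq_length])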
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.fast_bert(model, dtype=torch.bfloat16)
###################################################### # noqa F401
# Run a forward pass without tracking gradients
with torch.no_grad():
    model(data)
print("Execution finished")
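Because the API is currently best optimized for training, a training loop is the more representative use. The following is a minimal sketch, assuming the optimizer keyword described in the Fast BERT API doc (in which case the call returns both the model and the optimizer). The tiny randomly initialized BertForSequenceClassification, its sizes, and the dummy labels are hypothetical choices so the snippet runs without downloads; it falls back to the stock model when IPEX is not installed.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)
# Tiny hypothetical config; eager attention mirrors the inference example above.
config = BertConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128, num_labels=2,
                    attn_implementation="eager")
model = BertForSequenceClassification(config)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

try:
    import intel_extension_for_pytorch as ipex
    # Training form: pass the optimizer so the optimized model and optimizer are returned.
    model, optimizer = ipex.fast_bert(model, dtype=torch.bfloat16, optimizer=optimizer)
except ImportError:
    pass  # assumption: without IPEX, train the stock model so the sketch still runs

# Dummy batch: 4 sequences of 32 token IDs with random binary labels.
input_ids = torch.randint(config.vocab_size, size=[4, 32])
labels = torch.randint(config.num_labels, size=[4])

losses = []
for step in range(3):
    optimizer.zero_grad()
    outputs = model(input_ids=input_ids, labels=labels)
    outputs.loss.backward()
    optimizer.step()
    losses.append(outputs.loss.item())
print(losses)
```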