Examples
These examples will guide you through using the Intel® Extension for PyTorch* on Intel CPUs.
You can also refer to the Features section for examples and usage instructions related to particular features.
The source code for these examples, as well as the feature examples, can be found in the GitHub source tree under the examples directory.
Python examples demonstrate usage of Python APIs.
C++ examples demonstrate usage of C++ APIs.
Intel® AI Reference Models provide out-of-the-box use cases that demonstrate the performance benefits achievable with Intel® Extension for PyTorch*.
Prerequisites: Before running these examples, please note the following:
Examples using the BFloat16 data type require machines with the Intel® Advanced Vector Extensions 512 (Intel® AVX-512) BF16 and Intel® Advanced Matrix Extensions (Intel® AMX) BF16 instruction sets.
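On Linux, one way to check for these instruction sets before running the BFloat16 examples is to look for the avx512_bf16 and amx_bf16 flags that the kernel reports in /proc/cpuinfo. The sketch below assumes a Linux host and those kernel flag names. The helper bf16_isa_support is hypothetical and is not part of Intel® Extension for PyTorch*.

```python
import os

# Flag names as reported by the Linux kernel in /proc/cpuinfo (assumption: Linux host).
REQUIRED_FLAGS = ("avx512_bf16", "amx_bf16")


def bf16_isa_support(cpuinfo_text):
    """Hypothetical helper: map each required flag to whether cpuinfo_text lists it."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        # Each logical CPU has a line such as "flags : fpu vme ... avx512_bf16 amx_bf16 ..."
        if line.startswith("flags"):
            _, _, values = line.partition(":")
            flags.update(values.split())
    return {flag: flag in flags for flag in REQUIRED_FLAGS}


if __name__ == "__main__":
    if os.path.exists("/proc/cpuinfo"):
        with open("/proc/cpuinfo") as f:
            print(bf16_isa_support(f.read()))
    else:
        print("/proc/cpuinfo not found; this check only works on Linux.")
```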
Python
Training
Single-instance Training
To use Intel® Extension for PyTorch* for training, make the following changes to your code:
Import intel_extension_for_pytorch as ipex.
Invoke the ipex.optimize function to apply optimizations to the model and optimizer objects, as shown below:
...
import torch
import intel_extension_for_pytorch as ipex
...
model = Model()
criterion = ...
optimizer = ...
model.train()
# For Float32
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# For BFloat16
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
# Invoke the code below to enable beta feature torch.compile
model = torch.compile(model, backend="ipex")
...
optimizer.zero_grad()
output = model(data)
...
Below are complete code examples that show how to use the extension for training with different data types:
Float32
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# Uncomment the code below to enable beta feature `torch.compile`
# model = torch.compile(model, backend="ipex")
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
print("Execution finished")
BFloat16
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
# Uncomment the code below to enable beta feature `torch.compile`
# model = torch.compile(model, backend="ipex")
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    with torch.cpu.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
print("Execution finished")
Distributed Training
Distributed training with PyTorch DDP is accelerated by the oneAPI Collective Communications Library Bindings for PyTorch* (oneCCL Bindings for PyTorch*). The extension supports the FP32 and BF16 data types. More detailed information and examples are available in the GitHub repository.
Note: You need to install the torchvision Python package to run the following example.
import os
import torch
import torch.distributed as dist
import torchvision
import oneccl_bindings_for_pytorch as torch_ccl
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
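# Environment variables must be strings; PMI_RANK/PMI_SIZE are typically set by an MPI launcher
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
dist.init_process_group(
    backend='ccl',
    init_method='env://'
)
# Convert to int so that the "rank == 0" check below works
rank = int(os.environ['RANK'])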
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128,
sampler=dist_sampler
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer)
model = torch.nn.parallel.DistributedDataParallel(model)
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('batch_id: {}'.format(batch_idx))
if rank == 0:
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint.pth')
dist.destroy_process_group()
print("Execution finished")
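The example above derives each process's rank and world size from PMI_RANK and PMI_SIZE, which an MPI launcher is assumed to set. Environment variable values are always strings, so they must be converted before being compared with integers such as in rank == 0. The stdlib sketch below illustrates that conversion, plus a simplified model of how a distributed sampler gives each rank an interleaved shard of the dataset (shuffling disabled, no drop_last). It is not DistributedSampler's actual implementation, and both helper names are hypothetical.

```python
import math
import os


def resolve_rank_and_world_size(env):
    """Read rank and world size from PMI_* variables, defaulting to one process.

    Environment values are strings, so they are converted to int here.
    """
    rank = int(env.get("PMI_RANK", "0"))
    world_size = int(env.get("PMI_SIZE", "1"))
    return rank, world_size


def shard_indices(num_samples, rank, world_size):
    """Interleaved dataset indices for one rank, without shuffling.

    Indices wrap around so that every rank receives the same number of samples.
    """
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    padded = [i % num_samples for i in range(total)]
    return padded[rank:total:world_size]


rank, world_size = resolve_rank_and_world_size(os.environ)
print(f"rank {rank} of {world_size}; first indices of a 10-sample dataset:",
      shard_indices(10, rank, world_size)[:5])
```

Padding by wrapping around keeps every rank's batch count identical, which matters because DDP ranks synchronize gradients on every step.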
Inference
The optimize function of Intel® Extension for PyTorch* applies optimizations to the model for additional performance gains. For both computer vision and NLP workloads, we recommend applying the optimize function to the model object.
Float32
Eager Mode
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
###################################################### # noqa F401
with torch.no_grad():
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
###################################################### # noqa F401
with torch.no_grad():
    model(data)
print("Execution finished")
TorchScript Mode
We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
###################################################### # noqa F401
with torch.no_grad():
    d = torch.rand(128, 3, 224, 224)
    model = torch.jit.trace(model, d)
    model = torch.jit.freeze(model)
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
###################################################### # noqa F401
with torch.no_grad():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
    model = torch.jit.freeze(model)
    model(data)
print("Execution finished")
TorchDynamo Mode (Beta, NEW feature from 2.0.0)
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model.eval()
data = torch.rand(128, 3, 224, 224)
# Beta Feature
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
###################################################### # noqa F401
with torch.no_grad():
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
# Beta Feature
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
###################################################### # noqa F401
with torch.no_grad():
    model(data)
print("Execution finished")
Note: In TorchDynamo mode, because native PyTorch operators such as aten::convolution and aten::linear are already well supported and optimized in the ipex backend, weights prepacking must be disabled by setting weights_prepack=False in ipex.optimize().
BFloat16
The optimize function works for both the Float32 and BFloat16 data types. For BFloat16, set the dtype parameter to torch.bfloat16.
We recommend using Auto Mixed Precision (AMP) with BFloat16 data type.
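For intuition about what BFloat16 trades away: it keeps the 8-bit exponent of Float32 but only 7 explicit mantissa bits, so a BF16 value is essentially the upper 16 bits of a Float32 value and carries fewer than three significant decimal digits. The pure-Python sketch below rounds a value to the nearest BF16 value with round-to-nearest-even. The helper name to_bfloat16 is hypothetical, and NaN handling is deliberately omitted.

```python
import struct


def to_bfloat16(x):
    """Round a float to the nearest BFloat16 value (round-to-nearest-even).

    Sketch only: NaN inputs are not handled.
    """
    # Reinterpret the value as the 32-bit pattern of a Float32.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # Add a rounding bias so that exact ties round to an even 16-bit result.
    lsb = (bits >> 16) & 1
    rounded = (bits + 0x7FFF + lsb) >> 16
    # Expand back to Float32 by zero-filling the 16 dropped mantissa bits.
    return struct.unpack(">f", struct.pack(">I", rounded << 16))[0]


for value in (1.0, 1.00390625, 1.01171875, 3.14159265):
    print(value, "->", to_bfloat16(value))
```

Because the exponent range matches Float32, BF16 rarely overflows where Float32 would not; the cost is precision, which is the motivation for letting autocast choose which operations run in BF16.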
Eager Mode
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)
print("Execution finished")
TorchScript Mode
We recommend using Intel® Extension for PyTorch* with TorchScript for further optimizations.
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, torch.rand(128, 3, 224, 224))
    model = torch.jit.freeze(model)
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    d = torch.randint(vocab_size, size=[batch_size, seq_length])
    model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
    model = torch.jit.freeze(model)
    model(data)
print("Execution finished")
TorchDynamo Mode (Beta, NEW feature from 2.0.0)
Resnet50
Note: You need to install the torchvision Python package to run the following example.
import torch
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
# Beta Feature
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False)
model = torch.compile(model, backend="ipex")
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)
print("Execution finished")
BERT
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
# Beta Feature
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16, weights_prepack=False)
model = torch.compile(model, backend="ipex")
###################################################### # noqa F401
with torch.no_grad(), torch.cpu.amp.autocast():
    model(data)
print("Execution finished")
Fast Bert (Prototype)
Note: You need to install the transformers Python package to run the following example.
import torch
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
torch.manual_seed(43)
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
model = ipex.fast_bert(model, dtype=torch.bfloat16)
###################################################### # noqa F401
with torch.no_grad():
    model(data)
print("Execution finished")
INT8
Starting from Intel® Extension for PyTorch* 1.12.0, the quantization feature supports both static and dynamic modes.
Static Quantization
Calibration
Follow the steps below to perform calibration for static quantization:
Import intel_extension_for_pytorch as ipex.
Import prepare and convert from intel_extension_for_pytorch.quantization.
Create a qconfig mapping that holds the configuration used during calibration, either the default ipex.quantization.default_static_qconfig_mapping or a custom mapping built from torch.ao.quantization.QConfig.
Prepare the model for calibration.
Perform calibration against the dataset.
Invoke the ipex.quantization.convert function to apply the calibration configuration to the FP32 model object and get an INT8 model.
Save the INT8 model into a .pt file.
Note: You need to install the torchvision Python package to run the following example.
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
###################################################### # noqa F401
##### Example Model ##### # noqa F401
import torchvision.models as models
model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
model.eval()
data = torch.rand(128, 3, 224, 224)
######################### # noqa F401
qconfig_mapping = ipex.quantization.default_static_qconfig_mapping
# Alternatively, define your own qconfig_mapping:
# from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig, QConfigMapping
# qconfig = QConfig(
# activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
# weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
# qconfig_mapping = QConfigMapping().set_global(qconfig)
prepared_model = prepare(model, qconfig_mapping, example_inputs=data, inplace=False)
##### Example Dataloader ##### # noqa F401
import torchvision
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
calibration_data_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
with torch.no_grad():
    for batch_idx, (d, target) in enumerate(calibration_data_loader):
        print(f'calibrated on batch {batch_idx} out of {len(calibration_data_loader)}')
        prepared_model(d)
############################## # noqa F401
converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, data)
    traced_model = torch.jit.freeze(traced_model)
    traced_model.save("static_quantized_model.pt")
print("Saved model to: static_quantized_model.pt")
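Conceptually, calibration records the value range each observed tensor takes on the calibration data, and convert turns those ranges into quantization parameters. The pure-Python sketch below shows that arithmetic for the per-tensor affine quint8 scheme named in the commented-out MinMaxObserver configuration above: a scale and zero point are derived from the observed minimum and maximum, and values are mapped to integers in [0, 255]. It is a simplified illustration, not the Intel® Extension for PyTorch* implementation, and the function names are hypothetical.

```python
def affine_qparams(observed_min, observed_max, qmin=0, qmax=255):
    """Derive (scale, zero_point) for asymmetric (affine) uint8 quantization."""
    # The range must include 0 so that zero is exactly representable.
    lo = min(observed_min, 0.0)
    hi = max(observed_max, 0.0)
    scale = max((hi - lo) / (qmax - qmin), 1e-8)  # guard against an all-zero range
    zero_point = qmin - round(lo / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


# Pretend calibration observed activations in the range [-1.0, 3.0].
observed = [-1.0, 0.25, 3.0, 1.7]
scale, zp = affine_qparams(min(observed), max(observed))
for x in observed:
    q = quantize(x, scale, zp)
    print(f"{x:6.3f} -> {q:3d} -> {dequantize(q, scale, zp):6.3f}")
```

Values inside the calibrated range come back within half a quantization step; values outside it are clamped, which is why the calibration data should resemble real inputs.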
Deployment
For deployment, the INT8 model is loaded from the local file and can be used directly for inference.
Follow the steps below:
Import intel_extension_for_pytorch as ipex.
Load the INT8 model from the saved file.
Run inference.
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex # noqa F401
###################################################### # noqa F401
model = torch.jit.load('static_quantized_model.pt')
model.eval()
model = torch.jit.freeze(model)
data = torch.rand(128, 3, 224, 224)
with torch.no_grad():
    model(data)
print("Execution finished")
Dynamic Quantization
Follow the steps below to perform dynamic quantization:
Import intel_extension_for_pytorch as ipex.
Import prepare and convert from intel_extension_for_pytorch.quantization.
Create a qconfig mapping that describes how the model is quantized, either the default ipex.quantization.default_dynamic_qconfig_mapping or a custom mapping built from torch.ao.quantization.QConfig.
Prepare the model for quantization.
Convert the model.
Run inference to perform dynamic quantization.
Save the INT8 model into a .pt file.
Note: You need to install the transformers Python package to run the following example.
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
###################################################### # noqa F401
##### Example Model ##### # noqa F401
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
vocab_size = model.config.vocab_size
batch_size = 128
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
######################### # noqa F401
qconfig_mapping = ipex.quantization.default_dynamic_qconfig_mapping
# Alternatively, define your own qconfig:
# from torch.ao.quantization import PerChannelMinMaxObserver, PlaceholderObserver, QConfig, QConfigMapping
# qconfig = QConfig(
# activation = PlaceholderObserver.with_args(dtype=torch.float, is_dynamic=True),
# weight = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
# qconfig_mapping = QConfigMapping().set_global(qconfig)
prepared_model = prepare(model, qconfig_mapping, example_inputs=data)
converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, (data,), check_trace=False, strict=False)
    traced_model = torch.jit.freeze(traced_model)
    traced_model.save("dynamic_quantized_model.pt")
print("Saved model to: dynamic_quantized_model.pt")
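In dynamic quantization the weights are quantized ahead of time, while the activation scale is computed on the fly from each input, which is why no calibration dataset is needed. The pure-Python sketch below illustrates both halves for one small linear layer: per-output-channel symmetric int8 weights (the scheme named by per_channel_symmetric in the commented-out configuration above) and a per-tensor activation scale computed at call time. For simplicity it uses the common max(|w|) / 127 formulation for both weights and activations, which differs in detail from PyTorch's observers; all names are hypothetical.

```python
def symmetric_scale(values, qmax=127):
    """Scale that maps the largest magnitude in values to qmax (zero point is 0)."""
    amax = max(abs(v) for v in values)
    return amax / qmax if amax > 0 else 1.0


def quantize_sym(values, scale, qmax=127):
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


def prepare_weights(weight_rows):
    """Offline step: quantize each output channel (row) with its own scale."""
    scales = [symmetric_scale(row) for row in weight_rows]
    q_rows = [quantize_sym(row, s) for row, s in zip(weight_rows, scales)]
    return q_rows, scales


def dynamic_linear(x, q_rows, w_scales):
    """Runtime step: derive the activation scale from this input, then integer dot products."""
    x_scale = symmetric_scale(x)
    qx = quantize_sym(x, x_scale)
    out = []
    for q_row, w_scale in zip(q_rows, w_scales):
        acc = sum(a * b for a, b in zip(qx, q_row))  # integer accumulation
        out.append(acc * x_scale * w_scale)  # rescale back to floating point
    return out


weights = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.3]]
q_rows, w_scales = prepare_weights(weights)
x = [1.0, 2.0, -0.5]
print(dynamic_linear(x, q_rows, w_scales))
print([sum(a * b for a, b in zip(x, row)) for row in weights])  # float reference
```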
Large Language Model (LLM)
Intel® Extension for PyTorch* provides dedicated optimizations for running Large Language Models (LLM) faster. A set of data types is supported for various scenarios, including FP32, BF16, Smooth Quantization INT8, and Weight Only Quantization INT8/INT4 (prototype).
Note: You need to install the transformers==4.38.1 Python package to run the following example.
In addition, you may need to log in to your HuggingFace account to access the pretrained model files.
Please refer to HuggingFace login.
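Weight-only quantization, listed above, stores only the weights in INT8 or INT4 while activations stay in floating point; the computation precision is selected separately (see the --lowp-mode argument in the weight-only quantization example). As a rough illustration of why INT4 weights usually need finer-grained scales, the pure-Python sketch below quantizes one weight row to signed 4-bit integers in [-8, 7] with one scale per group of columns, then dequantizes it. The group size, the symmetric scheme, and the function names are assumptions for illustration and do not describe the extension's actual INT4 format.

```python
def quantize_int4_grouped(row, group_size=4):
    """Symmetric signed 4-bit quantization with one scale per group of columns."""
    q, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / 7 if amax > 0 else 1.0  # map the largest magnitude to 7
        scales.append(scale)
        q.extend(max(-8, min(7, round(v / scale))) for v in group)
    return q, scales


def dequantize_int4_grouped(q, scales, group_size=4):
    return [v * scales[i // group_size] for i, v in enumerate(q)]


row = [0.02, -0.01, 0.03, 0.015, 1.2, -0.8, 0.4, 0.1]
q, scales = quantize_int4_grouped(row)
restored = dequantize_int4_grouped(q, scales)
print(q)
print([round(v, 4) for v in restored])
```

With a single scale for the whole row (1.2 / 7, about 0.17), every value in the first group would round to 0; per-group scales preserve them at the cost of storing more scales.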
FP32/BF16
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
###################################################### # noqa F401
import argparse
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
# args
parser = argparse.ArgumentParser("Generation script (fp32/bf16 path)", add_help=False)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "bfloat16"],
default="float32",
help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument(
"--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument(
"--prompt", default="What are we having for dinner?", type=str, help="input prompt"
)
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
args = parser.parse_args()
print(args)
# dtype
amp_enabled = True if args.dtype != "float32" else False
amp_dtype = getattr(torch, args.dtype)
# load model
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(
model_id, torchscript=True, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=amp_dtype,
config=config,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
model = model.eval()
model = model.to(memory_format=torch.channels_last)
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
model = ipex.llm.optimize(
model,
dtype=amp_dtype,
inplace=True,
deployment_mode=True,
)
###################################################### # noqa F401
# generate args
num_beams = 1 if args.greedy else 4
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)
# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size
# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(enabled=amp_enabled):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(
        input_ids,
        max_new_tokens=args.max_new_tokens,
        **generate_kwargs
    )
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
input_tokens_lengths = [x.shape[0] for x in input_ids]
output_tokens_lengths = [x.shape[0] for x in gen_ids]
total_new_tokens = [
o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)
]
print(gen_text, total_new_tokens, flush=True)
Smooth Quantization INT8
The typical steps shown in the example are:
Calibration process: Run the example script with --calibration, along with other related arguments. When calibration completes, the quantization summary files are generated.
Model inference process: Run the example script without --calibration. In this process, the quantized model is generated from the original model together with the quantization config and summary files, and then produces results for the input prompt.
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
###################################################### # noqa F401
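import argparse
from datasets import load_dataset  # requires the datasets package; used for calibration data
from torch.utils.data import DataLoader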
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
# args
parser = argparse.ArgumentParser("Generation script (static quantization path)", add_help=False)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "bfloat16"],
default="float32",
help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument(
"--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument(
"--prompt", default="What are we having for dinner?", type=str, help="input prompt"
)
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
parser.add_argument("--calibration", action="store_true")
parser.add_argument("--calibration-samples", default=512, type=int, help="total number of calibration samples")
parser.add_argument("--int8-qconfig", nargs="?", default="./qconfig.json", help="static quantization factors summary files generated by calibration")
parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k")
parser.add_argument("--alpha", default=0.5, type=float, help="alpha value for smoothquant")
args = parser.parse_args()
print(args)
# dtype
amp_enabled = True if args.dtype != "float32" and not args.calibration else False
amp_dtype = getattr(torch, args.dtype) if not args.calibration else torch.float32
# load model
model_id = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(
model_id, torchscript=True, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=amp_dtype,
config=config,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
model = model.eval()
model = model.to(memory_format=torch.channels_last)
num_beams = 1 if args.greedy else 4
beam_idx_tmp = torch.zeros(
(2048, int(args.batch_size * num_beams)), dtype=torch.long
).contiguous()
global_past_key_value = [
(
torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
torch.zeros(
[
1,
model.config.num_attention_heads,
1,
int(
model.config.hidden_size
/ model.config.num_attention_heads
),
]
).contiguous(),
torch.zeros(
[
1,
model.config.num_attention_heads,
1,
int(
model.config.hidden_size
/ model.config.num_attention_heads
),
]
).contiguous(),
beam_idx_tmp,
)
for i in range(model.config.num_hidden_layers)
]
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
class Calibration:
    def __init__(self, dataset, tokenizer, batch_size=1, pad_val=1, pad_max=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.pad_val = pad_val
        self.pad_max = pad_max
        # tokenize the dataset
        self.dataset = self.dataset.map(self.tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])
    @torch.no_grad()
    def tokenize_function(self, examples):
        if "prompt" in examples:
            example = self.tokenizer(examples["prompt"])
        elif "text" in examples:
            example = self.tokenizer(examples["text"])
        elif "code" in examples:
            example = self.tokenizer(examples["code"])
        return example
    @torch.no_grad()
    def collate_batch(self, batch):
        position_ids_padded = []
        input_ids_padded = []
        last_ind = []
        attention_mask_padded = []
        for text in batch:
            input_ids = text["input_ids"]
            # truncate inputs longer than pad_max
            input_ids = (
                input_ids[: int(self.pad_max)]
                if len(input_ids) > int(self.pad_max)
                else input_ids
            )
            last_ind.append(input_ids.shape[0] - 1)
            attention_mask = torch.ones(len(input_ids))
            position_ids = torch.arange(len(input_ids))
            input_ids_padded.append(input_ids)
            attention_mask_padded.append(attention_mask)
            position_ids_padded.append(position_ids)
        return (
            (
                torch.vstack(input_ids_padded),
                torch.vstack(attention_mask_padded),
                torch.vstack(position_ids_padded),
                tuple(global_past_key_value),
            ),
            torch.tensor(last_ind),
        )
calib_dataset = load_dataset(args.dataset, split="train")
calib_evaluator = Calibration(calib_dataset, tokenizer, args.batch_size)
calib_dataloader = DataLoader(
calib_evaluator.dataset,
batch_size=1,
shuffle=False,
collate_fn=calib_evaluator.collate_batch,
)
qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha)
if args.calibration:
    # take the first batch as example inputs for preparing the model
    example_inputs = None
    for i, (
        (input_ids, attention_mask, position_ids, past_key_values),
        last_ind,
    ) in enumerate(calib_dataloader):
        example_inputs = (input_ids, attention_mask, position_ids, past_key_values)
        break
    from intel_extension_for_pytorch.quantization import prepare, convert
    model = ipex.llm.optimize(
        model.eval(),
        dtype=amp_dtype,
        quantization_config=qconfig,
        inplace=True,
        deployment_mode=False,
    )
    prepared_model = prepare(
        model.eval(), qconfig, example_inputs=example_inputs
    )
    # run calibration on up to args.calibration_samples batches
    with torch.no_grad():
        for i, (
            (input_ids, attention_mask, position_ids, past_key_values),
            last_ind,
        ) in enumerate(calib_dataloader):
            if i == args.calibration_samples:
                break
            prepared_model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
            )
    prepared_model.save_qconf_summary(qconf_summary=args.int8_qconfig)
    print("calibration Done! Will exit and please launch model quantization and benchmark")
    exit(0)
else:
    # quantize the model using the summary file produced by calibration
    model = ipex.llm.optimize(
        model.eval(),
        dtype=amp_dtype,
        quantization_config=qconfig,
        qconfig_summary_file=args.int8_qconfig,
        inplace=True,
        deployment_mode=True,
    )
    print("model quantization - Done!")
###################################################### # noqa F401
# generate args
num_beams = 1 if args.greedy else 4
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)
# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size
# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(enabled=amp_enabled):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(
        input_ids,
        max_new_tokens=args.max_new_tokens,
        **generate_kwargs
    )
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
input_tokens_lengths = [x.shape[0] for x in input_ids]
output_tokens_lengths = [x.shape[0] for x in gen_ids]
total_new_tokens = [
o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)
]
print(gen_text, total_new_tokens, flush=True)
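The --alpha argument in the example above controls SmoothQuant's migration strength. In the SmoothQuant formulation, each input channel j gets a factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha); activations are divided by s and the matching weight columns are multiplied by s, which leaves the layer's output mathematically unchanged while flattening activation outliers that are hard to quantize. The pure-Python sketch below demonstrates that identity on toy matrices. It is not how ipex.quantization.get_smooth_quant_qconfig_mapping works internally, and the function names are hypothetical.

```python
def smoothing_factors(x_rows, w_rows, alpha=0.5):
    """Per-input-channel factors s_j = max|X_j|**alpha / max|W_j|**(1 - alpha).

    x_rows: activations, one row per token, one column per input channel.
    w_rows: weights stored as [out_features][in_features], like torch.nn.Linear.weight.
    """
    factors = []
    for j in range(len(x_rows[0])):
        x_max = max(abs(row[j]) for row in x_rows)
        w_max = max(abs(row[j]) for row in w_rows)
        if x_max == 0 or w_max == 0:
            factors.append(1.0)  # nothing to migrate for an all-zero channel
        else:
            factors.append(x_max ** alpha / w_max ** (1 - alpha))
    return factors


def linear(x_rows, w_rows):
    """y = x @ W^T (no bias)."""
    return [[sum(a * b for a, b in zip(x, w)) for w in w_rows] for x in x_rows]


x = [[1.0, 40.0, 0.5], [-2.0, 35.0, 0.1]]  # input channel 1 carries outliers
w = [[0.3, 0.02, -0.5], [0.1, -0.04, 0.2]]
s = smoothing_factors(x, w, alpha=0.5)

# Divide activations by s and multiply the matching weight columns by s.
x_smooth = [[v / s_j for v, s_j in zip(row, s)] for row in x]
w_smooth = [[v * s_j for v, s_j in zip(row, s)] for row in w]

print("per-channel max |x| before:", [max(abs(r[j]) for r in x) for j in range(3)])
print("per-channel max |x| after: ", [round(max(abs(r[j]) for r in x_smooth), 3) for j in range(3)])
print(linear(x, w))
print(linear(x_smooth, w_smooth))  # equal up to floating-point rounding
```

With alpha = 0.5, each channel's activation maximum and weight maximum end up equal (both become sqrt(max|X_j| * max|W_j|)); a larger alpha shifts more of the range onto the weights.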
Weight Only Quantization INT8/INT4
import torch
#################### code changes #################### # noqa F401
import intel_extension_for_pytorch as ipex
###################################################### # noqa F401
import argparse
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
# args
parser = argparse.ArgumentParser("Generation script (weight only quantization path)", add_help=False)
parser.add_argument(
"--dtype",
type=str,
choices=["float32", "bfloat16"],
default="float32",
help="choose the weight dtype and whether to enable auto mixed precision or not",
)
parser.add_argument(
"--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument(
"--prompt", default="What are we having for dinner?", type=str, help="input prompt"
)
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
parser.add_argument(
"--lowp-mode",
choices=["AUTO", "BF16", "FP32", "INT8", "FP16"],
default="AUTO",
type=str,
help="low precision mode for weight only quantization. "
"It indicates the data type used for computation, trading accuracy for speed. "
"Unrelated to activation or weight data type. "
"lowp_mode=INT8 is not supported yet for INT8 weight; "
"it falls back to lowp_mode=BF16 implicitly in this case. "
"If set to AUTO, lowp_mode is determined by weight data type: "
"lowp_mode=BF16 is used for INT8 weight "
"and lowp_mode=INT8 is used for INT4 weight",
)
parser.add_argument(
"--weight-dtype",
choices=["INT8", "INT4"],
default="INT8",
type=str,
help="weight data type for weight only quantization. Unrelated to activation"
" data type or lowp-mode. If `--low-precision-checkpoint` is given, weight"
" data type is always INT4 and this argument is not needed.",
)
parser.add_argument(
"--low-precision-checkpoint",
default="",
type=str,
help="Low precision checkpoint file generated by calibration, such as GPTQ. It contains"
" modified weights, scales, zero points, etc. For better accuracy of weight only"
" quantization with INT4 weight.",
)
###################################################### # noqa F401
args = parser.parse_args()
print(args)
# dtype
amp_enabled = True if args.dtype != "float32" else False
amp_dtype = getattr(torch, args.dtype)
# load model
model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(
model_id, torchscript=True, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=amp_dtype,
config=config,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True
)
model = model.eval()
model = model.to(memory_format=torch.channels_last)
# Intel(R) Extension for PyTorch*
#################### code changes #################### # noqa F401
from intel_extension_for_pytorch.quantization import WoqWeightDtype
weight_dtype = WoqWeightDtype.INT4 if args.weight_dtype == "INT4" else WoqWeightDtype.INT8
if args.lowp_mode == "INT8":
    lowp_mode = ipex.quantization.WoqLowpMode.INT8
elif args.lowp_mode == "FP32":
    lowp_mode = ipex.quantization.WoqLowpMode.NONE
elif args.lowp_mode == "FP16":
    lowp_mode = ipex.quantization.WoqLowpMode.FP16
elif args.lowp_mode == "BF16":
    lowp_mode = ipex.quantization.WoqLowpMode.BF16
else:  # AUTO
    if args.low_precision_checkpoint != "" or weight_dtype == WoqWeightDtype.INT4:
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
    else:
        lowp_mode = ipex.quantization.WoqLowpMode.BF16
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype, lowp_mode=lowp_mode
)
if args.low_precision_checkpoint != "":
    low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
else:
    low_precision_checkpoint = None
model = ipex.llm.optimize(
model.eval(),
dtype=amp_dtype,
quantization_config=qconfig,
low_precision_checkpoint=low_precision_checkpoint,
deployment_mode=True,
inplace=True,
)
###################################################### # noqa F401
# generate args
num_beams = 1 if args.greedy else 4
generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=num_beams)
# input prompt
prompt = args.prompt
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
prompt = [prompt] * args.batch_size
# inference
with torch.no_grad(), torch.inference_mode(), torch.cpu.amp.autocast(enabled=amp_enabled):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_ids = model.generate(
        input_ids,
        max_new_tokens=args.max_new_tokens,
        **generate_kwargs
    )
    gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    input_tokens_lengths = [x.shape[0] for x in input_ids]
    output_tokens_lengths = [x.shape[0] for x in gen_ids]
    total_new_tokens = [
        o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)
    ]
    print(gen_text, total_new_tokens, flush=True)
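The --lowp-mode and --weight-dtype help texts, together with the if/elif chain in the script, amount to a small decision table. The sketch below restates it in plain Python, with strings standing in for the ipex.quantization.WoqLowpMode member names. resolve_lowp_mode and effective_compute_mode are hypothetical names, and the INT8-to-BF16 fallback is modeled from the help text only, not from the library source.

```python
# Plain-string stand-ins for ipex.quantization.WoqLowpMode member names.
EXPLICIT_MODES = {"INT8": "INT8", "FP32": "NONE", "FP16": "FP16", "BF16": "BF16"}


def resolve_lowp_mode(lowp_mode, weight_dtype, has_checkpoint=False):
    """Restate the script's if/elif chain: which WoqLowpMode it requests."""
    if lowp_mode in EXPLICIT_MODES:
        return EXPLICIT_MODES[lowp_mode]
    # AUTO: a low-precision (e.g. GPTQ) checkpoint or INT4 weights -> INT8 compute
    if has_checkpoint or weight_dtype == "INT4":
        return "INT8"
    return "BF16"


def effective_compute_mode(lowp_mode, weight_dtype, has_checkpoint=False):
    """Model the fallbacks described in the help texts (an assumption, not library code)."""
    # --weight-dtype help: a low-precision checkpoint always carries INT4 weights
    weights = "INT4" if has_checkpoint else weight_dtype
    mode = resolve_lowp_mode(lowp_mode, weight_dtype, has_checkpoint)
    # --lowp-mode help: INT8 compute with INT8 weights falls back to BF16
    if mode == "INT8" and weights == "INT8":
        return "BF16"
    return mode


for combo in [("AUTO", "INT8"), ("AUTO", "INT4"), ("INT8", "INT8"), ("FP32", "INT4")]:
    print(combo, "->", resolve_lowp_mode(*combo), "/", effective_compute_mode(*combo))
```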
Note: Please check the LLM Best Known Practice page for detailed environment setup and instructions for running LLM workloads.
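In the weight-only quantization script, linear-layer weights are stored as INT8 or INT4 while activations stay in floating point, and the --low-precision-checkpoint help text mentions the scales and zero points such checkpoints carry. As a rough illustration of what those numbers mean, the sketch below applies asymmetric per-output-channel INT4 quantization in plain Python. It is a generic textbook scheme, not the extension's kernels or GPTQ; the function names are hypothetical, and production kernels typically pack two INT4 codes per byte and often quantize in smaller groups.

```python
def quantize_row_int4(row):
    """Asymmetric INT4 quantization of one output channel (weight row).

    Returns (codes, scale, zero_point) with every code in [0, 15].
    """
    qmin, qmax = 0, 15
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)  # keep 0.0 representable
    scale = (hi - lo) / (qmax - qmin) or 1.0  # all-zero row: any scale works
    zero_point = round(-lo / scale)
    codes = [min(qmax, max(qmin, round(x / scale) + zero_point)) for x in row]
    return codes, scale, zero_point


def dequantize_row(codes, scale, zero_point):
    """Map INT4 codes back to approximate float weights."""
    return [(q - zero_point) * scale for q in codes]


def quantize_weights(weight_rows):
    """Quantize once, ahead of time, as a weight-only scheme does."""
    return [quantize_row_int4(row) for row in weight_rows]


def woq_linear(x, qweights):
    """y = W @ x: weights are dequantized on the fly, activations x stay float."""
    return [
        sum(w * xi for w, xi in zip(dequantize_row(c, s, zp), x))
        for c, s, zp in qweights
    ]


W = [[0.3, -1.2, 0.8, 2.0], [-0.5, 0.25, 0.0, -0.75]]
x = [1.0, 2.0, -1.0, 0.5]
print(woq_linear(x, quantize_weights(W)))
```

Each dequantized weight lies within half a scale step of the original, so a narrower value range per channel or per group (and therefore a smaller scale) generally means smaller quantization error.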
C++
To work with libtorch, the C++ library of PyTorch, Intel® Extension for PyTorch* also provides a C++ dynamic library. The C++ library is intended for inference workloads only, such as service deployment; for regular development, use the Python interface. Compared with using libtorch alone, no code changes are required. Compilation follows the recommended CMake methodology; detailed instructions can be found in the PyTorch tutorial.
Intel optimizations are activated automatically once the C++ dynamic library of Intel® Extension for PyTorch* is linked into the application at build time.
The example code below works for all data types.
example-app.cpp
#include <torch/script.h>
#include <iostream>
#include <memory>
#include <vector>
int main(int argc, const char* argv[]) {
  // Expect the path to a TorchScript model as the only argument
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }
  std::vector<torch::jit::IValue> inputs;
  torch::Tensor input = torch::rand({1, 3, 224, 224});
  inputs.push_back(input);
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << std::endl;
  std::cout << "Execution finished" << std::endl;
  return 0;
}
CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)
find_package(IPEX REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_IPEX_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 17)
Command for compilation
$ cd examples/cpu/inference/cpp
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> ..
$ make
If "Found IPEX" is shown with a dynamic library path, the extension has been linked into the binary. This can be verified with the Linux command ldd.
$ cmake -DCMAKE_PREFIX_PATH=/workspace/libtorch ..
-- The C compiler identification is GNU XX.X.X
-- The CXX compiler identification is GNU XX.X.X
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
CMake Warning at /workspace/libtorch/share/cmake/Torch/TorchConfig.cmake:22 (message):
static library kineto_LIBRARY-NOTFOUND not found.
Call Stack (most recent call first):
/workspace/libtorch/share/cmake/Torch/TorchConfig.cmake:127 (append_torchlib_if_found)
/workspace/libtorch/share/cmake/IPEX/IPEXConfig.cmake:84 (FIND_PACKAGE)
CMakeLists.txt:4 (find_package)
-- Found Torch: /workspace/libtorch/lib/libtorch.so
-- Found IPEX: /workspace/libtorch/lib/libintel-ext-pt-cpu.so
-- Configuring done
-- Generating done
-- Build files have been written to: examples/cpu/inference/cpp/build
$ ldd example-app
...
libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007f3cf98e0000)
libc10.so => /workspace/libtorch/lib/libc10.so (0x00007f3cf985a000)
libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007f3cf70fc000)
libtorch_cpu.so => /workspace/libtorch/lib/libtorch_cpu.so (0x00007f3ce16ac000)
...
libdnnl_graph.so.0 => /workspace/libtorch/lib/libdnnl_graph.so.0 (0x00007f3cde954000)
...
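The ldd check above can be automated, for example as a step in a build pipeline. The sketch below is a hypothetical helper, not part of the extension: it parses ldd output and reports whether libintel-ext-pt-cpu.so resolved to a path. It assumes a Linux system with ldd on PATH and the library name shown in the output above.

```python
import shutil
import subprocess
import sys


def linked_libraries(ldd_output):
    """Parse ldd output into {library_name: resolved_path_or_None}."""
    libs = {}
    for line in ldd_output.splitlines():
        line = line.strip()
        if "=>" not in line:
            continue  # skips "..." lines and entries without a "=>" mapping
        name, _, rest = line.partition("=>")
        target = rest.strip().split(" (")[0].strip()  # drop the load address
        libs[name.strip()] = None if target == "not found" else target
    return libs


def ipex_is_linked(ldd_output, lib="libintel-ext-pt-cpu.so"):
    """True if the IPEX CPU library appears with a resolved path."""
    return linked_libraries(ldd_output).get(lib) is not None


if __name__ == "__main__" and len(sys.argv) > 1 and shutil.which("ldd"):
    # Usage: python check_ipex_link.py ./example-app
    out = subprocess.run(["ldd", sys.argv[1]], capture_output=True, text=True).stdout
    print("IPEX linked:", ipex_is_linked(out))
```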
Intel® AI Reference Models
Use cases that have already been optimized by Intel engineers are available at Intel® AI Reference Models (formerly Model Zoo). A number of PyTorch use cases for benchmarking are also available in the benchmarks. You can get performance benefits out of the box by simply running the scripts in Intel® AI Reference Models.