INT8 Recipe Tuning API (Experimental) [CPU]
===========================================

This [new API](../api_doc.html#ipex.quantization.autotune) `ipex.quantization.autotune` supports INT8 recipe tuning by using Intel® Neural Compressor as the backend in Intel® Extension for PyTorch\*. In general, Intel® Extension for PyTorch\* provides a default recipe, and we recommend users try it first, without any tuning. If the default recipe does not deliver the desired accuracy, users can use this API to tune for a more advanced recipe. Users need to provide a prepared model and some parameters required for tuning. The API returns a tuned model with the advanced recipe.

### Usage Example

[//]: # (marker_feature_int8_autotune)
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import intel_extension_for_pytorch as ipex

########################################################################
# Reference for training portion:
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=1)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

model, optimizer = ipex.optimize(model, optimizer=optimizer)

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")

########################################################################

################################ QUANTIZE ##############################
model.eval()

def evaluate(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    accuracy = 0
    with torch.no_grad():
        for X, y in dataloader:
            # X, y = X.to('cpu'), y.to('cpu')
            pred = model(X)
            accuracy += (pred.argmax(1) == y).type(torch.float).sum().item()
    accuracy /= size
    return accuracy

# prepare model, do conv+bn folding, and init model quant_state.
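# NOTE: autotune expects this *prepared* model (with observers inserted),
# not a converted one. The tuner calibrates it with the dataloader and
# searches per-op INT8/FP32 settings against the accuracy criterion below.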
qconfig = ipex.quantization.default_static_qconfig
data = torch.randn(1, 1, 28, 28)
prepared_model = ipex.quantization.prepare(
    model, qconfig, example_inputs=data, inplace=False
)

######################## recipe tuning with INC ########################
def eval(prepared_model):
    accu = evaluate(test_dataloader, prepared_model)
    return float(accu)

# print(eval(prepared_model))

# sampling_sizes: calibration sampling sizes to try; accuracy_criterion:
# tolerate at most 1% relative accuracy drop; tuning_time=0: early stop
# once a recipe meeting the criterion is found, instead of a time budget.
tuned_model = ipex.quantization.autotune(
    prepared_model,
    test_dataloader,
    eval,
    sampling_sizes=[100],
    accuracy_criterion={"relative": 0.01},
    tuning_time=0,
)
########################################################################

# run tuned model
convert_model = ipex.quantization.convert(tuned_model)
with torch.no_grad():
    traced_model = torch.jit.trace(convert_model, data)
    traced_model = torch.jit.freeze(traced_model)
traced_model(data)

# save tuned qconfig file
tuned_model.save_qconf_summary(qconf_summary="tuned_conf.json")

print("Execution finished")
```

[//]: # (marker_feature_int8_autotune)
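The saved `tuned_conf.json` lets you reuse the tuned recipe in a later session without repeating the search. Below is a minimal sketch of that flow, assuming the same `model` and `data` objects from the example above; it re-prepares the FP32 model and loads the tuned settings via `load_qconf_summary`, the counterpart of `save_qconf_summary` on a prepared model.

```python
import torch
import intel_extension_for_pytorch as ipex

# Re-prepare the FP32 model, then replace its default quantization
# settings with the tuned recipe saved by autotune.
qconfig = ipex.quantization.default_static_qconfig
prepared_model = ipex.quantization.prepare(
    model, qconfig, example_inputs=data, inplace=False
)
prepared_model.load_qconf_summary(qconf_summary="tuned_conf.json")

# Convert and trace as usual to obtain a deployable INT8 model.
converted_model = ipex.quantization.convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, data)
    traced_model = torch.jit.freeze(traced_model)
traced_model(data)
```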