Troubleshooting
General Usage
Problem: Issues with the +cpu PyTorch package.
Cause: Certain Python packages may have PyTorch as a hard dependency. If you installed the +cpu version of PyTorch, installing these packages might replace the +cpu version with the default version released on PyPI.
Solution: Reinstall the +cpu version.
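A quick way to confirm whether a later installation replaced the +cpu build is to inspect the installed version string. The sketch below assumes the CPU-only wheel carries a +cpu local version label (for example 2.1.0+cpu); the helper names is_cpu_build and installed_torch_version are hypothetical.

```python
from importlib import metadata
from typing import Optional


def is_cpu_build(version: str) -> bool:
    """Return True if the version string carries the "+cpu" local label."""
    # PEP 440 local version labels follow a "+", e.g. "2.1.0+cpu".
    _, _, local = version.partition("+")
    return local == "cpu"


def installed_torch_version() -> Optional[str]:
    """Return the installed torch version, or None if torch is absent."""
    try:
        return metadata.version("torch")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_torch_version()
    if version is None:
        print("torch is not installed")
    elif is_cpu_build(version):
        print(f"torch {version}: +cpu build is in place")
    else:
        print(f"torch {version}: not a +cpu build, reinstall the +cpu version")
```

Running this check after installing any package that depends on PyTorch shows whether the +cpu build is still in place.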
Problem: The workload running with Intel® Extension for PyTorch* occupies a remarkably large amount of memory.
Solution: Try to reduce the occupied memory size by setting the weights_prepack parameter of the ipex.optimize() function to False.
Problem: The conv+bn folding feature of the ipex.optimize() function does not work if inference is done with a custom function:
import torch
import intel_extension_for_pytorch as ipex

class Module(torch.nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        self.conv = torch.nn.Conv2d(1, 10, 5, 1)
        self.bn = torch.nn.BatchNorm2d(10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    # custom entry point that bypasses the module's __call__
    def inference(self, x):
        return self.forward(x)

if __name__ == '__main__':
    m = Module()
    m.eval()
    m = ipex.optimize(m, dtype=torch.float32, level="O0")
    d = torch.rand(1, 1, 112, 112)
    with torch.no_grad():
        m.inference(d)
Cause: PyTorch FX limitation.
Solution: You can avoid this error by calling m = ipex.optimize(m, level="O0"), which doesn't apply ipex optimization, or disable conv+bn folding by calling m = ipex.optimize(m, level="O1", conv_bn_folding=False).
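For context, conv+bn folding merges an inference-time batch normalization into the weight and bias of the preceding convolution. The sketch below works through the arithmetic in plain Python for a single-channel 1x1 convolution, where the convolution reduces to w * x + b. It illustrates the general technique only, not the extension's implementation, and all names are hypothetical.

```python
import math


def fold_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode BatchNorm statistics into a scalar conv weight and bias."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Reference path: convolution followed by eval-mode BatchNorm."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


# Illustrative parameter values (assumptions, not taken from any model).
w, b = 0.5, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 0.04
folded_w, folded_b = fold_conv_bn(w, b, gamma, beta, mean, var)

x = 2.0
reference = conv_then_bn(x, w, b, gamma, beta, mean, var)
folded = folded_w * x + folded_b
print(math.isclose(reference, folded, rel_tol=1e-9))
```

Both paths compute the same value up to floating-point rounding, which is why the folded module can drop the separate BatchNorm step at inference time.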
Performance Regression
Some models may experience performance regression compared to 2.0.x due to the deprecation of the NNC feature in PyTorch*.
TorchDynamo
Problem: A workload that uses torch.compile() fails to run or demonstrates poor performance.
Cause: Support for torch.compile() with ipex as the backend is still a beta feature. Currently, the following HuggingFace models fail to run using torch.compile() with the ipex backend due to memory issues:
masked-language-modeling+xlm-roberta-base
causal-language-modeling+gpt2
causal-language-modeling+xlm-roberta-base
summarization+t5-base
text-classification+allenai-longformer-base-409
Solution: Use the torch.jit APIs and graph optimization APIs of the Intel® Extension for PyTorch*.
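A minimal sketch of the TorchScript path follows. It assumes a small stand-in model (TinyModel is hypothetical) and falls back to stock PyTorch when the extension is not installed, so it only illustrates the order of the calls.

```python
import torch

try:
    import intel_extension_for_pytorch as ipex
except ImportError:
    ipex = None  # the sketch still runs with stock PyTorch


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a real workload."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyModel().eval()
example = torch.rand(8, 16)
with torch.no_grad():
    expected = model(example)  # eager-mode reference output

if ipex is not None:
    model = ipex.optimize(model)

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    traced = torch.jit.freeze(traced)
    # Run a couple of warm-up iterations so graph optimization passes are applied.
    traced(example)
    traced(example)
    output = traced(example)
```

The traced and frozen module is then used for inference in place of a torch.compile() wrapped model.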
Dynamic Shape
Problem: When running inference for an NLP model with dynamic input data length using TorchScript (either torch.jit.trace or torch.jit.script), performance with Intel® Extension for PyTorch* may be lower than without it.
Solution: Use the workaround below:
Python interface
torch._C._jit_set_texpr_fuser_enabled(False)
C++ interface
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
torch::jit::setTensorExprFuserEnabled(false);
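If the fuser should be disabled only around a particular inference region, the Python workaround can be wrapped in a context manager that restores the previous setting. This is a sketch: texpr_fuser_disabled is a hypothetical helper, and it relies on the private getter torch._C._jit_texpr_fuser_enabled(), which may change between PyTorch releases.

```python
import contextlib

import torch


@contextlib.contextmanager
def texpr_fuser_disabled():
    """Temporarily disable the TensorExpr fuser, then restore its state."""
    previous = torch._C._jit_texpr_fuser_enabled()
    torch._C._jit_set_texpr_fuser_enabled(False)
    try:
        yield
    finally:
        torch._C._jit_set_texpr_fuser_enabled(previous)


state_before = torch._C._jit_texpr_fuser_enabled()
with texpr_fuser_disabled():
    state_inside = torch._C._jit_texpr_fuser_enabled()
    # Run TorchScript inference with dynamic input lengths here.
state_after = torch._C._jit_texpr_fuser_enabled()
```

Restoring the earlier state keeps the change from leaking into unrelated code that runs in the same process.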
INT8
Problem: Limitations of dynamic shapes support of static quantization:
When an input shape is provided at runtime for the first time, execution could take longer because a new kernel has to be compiled for this shape. The compilation time could be especially long for complicated kernels.
The Channels Last format does not take effect with dynamic input shapes for CNN models at this time. Optimization work is ongoing.
Problem: RuntimeError: Overflow when unpacking long occurs when a tensor's min/max value exceeds the int range while performing INT8 calibration.
Solution: Customize QConfig to use the min-max calibration method.
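A sketch of such a QConfig built from stock PyTorch observers is shown below. The observer choices, quantization schemes, and dtypes are assumptions for illustration; the resulting object would then be passed to the extension's quantization prepare step in place of the default recipe.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

# Min-max observers for activations (per tensor) and weights (per channel).
qconfig = QConfig(
    activation=MinMaxObserver.with_args(
        qscheme=torch.per_tensor_affine, dtype=torch.quint8
    ),
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric
    ),
)

# Instantiate the activation observer and feed it a tensor with a wide range;
# it only tracks the running min and max of the values it sees.
observer = qconfig.activation()
observer(torch.tensor([-1.0e6, 0.0, 3.0e6]))
print(observer.min_val.item(), observer.max_val.item())
```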
Problem: Models get large accuracy loss with the default quantization recipe.
Solution: Try using the INT8 Recipe Tuning API to tune a recipe with an acceptable accuracy loss.
Problem: Incorrect results with large tensors when calibrating with quantize_per_tensor, when benchmarking with 1 OpenMP* thread (find more detailed info here).
Solution: If you do need to explicitly set OMP_NUM_THREADS=1 for benchmarking, editing your code to follow the pseudocode below can work around this issue. However, there could be a performance regression if the oneDNN graph compiler prototype feature is used.
Workaround pseudocode:
# perform convert/trace/freeze with omp_num_threads > 1 (N)
torch.set_num_threads(N)
prepared_model = prepare(model, input)
converted_model = convert(prepared_model)
traced_model = torch.jit.trace(converted_model, input)
freezed_model = torch.jit.freeze(traced_model)
# run freezed model to apply optimization pass
freezed_model(input)

# benchmarking with omp_num_threads = 1
torch.set_num_threads(1)
run_benchmark(freezed_model, input)
For models with dynamic control flow, please try dynamic quantization. Users are likely to get performance gain for GEMM models.
Support for EmbeddingBag with INT8 when bag size > 1 is a work in progress.
BFloat16
Problem: BF16 AMP (auto-mixed-precision) runs abnormally with the extension on AVX2-only machines if the topology contains Conv, Matmul, Linear, and BatchNormalization.
Solution: TBD
Problem: A PyTorch* model containing a torch.nn.TransformerEncoderLayer component may encounter a RuntimeError during BF16 training or inference if the model is optimized by ipex.optimize() with arguments set to default values.
Solution: A TransformerEncoderLayer optimized by ipex.optimize() with weight prepacking enabled may encounter a weight dimension issue. The error can be avoided by disabling weight prepacking: model = ipex.optimize(model, weights_prepack=False).
Runtime Extension
The following limitations currently exist:
Runtime extension of MultiStreamModule does not support DLRM inference, since the input of DLRM (EmbeddingBag specifically) cannot be simply batch split.
Runtime extension of MultiStreamModule has poor performance for RNNT inference compared with native throughput mode. Only part of the RNNT model (joint_net specifically) can be JIT traced into a graph. However, in one batch inference, joint_net is invoked multiple times, which increases the MultiStreamModule overhead of input batch splitting, thread synchronization, and output concatenation.
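To illustrate where that overhead comes from, the following sketch mimics the split, run, synchronize, and concatenate pattern in plain Python with a thread pool. It is a conceptual model only, not the MultiStreamModule implementation; split_batch, run_multi_stream, and toy_stage are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor


def split_batch(batch, num_streams):
    """Split a batch (a list) into at most num_streams contiguous chunks."""
    if not batch:
        return []
    chunk = -(-len(batch) // num_streams)  # ceiling division
    return [batch[i:i + chunk] for i in range(0, len(batch), chunk)]


def run_multi_stream(stage, batch, num_streams):
    """Run stage on each chunk in its own thread and concatenate the results."""
    chunks = split_batch(batch, num_streams)      # input batch split
    with ThreadPoolExecutor(max_workers=num_streams) as pool:
        results = list(pool.map(stage, chunks))   # waits for every stream
    return [y for part in results for y in part]  # output concat


def toy_stage(chunk):
    """Stand-in for a traced sub-model; doubles every element."""
    return [2 * x for x in chunk]


batch = list(range(8))
output = batch
# A stage invoked several times per batch pays the split/sync/concat cost
# on every call, which is the situation described for joint_net.
for _ in range(3):
    output = run_multi_stream(toy_stage, output, num_streams=4)
```

When the per-call work is small, as with a sub-model that is called repeatedly, the fixed cost of the three bookkeeping steps can dominate the total time.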
Result Correctness
Problem: Incorrect Conv and Linear result if the number of OMP threads is changed at runtime.
Cause: The oneDNN memory layout depends on the number of OMP threads, so the caller must detect changes in the number of OMP threads. This release does not yet implement that detection.
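Until such detection is built in, one defensive option is to record the thread count when the model is prepared and check it before each run. The sketch below uses torch.get_num_threads() as the source of the thread count; ThreadCountGuard is a hypothetical helper, not part of the extension, and the suggested recovery in the error message is an assumption.

```python
import torch


class ThreadCountGuard:
    """Remember the intra-op thread count and flag later changes."""

    def __init__(self):
        self.expected = torch.get_num_threads()

    def check(self):
        current = torch.get_num_threads()
        if current != self.expected:
            raise RuntimeError(
                f"thread count changed from {self.expected} to {current}; "
                "prepare the model again with the new setting"
            )


torch.set_num_threads(2)
guard = ThreadCountGuard()   # create right before optimizing the model
guard.check()                # passes while the setting is unchanged

torch.set_num_threads(1)     # a runtime change that could affect layouts
try:
    guard.check()
    changed = False
except RuntimeError:
    changed = True
```

Calling check() before each inference turns a silent correctness problem into an explicit error.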