/lib/libiomp5.so
```
### Benchmarking with Launcher Core Pinning
As described previously in [TorchServe with Launcher](#torchserve-with-launcher), launcher core pinning boosts performance of multi-worker inference. We'll demonstrate launcher core pinning with the TorchServe benchmark, but keep in mind that launcher core pinning is a generic feature applicable to any TorchServe multi-worker inference use case.
For example, assume we are running 4 workers
```
python benchmark-ab.py --workers 4
```
on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core (56 physical cores in total). Launcher will bind worker 0 to cores 0-13, worker 1 to cores 14-27, worker 2 to cores 28-41, and worker 3 to cores 42-55.
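The binding arithmetic is simply an even split of physical cores across workers. Below is a minimal, illustrative Python sketch of that split; it mirrors the behavior described above but is not the launcher's actual implementation.
```
# Illustrative sketch of the even core split described above;
# not the launcher's actual implementation.
num_physical_cores = 2 * 28   # 2 sockets x 28 cores per socket (hyperthreads excluded)
num_workers = 4

cores_per_worker = num_physical_cores // num_workers  # 56 // 4 = 14

for worker_id in range(num_workers):
    start = worker_id * cores_per_worker
    end = start + cores_per_worker - 1
    print(f"worker {worker_id}: cores {start}-{end}")
# worker 0: cores 0-13
# worker 1: cores 14-27
# worker 2: cores 28-41
# worker 3: cores 42-55
```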
All that is needed to use TorchServe with launcher core pinning is to enable launcher in `config.properties`. Add the following line to `config.properties` in the benchmark directory:
```
cpu_launcher_enable=true
```
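With the configuration in place, start TorchServe as usual and launcher core pinning takes effect automatically. A typical invocation might look like the following; the model store path and model archive name are placeholders to adapt to your setup.
```
torchserve --start --ncs --model-store model_store --models resnet50.mar --ts-config config.properties
```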
CPU usage is shown below:
![launcher_core_pinning](https://user-images.githubusercontent.com/93151422/159063975-e7e8d4b0-e083-4733-bdb6-4d92bdc10556.gif)
4 main worker threads were launched, and each in turn launched num_physical_cores/num_workers (56/4 = 14) threads affinitized to its assigned physical cores. The worker-to-core bindings can be verified in the model log:
```
$ cat logs/model_log.log
2022-03-24 10:41:32,223 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:41:32,223 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:41:32,223 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:41:32,223 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:41:32,223 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
2022-03-24 10:41:32,223 - __main__ - INFO - numactl -C 0-13 -m 0 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9000
2022-03-24 10:49:03,760 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:03,761 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:03,762 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:03,762 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:03,762 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
2022-03-24 10:49:03,763 - __main__ - INFO - numactl -C 14-27 -m 0 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9001
2022-03-24 10:49:26,274 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:26,274 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:26,274 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:26,274 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:26,274 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
2022-03-24 10:49:26,274 - __main__ - INFO - numactl -C 28-41 -m 1 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9002
2022-03-24 10:49:42,975 - __main__ - INFO - Use TCMalloc memory allocator
2022-03-24 10:49:42,975 - __main__ - INFO - OMP_NUM_THREADS=14
2022-03-24 10:49:42,975 - __main__ - INFO - Using Intel OpenMP
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2022-03-24 10:49:42,975 - __main__ - INFO - KMP_BLOCKTIME=1
2022-03-24 10:49:42,975 - __main__ - INFO - LD_PRELOAD=/lib/libiomp5.so:/lib/libtcmalloc.so
2022-03-24 10:49:42,975 - __main__ - INFO - numactl -C 42-55 -m 1 /bin/python -u /lib/python/site-packages/ts/model_service_worker.py --sock-type unix --sock-name /tmp/.ts.sock.9003
```
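To double-check the pinning on a running instance, you can also inspect each worker process's CPU affinity directly. Below is a minimal, hypothetical sketch using Python's `os.sched_getaffinity` (Linux only); the PIDs are placeholders to replace with the actual `model_service_worker.py` process IDs.
```
# Hypothetical verification sketch (Linux only).
# Replace the placeholder PIDs with the actual model_service_worker.py PIDs,
# e.g. found via `ps aux | grep model_service_worker`.
import os

worker_pids = [11111, 22222, 33333, 44444]  # placeholders

for pid in worker_pids:
    cores = sorted(os.sched_getaffinity(pid))
    print(f"PID {pid}: cores {cores[0]}-{cores[-1]} ({len(cores)} cores)")
```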
## Performance Boost with Intel® Extension for PyTorch\* and Launcher
![pdt_perf](https://user-images.githubusercontent.com/93151422/159067306-dfd604e3-8c66-4365-91ae-c99f68d972d5.png)
The figure above shows the performance improvement of TorchServe with Intel® Extension for PyTorch\* and launcher on ResNet50 and BERT-base-uncased. The TorchServe official [apache-bench benchmark](https://github.com/pytorch/serve/tree/master/benchmarks#benchmarking-with-apache-bench) on Amazon EC2 m6i.24xlarge was used to collect the results². Add the following lines in ```config.properties``` to reproduce the results. Notice that launcher is configured such that a single instance uses all physical cores on a single socket to avoid cross-socket communication and core overlap.
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--node_id 0 --enable_jemalloc
```
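For context, `ipex_enable=true` makes TorchServe apply Intel® Extension for PyTorch\* optimizations to the model at load time. Outside of TorchServe, the equivalent eager-mode FP32 optimization looks roughly like the sketch below; this is illustrative only, so refer to the Intel® Extension for PyTorch\* documentation for the full API.
```
# Illustrative sketch of eager-mode optimization with
# Intel(R) Extension for PyTorch*; TorchServe performs the
# equivalent internally when ipex_enable=true is set.
import torch
import intel_extension_for_pytorch as ipex
import torchvision.models as models

model = models.resnet50()
model.eval()
model = ipex.optimize(model)  # FP32; pass dtype=torch.bfloat16 for BF16

with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224))
```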
Use the following command to reproduce the results.
```
python benchmark-ab.py --url {modelUrl} --input {inputPath} --concurrency 1
```
For example, run the following command to reproduce latency performance of ResNet50 with Intel® Extension for PyTorch\* int8 data type and batch size of 1. Refer to [Creating and Exporting INT8 model for Intel® Extension for PyTorch\*](#creating-and-exporting-int8-model-for-intel-extension-for-pytorch) for steps to create the ```rn50_ipex_int8.mar``` file for ResNet50 with Intel® Extension for PyTorch\* int8 data type.
```
python benchmark-ab.py --url 'file:///model_store/rn50_ipex_int8.mar' --concurrency 1
```
For example, run the following command to reproduce latency performance of BERT with Intel® Extension for PyTorch\* int8 data type and batch size of 1. Refer to [Creating and Exporting INT8 model for Intel® Extension for PyTorch\*](#creating-and-exporting-int8-model-for-intel-extension-for-pytorch) for steps to create the ```bert_ipex_int8.mar``` file for BERT with Intel® Extension for PyTorch\* int8 data type.
```
python benchmark-ab.py --url 'file:///model_store/bert_ipex_int8.mar' --input '../examples/Huggingface_Transformers/Seq_classification_artifacts/sample_text_captum_input.txt' --concurrency 1
```
²Amazon EC2 m6i.24xlarge was used for benchmarking purposes only. For multi-core instances, Intel® Extension for PyTorch\* optimizations automatically scale and leverage full instance resources.