DistributedDataParallel (DDP)

Introduction

DistributedDataParallel (DDP) is a PyTorch* module that implements multi-process data parallelism across multiple GPUs and machines. With DDP, the model is replicated on every process, and each model replica is fed a different set of input data samples. After the backward pass, gradients are averaged across replicas (an allreduce), so every replica applies the same update. DDP overlaps gradient communication with gradient computation to speed up training. Refer to the DDP Tutorial for an introduction to DDP.
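The replicate-and-average scheme can be emulated in a single process with plain PyTorch on CPU, without DDP or oneCCL. The sketch below (the function name `data_parallel_step` is hypothetical) gives each replica a different shard of one batch, averages their gradients by hand the way DDP's allreduce does, and returns a reference model trained on the whole batch for comparison. The replicas matching that reference relies on equal shard sizes and a mean-reduced loss.

```python
import copy

import torch
import torch.nn as nn


def data_parallel_step(world_size=2, lr=0.1, seed=0):
    """Emulate one DDP training step in a single process; return (replicas, reference)."""
    torch.manual_seed(seed)
    base = nn.Linear(4, 5)

    # Every "process" starts from an identical copy of the model.
    replicas = [copy.deepcopy(base) for _ in range(world_size)]
    optimizers = [torch.optim.SGD(r.parameters(), lr=lr) for r in replicas]

    # One global batch split into equal shards: each replica sees different samples.
    inputs = torch.randn(2 * world_size, 4)
    targets = torch.randn(2 * world_size, 5)
    shards = zip(inputs.chunk(world_size), targets.chunk(world_size))

    # Local forward and backward on each replica.
    for model, (x, y) in zip(replicas, shards):
        nn.functional.mse_loss(model(x), y).backward()

    # What DDP's allreduce achieves: every replica ends up with the mean gradient.
    for params in zip(*(r.parameters() for r in replicas)):
        mean_grad = torch.stack([p.grad for p in params]).mean(dim=0)
        for p in params:
            p.grad.copy_(mean_grad)

    for opt in optimizers:
        opt.step()

    # Reference: a single model trained on the whole batch.
    reference = copy.deepcopy(base)
    ref_opt = torch.optim.SGD(reference.parameters(), lr=lr)
    nn.functional.mse_loss(reference(inputs), targets).backward()
    ref_opt.step()
    return replicas, reference


replicas, reference = data_parallel_step()
```

Because every replica applies the same averaged gradient, the replicas stay identical after the update, which is why DDP only needs to synchronize gradients rather than weights.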

The PyTorch Collective Communication (c10d) library supports communication across processes. To run DDP on GPU, we use Intel® oneCCL Bindings for PyTorch* (formerly known as torch-ccl) to implement the PyTorch c10d ProcessGroup API (https://github.com/intel/torch-ccl). It provides PyTorch bindings, maintained by Intel, for the Intel® oneAPI Collective Communications Library* (oneCCL), a library for efficient distributed deep learning training that implements collectives such as allreduce, allgather, and alltoall. Refer to the oneCCL GitHub page for more information about oneCCL.
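To pin down what the collectives named above compute, here is a pure-Python model of their semantics, with one list per rank standing in for each process's buffer. It is a sketch of the results only, not of how oneCCL moves data between devices; in PyTorch the corresponding calls are `torch.distributed.all_reduce`, `all_gather`, and `all_to_all`.

```python
from typing import List

Buffers = List[List[float]]  # one buffer per rank, indexed by rank


def allreduce_sum(per_rank: Buffers) -> Buffers:
    """Every rank receives the element-wise sum of all ranks' buffers."""
    total = [sum(values) for values in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def allgather(per_rank: Buffers) -> Buffers:
    """Every rank receives all buffers concatenated in rank order."""
    gathered = [x for buf in per_rank for x in buf]
    return [list(gathered) for _ in per_rank]


def alltoall(per_rank: Buffers) -> Buffers:
    """Rank i splits its buffer into world_size equal chunks and sends chunk j to rank j."""
    world_size = len(per_rank)
    chunk = len(per_rank[0]) // world_size
    return [
        [x for src in range(world_size) for x in per_rank[src][dst * chunk:(dst + 1) * chunk]]
        for dst in range(world_size)
    ]


ranks = [[1.0, 2.0], [3.0, 4.0]]
print(allreduce_sum(ranks))  # [[4.0, 6.0], [4.0, 6.0]]
print(allgather(ranks))      # [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
print(alltoall(ranks))       # [[1.0, 3.0], [2.0, 4.0]]
```

DDP's gradient synchronization is an allreduce followed by division by the world size, which turns the sum into the average.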

Installation of Intel® oneCCL Bindings for PyTorch*

To use PyTorch DDP on GPU, install Intel® oneCCL Bindings for PyTorch* as described below.

Install PyTorch and Intel® Extension for PyTorch*

Make sure you have installed PyTorch and Intel® Extension for PyTorch* successfully. For more detailed information, check the installation guide.

Install Intel® oneCCL Bindings for PyTorch*

Install from source:

Installation for CPU:

git clone https://github.com/intel/torch-ccl.git -b v2.1.0+cpu
cd torch-ccl
git submodule sync
git submodule update --init --recursive
python setup.py install

Installation for GPU:

  • Clone the oneccl_bindings_for_pytorch repository

git clone https://github.com/intel/torch-ccl.git -b v2.1.300+xpu
cd torch-ccl
git submodule sync 
git submodule update --init --recursive
  • Install oneccl_bindings_for_pytorch

Option 1: build with third-party oneCCL

COMPUTE_BACKEND=dpcpp python setup.py install

Option 2: build without oneCCL and use the oneCCL installed on the system (recommended)

We recommend using apt/yum/dnf to install the oneCCL package. First-time users should refer to Base Toolkit Installation to add the APT/YUM/DNF key and sources.

Reference commands (run the one that matches your package manager):

# APT
sudo apt install intel-oneapi-ccl-devel=2021.11.1-6
# YUM
sudo yum install intel-oneapi-ccl-devel-2021.11.1-6
# DNF
sudo dnf install intel-oneapi-ccl-devel-2021.11.1-6

Compile with the commands below.

export INTELONEAPIROOT=/opt/intel/oneapi
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py install

Install from prebuilt wheel:

Prebuilt wheel files for CPU, GPU with generic Python* and GPU with Intel® Distribution for Python* are released in separate repositories.

# Generic Python* for CPU
REPO_URL: https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
# Generic Python* for GPU
REPO_URL: https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

Installation from either repository uses the command below. Replace the placeholder <REPO_URL> with one of the URLs above.

python -m pip install oneccl_bind_pt --extra-index-url <REPO_URL>
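After installing, a quick way to confirm that the three packages are visible to the Python interpreter you will train with is to look them up without importing them. This is a minimal standard-library sketch; `check_modules` is a hypothetical helper name. A module that is found can still fail to import if the oneCCL runtime libraries are not on the library path, which is what Runtime Dynamic Linking addresses.

```python
import importlib.util

REQUIRED_MODULES = ("torch", "intel_extension_for_pytorch", "oneccl_bindings_for_pytorch")


def check_modules(names=REQUIRED_MODULES):
    """Map each top-level module name to True if it can be found on the import path."""
    # find_spec locates a module without executing it, so missing native libraries
    # do not cause errors here.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_modules().items():
        print("{:35s} {}".format(name, "found" if found else "MISSING"))
```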

Runtime Dynamic Linking

  • If torch-ccl is built with third-party oneCCL or installed from a prebuilt wheel, dynamically link the oneCCL and Intel MPI libraries:

source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/setvars.sh

To dynamically link only oneCCL (without Intel MPI):

source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/vars.sh 
  • If torch-ccl is built without oneCCL and uses the oneCCL installed on the system, dynamically link oneCCL from the oneAPI Base Toolkit:

source <ONEAPI_ROOT>/ccl/latest/env/vars.sh

Note: Make sure the oneAPI Base Toolkit is installed when using Intel® oneCCL Bindings for PyTorch* on Intel® GPUs. If the Base Toolkit is installed with a package manager, <ONEAPI_ROOT> is /opt/intel/oneapi.
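The choice among the three `source` commands depends only on how torch-ccl was installed. The sketch below encodes that decision as a hypothetical helper, `ccl_env_script`, that returns the script path to source; it cannot source the script itself, because environment changes made by a child process do not reach the parent shell. For the bundled case, pass the value of `oneccl_bindings_for_pytorch.cwd` as `package_dir`.

```python
import posixpath


def ccl_env_script(mode, package_dir=None, oneapi_root="/opt/intel/oneapi", with_mpi=True):
    """Return the path of the script to `source` before launching training.

    mode="bundled": torch-ccl built with third-party oneCCL or installed from a wheel;
                    package_dir is oneccl_bindings_for_pytorch.cwd.
    mode="system":  torch-ccl built against the oneCCL from the oneAPI Base Toolkit.
    """
    if mode == "bundled":
        if package_dir is None:
            raise ValueError("package_dir is required when oneCCL is bundled")
        # setvars.sh also sets up Intel MPI; vars.sh covers oneCCL only.
        script = "setvars.sh" if with_mpi else "vars.sh"
        return posixpath.join(package_dir, "env", script)
    if mode == "system":
        return posixpath.join(oneapi_root, "ccl", "latest", "env", "vars.sh")
    raise ValueError("unknown mode: {!r}".format(mode))


print(ccl_env_script("system"))  # /opt/intel/oneapi/ccl/latest/env/vars.sh
```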

DDP Usage

DDP is used the same way as in stock PyTorch. To use DDP with Intel® Extension for PyTorch*, make the following modifications to your model script:

  1. Import the necessary packages.

import torch
import torch.distributed as dist
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch  # registers the 'ccl' backend with torch.distributed
  2. Initialize the process group with the ccl backend.

dist.init_process_group(backend='ccl')
  3. For DDP where each process works exclusively on a single GPU, set the device ID to the local rank. This step is not required for usage on CPU.

device = "xpu:{}".format(args.local_rank)
torch.xpu.set_device(device)
  4. Wrap the model with DDP.

model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])

Note: For single-device modules, device_ids can contain exactly one device id, which represents the only GPU device where the input module corresponding to this process resides. Alternatively, device_ids can be None.
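The four DDP steps can be tried without Intel® GPUs by swapping in stock PyTorch pieces. In the sketch below, the `gloo` backend stands in for `ccl` and a CPU module stands in for the XPU one, so `device_ids` is None; everything runs as a single process with world size 1, initialized from a temporary file. These substitutions, and the name `run_single_process_ddp`, are assumptions for illustration only; on XPU keep `backend='ccl'`, the two Intel imports, and `device_ids=[device]`.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run_single_process_ddp(steps=3):
    """Run the DDP steps with gloo on CPU in one process (world size 1); return the losses."""
    with tempfile.TemporaryDirectory() as tmp:
        # Stand-ins for this sketch: 'gloo' instead of 'ccl', CPU instead of 'xpu:N'.
        init_method = "file://" + os.path.join(tmp, "ddp_init")
        dist.init_process_group(backend="gloo", init_method=init_method, rank=0, world_size=1)
        try:
            # A CPU module, so device_ids must be None.
            model = DDP(nn.Linear(4, 5), device_ids=None)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
            loss_fn = nn.MSELoss()
            losses = []
            for _ in range(steps):
                optimizer.zero_grad()
                loss = loss_fn(model(torch.randn(2, 4)), torch.randn(2, 5))
                loss.backward()  # DDP allreduces the gradients during backward
                optimizer.step()
                losses.append(loss.item())
            return losses
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    print(run_single_process_ddp())
```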

Note: When using torch.xpu.optimize for distributed training with low precision, call torch.xpu.manual_seed(seed_number) to make sure the master weights are the same on all ranks.
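The seeding note relies on a simple property: when every rank seeds the random number generator with the same value before creating its weights, every rank draws the same initial values. The sketch below shows that property on CPU, with `torch.manual_seed` standing in for `torch.xpu.manual_seed`; it does not exercise `torch.xpu.optimize` itself, and `build_model_on_rank` is a hypothetical name.

```python
import torch
import torch.nn as nn


def build_model_on_rank(seed):
    """What each rank does at startup: seed the RNG, then build the model."""
    torch.manual_seed(seed)  # CPU stand-in for torch.xpu.manual_seed(seed)
    return nn.Linear(4, 5)


def same_weights(a, b):
    """True if two modules have bit-identical parameters."""
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))


print(same_weights(build_model_on_rank(123), build_model_on_rank(123)))  # True
```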

Example Usage (MPI launch for single node):

Intel® oneCCL Bindings for PyTorch* recommends MPI as the launcher to start multiple processes. Here’s an example to illustrate such usage.

Dynamically link the oneCCL and Intel MPI libraries:

source $(python -c "import oneccl_bindings_for_pytorch as torch_ccl;print(torch_ccl.cwd)")/env/setvars.sh
# Or
source <ONEAPI_ROOT>/ccl/latest/env/vars.sh

Example_DDP.py

"""
This example shows how to use MPI as the launcher to start DDP on single node with multiple devices.
"""
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(4, 5)

    def forward(self, input):
        return self.linear(input)


if __name__ == "__main__":

    torch.xpu.manual_seed(123)  # set a seed number
    mpi_world_size = int(os.environ.get('PMI_SIZE', -1))
    mpi_rank = int(os.environ.get('PMI_RANK', -1))
    if mpi_world_size > 0:
        os.environ['RANK'] = str(mpi_rank)
        os.environ['WORLD_SIZE'] = str(mpi_world_size)
    else:
        # set the default rank and world size to 0 and 1
        os.environ['RANK'] = str(os.environ.get('RANK', 0))
        os.environ['WORLD_SIZE'] = str(os.environ.get('WORLD_SIZE', 1))
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # your master address
    os.environ['MASTER_PORT'] = '29500'  # your master port

    # Initialize the process group with ccl backend
    dist.init_process_group(backend='ccl')

    # For single-node distributed training, local_rank is the same as global rank
    local_rank = dist.get_rank()
    # Bind this process to its own GPU device
    device = "xpu:{}".format(local_rank)
    torch.xpu.set_device(device)
    model = Model().to(device)
    if dist.get_world_size() > 1:
        model = DDP(model, device_ids=[device])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss().to(device)
    for i in range(3):
        print("Running iteration: {} on device {}".format(i, device))
        input = torch.randn(2, 4).to(device)
        labels = torch.randn(2, 5).to(device)
        # clear gradients left over from the previous iteration
        optimizer.zero_grad()
        # forward
        print("Running forward: {} on device {}".format(i, device))
        res = model(input)
        # loss
        print("Running loss: {} on device {}".format(i, device))
        L = loss_fn(res, labels)
        # backward
        print("Running backward: {} on device {}".format(i, device))
        L.backward()
        # update
        print("Running optim: {} on device {}".format(i, device))
        optimizer.step()

Running command:

mpirun -n 2 -l python Example_DDP.py
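Example_DDP.py maps the MPI variables by hand and assumes the local rank equals the global rank, which holds only on a single node. A small helper can factor that mapping out and pick up a node-local rank when one is available. This is a sketch: `resolve_dist_env` is a hypothetical name, and treating `MPI_LOCALRANKID` as the node-local rank is an assumption about Intel MPI that you should confirm for your MPI library.

```python
def resolve_dist_env(env):
    """Return (rank, world_size, local_rank) from launcher environment variables.

    Assumed order of preference (check your launcher's documentation):
      1. MPI launch (PMI_SIZE / PMI_RANK, as in Example_DDP.py), with MPI_LOCALRANKID
         taken as the node-local rank when the MPI library sets it;
      2. torchrun-style RANK / WORLD_SIZE / LOCAL_RANK;
      3. single-process defaults.
    """
    if int(env.get("PMI_SIZE", -1)) > 0:
        rank = int(env["PMI_RANK"])
        world_size = int(env["PMI_SIZE"])
        # Falling back to the global rank is only correct on a single node.
        local_rank = int(env.get("MPI_LOCALRANKID", rank))
    else:
        rank = int(env.get("RANK", 0))
        world_size = int(env.get("WORLD_SIZE", 1))
        local_rank = int(env.get("LOCAL_RANK", rank))
    return rank, world_size, local_rank


print(resolve_dist_env({"PMI_SIZE": "4", "PMI_RANK": "3", "MPI_LOCALRANKID": "1"}))  # (3, 4, 1)
```

In the example script, `rank, world_size, local_rank = resolve_dist_env(os.environ)` could replace the environment block, with `local_rank` then used to build the device string. A multi-node run would also need MASTER_ADDR to point at a reachable host rather than 127.0.0.1.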

DDP scaling API (GPU Only)

When using one GPU card with multiple tiles, each tile can be treated as a separate device for explicit scaling. A DDP scaling API that enables DDP on a single GPU card is provided in the GitHub repo.
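Before using the scaling API, it can help to see how many XPU devices (tiles, if the runtime exposes them as separate devices) PyTorch reports. The helper below is a hypothetical sketch: it relies on `torch.xpu.is_available()` and `torch.xpu.device_count()`, which exist once Intel® Extension for PyTorch* is imported (or natively in newer PyTorch builds), and it returns an empty list when no XPU is present.

```python
import torch


def list_xpu_devices():
    """Return device strings such as "xpu:0", "xpu:1" for every visible XPU device."""
    try:
        import intel_extension_for_pytorch  # noqa: F401  (adds torch.xpu on older PyTorch)
    except ImportError:
        pass  # fall back to whatever torch provides natively
    xpu = getattr(torch, "xpu", None)
    if xpu is None or not xpu.is_available():
        return []
    return ["xpu:{}".format(i) for i in range(xpu.device_count())]


print(list_xpu_devices())
```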

Usage of DDP scaling API

Note: This API supports only GPU devices on a single card.

Args:
model: model to be parallelized
train_dataset: dataset for training

If you have a model running on a single tile, you only need to make minor changes to enable DDP training by following these steps:

  1. Import the API:

try:
    from intel_extension_for_pytorch.xpu.single_card import single_card_dist
except ImportError:
    raise ImportError("single_card_dist not available!")
  2. Use the multi_process_spawn launcher as a torch.multiprocessing wrapper.

single_card_dist.multi_process_spawn(main_worker, (args, )) # put arguments of main_worker into a tuple
  3. Call the API to get the rank, the DDP-wrapped model, and the train sampler:

dist = single_card_dist(model, train_dataset)
local_rank, model, train_sampler = dist.rank, dist.model, dist.train_sampler
  4. Set the epoch on the train sampler in the training loop:

for epoch in range ...
    train_sampler.set_epoch(epoch)
  5. Adjust the model script to use local_rank, model, and train_sampler as shown here:

  • device: get the xpu information used in model training

xpu = "xpu:{}".format(local_rank)
print("DDP Use XPU: {} for training".format(xpu))
  • model: use the model wrapped by DDP in the rest of the training

  • train_sampler: use the train_sampler to get the train_loader

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

Then you can start training your model on multiple GPU devices of a single card.
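The `set_epoch` call matters because a distributed sampler derives its shuffle order from a seed plus the epoch. Assuming `train_sampler` behaves like `torch.utils.data.DistributedSampler` (the steps above only show that it exposes `set_epoch`), the sketch below builds that sampler directly for two hypothetical ranks over an 8-item dataset. Passing `num_replicas` and `rank` explicitly means no process group is needed; `indices_for_rank` is a hypothetical name.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))


def indices_for_rank(rank, epoch, num_replicas=2):
    """Indices one rank would load in a given epoch."""
    # num_replicas and rank are passed explicitly, so no process group is required.
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # the shuffle order is derived from seed + epoch
    return list(sampler)


rank0, rank1 = indices_for_rank(0, epoch=0), indices_for_rank(1, epoch=0)
print(sorted(rank0 + rank1))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Each rank loads a disjoint half of the data, and together the shards cover the dataset exactly once per epoch. Calling `set_epoch` with a new epoch number reshuffles the data; skipping the call would repeat the same order every epoch.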