Fully Sharded Data Parallel (FSDP)

Introduction

Fully Sharded Data Parallel (FSDP) is a PyTorch* module that provides an industry-grade solution for large model training. FSDP is a type of data parallel training. Unlike DDP, where each process/worker maintains a replica of the full model, FSDP shards model parameters, optimizer states, and gradients across DDP ranks to reduce the GPU memory footprint used in training. This makes training some large-scale models feasible. Please refer to the FSDP Tutorial for an introduction to FSDP.

To run FSDP on GPU, similar to DDP, we use Intel® oneCCL Bindings for PyTorch* (formerly known as torch-ccl) to implement the PyTorch c10d ProcessGroup API (https://github.com/intel/torch-ccl). It holds PyTorch bindings maintained by Intel for the Intel® oneAPI Collective Communications Library* (oneCCL), a library for efficient distributed deep learning training that implements collectives such as AllGather, ReduceScatter, and others needed by FSDP. Refer to the oneCCL Github page for more information about oneCCL. To install Intel® oneCCL Bindings for PyTorch*, follow the same installation steps as for DDP.
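
As a quick sanity check before launching a distributed job, you can verify that the bindings import cleanly and that XPU devices are visible. The snippet below is a minimal sketch for that purpose only; it is not part of the FSDP script:

import torch
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch  # importing this registers the 'ccl' backend

# Both calls below are provided through Intel® Extension for PyTorch*; if device_count()
# returns 0, check your GPU driver and oneAPI environment before attempting FSDP.
print(torch.xpu.is_available())
print(torch.xpu.device_count())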

FSDP Usage (GPU only)

FSDP is designed to align with PyTorch conventions. To use FSDP with Intel® Extension for PyTorch*, make the following modifications to your model script:

  1. Import the necessary packages.

import torch
import torch.distributed as dist
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  2. Initialize the process group with the ccl backend.

dist.init_process_group(backend='ccl')
  3. For FSDP with each process exclusively working on a single GPU, set the device ID to the local rank.

torch.xpu.set_device("xpu:{}".format(rank))
# or
device = "xpu:{}".format(args.local_rank)
torch.xpu.set_device(device)
  4. Wrap the model with FSDP.

model = model.to(device)
model = FSDP(model, device_id=device)

Note: for FSDP with XPU, you need to specify `device_id` with an XPU device; otherwise, it will trigger the CUDA path and throw an error.
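
Putting these steps together, a minimal sketch of wrapping a model with FSDP on XPU could look like the following. The environment variable names (PMI_RANK, PMI_SIZE) and the toy torch.nn.Linear model are assumptions for illustration; use whatever rank/world-size mechanism your launcher provides:

import os
import torch
import torch.distributed as dist
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumed rendezvous settings; adapt them to your launcher.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
rank = int(os.environ.get('PMI_RANK', 0))
world_size = int(os.environ.get('PMI_SIZE', 1))

dist.init_process_group(backend='ccl', rank=rank, world_size=world_size)

device = "xpu:{}".format(rank)
torch.xpu.set_device(device)

# Toy model used purely for illustration.
model = torch.nn.Linear(1024, 1024).to(device)
model = FSDP(model, device_id=device)  # device_id must point at an XPU device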

Example

Here’s an example based on the PyTorch FSDP Tutorial that illustrates the usage of FSDP on XPU and the changes needed to switch from the CUDA case to the XPU case.

  1. Import necessary packages:

"""
Import Intel® extension for Pytorch\* and Intel® oneCCL Bindings for Pytorch\*
"""
import os
import argparse
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import Intel® Extension for PyTorch* and Intel® oneCCL Bindings for PyTorch*
import intel_extension_for_pytorch
import oneccl_bindings_for_pytorch

from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
  2. Set up distributed training:

"""
Set the initialize the process group backend as Intel® oneCCL Bindings for Pytorch\*
"""
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group with Intel® oneCCL Bindings for PyTorch*
    dist.init_process_group("ccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
  3. Define the toy model for handwritten digit classification:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
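
As a quick check (a sketch, not part of the tutorial script), the model should map MNIST-sized inputs of shape (N, 1, 28, 28) to (N, 10) log-probabilities; the 9216 input features of fc1 correspond to the 64 x 12 x 12 activation left after the convolutions and pooling:

# Hypothetical shape check on CPU with a random batch of 4 images.
dummy = torch.randn(4, 1, 28, 28)
print(Net()(dummy).shape)  # expected: torch.Size([4, 10])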
  4. Define a training function:

"""
Change the device-related logic from 'rank' to '"xpu:{}".format(rank)'
"""
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    # The XPU device must be passed as a string; replace 'rank' with '"xpu:{}".format(rank)'
    ddp_loss = torch.zeros(2).to("xpu:{}".format(rank))
    if sampler:
        sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to("xpu:{}".format(rank)), target.to("xpu:{}".format(rank))
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
  5. Define a validation function:

"""
Change the device-related logic from 'rank' to '"xpu:{}".format(rank)'
"""
def test(model, rank, world_size, test_loader):
    model.eval()
    correct = 0
    # The XPU device must be passed as a string; replace 'rank' with '"xpu:{}".format(rank)'
    ddp_loss = torch.zeros(3).to("xpu:{}".format(rank))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to("xpu:{}".format(rank)), target.to("xpu:{}".format(rank))
            output = model(data)
            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
            ddp_loss[2] += len(data)

    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

    if rank == 0:
        test_loss = ddp_loss[0] / ddp_loss[2]
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
            100. * ddp_loss[1] / ddp_loss[2]))
  6. Define a distributed training function that wraps the model in FSDP:

"""
Change the device-related logic from 'rank' to '"xpu:{}".format(rank)'.
Specify the argument `device_id` as an XPU device ("xpu:{}".format(rank)) in the FSDP API.
"""
def fsdp_main(rank, world_size, args):
    setup(rank, world_size)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                        transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                        transform=transform)

    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    xpu_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(xpu_kwargs)
    test_kwargs.update(xpu_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
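    # Size-based auto-wrap policy from the PyTorch FSDP tutorial: submodules with at
    # least `min_num_params` parameters become their own FSDP units. It is only defined
    # here; to apply it, pass auto_wrap_policy=my_auto_wrap_policy to the FSDP call below.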
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    torch.xpu.set_device("xpu:{}".format(rank))


    init_start_event = torch.xpu.Event(enable_timing=True)
    init_end_event = torch.xpu.Event(enable_timing=True)

    model = Net().to("xpu:{}".format(rank))
    # Specify the argument `device_id` as an XPU device ("xpu:{}".format(rank)) in the FSDP API.
    model = FSDP(model, device_id="xpu:{}".format(rank))

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()
    for epoch in range(1, args.epochs + 1):
        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()

    if rank == 0:
        print(f"XPU event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        print(f"{model}")

    if args.save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    cleanup()
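
Note that, depending on the FSDP/PyTorch version, gathering a full (unsharded) checkpoint is often done explicitly through FSDP's state-dict utilities rather than by calling model.state_dict() on every rank. A hedged sketch of that approach, to be adapted inside fsdp_main, might look like this:

from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# Gather a full state dict, offloaded to CPU and materialized on rank 0 only.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    states = model.state_dict()
if rank == 0:
    torch.save(states, "mnist_cnn.pt")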
  7. Finally, parse the arguments and set up the main function:

"""
Replace the CUDA runtime API with the XPU runtime API.
"""
if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    WORLD_SIZE = torch.xpu.device_count()
    mp.spawn(fsdp_main,
        args=(WORLD_SIZE, args),
        nprocs=WORLD_SIZE,
        join=True)
  8. Put the above code snippets into a Python script named FSDP_mnist_xpu.py, and run it:

python FSDP_mnist_xpu.py