Distributed Training with PyTorch
A summary of some experiences with distributed training in PyTorch.
1-Background
This part covers some basic concepts of distributed training.
In the following part of this article, we consider the case where we have 2 nodes, with 4 GPUs on each node (2 nodes * 4 GPUs/node = 8 GPUs in total).
Some concepts for DDP with PyTorch
- group/world: the process group, i.e., a group of all processes. By default, there is usually one process group per job.
- world_size: number of processes in the group/world.
- rank: the global index of each process. Within a process group, each process has a unique index used for inter-process communication. The process with rank=0 runs on the master node.
- local_rank: the local index of a process on the node where it runs. For example, rank=7, local_rank=3 indicates the 4th GPU on the 2nd node.
Remark: The relationship between rank/process and GPU: one rank or process can contain either one GPU or multiple GPUs. The recommended manner is “one process for one GPU”.
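For illustration, a minimal sketch of how these quantities can be read in code, assuming the process group has already been initialized and the launcher (e.g., torchrun) exports LOCAL_RANK:

import os
import torch.distributed as dist

# assumes dist.init_process_group() has already been called
world_size = dist.get_world_size()          # 8 in our 2-node x 4-GPU example
rank = dist.get_rank()                      # 0..7, unique across all processes
local_rank = int(os.environ["LOCAL_RANK"])  # 0..3, index of the GPU on this node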
Core arguments of Slurm
#SBATCH --nodes=2 : number of nodes (or machines)
#SBATCH --ntasks=8 : number of tasks/processes. Usually set to the number of GPUs, one GPU per task/process.
#SBATCH --ntasks-per-node=4 : (Recommended) number of tasks/processes on each node. Usually set to the number of GPUs on each node, one GPU per task/process. (Specify either `--ntasks` or `--ntasks-per-node`.)
#SBATCH --cpus-per-task=2 : number of CPU cores for each task/process.
#### How to request GPUs
#SBATCH --gpus-per-task=v100:1 : (Recommended) request n GPUs for each task. Usually set to 1 (one GPU per task).
#SBATCH --gpus-per-node=v100:n : request n GPUs for each node. Usually set to the number of tasks on each node (one GPU per task).
#### How to request RAM
#SBATCH --mem=32G : allocate 32 GB on each node
#SBATCH --mem-per-gpu=8G : (Recommended) allocate memory w.r.t. each GPU
#SBATCH --mem-per-cpu=4G : allocate memory w.r.t. each CPU core

srun python main.py
Some useful environment variables after submitting a Slurm job
If we execute srun python main.py with Slurm, the srun command will execute the program in ntasks individual processes, and each process will get the following environment variables (some are shared among processes, some are specific to each process):

- SLURM_JOBID: job ID
- SLURM_NTASKS / SLURM_NPROCS: number of tasks/processes
- SLURM_NNODES / SLURM_JOB_NUM_NODES: number of nodes
- SLURM_NTASKS_PER_NODE: number of tasks/processes on each node
- SLURM_GPUS_ON_NODE: number of available GPUs on each node
- SLURM_PROCID: the (global) rank
- SLURM_LOCALID: the local rank
- SLURM_NODEID: the rank of the node (like node_rank)
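For illustration, a minimal sketch of how each process can read its own Slurm variables, assuming the job was submitted with settings like those above:

import os

# each srun-spawned process sees its own values for the per-process variables
rank = int(os.environ["SLURM_PROCID"])         # global rank
local_rank = int(os.environ["SLURM_LOCALID"])  # local rank on this node
world_size = int(os.environ["SLURM_NTASKS"])   # total number of tasks/processes
node_rank = int(os.environ["SLURM_NODEID"])    # rank of the node
print(f"rank={rank}, local_rank={local_rank}, world_size={world_size}, node_rank={node_rank}")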
How to get the IP address of the master node
- SLURM_NODELIST: the list of all nodes allocated to the job
- SLURMD_NODENAME: the name of the node on which the current process is running
- SLURM_LAUNCH_NODE_IPADDR: the IP address of the launch node (on which the task launch was initiated, i.e., where the srun command was executed)
- SLURM_SRUN_COMM_HOST: the node used for srun communication
- SLURM_SRUN_COMM_PORT: the port used for srun communication. (Attention: do not pass this port to dist.init_process_group().)
A tricky way to get the IP address of the master node
master_addr = subprocess.getoutput(f"scontrol show hostname {os.environ['SLURM_NODELIST']} | head -n 1")
The architecture of distributed training
Ring-AllReduce
[To be updated.]
2-torch.nn.DataParallel (Deprecated)
Features:
- a single process for all GPUs
- only supports Single-Node Multi-GPU mode
This method is slow and is not recommended.
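For reference, a minimal usage sketch; the toy model and input below are only placeholders for illustration:

import torch
import torch.nn as nn

# a toy model, just for illustration
model = nn.Linear(16, 4)

# wrap the model: a single process scatters each input batch across all visible GPUs,
# runs the forward pass in parallel, and gathers the outputs on the default GPU
model = nn.DataParallel(model).cuda()

images = torch.randn(32, 16).cuda()
output = model(images)  # shape: (32, 4)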
3-torch.nn.DistributedDataParallel
Features:
- Supports Multi-Node Multi-GPU;
- Each GPU has its own process, and each process has its own model and optimizer;
- Only a limited amount of data (i.e., the gradients) needs to be communicated among processes;
- After each process computes its gradients, Ring-AllReduce is used to average them across processes, so every process holds the same averaged gradients before the optimizer step.
Some functions/classes in PyTorch
- torch.nn.parallel.DistributedDataParallel(): wrap the model
  - Used after model.cuda()
- torch.nn.SyncBatchNorm.convert_sync_batchnorm(): convert the BN layers in the model
- torch.utils.data.distributed.DistributedSampler(): split the data across processes
  - via dataloader = DataLoader(dataset, ..., sampler=DistributedSampler(dataset))
  - ([Attention] do not forget to call sampler.set_epoch(current_epoch) or loader.sampler.set_epoch(current_epoch) at the beginning of each epoch to guarantee the data split is different in each epoch)
- torch.distributed.init_process_group(): initialize the process group; it tells each process how large the process group is and which rank it has
  - This function can read environment variables
3.0-[Take-away] unified code
Flexible switching between Single-Node Multi-GPU and Multi-Node Multi-GPU
import os
import subprocess

import torch
import torch.distributed as dist


def init_distributed(backend="nccl", port=None):
    ### check the launch method
    # if launched with 'torch.distributed.launch/torchrun' on a single node,
    # os.environ will have the environment variable "RANK"
    if "RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    # otherwise, using `slurm+srun` as default,
    # `os.environ` will have the environment variable `SLURM_JOB_ID`
    elif "SLURM_JOB_ID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_size = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        # specify master port
        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" in os.environ:
            pass  # use the MASTER_PORT already in the environment
        else:
            os.environ["MASTER_PORT"] = "29005"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
        os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
        os.environ["RANK"] = os.environ["SLURM_PROCID"]
    else:
        raise EnvironmentError("Launched with neither Slurm nor torchrun/torch.distributed.launch")

    torch.cuda.set_device(local_rank)
    # a simple call since `WORLD_SIZE/RANK/LOCAL_RANK/...` have been saved into `os.environ`
    dist.init_process_group(backend=backend, init_method="env://")
    return local_rank, rank, world_size


def main(args):
    local_rank, rank, world_size = init_distributed()
    main_worker(local_rank, rank, world_size, args)
    if dist.is_initialized():
        dist.destroy_process_group()
How to launch
### Single-Node Multi-GPU:'torchrun'
torchrun \
--nnodes=1 \
--nproc_per_node=4 \
--node_rank=0 \
main.py
### Multi-Node Multi-GPU:'slurm'
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=2
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}
srun python main.py
3.1-Launched with torch.distributed.launch
(deprecated; prefer torchrun for torch>=1.9.0)
The script needs to be launched on each node.
Access the environment variables:
local_rank = int(os.environ["LOCAL_RANK"])  # if "--use_env" is adopted
rank = dist.get_rank()
# or
rank = int(os.environ["RANK"])
world_size = dist.get_world_size()
# or
world_size = int(os.environ["WORLD_SIZE"])
An example for main.py
# main.py
import argparse
import os

import torch
import torch.distributed as dist


def main_worker(local_rank, rank, world_size, args):
    ### training
    ...


def main():
    ################### if without `--use_env` #####################################
    ### If you don't use `--use_env`, you must add `--local_rank` to argparse:
    # during spawning, the launcher automatically appends --local_rank=X as a CLI
    # argument to each subprocess, but it does not inject LOCAL_RANK into os.environ.
    # You must use argparse to explicitly parse that `--local_rank` argument in your script.
    # Note that this argument does not need to be specified by the user!
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int,
                        help='local rank for distributed training')
    # Changed in version 2.0.0: `--local-rank` instead of `--local_rank`
    args = parser.parse_args()
    local_rank = args.local_rank

    ################### if with `--use_env` #####################################
    local_rank = int(os.environ["LOCAL_RANK"])

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # set device
    torch.cuda.set_device(local_rank)
    # set backend for GPU communication
    dist.init_process_group(backend='nccl')
    main_worker(local_rank, rank, world_size, args)
    if dist.is_initialized():
        dist.destroy_process_group()
How to launch
### execute on Node 0 (with 4 gpus):
python -m torch.distributed.launch \
--nnodes=2 --node_rank=0 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500' \
--use_env \
main.py
### execute on Node 1 (with 4 gpus):
python -m torch.distributed.launch \
--nnodes=2 --node_rank=1 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500' \
--use_env \
main.py
Explanation of the command:
python -m torch.distributed.launch --help
usage: launch.py [-h] [--nnodes NNODES] [--node_rank NODE_RANK] [--nproc_per_node NPROC_PER_NODE] [--master_addr MASTER_ADDR]
[--master_port MASTER_PORT] [--use_env] [-m] [--no_python]
[--logdir LOGDIR]
training_script ...
core arguments:

- init_method (an argument of dist.init_process_group(), not of the launcher): how to initialize the process group. If init_method and store are not specified, the default value is "env://", which means "initialize by reading environment variables". (A sketch of the tcp:// variant appears after this list.)
  - init_method="env://": initialize by reading environment variables
  - init_method="tcp://192.168.1.1:1234": initialize with the IP address and port of the node with rank=0
  - init_method="file://xxxx": initialize through a shared filesystem
- --nnodes: number of nodes
- --node_rank: the rank of the node, starting from 0
- --nproc_per_node: number of processes (or GPUs) on each node
- --use_env: whether to load the arguments from environment variables. If used, the following environment variables will be read:
  - MASTER_ADDR: the IP address of the master node
  - MASTER_PORT: a free port on the master node
  - WORLD_SIZE: size of the process group
  - RANK: the rank of the current process
  - LOCAL_RANK: the local rank of the current process on the current node
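For illustration, a minimal sketch of the tcp:// variant; the IP address, port, world_size, and rank below are placeholder values:

import torch.distributed as dist

# placeholder values: use the IP/port of the rank-0 node and the
# actual world_size/rank of the current process
dist.init_process_group(
    backend="nccl",
    init_method="tcp://192.168.1.1:1234",
    world_size=8,
    rank=0,
)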
3.2-Launched with torchrun
(Recommended for torch>=1.9.0)
New features compared to torch.distributed.launch:
- Failover: can automatically relaunch all the workers when a worker failure happens;
- Elastic: can dynamically add or remove nodes.
### execute on Node 0:
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500' \
main.py
### execute on Node 1:
torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500' \
main.py
$ torchrun -h
usage: torchrun [-h] [--nnodes NNODES] [--nproc_per_node NPROC_PER_NODE] [--rdzv_backend RDZV_BACKEND] [--rdzv_endpoint RDZV_ENDPOINT]
[--rdzv_id RDZV_ID] [--rdzv_conf RDZV_CONF] [--standalone] [--max_restarts MAX_RESTARTS] [--monitor_interval MONITOR_INTERVAL]
[--start_method {spawn,fork,forkserver}] [--role ROLE] [-m] [--no_python] [--run_path] [--log_dir LOG_DIR] [-r REDIRECTS]
[-t TEE] [--node_rank NODE_RANK] [--master_addr MASTER_ADDR] [--master_port MASTER_PORT]
training_script ...
Torch Distributed Elastic Training Launcher

positional arguments:
  training_script: Full path to the (single-GPU) training program/script to be launched in parallel, followed by all the arguments for the training script.
  training_script_args:
Some core arguments (only the new features compared to torch.distributed.launch):
- --nnodes=1:3: the range of nodes used for training (e.g., minimum 1 node, maximum 3 nodes). (Nodes can be added dynamically if necessary.)
- --max_restarts=3: the maximum number of restarts of the process group. Note that node failure, node scale-down, and node scale-up all trigger a restart.
- --rdzv_id=1: all nodes use the same job id.
- --rdzv_backend: the backend for rendezvous; "c10d" and "etcd" are supported by default. Rendezvous is used for communication and scheduling among the nodes.
- --rdzv_endpoint: the address of the rendezvous, given as the "host ip:port" of one node.
3.3-Launched with torch.multiprocessing.spawn
- Single-node, multi-GPU setups.
- Good for debugging or simple setups where you don't want to rely on external launchers like torchrun or torch.distributed.launch.
- Manual handling of init_process_group() and local_rank is required inside your main_worker().
# main.py
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def main_worker(local_rank, world_size, args):
    rank = args.node_rank * args.nproc_per_node + local_rank
    # initialization
    dist.init_process_group(backend='nccl',
                            init_method='tcp://{}:{}'.format(args.master_addr, args.master_port),
                            world_size=world_size,
                            rank=rank)
    # set device
    torch.cuda.set_device(local_rank)
    ### training
    ...


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nnodes", default=1, type=int, help="number of nodes for distributed training"
    )
    parser.add_argument(
        "--nproc_per_node",
        default=1,
        type=int,
        help="number of processes (GPUs) per node for distributed training",
    )
    parser.add_argument(
        "--master_addr", default="127.0.0.1", type=str, help="master node IP address"
    )
    parser.add_argument("--master_port", default="12345", type=str, help="master node free port")
    parser.add_argument(
        "--node_rank", default=0, type=int, help="node rank for distributed training"
    )
    args = parser.parse_args()
    args.world_size = args.nnodes * args.nproc_per_node
    ### `mp.spawn()` passes "local_rank" automatically as the first argument,
    # which is why we only have `args=(args.world_size, args)` here
    mp.spawn(main_worker, nprocs=args.nproc_per_node, args=(args.world_size, args))
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
How to launch
# execute on Node 0:
python main.py \
--nnodes=2 --node_rank=0 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500'
# execute on Node 1:
python main.py \
--nnodes=2 --node_rank=1 --nproc_per_node=4 \
--master_addr='172.18.39.122' --master_port='29500'
3.4-Launched with scheduler Slurm
Many clusters are managed by the Slurm scheduler. With commands like srun, we do not need to execute/launch the script on each node manually.
How to organize main.py
It is noteworthy that with srun, some useful variables are saved as environment variables that we can access via os.environ["xxx"].
import os
import subprocess

import torch
import torch.distributed as dist


def main_worker(local_rank, rank, world_size, args):
    print(f"launching: local_rank={local_rank}, rank={rank}, world_size={world_size}")
    dist.init_process_group(backend="nccl")
    # set device
    torch.cuda.set_device(local_rank)
    ### training
    ...


def main(args):
    # ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
    # world_size = int(os.environ["SLURM_NTASKS"])
    # node_rank = int(os.environ["SLURM_NODEID"])
    node_list = os.environ["SLURM_NODELIST"]
    master_addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
    # master_addr = os.environ["SLURM_LAUNCH_NODE_IPADDR"]

    # save the following variables as environment variables for dist.init_process_group()
    os.environ["RANK"] = os.environ["SLURM_PROCID"]
    os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
    os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = "25900"

    local_rank = int(os.environ["SLURM_LOCALID"])
    rank = int(os.environ["SLURM_PROCID"])
    world_size = int(os.environ["SLURM_NTASKS"])
    main_worker(local_rank, rank, world_size, args)
Slurm script for job submission
#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=2
export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}
srun python main.py
Then run with
sbatch submitter.sh
4-Third-party tools
PyTorch-Lightning
Features:
- supports multiple strategies
Example:
# train on 32 GPUs (4 nodes)
trainer = Trainer(accelerator="gpu", devices=8, strategy="ddp", num_nodes=4)
Apex:
Horovod:
5-Distributed evaluation or test
References:
Core methods:
-
torch.distributed.all_reduce()
: Documentation
# synchronize the per-GPU results for distributed evaluation
def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt

output = model(images)
loss = criterion(output, target)

reduced_loss = reduce_tensor(loss.data)

# or equivalently, inline:
torch.distributed.barrier()
reduced_loss = loss.data.clone()
dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM)
reduced_loss /= world_size
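The same pattern can be used to aggregate other metrics such as accuracy. A minimal sketch, where distributed_accuracy is a hypothetical helper, the process group is assumed to be initialized, and each process evaluates its own data shard:

import torch
import torch.distributed as dist


@torch.no_grad()
def distributed_accuracy(model, loader, device):
    # per-process counters
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)
    for images, target in loader:
        images, target = images.to(device), target.to(device)
        pred = model(images).argmax(dim=1)
        correct += (pred == target).sum()
        total += target.numel()
    # sum the counters over all processes
    dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return (correct / total).item()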