分布式训练与通信原语

随着模型规模指数级增长,单卡早已无法容纳整个模型。本文系统梳理数据/张量/流水线/序列/专家并行、NCCL 集合通信原语、Ring AllReduce、FSDP/ZeRO 以及通信-计算重叠等分布式训练核心机制。正文为 markdown,关键机制配有可交互动画(点按钮逐步演示)。

3.0 为什么需要分布式训练?

随着模型规模的指数增长(GPT-3: 175B, PaLM: 540B, Llama 3: 405B),单张 GPU 已经无法容纳整个模型的参数、梯度和优化器状态。分布式训练通过将计算和存储分摊到多个设备上,使得训练超大模型成为可能。

核心挑战: 在保持训练等价性(mathematical equivalence)的前提下,最小化通信开销、最大化 GPU 利用率。

三种基本并行策略可以组合使用(称为 3D Parallelism):

策略 切分对象 通信模式 适用场景
Data Parallelism (DP) 数据 (batch) AllReduce 梯度 模型能放入单卡
Tensor Parallelism (TP) 单层参数 AllReduce / AllGather 单层参数过大
Pipeline Parallelism (PP) 层间切分 点对点 Send/Recv 模型层数多

3.1 并行策略对比

数据并行 (Data Parallelism)

核心思想: 每个 GPU 持有模型的完整副本,将一个 mini-batch 均分为若干 micro-batch,各 GPU 独立做前向和反向,然后通过 AllReduce 同步梯度。

将 Mini-Batch(B=32)分发到 4 个 GPU,各持有完整模型副本,独立计算后通过 AllReduce 同步梯度,同步后所有 GPU 持有相同梯度,更新后参数一致。

优势与局限

  • 实现简单,PyTorch 的 DistributedDataParallel 开箱即用
  • 通信量:每次迭代需要 AllReduce 全部梯度(2M bytes,M 为参数量)
  • 局限:模型必须能放入单张 GPU 的显存
  • 局限:GPU 数量增多时,通信开销线性增长(但 Ring AllReduce 可缓解)

下面的动画逐步演示 Data Parallelism 的训练流程(数据切分 → Forward/Backward → AllReduce → 参数更新):

张量并行 (Tensor Parallelism)

核心思想: 将单个层的权重矩阵切分到多个 GPU 上,每个 GPU 只计算部分结果,通过通信合并。Megatron-LM 首次系统性地提出了这种方法。

  • Column Parallel(列切分): ,各 GPU 得到 Y = [Y₁, Y₂](concat),forward 无需通信。
  • Row Parallel(行切分): ,需要 AllReduce 合并部分和

Megatron-LM 的 MLP 切分策略

在 Transformer 的 MLP 中,将第一个线性层做列切分,第二个线性层做行切分,这样 forward 只需一次 AllReduce,backward 也只需一次 AllReduce:

下面的动画逐步演示 Tensor Parallel MLP 的计算流(列并行 → GeLU → 行并行 → AllReduce):

流水线并行 (Pipeline Parallelism)

核心思想: 将模型按层切分成若干 stage,每个 stage 放在不同的 GPU 上。数据以 micro-batch 的形式在各 stage 之间流动。例如 Stage 0(Layer 0-5)在 GPU 0,Stage 1(Layer 6-11)在 GPU 1,依此类推。

Bubble Ratio(气泡比)

Pipeline parallelism 的核心问题是 pipeline bubble —— GPU 空闲等待的时间。

其中 = pipeline stages 数, = micro-batches 数。增加 micro-batch 数可以降低 bubble ratio。

常见调度策略

  • GPipe: 先做完所有 forward micro-batches,再做所有 backward
  • 1F1B: 交替执行 forward 和 backward,减少峰值内存
  • Interleaved 1F1B: 每个 GPU 持有多个非连续 stage,进一步减少 bubble

下面的动画可切换 GPipe / 1F1B 两种调度,逐时间步观察 Forward/Backward 数据流与 bubble:

3D Parallelism(混合并行)

实际训练超大模型时,通常同时使用 DP + TP + PP 三种策略。例如 Megatron-Turing NLG (530B) 使用(2048 GPUs):

维度 并行度 说明
Tensor Parallelism 8 同一节点内 8 张 GPU(NVLink 高带宽)
Pipeline Parallelism 8 跨节点,通信量小(仅传 activation)
Data Parallelism 32 2048 / (8×8) = 32 路数据并行
总 GPU 数 8 × 8 × 32 = 2048

设计原则

  • TP 放在节点内(intra-node):利用 NVLink/NVSwitch 高带宽(900 GB/s on H100)
  • PP 跨节点:点对点通信量小,对带宽要求低
  • DP 在剩余维度上扩展:配合梯度累积进一步隐藏通信

序列并行 (Sequence Parallelism)

当序列长度很长时(如 128K+ tokens),即使单层的 activation 也可能撑爆显存。序列并行将序列维度切分到多个 GPU 上。

下面的动画演示 TP 区域 ↔ SP 区域的切换(4 GPUs):

核心思想:在非 TP 区域沿序列维度切分。 Tensor Parallelism 已经将 Attention 和 FFN 切分到多卡。但 LayerNorm、Dropout 等操作仍然在每张卡上保留完整序列的 activation —— 这是冗余的。

  • TP 区域(已切分): QKV Proj → Attention → O Proj → FFN
  • SP 区域(序列切分): LayerNorm, Dropout, Residual Add
  • TP 到 SP 的转换:AllReduce → ReduceScatter(前向)/ AllGather(反向)
收益 说明
显存节省 activation 内存降为 1/TP_size(如 TP=8 → 节省 87.5% activation)
通信开销 无额外通信!将原有 AllReduce 拆为 ReduceScatter + AllGather

Megatron-SP 代码示意

1
2
3
4
5
6
7
8
9
10
11
12
# 前向: TP 区域结束 → SP 区域开始
# 原本: AllReduce(output) [通信量 2M]
# 改为: ReduceScatter(output) [通信量 M]
output = reduce_scatter(tp_output, dim=1) # seq 维度切分

# LayerNorm / Dropout 只在 local seq chunk 上计算
output = layer_norm(output) # shape: [batch, seq_len/tp, hidden]
output = dropout(output)

# SP 区域结束 → 进入下一个 TP 区域
output = all_gather(output, dim=1) # 恢复完整序列 [通信量 M]
# 总通信量 = M + M = 2M, 与 AllReduce 相同!

同一层的 activation: Normal vs TP vs TP+SP 形状对照

把一层 transformer block 的每个中间 tensor 走一遍, 三栏对比:

  • 左 NORMAL: 单卡保留所有 activation, 是基线
  • 中 TP only: TP=8 沿 hidden / heads 切, 但 LN io、残差那几行 TP 切不动 (仍是完整 (b,s,h))
  • 右 TP+SP: 在 TP 之上叠 SP, 沿 seq 切 TP 切不动那部分 — 把残差/LN 区也压到 1/T

配置: 70B (h=8192, n_h=64, d_k=128), b=1, s=2048, FP16, T=8。点 [下一步] 一行行加, 特别注意 Step 6/7/11: TP 那两列 x_attn/LN2/x_out 是 32 MB 不动 — SP 救场之后只剩 4 MB。最后一帧给出 Korthikanti 的正式公式: NORMAL 1850 MB → TP 373 MB (5×) → TP+SP 224 MB (8.3×), SP 在 TP 之上额外多省的 ~1.7× 全来自那几行残差/LN io。

★ 标记的是体积大头 — O(s²) 在 s=8192 时单项就 17 GB / 层, 是 FlashAttention 重点要砍掉的。

三类区域、三种 sharding 维, 一行看懂

区域 sharding 维 每卡形状 谁切了它
SP 区 (LN io, residual) 沿 seq 切 (b, s/T, h) SP
TP 区, attention (Q, K, V, scores, softmax, ctx) 沿 head 切 (b, n_h/T, s, d_k) TP
TP 区, MLP (up, GeLU) 沿 hidden 切 (b, s, 4h/T) TP
边界 transient (TP 区入口的 x1, x2; TP 区出口的 attn_out, down) 完整 (b, s, h), 但 partial sum, 立即被通信归约 (无, 是中间态)

边界 transient 看似不省 — 但它们是计算流过去的临时态, 算完立即被 reduce-scatter 归约或 @ 矩阵乘 消费, 不会同时活在显存里。Peak 显存账里它们只占一份, 不是所有 transient 同时存在。

为什么省了 ~8×

  • SP 区 (LN io, residual): 长得最多 (每层 6+ 份 (b,s,h)), 沿 seq 切 → 1/T
  • TP 区 attention: O(s²) 大户被 heads 切 → 1/T
  • TP 区 MLP: 4× 大胖 (up + GeLU) 沿 hidden 切 → 1/T
  • 边界 transient: 不切, 但只活一瞬, peak 算一份

总公式 (Korthikanti et al. 2022): activation per layer 34·s·b·h + 5·a·s²·b(34·s·b·h + 5·a·s²·b) / T, 完整 1/T 节省。配上 selective recompute 再砍 ~80% O(s²) → 70B 训练里每卡每层 activation 从 1.85 GB 压到 ~50 MB 量级。

关于 activation 本身在 Transformer 各位置的形状解剖 (x_in/Q/K/V/scores/MLP up 等三档量级、FlashAttention 砍掉哪一档、output logits 为什么 524 MB), 见 LLM Inference §4.1 “activation 到底长在 Transformer 哪里”

专家并行 (Expert Parallelism)

Mixture of Experts (MoE) 模型中,不同 expert 分布在不同 GPU 上。每个 token 经由 router 选择 top-k expert 处理。

下面的动画演示 Token Routing 与 All-to-All(4 GPUs,8 Experts,top-1):

MoE + Expert Parallelism 执行流程: Input Tokens [B, S, H] → Router(softmax → top-k)→ All-to-All(dispatch tokens)→ Expert FFN(local compute)→ All-to-All(combine results)→ Output [B, S, H]

GPU 分布示例(8 Experts, EP=4):GPU 0 持有 E0, E1;GPU 1 持有 E2, E3;GPU 2 持有 E4, E5;GPU 3 持有 E6, E7。

优势 挑战
参数量可以 10× 扩展,计算量仅 ~2×(每 token 只激活 top-k expert) All-to-All 通信:每个 token 需发送到目标 expert 所在的 GPU
每张 GPU 只存部分 expert 参数 负载不均衡:热门 expert 可能收到过多 token(需 load balancing loss)
与 DP/TP/PP 正交组合 Token dropping:expert buffer 满时丢弃 token

DeepSeek-V3 的 EP 实践

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# DeepSeek-V3: 256 Experts, top-8 routing, EP=64
# 每张 GPU 存放 256/64 = 4 个 expert

# Step 1: Router 决定每个 token 去哪些 expert
router_logits = router(hidden_states) # [B*S, num_experts]
topk_ids, topk_weights = top_k(router_logits, k=8)

# Step 2: All-to-All dispatch — 把 token 发送到对应 expert 所在 GPU
# 每张 GPU: 发出本地 token 中属于远端 expert 的部分
# 接收其他 GPU 发来的属于本地 expert 的 token
dispatched = all_to_all(tokens_by_expert, group=ep_group)

# Step 3: 本地 expert 计算
expert_output = local_experts(dispatched) # 只算本 GPU 上的 expert

# Step 4: All-to-All combine — 把结果发回原 GPU
combined = all_to_all(expert_output, group=ep_group)

# Step 5: 加权求和
output = weighted_sum(combined, topk_weights) # [B*S, hidden]

3.2 NCCL 集合通信原语

NVIDIA Collective Communications Library (NCCL) 提供了分布式训练所需的底层通信操作。理解这些原语是理解分布式训练的基础。

操作 描述 输入 → 输出 通信量 (N GPUs, M bytes)
Broadcast 一个 GPU 的数据发送到所有 GPU 1 份 → N 份相同
Reduce 所有 GPU 的数据规约到一个 GPU N 份 → 1 份聚合
AllReduce 所有 GPU 的数据规约,结果发到所有 GPU N 份 → N 份相同聚合
AllGather 收集所有 GPU 的数据到所有 GPU
ReduceScatter 规约后将结果分片发到各 GPU

关键洞察: AllReduce = ReduceScatter + AllGather。这是 Ring AllReduce 算法的核心分解。

选择不同原语,观察 4 个 GPU 之间数据如何流动。每个 GPU 初始持有不同颜色的数据块:

通信量对比

下图展示了 4 个 GPU、参数量 1GB 时各原语的通信量(AllReduce 通信量 = ;其他 = 。N=4, M=1GB):

3.3 Ring AllReduce 算法详解

Ring AllReduce 是最经典的 AllReduce 实现之一,由百度在 2017 年引入深度学习训练中。其核心优势是通信量与 GPU 数量无关(per-GPU 带宽利用恒定)。

算法步骤 —— Ring AllReduce 分为两个阶段,每个阶段有 N-1 步(N = GPU 数):

  • Phase 1: Reduce-Scatter —— 经过 N-1 步后,每个 GPU 持有最终结果的 1/N
  • Phase 2: AllGather —— 再经过 N-1 步后,每个 GPU 持有完整的最终结果

下面的动画用具体数值逐步展示每个 GPU 每个 chunk 的内容变化(环形拓扑,每个 GPU 只向右邻居发送):

Ring AllReduce 伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Ring AllReduce 实现 (ReduceScatter + AllGather)
def ring_allreduce(data, rank, world_size):
# 将数据分为 world_size 个 chunk
chunks = split(data, world_size)

# Phase 1: Reduce-Scatter
for step in range(world_size - 1):
send_idx = (rank - step) % world_size
recv_idx = (rank - step - 1) % world_size

# 向下一个 GPU 发送, 从上一个 GPU 接收
send(chunks[send_idx], dst=(rank + 1) % world_size)
recv_buf = recv(src=(rank - 1) % world_size)

# 将接收到的数据累加到对应 chunk
chunks[recv_idx] += recv_buf

# 此时 chunks[(rank+1) % world_size] 包含该 chunk 的全局归约结果

# Phase 2: AllGather
for step in range(world_size - 1):
send_idx = (rank - step + 1) % world_size
recv_idx = (rank - step) % world_size

send(chunks[send_idx], dst=(rank + 1) % world_size)
chunks[recv_idx] = recv(src=(rank - 1) % world_size)

return concat(chunks)

通信量推导

设定: = GPU 数量, = 每个 GPU 的梯度大小 (bytes),每个 GPU 把梯度切成 个 chunk,每个 chunk 大小 =

Phase 1: Reduce-Scatter —— 每一步每个 GPU 发送 1 个 chunk 给右邻居;共 N-1 步(才能让一个 chunk 经过所有 N 个 GPU 累加)。

Phase 2: AllGather —— 每一步每个 GPU 发送 1 个已完成归约的 chunk 给右邻居;同样 N-1 步。

总通信量:

关键性质:通信量与 GPU 数几乎无关! 很大时 ,所以每 GPU 通信量趋近于 。无论是 4 卡还是 4096 卡,每张卡发送的总数据量基本一样。

对比:Naive AllReduce(树形)

方法 每 GPU 发送量 N=4 N=64 N=1024
Naive (one-to-all) 3M 63M 1023M
Ring AllReduce 1.5M 1.97M 1.998M

带宽利用率: Ring AllReduce 中,每一步每个 GPU 的 NVLink/网卡都在同时收发数据(全双工),带宽利用率 = 。而 Naive 方法中,大部分时间只有根节点在通信,其他 GPU 空闲。

3.4 FSDP (Fully Sharded Data Parallel) / ZeRO

微软的 ZeRO (Zero Redundancy Optimizer) 和 PyTorch 的 FSDP 通过分片 (sharding) 优化器状态、梯度和参数来突破数据并行的内存限制。

内存组成分析 (Mixed Precision Training)

对于参数量为 的模型,使用 Adam + Mixed Precision 训练时,每个 GPU 的内存占用:

组件 精度 内存 (bytes)
参数 (FP16) FP16
梯度 (FP16) FP16
参数 master copy (FP32) FP32
Adam momentum (FP32) FP32
Adam variance (FP32) FP32
总计

ZeRO 三个阶段

阶段 分片对象 内存 通信量
Stage 1 Optimizer States 不变(= DP)
Stage 2 Optimizer + Gradients 不变(= DP)
Stage 3 All(包括 Parameters) 增加 1.5x

FSDP 内存计算器

拖动滑块调整模型参数量、GPU 数量与单卡显存,对比 DDP 与 ZeRO 三个阶段的单卡显存占用:

FSDP 运行流程 (Gather-Compute-Scatter)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# FSDP Forward Pass 伪代码
for layer in model.layers:
# 1. AllGather: 从所有 GPU 收集完整参数
full_params = all_gather(layer.sharded_params)

# 2. Compute: 用完整参数做前向计算
output = layer.forward(input, full_params)

# 3. Free: 释放非本 shard 的参数, 节省内存
free(full_params except my_shard)

# FSDP Backward Pass
for layer in reversed(model.layers):
# 1. AllGather: 重新收集参数 (用于计算梯度)
full_params = all_gather(layer.sharded_params)

# 2. Compute: 计算梯度
grad = layer.backward(output_grad, full_params)

# 3. ReduceScatter: 梯度规约并分片
sharded_grad = reduce_scatter(grad)

# 4. Free: 释放完整参数
free(full_params except my_shard)

3.5 torch.distributed API 实战

初始化进程组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.distributed as dist
import os

def setup(rank, world_size):
"""初始化分布式训练环境"""
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# 初始化进程组, 选择通信后端
# 'nccl' 用于 GPU, 'gloo' 用于 CPU
dist.init_process_group(
backend='nccl',
rank=rank,
world_size=world_size
)

# 设置当前进程使用的 GPU
torch.cuda.set_device(rank)

def cleanup():
dist.destroy_process_group()

AllReduce 使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def allreduce_example(rank, world_size):
setup(rank, world_size)

# 每个 GPU 创建自己的 tensor
tensor = torch.tensor([rank + 1.0], device=f'cuda:{rank}')
print(f"Rank {rank} before AllReduce: {tensor}")

# AllReduce: 对所有 GPU 的 tensor 求和
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"Rank {rank} after AllReduce: {tensor}")
# 所有 GPU 输出: tensor([10.0]) (1+2+3+4=10, 假设 4 GPUs)

cleanup()

# 启动
import torch.multiprocessing as mp
mp.spawn(allreduce_example, args=(world_size,), nprocs=world_size)

AllGather 使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def allgather_example(rank, world_size):
setup(rank, world_size)

# 每个 GPU 的本地数据
local_tensor = torch.tensor([rank * 10, rank * 10 + 1],
device=f'cuda:{rank}')

# 准备接收 buffer (world_size 个 tensor)
gathered = [torch.zeros(2, device=f'cuda:{rank}')
for _ in range(world_size)]

# AllGather: 收集所有 GPU 的数据
dist.all_gather(gathered, local_tensor)

print(f"Rank {rank}: gathered = {gathered}")
# 每个 GPU 都有: [[0,1], [10,11], [20,21], [30,31]]

cleanup()

ReduceScatter 使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def reduce_scatter_example(rank, world_size):
setup(rank, world_size)

# 每个 GPU 有 world_size 个 chunk
input_tensor = torch.ones(world_size, device=f'cuda:{rank}') * (rank + 1)
output_tensor = torch.zeros(1, device=f'cuda:{rank}')

# ReduceScatter: 规约后分片
dist.reduce_scatter_tensor(output_tensor, input_tensor,
op=dist.ReduceOp.SUM)

print(f"Rank {rank}: result = {output_tensor}")
# Rank 0: 10.0, Rank 1: 10.0, ... (sum of all ranks' chunk_i)

cleanup()

使用 Send/Recv 实现点对点通信

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def point_to_point(rank, world_size):
setup(rank, world_size)

tensor = torch.zeros(1, device=f'cuda:{rank}')

if rank == 0:
# Rank 0 发送数据到 Rank 1
tensor += 42
dist.send(tensor, dst=1)
print(f"Rank 0 sent: {tensor}")
elif rank == 1:
# Rank 1 从 Rank 0 接收数据
dist.recv(tensor, src=0)
print(f"Rank 1 received: {tensor}")

cleanup()

3.6 通信策略开销对比

不同并行策略的通信模式和开销差异巨大。下表对比了在不同配置下每次迭代的通信量:

策略 通信原语 每次迭代通信量 (per GPU) 通信频率
DDP AllReduce 每个 backward step
FSDP (ZeRO-3) AllGather + ReduceScatter 每层 forward + backward
Tensor Parallel AllReduce '_' allowed only in math mode2 \times \text{activation_size} \times (N-1)/N per layer 每层 forward + backward
Pipeline Parallel P2P Send/Recv '_' allowed only in math mode\text{activation_size} \times \text{micro_batches} 每个 micro-batch

Trade-off: FSDP Stage 3 的通信量比 DDP 多 50%,但内存节省可以让你训练大 N 倍的模型或使用更大的 batch size。

CUDA Stream 是什么?

CUDA Stream 是 GPU 上的一个有序命令队列。你可以把它想象成一条流水线传送带

  • 同一 stream 内的操作严格按顺序执行(FIFO)
  • 不同 stream 之间的操作可以并行 —— GPU 硬件会自动调度
  • 默认所有操作都在 stream 0(default stream)上,因此默认是串行的

单 Stream(串行):总时间 = A + copy + B。多 Stream(并行):总时间 = max(A+B, copy)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# PyTorch 中使用多 stream
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()

# Stream 1: 计算
with torch.cuda.stream(s1):
output = model(input_batch)

# Stream 2: 同时做通信 (与 s1 并行!)
with torch.cuda.stream(s2):
dist.all_reduce(gradients)

# 需要同步时: 插入 event 依赖
event = s2.record_event()
s1.wait_event(event) # s1 等 s2 的 all_reduce 完成

关键硬件基础: GPU 有独立的硬件引擎:计算用 SM (Streaming Multiprocessor),数据搬运用 Copy Engine (CE),跨卡通信用 NVLink/PCIe DMA。多 stream 之所以能并行,是因为这些引擎物理上就是独立的 —— 不是软件模拟的并发,而是真正的硬件并行。

通信与计算重叠 (Communication-Computation Overlap)

分布式训练的性能瓶颈往往不是通信本身,而是 通信阻塞了计算。核心思想:让 GPU 在等待数据传输时继续做有用的计算,从而”隐藏”通信延迟。

为什么可以重叠? 有了多 CUDA Stream 的概念,答案就很自然了:把通信和计算放到不同的 stream 上,硬件就能并行执行。

  • 无重叠(Naive): 通信和计算串行,Total = T_compute + T_comm
  • 有重叠(Overlap): 通信与计算并行执行,Total = max(T_compute, T_comm)

技术 1: Gradient Bucketing (DDP)

DDP 不会每算完一个梯度就立刻通信,而是将多个梯度打包成 “bucket”(默认 25MB),从最后一层开始:

  1. Backward 算完 Layer N 的梯度 → 放入 Bucket 1
  2. Bucket 1 满了 → 触发异步 AllReduce(通信 stream)
  3. 同时,计算 stream 继续算 Layer N-1, N-2… 的梯度
  4. 等所有 backward 做完,前面的 AllReduce 也早已完成

技术 2: FSDP Prefetching

FSDP 在计算当前层时,提前 AllGather 下一层的参数。Forward 和 Backward 都有 prefetch。Forward 时计算完一层后立即释放该层参数节省显存;Backward 时三条 stream 同时工作:计算当前层梯度 | 预取下一层参数 | 发送上一层梯度(异步 ReduceScatter)。

技术 3: Async Communication(手动控制)

底层 API 可以手动实现 overlap:

1
2
3
4
5
6
7
8
# 启动异步通信 (不阻塞)
handle = dist.all_reduce(gradient_bucket, async_op=True)

# GPU 继续做其他计算...
next_layer_grad = compute_backward(next_layer)

# 需要结果时才等待
handle.wait() # 此时通信可能早已完成

GPU 硬件视角:为什么能 Overlap

  • Compute Stream: kernel 在 SM 上执行,占用 CUDA Cores / Tensor Cores
  • Comm Stream: NCCL 通过 Copy Engine 搬数据,不占用 SM,完全独立硬件
  • 关键洞察: 两套硬件可以同时满负荷工作,软件只需放到不同 CUDA stream 即可

什么时候 Overlap 失效?

  • 通信远大于计算: 小模型 + 多卡 → AllReduce 时间 >> 单层 backward 时间,无法完全隐藏
  • 带宽不足: PCIe 连接的多节点场景,通信瓶颈太大
  • Bubble 阶段: Pipeline Parallelism 的 warmup/cooldown 阶段,没有计算可以重叠
  • 单 stream 实现: 如果通信和计算在同一个 CUDA stream 上,硬件无法并行

Overlap 程度对比三种场景:理想情况 (通信完全隐藏,effective time = );部分重叠 (compute 完成后还要等通信结束,产生 idle bubble);无重叠(同一 stream 或同步通信,Total = T_compute + T_comm)。

经验法则:通信时间 / 计算时间 ≤ 1 时,overlap 几乎能完全隐藏通信开销。实际调优时用 torch.profiler 的 trace view 观察两条 stream 是否真正并行。

3.7 练习题

练习 1:计算 Ring AllReduce 的通信量

假设有 N 个 GPU,每个 GPU 上的梯度大小为 M bytes。请推导 Ring AllReduce 算法中每个 GPU 的总发送量和总接收量。具体问题: 当 N=8, M=2GB(一个 500M 参数模型的 FP32 梯度)时,每个 GPU 发送和接收多少数据?

完整解答:

  1. 数据被分为 N 个 chunk,每个 chunk 大小 = M/N
  2. Reduce-Scatter 阶段:N-1 步,每步发送 M/N bytes
  3. AllGather 阶段:N-1 步,每步发送 M/N bytes
  4. 总发送量 = 2(N-1) × M/N = 2M(N-1)/N
  5. 总接收量同样 = 2M(N-1)/N

代入数值 N=8, M=2GB:每 GPU 通信量 = 2 × 2GB × 7/8 = 3.5 GB。注意:相比 Naive AllReduce(每个 GPU 发送到一个 root 再广播)的 2M(N-1) = 14 GB,Ring 方式大幅降低了单 GPU 的带宽需求。

练习 2:设计并行策略

你需要训练一个 70B 参数的模型,拥有 64 张 A100 80GB GPU(8 节点,每节点 8 卡,NVLink 互联)。已知:模型训练内存需求(含优化器)~16 bytes/param = 1120 GB;NVLink 带宽 600 GB/s(intra-node);InfiniBand 带宽 50 GB/s(inter-node)。问题: 设计 DP/TP/PP 的组合,说明理由。

推荐方案:TP=8, PP=2, DP=4(验证:8 × 2 × 4 = 64 GPU)

  1. TP=8(节点内): 70B / 8 TP = 每卡 ~8.75B 参数的存储,加上 activation 约 ~35GB,可放入 80GB;TP 通信频繁(每层两次 AllReduce),必须使用 NVLink 高带宽。
  2. PP=2(跨节点): 将模型分为 2 个 stage,每个 stage ~35B 参数;PP 仅传递 activation(相对较小),适合 InfiniBand 带宽;使用足够多的 micro-batch 来降低 bubble ratio。
  3. DP=4: 64 / (8 × 2) = 4 路数据并行;可配合 ZeRO Stage 1 进一步降低优化器状态的内存占用;梯度 AllReduce 可与 backward 计算 overlap。

替代方案: 如果使用 FSDP Stage 3 代替 TP+DP,则可以 PP=8, FSDP=8,但通信开销会更大。

单卡显存占用明细(A100 80GB) —— 70B 模型,TP=8 切 hidden dim,PP=2 切层,每卡有 35B/8 ≈ 4.375B 参数

组件 精度 计算 显存
模型参数 (FP16) FP16 4.375B × 2 bytes 8.75 GB
梯度 (FP16) FP16 4.375B × 2 bytes 8.75 GB
参数 Master Copy (FP32) FP32 4.375B × 4 bytes 17.5 GB
Adam Momentum (FP32) FP32 4.375B × 4 bytes 17.5 GB
Adam Variance (FP32) FP32 4.375B × 4 bytes 17.5 GB
Activations (+ checkpointing) Mixed b=1, s=2048, full ckpt ~1.3 GB
总计 ~71.3 GB

进一步优化 + ZeRO Stage 1:优化器状态(Master + Momentum + Variance = 52.5GB)在 DP=4 的 group 内分片,每卡优化器占用 = 52.5/4 = 13.1 GB,总显存 ≈ 32 GB,余量充足。

Activation 显存怎么估算? 每层总 activation(Megatron 论文公式):

'_' allowed only in math mode\text{Act_per_layer} = sbh \times \left(34 + \frac{5 \cdot s \cdot n_{head}}{h}\right) \approx 34,sbh

其中 s=seq_len, b=micro_batch_size, h=hidden_dim, n_head=注意力头数;34 为各项系数之和(FP16 每元素 2 bytes),是不使用 activation checkpointing 的情况。代入 70B(h=8192, n_head=64, 本卡 40 layers, s=2048, b=1):每层 Act = 2048×1×8192×34 = 544 MB/layer。TP=8 可切分约 65%(22/34)的 activation(QKV/attn_out/FFN),LayerNorm 输入等约 12bsh 不可切分。每层实际 ≈ 247 MB/GPU,40 层 ≈ 10 GB/GPU(不做 checkpointing)。1F1B schedule 中同时 in-flight 的 micro-batch 数 = PP stages,峰值 ≈ 2 × 10 = 20 GB。

实际工程必须用 Activation Checkpointing:Full checkpointing 只存每层输入(2bsh = 32 MB/layer),40 层 ≈ 1.3 GB(代价 ~33% 额外计算);Selective checkpointing 只重算 attention score,≈ 3-4 GB(额外计算 ~10%)。

影响 activation 大小的关键因素:seq_len(线性 + attention score 是 )、micro_batch(线性)、hidden_dim(线性)、层数(PP 切层可减)、TP degree(部分可切)、Checkpointing(~33% 额外计算换 ~90% activation 节省)。

编程练习 A:Ring AllReduce(完整实现)

torch.distributedisend/irecv 原语从零实现 Ring AllReduce。要求:实现 Reduce-Scatter + AllGather,4 进程验证与 dist.all_reduce 一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os

def ring_allreduce(tensor, rank, world_size):
assert tensor.numel() % world_size == 0
chunk_size = tensor.numel() // world_size
chunks = list(tensor.view(world_size, chunk_size).unbind(dim=0))
chunks = [c.contiguous().clone() for c in chunks]

left = (rank - 1) % world_size
right = (rank + 1) % world_size
send_buf = torch.zeros(chunk_size, device=tensor.device)
recv_buf = torch.zeros(chunk_size, device=tensor.device)

# ===== Phase 1: Reduce-Scatter (N-1 steps) =====
for step in range(world_size - 1):
send_idx = (rank - step) % world_size # 哪个 chunk 要发出去
recv_idx = (rank - step - 1) % world_size # 收到的数据累加到哪个 chunk

send_buf.copy_(chunks[send_idx])
req_s = dist.isend(send_buf, dst=right)
req_r = dist.irecv(recv_buf, src=left)
req_s.wait()
req_r.wait()

chunks[recv_idx] += recv_buf # 关键: 累加!

# ===== Phase 2: AllGather (N-1 steps) =====
for step in range(world_size - 1):
send_idx = (rank - step + 1) % world_size # 发送已归约的 chunk
recv_idx = (rank - step) % world_size # 覆盖 (不是累加!)

send_buf.copy_(chunks[send_idx])
req_s = dist.isend(send_buf, dst=right)
req_r = dist.irecv(recv_buf, src=left)
req_s.wait()
req_r.wait()

chunks[recv_idx].copy_(recv_buf) # 关键: 覆盖!

return torch.cat(chunks, dim=0)

def worker(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)

torch.manual_seed(rank)
tensor = torch.randn(16, device=f'cuda:{rank}')

result_ring = ring_allreduce(tensor.clone(), rank, world_size)

result_ref = tensor.clone()
dist.all_reduce(result_ref, op=dist.ReduceOp.SUM)

assert torch.allclose(result_ring, result_ref, atol=1e-5)
print(f"Rank {rank}: PASSED!")
dist.destroy_process_group()

if __name__ == '__main__':
mp.spawn(worker, args=(4,), nprocs=4)

index 计算规律: Reduce-Scatter 中 send_idx = (rank - step) % Nrecv_idx = (rank - step - 1) % N;AllGather 中 send_idx = (rank - step + 1) % Nrecv_idx = (rank - step) % N。直觉:收到的数据要累加到”下一个”要发出去的 chunk。

编程练习 B:Pipeline Schedule 模拟器(GPipe vs 1F1B)

模拟两种调度,对比 bubble ratio 和峰值 activation 内存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np

def simulate_gpipe(p, m):
"""p: num pipeline stages, m: num micro-batches
Returns: (schedule, bubble_ratio, peak_memory)"""
total_time = 2 * (m + p - 1)
schedule = [['idle'] * total_time for _ in range(p)]

# Forward: stage s 处理 micro-batch mb 的时间 = s + mb
for mb in range(m):
for s in range(p):
schedule[s][s + mb] = f'F{mb}'

# Backward: GPipe 所有 forward 做完再做 backward
b_start = m + p - 1
for mb in range(m):
for s in range(p - 1, -1, -1):
schedule[s][b_start + mb + (p - 1 - s)] = f'B{mb}'

busy = sum(1 for row in schedule for x in row if x != 'idle')
bubble_ratio = 1.0 - busy / (p * total_time)
peak_memory = m # 开始 backward 前每个 stage 存了 m 个 activation
return schedule, bubble_ratio, peak_memory

def simulate_1f1b(p, m):
"""1F1B schedule: warmup → steady (1B1F) → cooldown"""
assert m >= p
total_time = 2 * (m + p - 1)
schedule = [['idle'] * total_time for _ in range(p)]

for s in range(p):
t = s
fwd_done, bwd_done = 0, 0

# Warmup — 做 (p - s) 个 forward 填满 pipeline
warmup_count = min(p - s, m)
for _ in range(warmup_count):
schedule[s][t] = f'F{fwd_done}'
fwd_done += 1; t += 1

# Steady state — 交替 1B + 1F
while fwd_done < m:
schedule[s][t] = f'B{bwd_done}'
bwd_done += 1; t += 1
schedule[s][t] = f'F{fwd_done}'
fwd_done += 1; t += 1

# Cooldown — 做剩余 backward
while bwd_done < m:
schedule[s][t] = f'B{bwd_done}'
bwd_done += 1; t += 1

busy = sum(1 for row in schedule for x in row if x != 'idle')
bubble_ratio = 1.0 - busy / (p * total_time)
peak_memory = p # stage s 最多同时持有 (p - s) 个 activation
return schedule, bubble_ratio, peak_memory

# === 运行 ===
p, m = 4, 8
_, bg, mg = simulate_gpipe(p, m)
_, b1, m1 = simulate_1f1b(p, m)
print(f"GPipe: bubble={bg:.1%}, peak_mem={mg} activations")
print(f"1F1B: bubble={b1:.1%}, peak_mem={m1} activations")
print(f"Memory saving: {(1-m1/mg)*100:.0f}% (1F1B vs GPipe)")
# 预期: GPipe bubble=27.3% peak=8; 1F1B bubble=27.3% peak=4; saving 50%
# 注意: bubble ratio 相同! 1F1B 的优势在内存, 不在 bubble
# (实际中 1F1B 允许更大的 m, 从而间接降低 bubble)

3.8 本章小结

概念 核心要点
Data Parallelism 复制模型,切分数据,AllReduce 梯度;简单但有显存瓶颈
Tensor Parallelism 切分层内权重(column/row split);需要高带宽互联
Pipeline Parallelism 按层切分模型;bubble ratio = (p-1)/(p-1+m)
Ring AllReduce 通信量 O(M),与 GPU 数无关;带宽最优
FSDP / ZeRO 分片优化器/梯度/参数;Stage 3 可线性降低内存
3D Parallelism TP(节点内)+ PP(跨节点)+ DP(扩展);训练超大模型的标准范式

下一章预告: Chapter 4 将深入探讨 GPU 内存管理与高效 Attention 实现(FlashAttention, PagedAttention),进一步从硬件层面理解训练和推理效率。


本文是 ML Systems 系列 Chapter 3。正文 markdown 渲染,9 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。