交互动画:完整 MLP 层的 TP 计算流程(Megatron-LM 风格)

Column Parallel → GeLU(无通信)→ Row Parallel → AllReduce(唯一通信点)

X (输入)
[batch, seq, d]
↓ 广播到每张卡(f 算子: identity forward, AllReduce backward)
X × W₁col
[d, d/2] → Y₁ [batch, seq, d/2]
GPU 0 独立计算
X × W₂col
[d, d/2] → Y₂ [batch, seq, d/2]
GPU 1 独立计算
↓ 无需通信!GeLU 是逐元素操作
GeLU(Y₁)
shape 不变 [batch, seq, d/2]
无通信!
GeLU(Y₂)
shape 不变 [batch, seq, d/2]
无通信!
↓ 关键:Row Parallel 按行切分第二个权重矩阵
GeLU(Y₁) × W₁row
[d/2, d] → Z₁ [batch, seq, d]
GPU 0 部分结果
GeLU(Y₂) × W₂row
[d/2, d] → Z₂ [batch, seq, d]
GPU 1 部分结果
↓ AllReduce(g 算子: AllReduce forward, identity backward)
Z = Z₁ + Z₂ (AllReduce)
[batch, seq, d] — 完整输出
唯一通信点!
输出 → 下一层
[batch, seq, d]
点击「播放」逐步展示 MLP 在两张 GPU 上的完整计算流程