大规模 ML 系统的高级优化技术

本章覆盖大规模 ML 系统中的前沿优化技术:从 Zero-Overhead 调度、DP Attention、在线权重更新,到 FP8/INT4 训练、投机解码、长上下文优化与 Diffusion LLM。正文为 markdown,关键机制配有可交互动画。

6.1 Zero-Overhead Batch Scheduling

核心思想

在 LLM 推理中,batch scheduling 的 CPU 开销会导致 GPU 出现 idle bubble。Zero-Overhead Scheduling 通过 CPU-GPU overlap 实现调度零开销:在 GPU 执行当前 batch 的 forward pass 时,CPU 并行完成下一个 batch 的调度决策和数据准备。

关键技术组件

1. Placeholder Buffering: 预分配 placeholder slots,避免动态内存分配的开销。新请求到来时只需填充 placeholder,不需要 resize tensor。

2. Async Scheduling: 调度逻辑在独立线程运行,通过 double-buffering 机制与 GPU kernel 执行完全 overlap。

3. CPU-S 与 CPU-L 分离: 将 CPU 操作分为 short-latency (CPU-S) 和 long-latency (CPU-L) 两类。CPU-S 包含 token sampling、detokenize 等轻量操作;CPU-L 包含 scheduling policy、radix tree 操作等重量操作。两者分别 pipeline 化。

下面的动画展示 CPU-S、CPU-L 与 GPU 的流水线重叠效果,对比 “无重叠” 与 “有重叠 (pipelined)” 两种执行模式:

Overlap Scheduler 伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class OverlapScheduler:
def __init__(self, max_batch_size, num_placeholders):
self.front_buffer = BatchBuffer(max_batch_size)
self.back_buffer = BatchBuffer(max_batch_size)
self.placeholder_pool = PlaceholderPool(num_placeholders)
self.schedule_thread = Thread(target=self._schedule_loop)

def _schedule_loop(self):
"""CPU-L: 在后台线程运行调度策略"""
while not self.stopped:
# 等待 GPU forward 开始
self.gpu_start_event.wait()

# CPU-L operations (overlap with GPU)
new_reqs = self.request_queue.drain()
scheduled = self.scheduling_policy(new_reqs, self.back_buffer)

# 填充 placeholder slots
for req in scheduled:
slot = self.placeholder_pool.acquire()
slot.fill(req.input_ids, req.metadata)
self.back_buffer.add(slot)

# Radix tree 更新 (prefix caching)
self.radix_tree.insert_batch(scheduled)

# Signal: 下一个 batch 准备好了
self.batch_ready_event.set()

def step(self):
"""主循环: 交替执行 GPU forward 和 buffer swap"""
# GPU forward (当前 batch)
self.gpu_start_event.set()
output = self.model.forward(self.front_buffer)

# CPU-S: sampling + detokenize (很快)
tokens = self.sampler.sample(output.logits)
self.detokenizer.decode_batch(tokens)

# 等待 CPU-L 完成
self.batch_ready_event.wait()

# Double-buffer swap (零拷贝)
self.front_buffer, self.back_buffer = self.back_buffer, self.front_buffer
self.batch_ready_event.clear()

return tokens

性能分析

假设 GPU forward 耗时 ,CPU-S 耗时 ,CPU-L 耗时

  • Without overlap: 每 step 耗时
  • With overlap: 每 step 耗时

(通常成立)时,加速比

典型场景下 ,加速比约 1.29x

6.2 DP Attention for MoE

问题背景

在 Mixture-of-Experts (MoE) 模型中,传统 Tensor Parallelism (TP) 要求每个 GPU 都持有完整的 KV cache 副本。对于 DeepSeek-V3 这样的模型(236B 参数,128 experts),这导致巨大的内存浪费。

DP Attention 的核心思路

将 Attention 层使用 Data Parallelism(每个 GPU 处理不同的 request,各自维护自己的 KV cache),而 MLP/Expert 层使用 Tensor/Expert Parallelism。这样:

  • KV cache 不需要在 GPU 之间复制,内存效率大幅提升
  • Attention 计算无需 all-reduce,通信量减少
  • Expert 层仍然使用 EP/TP 来分摊参数

通信模式

Attention → MLP 之间需要 All-to-All 通信:每个 GPU 将自己的 token hidden states 发送给负责对应 expert 的 GPU。这类似于 MoE routing 的通信,但方向相反。

下面的交互图示可调整 Num Heads、TP Degree、Seq Length,对比 TP Attention 与 DP Attention 的 KV cache、通信与计算开销:

DP Attention 配置示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# DeepSeek-V3 style DP Attention + EP MLP configuration
from dataclasses import dataclass

@dataclass
class DPAttentionConfig:
# Model dimensions
num_attention_heads: int = 128
num_kv_heads: int = 8 # GQA: 8 KV heads shared
head_dim: int = 128
hidden_size: int = 7168

# Parallelism config
dp_size: int = 8 # Data parallel for attention
ep_size: int = 8 # Expert parallel for MLP
num_experts: int = 256
experts_per_gpu: int = 32 # 256 / 8

# Communication
use_all_to_all: bool = True # Between attention DP and MLP EP
overlap_comm_compute: bool = True

def kv_cache_per_gpu(self, seq_len, batch_size, dtype_bytes=2):
"""每个 GPU 的 KV cache 大小 (bytes)"""
# DP: 每个 GPU 只存本地 batch 的 KV cache
local_batch = batch_size // self.dp_size
kv_size = 2 * self.num_kv_heads * self.head_dim * seq_len * local_batch * dtype_bytes
return kv_size

def tp_kv_cache_per_gpu(self, seq_len, batch_size, dtype_bytes=2):
"""TP 模式下每个 GPU 的 KV cache 大小 (需要全部 batch)"""
# TP: 每个 GPU 需要全部 batch 的 KV cache (heads 分割)
kv_heads_per_gpu = self.num_kv_heads # GQA heads 不分割
kv_size = 2 * kv_heads_per_gpu * self.head_dim * seq_len * batch_size * dtype_bytes
return kv_size

config = DPAttentionConfig()
print(f"DP KV/GPU: {config.kv_cache_per_gpu(8192, 64) / 1e9:.2f} GB")
print(f"TP KV/GPU: {config.tp_kv_cache_per_gpu(8192, 64) / 1e9:.2f} GB")

6.3 Online Weight Updates

动机与挑战

在 RLHF/Online RL 训练中,policy model 需要不断更新。传统做法是训练完成后重新加载 checkpoint,这导致:

  • 高延迟: 保存 + 加载 checkpoint 耗时数十秒到数分钟
  • 资源浪费: 训练期间推理引擎 idle,推理期间训练集群 idle
  • Staleness: rollout 使用的模型版本落后于最新 policy

NCCL Broadcast 方案

利用 NCCL 的跨节点通信能力,直接将训练节点的权重 broadcast 到推理节点:

  1. 训练完成一个 step/epoch 后,trainer 触发 weight broadcast
  2. 通过 NCCL all-gather 或 broadcast 将权重同步到 inference workers
  3. Inference engine 原子性地更新模型权重(使用 double buffering)
  4. 正在进行的推理请求不受影响(用旧权重完成)

延迟优化技巧

Chunked Transfer: 将模型按 layer 分块传输,边传边更新,不需要等所有层传完。

Priority Scheduling: 先传输 embedding 和最后几层(对输出影响最大),其他层异步传输。

Delta Compression: 只传输权重的差值 (),利用稀疏性压缩传输量。

Online Weight Update 代码架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.distributed as dist

class OnlineWeightUpdater:
def __init__(self, model, trainer_rank, inference_ranks, nccl_group):
self.model = model
self.trainer_rank = trainer_rank
self.inference_ranks = inference_ranks
self.group = nccl_group
self.update_lock = threading.Lock()
# Double buffer for zero-downtime update
self.active_params = 0
self.param_buffers = [
{n: p.clone() for n, p in model.named_parameters()},
{n: p.clone() for n, p in model.named_parameters()}
]

def broadcast_weights(self, priority_layers=None):
"""从 trainer broadcast 权重到 inference workers"""
inactive = 1 - self.active_params
layers = priority_layers or list(self.model.named_parameters())

for name, param in layers:
# NCCL broadcast: trainer → inference workers
dist.broadcast(param.data, src=self.trainer_rank, group=self.group)
# 写入 inactive buffer
self.param_buffers[inactive][name].copy_(param.data)

# Atomic swap
with self.update_lock:
self.active_params = inactive
self._apply_active_params()

def _apply_active_params(self):
"""Apply active buffer to model (pointer swap)"""
active = self.param_buffers[self.active_params]
for name, param in self.model.named_parameters():
param.data = active[name]

6.4 FP8/INT4 Training for RL

统一精度的必要性

在 RL for LLM 场景中,training 和 inference 使用不同精度会导致:

  • Reward hacking: 量化误差造成的 reward 信号偏差
  • Train-inference gap: 训练时 BF16,推理时 INT4/FP8,行为不一致
  • 精度转换开销: 频繁的 quantize/dequantize 消耗计算

FP8 E4M3 格式

FP8 E4M3 使用 1 bit sign + 4 bits exponent + 3 bits mantissa,动态范围 ,适合前向计算。E5M2 使用 5 bits exponent + 2 bits mantissa,动态范围更大,适合梯度。

Dynamic Scaling

为了最大化 FP8 的有效利用率,使用 per-tensor 或 per-channel 的 dynamic scaling factor:

'_' allowed only in math mode\text{scale} = \frac{\max(|\text{tensor}|)}{\text{max_fp8_value}}

量化:'_' allowed only in math mode\text{tensor_fp8} = \text{round}(\text{tensor} / \text{scale})

反量化:'_' allowed only in math mode\text{tensor_restored} = \text{tensor_fp8} \times \text{scale}

下面的交互图示展示一个浮点数如何被编码为 FP8 E4M3 的 sign/exponent/mantissa 位,并计算量化误差与 dynamic scale factor:

FP8 Quantization with Dynamic Scaling

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn.functional as F

class FP8Quantizer:
"""FP8 E4M3 quantizer with dynamic per-tensor scaling"""
FP8_E4M3_MAX = 448.0
FP8_E5M2_MAX = 57344.0

@staticmethod
def quantize_e4m3(tensor, per_channel=False):
"""Forward pass quantization (E4M3 for activations/weights)"""
if per_channel:
# Per-channel scaling for weights
amax = tensor.abs().amax(dim=-1, keepdim=True)
else:
# Per-tensor scaling
amax = tensor.abs().amax()

# Compute scale factor
scale = amax / FP8Quantizer.FP8_E4M3_MAX
scale = torch.where(scale > 0, scale, torch.ones_like(scale))

# Quantize
tensor_scaled = tensor / scale
tensor_fp8 = tensor_scaled.clamp(
-FP8Quantizer.FP8_E4M3_MAX, FP8Quantizer.FP8_E4M3_MAX
).to(torch.float8_e4m3fn)

return tensor_fp8, scale

@staticmethod
def quantize_e5m2(tensor):
"""Gradient quantization (E5M2 for larger dynamic range)"""
amax = tensor.abs().amax()
scale = amax / FP8Quantizer.FP8_E5M2_MAX
scale = torch.where(scale > 0, scale, torch.ones_like(scale))

tensor_fp8 = (tensor / scale).clamp(
-FP8Quantizer.FP8_E5M2_MAX, FP8Quantizer.FP8_E5M2_MAX
).to(torch.float8_e5m2)

return tensor_fp8, scale

@staticmethod
def dequantize(tensor_fp8, scale):
"""Restore to higher precision"""
return tensor_fp8.float() * scale


# QAT (Quantization-Aware Training) for RL consistency
class FP8QATLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(out_features, in_features)
)
self.quantizer = FP8Quantizer()

def forward(self, x):
# Quantize weight → FP8 → dequantize (STE gradient)
w_fp8, w_scale = self.quantizer.quantize_e4m3(self.weight)
w_approx = self.quantizer.dequantize(w_fp8, w_scale)
# Straight-Through Estimator: grad flows through dequantize
w_final = self.weight + (w_approx - self.weight).detach()
return F.linear(x, w_final)

6.5 Speculative Decoding in RL

RL Rollout 的瓶颈

在 RL training (如 PPO, GRPO) 中,rollout (生成 trajectory) 占总训练时间的 60-80%。Speculative Decoding 通过小模型 (draft model) 加速生成:

基本流程

  1. Draft: 使用小模型 (如 1.5B) 快速生成 K 个 token
  2. Verify: 大模型 (如 70B) 一次性验证这 K 个 token 的概率
  3. Accept/Reject: 按照 rejection sampling 接受或拒绝 draft tokens
  4. Correction: 如果某个 token 被拒绝,用大模型重新采样

RL 场景特殊性

Online SFT for Draft Model: 在 RL 训练过程中,policy 模型不断变化。如果 draft model 是固定的,acceptance rate 会随训练下降。解决方案:

  • 每 N 个 RL step,用 policy model 的 rollout 数据 SFT 更新 draft model
  • 使用 knowledge distillation loss 让 draft model 跟随 policy 分布
  • Draft model 的训练开销小(参数少,数据已有)

加速比分析

设 acceptance rate 为 ,draft 长度为 ,draft/verify 时间比为

, , 时,加速比约 2.5x

6.6 Long-Context Optimization

DuoAttention: Retrieval vs Streaming Heads

核心发现:在 LLM 的多头注意力中,不同的 head 承担不同的功能:

  • Retrieval Heads (约 25%): 负责从长上下文中检索信息,attention pattern 稀疏且 long-range。这些 head 需要完整的 KV cache。
  • Streaming Heads (约 75%): 主要关注近期 token 和特殊位置(如 sink token),pattern 呈现局部性。这些 head 只需要保留 recent window + sink tokens 的 KV cache。

内存节省

假设 75% 的 heads 是 streaming heads,只保留最近 1024 tokens 的 KV:

'_' allowed only in math mode\text{内存节省} = 0.75 \times \left(1 - \frac{1024}{\text{seq_len}}\right)

当 seq_len = 128K 时,节省约 74.4% 的 KV cache 内存。

StreamingLLM

StreamingLLM 观察到 attention 的 “sink” 现象:初始几个 token 总是获得高 attention score(不论内容)。因此只需保留:

  • Sink tokens (前 4 个 token 的 KV cache)
  • Recent window (最近 N 个 token 的 KV cache)

这使得 LLM 可以处理无限长度的 streaming 输入,KV cache 恒定大小。

6.7 Diffusion LLMs

从 Autoregressive 到 Non-Autoregressive

传统 LLM 是 autoregressive 的:一次生成一个 token,总延迟与序列长度线性正比。Diffusion LLM (如 LLaDA - Large Language Diffusion with mAsking) 采用不同的生成范式:

LLaDA 的工作原理

  1. Forward Process (训练): 随机 mask 输入 token(类似 BERT 的 MLM,但 mask ratio 连续变化 0→1)
  2. Reverse Process (推理): 从全 mask 开始,逐步 “denoise” 恢复 token
  3. Block-wise Decoding: 每一步可以同时恢复多个 token(confidence-based unmasking)

优势与挑战

  • 并行生成: 每步可恢复多个 token,减少总 step 数 (如 128 tokens 只需 8-16 steps)
  • 全局一致性: 每步都能 “看到” 所有位置的当前状态,避免 AR 的局部决策
  • 速度 vs 质量 trade-off: 更多 diffusion steps = 更高质量,更少 steps = 更快速度
  • 挑战: 当前质量仍略逊于同规模 AR 模型,但差距在缩小

下面的动画展示 token 如何从随机 mask 逐步被 “去噪” 为连贯文本(block-parallel decoding,每步并行恢复多个 token):

练习 / Exercises

Exercise 1: 计算 Overlap Scheduling 的理论加速比

给定以下参数:

  • GPU forward pass 时间: = 20ms
  • CPU-S (sampling + detokenize) 时间: = 1.2ms
  • CPU-L (scheduling + radix tree) 时间: = 5.8ms

请计算:(a) Without overlap 的每 step 延迟;(b) With overlap 的每 step 延迟;(c) 加速比;(d) 如果 增加到 25ms(大于 ),加速比如何变化?

关键公式:

  • Without overlap:
  • With overlap:

注意 overlap 只能在 时完全隐藏 CPU-L 的开销。当 时,GPU 会有 idle。

解答:

(a) Without overlap: ms

(b) With overlap: ms(CPU-L 完全 overlap 在 GPU forward 内)

(c) 加速比: 延迟减少

(d) 当 = 25ms 时:

  • Without overlap: ms
  • With overlap: ms
  • 加速比

当 CPU-L 较大时,overlap 的收益更加显著,但此时 GPU 有 5ms 的 idle bubble (ms)。应该考虑优化 CPU-L 或增加 batch size 来增大

Exercise 2: DP Attention 内存节省计算

给定以下 MoE 模型配置:

  • Num KV heads (GQA): 8
  • Head dimension: 128
  • Sequence length: 32768
  • Total batch size: 128 requests
  • TP/DP degree: 8 GPUs
  • Precision: FP16 (2 bytes per element)

请计算:(a) TP Attention 模式下每个 GPU 的 KV cache 大小;(b) DP Attention 模式下每个 GPU 的 KV cache 大小;(c) 内存节省百分比;(d) 如果 batch size 翻倍到 256,节省比例如何变化?

KV cache 大小公式:'_' allowed only in math mode2 ,(K+V) \times \text{num_kv_heads} \times \text{head_dim} \times \text{seq_len} \times \text{batch_size} \times \text{dtype_bytes}

解答:

(a) TP Attention - KV cache per GPU: 每个 GPU 需要全部 batch 的 KV cache(GQA heads 不分割):

(b) DP Attention - KV cache per GPU: 每个 GPU 只处理 个请求:

(c) 内存节省: ,即 '_' allowed only in math mode1 - 1/\text{DP_degree} = 1 - 1/8 = 87.5%

(d) batch_size = 256 时:

  • TP: GB
  • DP: GB (256/8=32)
  • 节省比例仍然是 (与 batch size 无关!)

节省比例取决于 DP degree,不取决于 batch size。这是 DP Attention 的优美之处。

Exercise 3: FP8 量化误差分析

假设一个权重 tensor 服从正态分布 ,其中 (典型的 LLM 初始化)。

  • Tensor 大小: 元素
  • 使用 FP8 E4M3 (max representable = 448)
  • Per-tensor dynamic scaling

请计算:(a) 期望的 tensor max ();(b) Scale factor;(c) 量化后的有效精度 (最小可区分步长);(d) 近似的量化 MSE。

提示:对于 N 个 i.i.d. 的样本,max 的期望约为 。FP8 E4M3 在数值 1 附近的精度为 (3 bits mantissa)。量化误差可近似为均匀分布在 ,方差为

解答:

(a) 期望 tensor max: 元素

(b) Scale factor:

(c) 有效精度(最小可区分步长): FP8 E4M3 在值 附近的步长 。对于原始空间,最小步长 。对于典型值():scaled 到 ,步长 ,原始步长

(d) 近似量化 MSE: 简化估算

,相对 的比率

结论: FP8 量化对于 的正态分布权重引入的相对误差约 0.2%,对训练质量影响很小,支持 train-inference 统一精度。

Exercise 4 (Bonus): Speculative Decoding Acceptance Rate

给定 draft model 的 per-token acceptance rate ,draft length

请计算:(a) 每次 verify 的期望接受 token 数;(b) 如果 draft model 推理速度是 target model 的 10x,计算端到端加速比;(c) 经过 100 步 RL 训练后, 下降到 0.55,加速比变为多少?是否值得 online SFT 更新 draft model?

期望接受 token 数 ;加速比 期望接受 token

解答:

(a) 期望接受 token 数 (, ):

(b) 加速比 (draft 10x faster): ,一次迭代时间 。AR 模式生成相同 token 数需要 3.466 次 target forward。加速比

(c) 时:

加速比 ,从 2.17x 降到 1.37x,下降了 37%。

是否值得 Online SFT? 如果 online SFT 能将 恢复到 0.70+(接近初始值),且 SFT 开销 < 获得的加速收益,则值得。通常 draft model 很小(1-2B),SFT 成本低,因此答案是值得的。

Coding Exercise A: 实现 FP8 量化和反量化

实现 per-tensor dynamic scaling 的 FP8 E4M3 量化/反量化,包含误差分析和可视化。要求:实现 quantize(tensor)dequantize(fp8_tensor, scale)、计算 MSE 和 max absolute error,对正态分布 和均匀分布 分别测试,并画出原始值 vs 量化后值的散点图。

完整解答:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import numpy as np
import matplotlib.pyplot as plt

class FP8E4M3Quantizer:
"""FP8 E4M3 quantizer with per-tensor dynamic scaling."""
FP8_E4M3_MAX = 448.0

def quantize(self, tensor: torch.Tensor):
# Step 1: Compute absolute max
amax = tensor.abs().amax()

# Step 2: Compute scale (handle zero tensor)
scale = amax / self.FP8_E4M3_MAX
scale = torch.where(scale > 0, scale, torch.ones_like(scale))

# Step 3: Scale and clamp
tensor_scaled = tensor / scale
tensor_scaled = tensor_scaled.clamp(-self.FP8_E4M3_MAX, self.FP8_E4M3_MAX)

# Step 4: Quantize to FP8 E4M3
# Use native dtype if available, otherwise simulate
if hasattr(torch, 'float8_e4m3fn'):
tensor_fp8 = tensor_scaled.to(torch.float8_e4m3fn)
else:
# Simulate: FP8 E4M3 has 3 mantissa bits (8 levels per power of 2)
sign = tensor_scaled.sign()
abs_val = tensor_scaled.abs()
# Find exponent
exp = torch.floor(torch.log2(abs_val.clamp(min=2**-9)))
# Quantize mantissa to 3 bits
mantissa = abs_val / (2.0 ** exp) - 1.0
mantissa_q = torch.round(mantissa * 8) / 8
mantissa_q = mantissa_q.clamp(0, 0.875) # 7/8 max
tensor_fp8 = sign * (1.0 + mantissa_q) * (2.0 ** exp)
tensor_fp8 = torch.where(abs_val == 0, torch.zeros_like(tensor_fp8), tensor_fp8)

return tensor_fp8, scale

def dequantize(self, tensor_fp8: torch.Tensor, scale: torch.Tensor):
# Convert back to float32 and multiply by scale
return tensor_fp8.float() * scale

def compute_error(self, original: torch.Tensor, restored: torch.Tensor):
diff = original - restored
mse = (diff ** 2).mean().item()
max_err = diff.abs().max().item()
return mse, max_err


def test_and_visualize():
quantizer = FP8E4M3Quantizer()
torch.manual_seed(42)

# Create test tensors
tensor_normal = torch.randn(4096) * 0.02 # N(0, 0.02^2)
tensor_uniform = torch.rand(4096) * 2 - 1 # U(-1, 1)

# Quantize and dequantize
results = {}
for name, tensor in [("Normal N(0,0.02²)", tensor_normal),
("Uniform U(-1,1)", tensor_uniform)]:
fp8, scale = quantizer.quantize(tensor)
restored = quantizer.dequantize(fp8, scale)
mse, max_err = quantizer.compute_error(tensor, restored)
results[name] = {
'original': tensor,
'restored': restored,
'mse': mse,
'max_err': max_err,
'scale': scale.item()
}
print(f"[{name}] MSE={mse:.2e}, Max Error={max_err:.2e}, Scale={scale.item():.6f}")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (name, res) in zip(axes, results.items()):
orig = res['original'].numpy()
rest = res['restored'].numpy()
ax.scatter(orig, rest, alpha=0.3, s=2, color='cyan')
# Perfect quantization reference line
lim = max(abs(orig.min()), abs(orig.max())) * 1.1
ax.plot([-lim, lim], [-lim, lim], 'r--', linewidth=0.8, label='ideal (y=x)')
ax.set_xlabel('Original Value')
ax.set_ylabel('Dequantized Value')
ax.set_title(f'{name}\nMSE={res["mse"]:.2e}, MaxErr={res["max_err"]:.2e}')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('fp8_quantization_scatter.png', dpi=150)
plt.show()


if __name__ == "__main__":
test_and_visualize()

Expected Output:

1
2
3
4
5
6
7
8
# Terminal output:
[Normal N(0,0.02²)] MSE=1.47e-09, Max Error=7.63e-05, Scale=0.000192
[Uniform U(-1,1)] MSE=2.18e-05, Max Error=1.12e-02, Scale=0.002232

# Scatter plot 说明:
# - Normal 分布: 点紧密分布在对角线上,误差极小 (scale 很小,利用率高)
# - Uniform 分布: 离对角线稍有偏离,尤其在接近 0 处 (相对误差较大)
# - 两种情况的 FP8 量化误差对训练影响都可以忽略不计
Coding Exercise B: 实现 Double-Buffering Weight Update 模拟

模拟 training → inference 的 weight update pipeline,对比 double-buffering (零 downtime) 和 stop-the-world (暂停服务) 两种策略的 throughput 差异。要求实现 TrainingEngineInferenceEngine、double-buffer 与 stop-the-world 两种 updater,输出 timeline 并对比总 throughput。

完整解答:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import threading
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class WeightBuffer:
version: int = 0
data: List[float] = field(default_factory=lambda: [0.0] * 100)
is_ready: bool = True


@dataclass
class TimelineEvent:
start_ms: float
end_ms: float
event_type: str # 'training', 'transfer', 'serving', 'downtime'
label: str = ""


class TrainingEngine:
def __init__(self, step_time_ms: float = 50.0):
self.step_time_ms = step_time_ms
self.current_version = 0

def train_step(self) -> Tuple[int, List[float]]:
self.current_version += 1
# Simulate: return version and "new weights"
new_weights = [float(self.current_version) * 0.01] * 100
return self.current_version, new_weights


class Simulator:
"""Virtual-time simulator for both strategies."""

def __init__(self, train_step_ms=50.0, transfer_ms=20.0, request_ms=5.0):
self.train_step_ms = train_step_ms
self.transfer_ms = transfer_ms
self.request_ms = request_ms

def run_double_buffer(self, total_ms: float, num_updates: int):
"""Simulate double-buffer strategy."""
timeline = []
requests_served = 0
clock = 0.0
update_interval = total_ms / num_updates

active_version = 0
trainer = TrainingEngine(self.train_step_ms)
next_update_at = update_interval

while clock < total_ms:
# Check if it's time for a weight update
if clock >= next_update_at:
# Training step (happens on training cluster)
train_start = clock
version, weights = trainer.train_step()
train_end = clock + self.train_step_ms
timeline.append(TimelineEvent(train_start, train_end, 'training',
f'v{version}'))

# Transfer to inactive buffer (OVERLAPS with serving)
transfer_start = train_end
transfer_end = transfer_start + self.transfer_ms
timeline.append(TimelineEvent(transfer_start, transfer_end, 'transfer',
f'→buf'))

# Serving continues during transfer!
serve_during_transfer = int(self.transfer_ms / self.request_ms)
requests_served += serve_during_transfer
timeline.append(TimelineEvent(transfer_start, transfer_end, 'serving',
f'{serve_during_transfer}req'))

# Swap buffers (instant, ~0ms)
active_version = version
clock = transfer_end
next_update_at = clock + update_interval
else:
# Normal serving
serve_end = min(clock + self.request_ms, total_ms)
if serve_end <= total_ms:
requests_served += 1
clock = serve_end

return requests_served, timeline, active_version

def run_stop_the_world(self, total_ms: float, num_updates: int):
"""Simulate stop-the-world strategy."""
timeline = []
requests_served = 0
clock = 0.0
update_interval = total_ms / num_updates

active_version = 0
trainer = TrainingEngine(self.train_step_ms)
next_update_at = update_interval

while clock < total_ms:
if clock >= next_update_at:
# Training step
train_start = clock
version, weights = trainer.train_step()
train_end = clock + self.train_step_ms
timeline.append(TimelineEvent(train_start, train_end, 'training',
f'v{version}'))

# STOP serving → transfer → resume
transfer_start = train_end
transfer_end = transfer_start + self.transfer_ms
timeline.append(TimelineEvent(transfer_start, transfer_end, 'downtime',
'BLOCKED'))
timeline.append(TimelineEvent(transfer_start, transfer_end, 'transfer',
f'→buf'))

active_version = version
clock = transfer_end
next_update_at = clock + update_interval
else:
serve_end = min(clock + self.request_ms, total_ms)
if serve_end <= total_ms:
requests_served += 1
clock = serve_end

return requests_served, timeline, active_version


def print_timeline(label: str, timeline: List[TimelineEvent], total_ms: float):
"""Print ASCII timeline."""
print(f"\n{'='*60}")
print(f" {label}")
print(f"{'='*60}")
width = 50
for etype in ['training', 'transfer', 'serving', 'downtime']:
events = [e for e in timeline if e.event_type == etype]
if not events:
continue
bar = ['.'] * width
for e in events:
s = int(e.start_ms / total_ms * width)
end = int(e.end_ms / total_ms * width)
char = {'training': 'T', 'transfer': '>', 'serving': 'S', 'downtime': 'X'}[etype]
for i in range(s, min(end, width)):
bar[i] = char
print(f" {etype:<10} |{''.join(bar)}|")


def run_simulation(total_time_ms=500.0, num_updates=5):
sim = Simulator(train_step_ms=50.0, transfer_ms=20.0, request_ms=5.0)

# Run both strategies
db_reqs, db_timeline, db_ver = sim.run_double_buffer(total_time_ms, num_updates)
stw_reqs, stw_timeline, stw_ver = sim.run_stop_the_world(total_time_ms, num_updates)

# Print timelines
print_timeline("Double-Buffer (zero downtime)", db_timeline, total_time_ms)
print_timeline("Stop-the-World (has downtime)", stw_timeline, total_time_ms)

# Comparison
print(f"\n{'='*60}")
print(f" COMPARISON (total_time={total_time_ms}ms, {num_updates} updates)")
print(f"{'='*60}")
print(f" Double-Buffer: {db_reqs} requests served")
print(f" Stop-the-World: {stw_reqs} requests served")
improvement = (db_reqs - stw_reqs) / stw_reqs * 100 if stw_reqs > 0 else 0
print(f" Throughput gain: +{improvement:.1f}%")
downtime_total = num_updates * sim.transfer_ms
print(f" STW total downtime: {downtime_total}ms ({downtime_total/total_time_ms*100:.1f}% of time)")
print(f" DB downtime: 0ms (zero downtime)")


if __name__ == "__main__":
run_simulation()

Expected Output:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
============================================================
Double-Buffer (zero downtime)
============================================================
training |..........T.T.......T.T.......T.T.......T.T.......T.T.|
transfer |............>>........>>........>>........>>........>>.|
serving |SSSSSSSSSS..SS.SSSSS..SS.SSSSS..SS.SSSSS..SS.SSSSS..SS|
============================================================
Stop-the-World (has downtime)
============================================================
training |..........T.T.......T.T.......T.T.......T.T.......T.T.|
transfer |............>>........>>........>>........>>........>>.|
downtime |............XX........XX........XX........XX........XX.|
============================================================
COMPARISON (total_time=500ms, 5 updates)
============================================================
Double-Buffer: 80 requests served
Stop-the-World: 60 requests served
Throughput gain: +33.3%
STW total downtime: 100ms (20.0% of time)
DB downtime: 0ms (zero downtime)
Exercise 5: DP Attention 通信量分析

对比传统 TP+EP 方案 vs DP Attention+EP 方案,以 DeepSeek-V3 为例:

  • 模型:61 layers, hidden_size=7168, 256 experts
  • 传统方案:TP=8, EP=64
  • DP Attention 方案:DP Attention (no TP for attention), EP=64
  • 精度:BF16 (2 bytes per element)

请计算:(a) 传统 TP 方案中,每层有多少次 AllReduce 操作?总共多少次?(b) DP Attention 中,用什么替代了 AllReduce?(c) 分别计算两种方案每个 token 的总通信量。

提示:传统 TP 每层需要 2 次 AllReduce(Attention 输出 + FFN/MoE 输出)。Ring AllReduce 通信量 '_' allowed only in math mode= 2 \times (N-1)/N \times \text{message_size}

解答:

(a) 传统 TP 的 AllReduce 次数: 每层 2 次 AllReduce(1 次 Attention output, 1 次 FFN/MoE output) 61 layers 122 次 AllReduce。每次 AllReduce 通信量(ring): bytes 25 KB/token。

(b) DP Attention 的替代方案: Attention 部分零 AllReduce!每个 GPU 独立处理自己的 batch subset 的完整 attention(因为 KV cache 也是 DP 分布的)。Expert 部分仍使用 2 次 All-to-All(dispatch + combine),两种方案都存在。

(c) 每 token 总通信量对比:

传统 TP+EP:

  • AllReduce: MB/token
  • All-to-All (EP): MB/token
  • Total MB/token

DP Attention+EP:

  • AllReduce: 0(attention 无需通信!)
  • All-to-All (EP): MB/token
  • Total MB/token

节省: 3 MB/token 的 AllReduce 带宽,通信量减少约 64%。这是 DP Attention 对 MoE 模型的核心价值:消除了 attention 层的所有集合通信开销。

Exercise 6: FP8 Training 的 Loss Scaling 设计

FP8 E4M3 的可表示范围是 。在 RL 训练中,梯度可能出现超过 1000 的 outlier:

  • FP8 E4M3 max = 448
  • RL 训练梯度方差比 SFT 大 3-5 倍
  • Policy exploration 导致 activation 范围更宽

请回答:(a) 如果不做 loss scaling,梯度 outlier 会怎样?(b) 设计一个 dynamic loss scaling 策略;(c) 为什么 FP8 训练在 RL 场景比 supervised training 更困难?

解答:

(a) 不做 loss scaling 的后果: 梯度值 > 448 时溢出为 NaN 或 Inf → 参数更新为 NaN → 训练完全崩溃(diverge)。即使只有一个元素溢出,NaN 会在 AllReduce 中 “感染” 所有 GPU 的梯度。

(b) Dynamic Loss Scaling 策略:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class DynamicFP8LossScaler:
def __init__(self):
self.scale = 2**16 # 初始 scale = 65536
self.growth_interval = 1000 # 连续 1000 步无 overflow 则放大
self.steps_since_overflow = 0

def step(self, optimizer, gradients):
if has_overflow(gradients): # 检测 NaN/Inf
self.scale /= 2 # 缩小 scale
self.steps_since_overflow = 0
return # 跳过此步(不更新参数)
else:
optimizer.step(gradients / self.scale) # 正常更新
self.steps_since_overflow += 1
if self.steps_since_overflow >= self.growth_interval:
self.scale *= 2 # 放大 scale
self.steps_since_overflow = 0

进阶优化:使用 per-tensor scaling(不同层用不同 scale),因为 attention 层和 FFN 层的梯度量级可能差异很大。

(c) FP8 在 RL 中更困难的原因:

  • 梯度方差更大: RL 的 reward signal 是 noisy 的(来自 RM 打分或规则验证),不像 SFT 有干净的 label,导致梯度方差高 3-5×
  • Policy exploration: RL 鼓励模型探索不同输出,导致 activation 值范围更宽,更容易触发 overflow
  • 非平稳性: RL 的数据分布随策略变化而变化(on-policy),梯度统计量不稳定,dynamic scaling 需要更频繁调整
  • Outlier 更频繁: 当模型产生非常好或非常差的输出时,reward 差异大 → advantage 大 → 梯度大 → overflow 风险高

解决方案:更保守的初始 scale、更短的 growth interval、per-tensor scaling + delay scaling(用前几步的 amax 统计来预设 scale)。

Exercise 7: Speculative Decoding in RL Rollout

在 GRPO rollout 中使用 speculative decoding 加速生成:

  • Draft model: 1B 参数
  • Target model (Actor): 70B 参数
  • Draft length
  • Per-token acceptance rate
  • 每个序列生成 512 tokens

请计算:(a) 每次 verify step 的期望输出 token 数;(b) 生成 512 token 需要多少次 target model forward pass?(c) 随着 RL 训练进行,draft model 的分布还 “够接近” actor 吗?会发生什么?

解答:

(a) 每次 verify step 的期望输出 token 数:

考虑 rejection sampling 保证的 bonus token,实际 tokens/step。

(b) 需要的 target model forward pass 次数: 次 verify passes。对比不用 speculative decoding 需要 512 次。加速比 (draft 开销可忽略,1B vs 70B 成本比约 1:70)。

(c) Draft model 分布漂移问题: 随着 RL 训练进行,actor (70B) 的分布会偏离初始策略。若 draft model (1B) 冻结,其分布会越来越不匹配:

  • acceptance rate 下降: 从 0.7 可能降到 0.4-0.5,加速比显著下降
  • 训练早期: actor 变化小,draft 还能用 → 高加速比
  • 训练后期: actor 已大幅偏离 → draft 几乎无用

解决方案:

  • 定期用 actor 的输出 online SFT 更新 draft model(1B 模型 SFT 很便宜)
  • 使用 actor 的 earlier checkpoint 作为 draft(每 N 步更新一次 draft = actor[step-N])
  • Adaptive K:当 acceptance rate 下降时,减小 draft length K 以减少浪费

本章小结

本章覆盖了 ML 系统中的前沿优化技术:

  • Zero-Overhead Scheduling: 通过 CPU-GPU overlap 和 double-buffering 消除调度开销
  • DP Attention: 针对 MoE 模型优化内存效率,将 KV cache 需求降低 DP_degree 倍
  • Online Weight Updates: 利用 NCCL broadcast 实现训练-推理间的无缝权重同步
  • FP8 Training: 统一训练和推理精度,消除 train-inference gap
  • Speculative Decoding: 通过 draft model + online SFT 加速 RL rollout
  • Long Context: DuoAttention 利用 head 功能分化节省 KV cache
  • Diffusion LLMs: 非自回归生成,block-parallel decoding 打破序列依赖

本文是 ML Systems 系列 Chapter 6。正文 markdown 渲染,4 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。