从 Self-Attention 到 Flash Attention

从 Self-Attention 到 Flash Attention,理解现代 LLM 的核心计算原理。本文正文为 markdown,关键机制配有可交互动画(点按钮逐步演示)。

1.1 Self-Attention 机制

Self-Attention 的核心是让序列中每个 token 都能”关注”到其他所有 token,计算公式:

下面这个动画用一组小矩阵(Q 2×2、K/V 3×2)逐步演示:QKᵀ → 缩放 → softmax → 加权求和 V

代码实现:Scaled Dot-Product Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
"""
手动实现 Scaled Dot-Product Attention

Args:
Q: Query tensor, shape (batch, seq_len_q, d_k)
K: Key tensor, shape (batch, seq_len_k, d_k)
V: Value tensor, shape (batch, seq_len_k, d_v)
mask: 可选的 attention mask

Returns:
output: shape (batch, seq_len_q, d_v)
attention_weights: shape (batch, seq_len_q, seq_len_k)
"""
d_k = Q.size(-1)

# Step 1: 计算 Q @ K^T,得到 attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, seq_q, seq_k)

# Step 2: 缩放
scores = scores / math.sqrt(d_k)

# Step 3: 应用 mask(用于 causal attention)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Step 4: Softmax 归一化
attention_weights = F.softmax(scores, dim=-1)

# Step 5: 加权求和 V
output = torch.matmul(attention_weights, V) # (batch, seq_q, d_v)

return output, attention_weights


# 示例用法
batch_size, seq_len, d_model = 2, 8, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

# Causal mask: 下三角矩阵,防止看到未来的 token
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(f"Output shape: {output.shape}") # (2, 8, 64)
print(f"Weights shape: {weights.shape}") # (2, 8, 8)
print(f"Weights sum: {weights.sum(dim=-1)}") # 每行和为 1

1.2 Multi-Head Attention

Multi-Head Attention (MHA) 将输入投影到多个不同的子空间中,每个 “head” 独立计算 attention,最后将结果拼接并投影回原始维度。


为什么要拆分多个 Head?

  • 多样性:不同的 head 可以学习关注不同类型的关系(语法、语义、位置等)
  • 计算效率:每个 head 的维度为 ,总计算量与单个大 head 相同
  • 表达能力:多个子空间的注意力模式比单一空间更丰富

典型配置:GPT-3 使用 96 个 head,,每个 head 维度为 128。LLaMA-2-70B 使用 64 个 head(加上 8 个 KV head 用于 GQA)。

下面动画展示 tensor 如何被拆成多个 head、并行计算、再 concat 投影回去:

Grouped Query Attention (GQA)

标准 MHA 中,每个 head 都有独立的 K 和 V 投影。GQA 将多个 query head 共享同一组 KV head,显著减少 KV Cache 的内存占用:

方法 Query Heads KV Heads KV Cache 大小
MHA h h
GQA h h/g 1/g ×
MQA h 1 1/h ×

1.3 Flash Attention

Flash Attention 是一种 IO-aware 的精确 attention 算法。它的核心动机:GPU 的计算速度远快于内存带宽。标准 Attention 的瓶颈不是计算 FLOPS,而是 HBM(高带宽内存)的读写次数。

问题根源:标准 Attention 的 HBM 访问

标准 Attention 的计算步骤,每一步都需要 HBM 读/写。以单个 head 为例追踪 shape 和 HBM 访问量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 设 N = 序列长度, d = head 维度 (如 N=4096, d=128)
# Q: [N, d] K: [N, d] V: [N, d] 均存储在 HBM 中,每个元素占 2 bytes (FP16)

# ═══════ Step 1: 计算 Attention Score 矩阵 ═══════
# 从 HBM 读取 Q[N,d] 和 K[N,d] → 读取量 = 2Nd 个元素
S = Q @ K.T # S: [N, d] × [d, N] = [N, N]
# 将 S[N,N] 写回 HBM → 写入量 = N² 个元素 ← 这是灾难性的!

# ═══════ Step 2: Softmax ═══════
# 从 HBM 读取 S[N,N] → 读取量 = N²
P = softmax(S / sqrt(d)) # P: [N, N] (每行 sum=1)
# 将 P[N,N] 写回 HBM → 写入量 = N²

# ═══════ Step 3: 加权求和 Value ═══════
# 从 HBM 读取 P[N,N] 和 V[N,d] → 读取量 = N² + Nd
O = P @ V # O: [N, N] × [N, d] = [N, d]
# 将 O[N,d] 写回 HBM → 写入量 = Nd

# ═══════ 总 HBM 访问量汇总 ═══════
# 读: 2Nd + N² + N² + Nd = 3Nd + 2N²
# 写: N² + N² + Nd = 2N² + Nd
# 合计: 4N² + 4Nd = Θ(N² + Nd)
#
# 代入 N=4096, d=128: 总计 ≈ 69M 个元素 ≈ 132 MB
# 额外 HBM 内存 (存中间结果): S[N,N] + P[N,N] = 2N² = 32M 元素 = 64 MB!

关键瓶颈:中间矩阵 S 和 P 都是 [N, N] 大小。当 N=4096, d=128 时:S[4096×4096] 占 32MB,而 Q/K/V 各只占 1MB —— 中间结果比输入/输出大 32 倍!我们为了存临时中间矩阵,付出了 O(N²) 的额外内存和大量 HBM 读写。

Flash Attention 的解法:分块 + 在线 Softmax

核心思想:永远不把 N×N 矩阵写入 HBM。将 Q, K, V 切成小块,在 Shared Memory(每个 SM 上的可编程 SRAM,A100 上最大 ~192KB/SM)中完成所有计算。注意这不是 L2 Cache(40MB,硬件管理),而是程序员可直接控制的片上存储。

设单个 SM 的 Shared Memory 容量为 M 个元素,需要同时容纳 K 块、V 块、Q 块、O 块、中间 score 矩阵:


下面的动画用 N=8, d=4, 的例子,展示每一步各 block 矩阵的 shape、计算过程,以及结果如何累加到最终输出 O:

IO 复杂度推导(核心!)

标准 Attention:总计 次 HBM 访问,额外内存

Flash Attention:外循环遍历 K,V 块 次,内循环遍历 Q 块 次:

访

代入

额外内存只需 (每行的 running max 和 running sum)。

为什么这么好?(典型 M=100K, d²=16K)时,。例:N=4096, d=128, M=100K → 标准 ≈ 17.3M 次 HBM 访问,Flash ≈ 2.75M 次(节省 84%!)。

在线 Softmax 算法

标准 softmax 需要两次遍历(pass 1 求 max;pass 2 求 exp 之和),要求 S 矩阵完全 materialize。Flash Attention 用在线算法只需一次遍历。下面动画用一个 1×8 score 向量(分 4 块),动态展示 m, l, O 如何逐块更新,以及新块出现更大 max 时旧值如何被 correction factor α 修正:

数学原理:attention 输出本质上就是 分子 / 分母

1
2
3
output_i = Σ_k softmax(s_k) × v_k
= (Σ_k exp(s_k) × v_k) / (Σ_k exp(s_k))
= numerator / denominator

在线算法增量式累积分子和分母,同时用 running max 保持数值稳定:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# ═══ 维护三个统计量(处理完前 j-1 个 KV 块后)═══
m = 目前见过的所有 score 的 max # 标量, 用于数值稳定
l = Σ exp(s_k - m) # 标量, rescaled 分母
O = Σ exp(s_k - m) × v_k # [1×d], rescaled 分子
# 关系: O / l = Σ softmax(s_k) × v_k = 最终 attention 输出!

# ═══ 处理第 j 个 KV 块 ═══
# 计算新 score: S_j = q_i × K_j^T (在 SRAM 中)
m_new = max(m, max(S_j)) # 可能发现更大的值
alpha = exp(m - m_new) # correction factor (≤1)
l_new = alpha * l + Σ exp(S_j - m_new) # 更新分母
O_new = alpha * O + exp(S_j - m_new) * V_j # 更新分子

# ═══ 遍历完所有块后 ═══
output_i = O / l # 分子 / 分母 = softmax attention 输出 ✓

直觉理解

  • O 和 l 的关系 —— l 累加 exp(score),O 累加 exp(score)×v,最后 O/l 就是 softmax 加权平均,更新方式完全对称。
  • 为什么需要 α = exp(m_old − m_new)? —— 遇到更大 score 后,之前所有 exp(s − m_old) 应变成 exp(s − m_new),由 exp(s − m_new) = exp(s − m_old) × exp(m_old − m_new),旧值全部乘 α 即可。
  • 新块没有更大 max? —— m_new = m_old → α = 1 → 旧值不动,只加新项。
  • 数学等价性 —— 不是近似!遍历完所有块后 O/l 严格等于 softmax(QKᵀ)V。

完整伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Flash Attention Forward Pass 完整算法
# 输入: Q, K, V ∈ R^{N×d},存储在 HBM 中;输出: O ∈ R^{N×d};SRAM 容量: M 个元素

# Step 1: 确定 block sizes
B_c = M // (4 * d) # KV block 行数
B_r = min(M // (4 * d), d) # Q block 行数
T_c = ceil(N / B_c) # K,V 的块数
T_r = ceil(N / B_r) # Q 的块数

# Step 2: 初始化输出和统计量(在 HBM 中)
O = zeros(N, d); l = zeros(N); m = full(N, -inf)

# Step 3: 外循环 - 遍历 K, V 块
for j in range(T_c):
K_j = K[j*B_c : (j+1)*B_c] # [B_c, d] → HBM 读取
V_j = V[j*B_c : (j+1)*B_c] # [B_c, d] → HBM 读取
for i in range(T_r): # 内循环 - 遍历 Q 块
Q_i = Q[i*B_r : (i+1)*B_r]
O_i = O[i*B_r : (i+1)*B_r]
l_i = l[i*B_r : (i+1)*B_r]
m_i = m[i*B_r : (i+1)*B_r]

# === 以下全部在 SRAM 中完成 ===
S_ij = Q_i @ K_j.T # [B_r, B_c] ← 在 SRAM 中!
m_new = max(m_i, S_ij.max(dim=-1))
P_ij = exp(S_ij - m_new[:, None])
l_new = exp(m_i - m_new) * l_i + P_ij.sum(dim=-1)
O_i = (l_i / l_new)[:, None] * exp(m_i - m_new)[:, None] * O_i \
+ (1 / l_new)[:, None] * P_ij @ V_j
m_i, l_i = m_new, l_new

# === 写回 HBM ===
O[i*B_r : (i+1)*B_r] = O_i
l[i*B_r : (i+1)*B_r] = l_i
m[i*B_r : (i+1)*B_r] = m_i

return O # 数学上 ≡ softmax(QK^T / √d) @ V

IO 复杂度总结

指标 标准 Attention Flash Attention 优势
HBM 读写次数 Θ(Nd + N²) Θ(N²d²/M) 当 M > d² 时减少
额外 HBM 内存 O(N²) (存 S, P) O(N) (存 l, m) 从二次降到线性!
SRAM 使用 不利用 O(M) (满载利用)
计算量 (FLOPs) O(N²d) O(N²d) (相同!) 计算量不变
数值结果 精确 精确 (bit-identical) 不是近似!

Flash Attention 不是近似算法! 它与标准 Attention 产生 bit-identical 的结果。差异仅在 IO 模式:用更多计算换更少的内存访问(因为 GPU 计算远快于内存带宽)。

1.4 Cross Attention vs Self Attention

Self Attention 和 Cross Attention 的核心区别在于 Q, K, V 的来源:

特性 Self Attention Cross Attention
Q 来源 同一序列 Decoder 序列
K, V 来源 同一序列 Encoder 输出
典型用途 语言建模、BERT 翻译、图文理解
Causal Mask Decoder 中使用 通常不使用
KV Cache 每步增长 编码后固定不变

应用场景

  • Self Attention:GPT 系列(causal)、BERT(bidirectional)
  • Cross Attention:机器翻译(Encoder-Decoder)、多模态模型(image tokens attend to text tokens)、Stable Diffusion(text→image conditioning)

1.5 KV Cache 原理

在 autoregressive 生成中,每生成一个新 token 都需要重新计算所有前序 token 的 attention。KV Cache 通过缓存已计算的 K 和 V 矩阵,避免重复计算。

为什么只缓存 K 和 V,不缓存 Q? 因为生成时只需要计算新 token 的 Q 与所有之前 token 的 K, V 之间的 attention。之前 token 的 Q 不再需要(它们的输出已经确定):

下面动画一步步生成 token,左边是 attention 矩阵(只算新的一行),右边是 KV Cache 的增长,底部对比有/无 cache 的计算量:

内存占用分析:KV Cache 大小 = 2 (K+V) × num_layers × num_heads × head_dim × seq_len × dtype_size。

例如:LLaMA-2-70B(80 层, 64 heads, head_dim=128, FP16)在 seq_len=4096 时:KV Cache = 2 × 80 × 64 × 128 × 4096 × 2 bytes ≈ 10.7 GB per request

代码实现:KV Cache

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn

class CachedAttention(nn.Module):
"""带有 KV Cache 的 Attention 实现"""

def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.k_cache = None # KV Cache: 初始为空
self.v_cache = None

def forward(self, x, use_cache=False):
"""
x: (batch, seq_len, d_model)
- Prefill 阶段: seq_len = 完整 prompt 长度
- Decode 阶段: seq_len = 1 (只有新 token)
"""
B, L, _ = x.shape
q = self.W_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
k = self.W_k(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
v = self.W_v(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

if use_cache:
if self.k_cache is not None:
# Decode 阶段:拼接缓存的 K, V
k = torch.cat([self.k_cache, k], dim=2)
v = torch.cat([self.v_cache, v], dim=2)
self.k_cache = k.detach() # 更新缓存
self.v_cache = v.detach()

scale = self.head_dim ** -0.5
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
if L > 1: # Causal mask (decode 时 q 只有 1 个 token,不需要 mask)
mask = torch.tril(torch.ones(L, k.size(2), device=x.device))
scores = scores.masked_fill(mask[-L:] == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
return self.W_o(out)

def clear_cache(self):
self.k_cache = None
self.v_cache = None


# 使用示例:模拟 autoregressive 生成
model = CachedAttention(d_model=512, num_heads=8)
model.eval()

# Prefill: 处理完整 prompt
prompt = torch.randn(1, 10, 512) # 10 个 token
with torch.no_grad():
out = model(prompt, use_cache=True)
print(f"Prefill: cache size = {model.k_cache.shape}") # (1, 8, 10, 64)

# Decode: 每次只输入 1 个新 token
for step in range(5):
new_token = torch.randn(1, 1, 512)
with torch.no_grad():
out = model(new_token, use_cache=True)
print(f"Step {step+1}: cache size = {model.k_cache.shape}")
# cache 逐步增长: (1, 8, 11, 64) → (1, 8, 12, 64) → ...

1.6 Special Tokens 与 Chat Templates

现代 LLM 使用特殊 token 来标记序列结构和控制生成行为。这些 token 在训练时就已定义,模型学会了它们的语义。

Token 用途 示例模型
<BOS> / <s> 序列开始 LLaMA, Mistral
<EOS> / </s> 序列结束,停止生成 LLaMA, Mistral
<PAD> 填充到固定长度 BERT, T5
<UNK> 未知 token(OOV 处理) 多数模型
[INST][/INST] 指令标记 LLaMA-2-Chat
<|im_start|> / <|im_end|> Chat 消息边界 ChatML format

Chat Template 将多轮对话转换为模型可理解的 token 序列,不同模型格式不同:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# LLaMA-2-Chat 格式
"""
<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

What is attention? [/INST] Attention is a mechanism that... </s>
<s>[INST] Can you explain more? [/INST]
"""

# ChatML 格式 (used by many models)
"""
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is attention?<|im_end|>
<|im_start|>assistant
Attention is a mechanism that...<|im_end|>
"""

# 使用 Hugging Face transformers 应用 chat template
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Flash Attention?"},
]
formatted = tokenizer.apply_chat_template(messages, tokenize=False)
print(formatted)
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")

重要:错误的 chat template 会导致模型输出质量严重下降。每个模型都有期望的格式,必须严格遵守。使用 tokenizer.apply_chat_template() 是最安全的方式。

角色划分与 Attention Mask:在 chat 场景中,System 是全局上下文(所有后续 token 都能 attend 到)、User 是用户输入、Assistant 是模型输出。在 SFT 中通常只对 assistant 部分的 token 计算交叉熵 loss,system 和 user 部分作为 context 但不参与 loss 计算。

练习题

练习 1:手动计算 Scaled Dot-Product Attention

给定矩阵():Q = [[1,0],[0,1]]K = [[1,0],[0,1],[1,1]]V = [[1,0],[0,1],[1,1]]。请计算 Attention(Q, K, V)(不使用 causal mask)。

完整解答

1
2
3
4
5
6
7
8
9
Step 1: QK^T = [[1,0,1], [0,1,1]]
Step 2: 缩放 (÷√2) = [[0.707, 0, 0.707], [0, 0.707, 0.707]]
Step 3: Softmax
Row 1: [0.401, 0.198, 0.401]
Row 2: [0.198, 0.401, 0.401]
Step 4: Weights @ V
Row 1: [0.802, 0.599]
Row 2: [0.599, 0.802]
最终结果 ≈ [[0.802, 0.599], [0.599, 0.802]]
练习 2:Flash Attention IO 复杂度分析

设 N=4096, d=128, SRAM M=100KB(约 50K 个 FP16 元素)。

标准 Attention HBM 访问量:读 = 3Nd + 2N² ≈ 35.1M,写 = 2N² + Nd ≈ 34.1M,合计 4N² + 4Nd ≈ 69.2M 元素 × 2 bytes ≈ 132 MB。

Flash 配置,取 ;外/内循环各 次。

Flash HBM 访问量:≈ N²d/B_c = 4096²×128/64 ≈ 33.5M 元素,相比标准节省约 52%。更重要的是不需要 O(N²) 的额外内存——标准需 32MB,Flash 只需 O(N) ≈ 几 KB。

练习 3:估算 KV Cache 内存占用(LLaMA-3-70B + GQA)

配置:80 层、64 query heads(GQA,KV heads=8)、head_dim=128、FP16。batch=32, seq=8192。

1
2
3
4
5
6
7
8
num_layers, num_kv_heads, head_dim = 80, 8, 128
seq_len, batch_size, dtype_bytes = 8192, 32, 2

per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes # ≈ 320 KB
per_request = per_token * seq_len # ≈ 2.5 GB
total = per_request * batch_size # = 80 GB (FP16)
total_mha = total * (64 / 8) # = 640 GB (不可行!)
total_fp8 = total / 2 # = 40 GB

结论:GQA (8 KV heads) 使 KV Cache 从 640GB 降至 80GB(省 87.5%);FP8 量化再降至 40GB。这也是 PagedAttention (vLLM) 等技术重要的原因。

练习 4(Bonus):为什么 Prefill 是 compute-bound,Decode 是 memory-bound?

Arithmetic Intensity = FLOPs / Bytes Accessed。

Prefill:N 个 token 同时处理,做大 GEMM。QKV 投影 AI ≈ 3000 FLOPs/byte,Attention AI ≈ 2730 FLOPs/byte,都 >> H100 的计算/带宽比 (~295) → Compute-bound

Decode:仅 1 个新 token,但需读整个 KV Cache。FLOPs ≈ 67M,读取 ≈ 64MB → AI ≈ 1.05 FLOPs/byte << 295 → Memory-bound,GPU 大部分时间在等 HBM 读 KV Cache。

两者 AI 差了约 3000 倍!这就是 LLM serving 区分 prefill / decode(continuous batching、chunked prefill)的原因。

Coding Exercise A:从零实现 Multi-Head Attention

要求:实现 Q/K/V 投影 → split heads → scaled dot-product attention → concat → output projection,支持 causal mask,并与 torch.nn.MultiheadAttention 对比验证(误差 < 1e-5)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MyMultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model, bias=True)
self.W_k = nn.Linear(d_model, d_model, bias=True)
self.W_v = nn.Linear(d_model, d_model, bias=True)
self.W_o = nn.Linear(d_model, d_model, bias=True)

def forward(self, x, causal_mask=False):
B, N, D = x.shape
Q = self.W_q(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_k(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
if causal_mask:
mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
output = attention_output.transpose(1, 2).contiguous().view(B, N, D)
return self.W_o(output)


def verify_implementation():
torch.manual_seed(42)
d_model, num_heads = 64, 8
x = torch.randn(2, 10, d_model)
my_mha = MyMultiHeadAttention(d_model, num_heads).eval()
official_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True).eval()
with torch.no_grad():
# nn.MultiheadAttention 的 in_proj_weight = [W_q; W_k; W_v]
official_mha.in_proj_weight.copy_(
torch.cat([my_mha.W_q.weight, my_mha.W_k.weight, my_mha.W_v.weight], dim=0))
official_mha.in_proj_bias.copy_(
torch.cat([my_mha.W_q.bias, my_mha.W_k.bias, my_mha.W_v.bias], dim=0))
official_mha.out_proj.weight.copy_(my_mha.W_o.weight)
official_mha.out_proj.bias.copy_(my_mha.W_o.bias)
my_output = my_mha(x, causal_mask=False)
official_output, _ = official_mha(x, x, x)
max_diff = (my_output - official_output).abs().max().item()
print(f"最大差异: {max_diff:.2e}")
assert max_diff < 1e-5
print("✓ 验证通过!")

verify_implementation()
Coding Exercise B:实现 KV Cache 的增量 Decode

要求:实现 prefill(处理完整 prompt)+ decode(逐 token,复用 cached K/V),并实现无 cache baseline,验证两者输出 bit-identical 且展示加速。提示:decode 时只计算新 token 的 Q,与拼接后的完整 K/V 做 attention;cache shape 为 (batch, num_heads, past_len, head_dim)


本文是 ML Systems 系列 Chapter 1。正文 markdown 渲染,5 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。