从 Self-Attention 到 Flash Attention,理解现代 LLM 的核心计算原理。本文正文为 markdown,关键机制配有可交互动画(点按钮逐步演示)。
1.1 Self-Attention 机制 Self-Attention 的核心是让序列中每个 token 都能”关注”到其他所有 token,计算公式:
下面这个动画用一组小矩阵(Q 2×2、K/V 3×2)逐步演示:QKᵀ → 缩放 → softmax → 加权求和 V。
代码实现:Scaled Dot-Product Attention 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 import torchimport torch.nn.functional as Fimport mathdef scaled_dot_product_attention (Q, K, V, mask=None ): """ 手动实现 Scaled Dot-Product Attention Args: Q: Query tensor, shape (batch, seq_len_q, d_k) K: Key tensor, shape (batch, seq_len_k, d_k) V: Value tensor, shape (batch, seq_len_k, d_v) mask: 可选的 attention mask Returns: output: shape (batch, seq_len_q, d_v) attention_weights: shape (batch, seq_len_q, seq_len_k) """ d_k = Q.size(-1 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) scores = scores / math.sqrt(d_k) if mask is not None : scores = scores.masked_fill(mask == 0 , float ('-inf' )) attention_weights = F.softmax(scores, dim=-1 ) output = torch.matmul(attention_weights, V) return output, attention_weights batch_size, seq_len, d_model = 2 , 8 , 64 Q = torch.randn(batch_size, seq_len, d_model) K = torch.randn(batch_size, seq_len, d_model) V = torch.randn(batch_size, seq_len, d_model) causal_mask = torch.tril(torch.ones(seq_len, seq_len)) output, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)print (f"Output shape: {output.shape} " ) print (f"Weights shape: {weights.shape} " ) print (f"Weights sum: {weights.sum (dim=-1 )} " )
1.2 Multi-Head Attention Multi-Head Attention (MHA) 将输入投影到多个不同的子空间中,每个 “head” 独立计算 attention,最后将结果拼接并投影回原始维度。
为什么要拆分多个 Head?
多样性 :不同的 head 可以学习关注不同类型的关系(语法、语义、位置等)
计算效率 :每个 head 的维度为 ,总计算量与单个大 head 相同
表达能力 :多个子空间的注意力模式比单一空间更丰富
典型配置 :GPT-3 使用 96 个 head, ,每个 head 维度为 128。LLaMA-2-70B 使用 64 个 head(加上 8 个 KV head 用于 GQA)。
下面动画展示 tensor 如何被拆成多个 head、并行计算、再 concat 投影回去:
Grouped Query Attention (GQA) 标准 MHA 中,每个 head 都有独立的 K 和 V 投影。GQA 将多个 query head 共享同一组 KV head,显著减少 KV Cache 的内存占用:
方法
Query Heads
KV Heads
KV Cache 大小
MHA
h
h
1×
GQA
h
h/g
1/g ×
MQA
h
1
1/h ×
1.3 Flash Attention Flash Attention 是一种 IO-aware 的精确 attention 算法。它的核心动机:GPU 的计算速度远快于内存带宽 。标准 Attention 的瓶颈不是计算 FLOPS,而是 HBM(高带宽内存)的读写次数。
问题根源:标准 Attention 的 HBM 访问 标准 Attention 的计算步骤,每一步都需要 HBM 读/写。以单个 head 为例追踪 shape 和 HBM 访问量:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 S = Q @ K.T P = softmax(S / sqrt(d)) O = P @ V
关键瓶颈 :中间矩阵 S 和 P 都是 [N, N] 大小。当 N=4096, d=128 时:S[4096×4096] 占 32MB,而 Q/K/V 各只占 1MB —— 中间结果比输入/输出大 32 倍 !我们为了存临时中间矩阵,付出了 O(N²) 的额外内存和大量 HBM 读写。
Flash Attention 的解法:分块 + 在线 Softmax 核心思想:永远不把 N×N 矩阵写入 HBM 。将 Q, K, V 切成小块,在 Shared Memory (每个 SM 上的可编程 SRAM,A100 上最大 ~192KB/SM)中完成所有计算。注意这不是 L2 Cache(40MB,硬件管理),而是程序员可直接控制的片上存储。
设单个 SM 的 Shared Memory 容量为 M 个元素,需要同时容纳 K 块、V 块、Q 块、O 块、中间 score 矩阵:
下面的动画用 N=8, d=4, 的例子,展示每一步各 block 矩阵的 shape、计算过程,以及结果如何累加到最终输出 O:
IO 复杂度推导(核心!) 标准 Attention :总计 次 HBM 访问,额外内存 。
Flash Attention :外循环遍历 K,V 块 次,内循环遍历 Q 块 次:
访 问
代入 :
额外内存只需 (每行的 running max 和 running sum)。
为什么这么好? 当 (典型 M=100K, d²=16K)时, 。例:N=4096, d=128, M=100K → 标准 ≈ 17.3M 次 HBM 访问,Flash ≈ 2.75M 次(节省 84% !)。
在线 Softmax 算法 标准 softmax 需要两次遍历(pass 1 求 max;pass 2 求 exp 之和),要求 S 矩阵完全 materialize。Flash Attention 用在线算法 只需一次遍历。下面动画用一个 1×8 score 向量(分 4 块),动态展示 m, l, O 如何逐块更新,以及新块出现更大 max 时旧值如何被 correction factor α 修正:
数学原理 :attention 输出本质上就是 分子 / 分母 :
1 2 3 output_i = Σ_k softmax(s_k) × v_k = (Σ_k exp(s_k) × v_k) / (Σ_k exp(s_k)) = numerator / denominator
在线算法增量式累积分子和分母,同时用 running max 保持数值稳定:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 m = 目前见过的所有 score 的 max l = Σ exp(s_k - m) O = Σ exp(s_k - m) × v_k m_new = max (m, max (S_j)) alpha = exp(m - m_new) l_new = alpha * l + Σ exp(S_j - m_new) O_new = alpha * O + exp(S_j - m_new) * V_j output_i = O / l
直觉理解 :
O 和 l 的关系 —— l 累加 exp(score),O 累加 exp(score)×v,最后 O/l 就是 softmax 加权平均,更新方式完全对称。
为什么需要 α = exp(m_old − m_new)? —— 遇到更大 score 后,之前所有 exp(s − m_old) 应变成 exp(s − m_new),由 exp(s − m_new) = exp(s − m_old) × exp(m_old − m_new),旧值全部乘 α 即可。
新块没有更大 max? —— m_new = m_old → α = 1 → 旧值不动,只加新项。
数学等价性 —— 不是近似!遍历完所有块后 O/l 严格等于 softmax(QKᵀ)V。
完整伪代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 B_c = M // (4 * d) B_r = min (M // (4 * d), d) T_c = ceil(N / B_c) T_r = ceil(N / B_r) O = zeros(N, d); l = zeros(N); m = full(N, -inf)for j in range (T_c): K_j = K[j*B_c : (j+1 )*B_c] V_j = V[j*B_c : (j+1 )*B_c] for i in range (T_r): Q_i = Q[i*B_r : (i+1 )*B_r] O_i = O[i*B_r : (i+1 )*B_r] l_i = l[i*B_r : (i+1 )*B_r] m_i = m[i*B_r : (i+1 )*B_r] S_ij = Q_i @ K_j.T m_new = max (m_i, S_ij.max (dim=-1 )) P_ij = exp(S_ij - m_new[:, None ]) l_new = exp(m_i - m_new) * l_i + P_ij.sum (dim=-1 ) O_i = (l_i / l_new)[:, None ] * exp(m_i - m_new)[:, None ] * O_i \ + (1 / l_new)[:, None ] * P_ij @ V_j m_i, l_i = m_new, l_new O[i*B_r : (i+1 )*B_r] = O_i l[i*B_r : (i+1 )*B_r] = l_i m[i*B_r : (i+1 )*B_r] = m_ireturn O
IO 复杂度总结
指标
标准 Attention
Flash Attention
优势
HBM 读写次数
Θ(Nd + N²)
Θ(N²d²/M)
当 M > d² 时减少
额外 HBM 内存
O(N²) (存 S, P)
O(N) (存 l, m)
从二次降到线性!
SRAM 使用
不利用
O(M) (满载利用)
—
计算量 (FLOPs)
O(N²d)
O(N²d) (相同!)
计算量不变
数值结果
精确
精确 (bit-identical)
不是近似!
Flash Attention 不是近似算法! 它与标准 Attention 产生 bit-identical 的结果。差异仅在 IO 模式:用更多计算换更少的内存访问(因为 GPU 计算远快于内存带宽)。
1.4 Cross Attention vs Self Attention Self Attention 和 Cross Attention 的核心区别在于 Q, K, V 的来源:
特性
Self Attention
Cross Attention
Q 来源
同一序列
Decoder 序列
K, V 来源
同一序列
Encoder 输出
典型用途
语言建模、BERT
翻译、图文理解
Causal Mask
Decoder 中使用
通常不使用
KV Cache
每步增长
编码后固定不变
应用场景 :
Self Attention :GPT 系列(causal)、BERT(bidirectional)
Cross Attention :机器翻译(Encoder-Decoder)、多模态模型(image tokens attend to text tokens)、Stable Diffusion(text→image conditioning)
1.5 KV Cache 原理 在 autoregressive 生成中,每生成一个新 token 都需要重新计算所有前序 token 的 attention。KV Cache 通过缓存已计算的 K 和 V 矩阵,避免重复计算。
为什么只缓存 K 和 V,不缓存 Q? 因为生成时只需要计算新 token 的 Q 与所有之前 token 的 K, V 之间的 attention。之前 token 的 Q 不再需要(它们的输出已经确定):
下面动画一步步生成 token,左边是 attention 矩阵(只算新的一行),右边是 KV Cache 的增长,底部对比有/无 cache 的计算量:
内存占用分析 :KV Cache 大小 = 2 (K+V) × num_layers × num_heads × head_dim × seq_len × dtype_size。
例如 :LLaMA-2-70B(80 层, 64 heads, head_dim=128, FP16)在 seq_len=4096 时:KV Cache = 2 × 80 × 64 × 128 × 4096 × 2 bytes ≈ 10.7 GB per request 。
代码实现:KV Cache 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 import torchimport torch.nn as nnclass CachedAttention (nn.Module): """带有 KV Cache 的 Attention 实现""" def __init__ (self, d_model, num_heads ): super ().__init__() self .d_model = d_model self .num_heads = num_heads self .head_dim = d_model // num_heads self .W_q = nn.Linear(d_model, d_model) self .W_k = nn.Linear(d_model, d_model) self .W_v = nn.Linear(d_model, d_model) self .W_o = nn.Linear(d_model, d_model) self .k_cache = None self .v_cache = None def forward (self, x, use_cache=False ): """ x: (batch, seq_len, d_model) - Prefill 阶段: seq_len = 完整 prompt 长度 - Decode 阶段: seq_len = 1 (只有新 token) """ B, L, _ = x.shape q = self .W_q(x).view(B, L, self .num_heads, self .head_dim).transpose(1 , 2 ) k = self .W_k(x).view(B, L, self .num_heads, self .head_dim).transpose(1 , 2 ) v = self .W_v(x).view(B, L, self .num_heads, self .head_dim).transpose(1 , 2 ) if use_cache: if self .k_cache is not None : k = torch.cat([self .k_cache, k], dim=2 ) v = torch.cat([self .v_cache, v], dim=2 ) self .k_cache = k.detach() self .v_cache = v.detach() scale = self .head_dim ** -0.5 scores = torch.matmul(q, k.transpose(-2 , -1 )) * scale if L > 1 : mask = torch.tril(torch.ones(L, k.size(2 ), device=x.device)) scores = scores.masked_fill(mask[-L:] == 0 , float ('-inf' )) attn = torch.softmax(scores, dim=-1 ) out = torch.matmul(attn, v) out = out.transpose(1 , 2 ).contiguous().view(B, L, self .d_model) return self .W_o(out) def clear_cache (self ): self .k_cache = None self .v_cache = None model = CachedAttention(d_model=512 , num_heads=8 ) model.eval () prompt = torch.randn(1 , 10 , 512 ) with torch.no_grad(): out = model(prompt, use_cache=True ) print (f"Prefill: cache size = {model.k_cache.shape} " ) for step in range (5 ): new_token = torch.randn(1 , 1 , 512 ) with torch.no_grad(): out = model(new_token, use_cache=True ) print (f"Step {step+1 } : cache size = {model.k_cache.shape} " )
1.6 Special Tokens 与 Chat Templates 现代 LLM 使用特殊 token 来标记序列结构和控制生成行为。这些 token 在训练时就已定义,模型学会了它们的语义。
Token
用途
示例模型
<BOS> / <s>
序列开始
LLaMA, Mistral
<EOS> / </s>
序列结束,停止生成
LLaMA, Mistral
<PAD>
填充到固定长度
BERT, T5
<UNK>
未知 token(OOV 处理)
多数模型
[INST]…[/INST]
指令标记
LLaMA-2-Chat
<|im_start|> / <|im_end|>
Chat 消息边界
ChatML format
Chat Template 将多轮对话转换为模型可理解的 token 序列,不同模型格式不同:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 """ <s>[INST] <<SYS>> You are a helpful assistant. <</SYS>> What is attention? [/INST] Attention is a mechanism that... </s> <s>[INST] Can you explain more? [/INST] """ """ <|im_start|>system You are a helpful assistant.<|im_end|> <|im_start|>user What is attention?<|im_end|> <|im_start|>assistant Attention is a mechanism that...<|im_end|> """ from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf" ) messages = [ {"role" : "system" , "content" : "You are a helpful assistant." }, {"role" : "user" , "content" : "What is Flash Attention?" }, ] formatted = tokenizer.apply_chat_template(messages, tokenize=False )print (formatted) inputs = tokenizer.apply_chat_template(messages, return_tensors="pt" )
重要 :错误的 chat template 会导致模型输出质量严重下降。每个模型都有期望的格式,必须严格遵守。使用 tokenizer.apply_chat_template() 是最安全的方式。
角色划分与 Attention Mask :在 chat 场景中,System 是全局上下文(所有后续 token 都能 attend 到)、User 是用户输入、Assistant 是模型输出。在 SFT 中通常只对 assistant 部分的 token 计算交叉熵 loss,system 和 user 部分作为 context 但不参与 loss 计算。
练习题
练习 1:手动计算 Scaled Dot-Product Attention
给定矩阵( ):Q = [[1,0],[0,1]],K = [[1,0],[0,1],[1,1]],V = [[1,0],[0,1],[1,1]]。请计算 Attention(Q, K, V)(不使用 causal mask)。
完整解答 :
1 2 3 4 5 6 7 8 9 Step 1: QK^T = [[1,0,1], [0,1,1]] Step 2: 缩放 (÷√2) = [[0.707, 0, 0.707], [0, 0.707, 0.707]] Step 3: Softmax Row 1: [0.401, 0.198, 0.401] Row 2: [0.198, 0.401, 0.401] Step 4: Weights @ V Row 1: [0.802, 0.599] Row 2: [0.599, 0.802] 最终结果 ≈ [[0.802, 0.599], [0.599, 0.802]]
练习 2:Flash Attention IO 复杂度分析
设 N=4096, d=128, SRAM M=100KB(约 50K 个 FP16 元素)。
标准 Attention HBM 访问量 :读 = 3Nd + 2N² ≈ 35.1M,写 = 2N² + Nd ≈ 34.1M,合计 4N² + 4Nd ≈ 69.2M 元素 × 2 bytes ≈ 132 MB。
Flash 配置 : ,取 ;外/内循环各 次。
Flash HBM 访问量 :≈ N²d/B_c = 4096²×128/64 ≈ 33.5M 元素,相比标准节省约 52%。更重要的是不需要 O(N²) 的额外内存——标准需 32MB,Flash 只需 O(N) ≈ 几 KB。
练习 3:估算 KV Cache 内存占用(LLaMA-3-70B + GQA)
配置:80 层、64 query heads(GQA,KV heads=8)、head_dim=128、FP16。batch=32, seq=8192。
1 2 3 4 5 6 7 8 num_layers, num_kv_heads, head_dim = 80 , 8 , 128 seq_len, batch_size, dtype_bytes = 8192 , 32 , 2 per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes per_request = per_token * seq_len total = per_request * batch_size total_mha = total * (64 / 8 ) total_fp8 = total / 2
结论 :GQA (8 KV heads) 使 KV Cache 从 640GB 降至 80GB(省 87.5%);FP8 量化再降至 40GB。这也是 PagedAttention (vLLM) 等技术重要的原因。
练习 4(Bonus):为什么 Prefill 是 compute-bound,Decode 是 memory-bound?
Arithmetic Intensity = FLOPs / Bytes Accessed。
Prefill :N 个 token 同时处理,做大 GEMM。QKV 投影 AI ≈ 3000 FLOPs/byte,Attention AI ≈ 2730 FLOPs/byte,都 >> H100 的计算/带宽比 (~295) → Compute-bound 。
Decode :仅 1 个新 token,但需读整个 KV Cache。FLOPs ≈ 67M,读取 ≈ 64MB → AI ≈ 1.05 FLOPs/byte << 295 → Memory-bound ,GPU 大部分时间在等 HBM 读 KV Cache。
两者 AI 差了约 3000 倍!这就是 LLM serving 区分 prefill / decode(continuous batching、chunked prefill)的原因。
Coding Exercise A:从零实现 Multi-Head Attention
要求:实现 Q/K/V 投影 → split heads → scaled dot-product attention → concat → output projection,支持 causal mask,并与 torch.nn.MultiheadAttention 对比验证(误差 < 1e-5)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 import torchimport torch.nn as nnimport torch.nn.functional as Fimport mathclass MyMultiHeadAttention (nn.Module): def __init__ (self, d_model, num_heads ): super ().__init__() assert d_model % num_heads == 0 self .d_model = d_model self .num_heads = num_heads self .head_dim = d_model // num_heads self .W_q = nn.Linear(d_model, d_model, bias=True ) self .W_k = nn.Linear(d_model, d_model, bias=True ) self .W_v = nn.Linear(d_model, d_model, bias=True ) self .W_o = nn.Linear(d_model, d_model, bias=True ) def forward (self, x, causal_mask=False ): B, N, D = x.shape Q = self .W_q(x).view(B, N, self .num_heads, self .head_dim).transpose(1 , 2 ) K = self .W_k(x).view(B, N, self .num_heads, self .head_dim).transpose(1 , 2 ) V = self .W_v(x).view(B, N, self .num_heads, self .head_dim).transpose(1 , 2 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) / math.sqrt(self .head_dim) if causal_mask: mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1 ).bool () scores = scores.masked_fill(mask, float ('-inf' )) attention_weights = F.softmax(scores, dim=-1 ) attention_output = torch.matmul(attention_weights, V) output = attention_output.transpose(1 , 2 ).contiguous().view(B, N, D) return self .W_o(output)def verify_implementation (): torch.manual_seed(42 ) d_model, num_heads = 64 , 8 x = torch.randn(2 , 10 , d_model) my_mha = MyMultiHeadAttention(d_model, num_heads).eval () official_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True ).eval () with torch.no_grad(): official_mha.in_proj_weight.copy_( torch.cat([my_mha.W_q.weight, my_mha.W_k.weight, my_mha.W_v.weight], dim=0 )) official_mha.in_proj_bias.copy_( torch.cat([my_mha.W_q.bias, my_mha.W_k.bias, my_mha.W_v.bias], dim=0 )) official_mha.out_proj.weight.copy_(my_mha.W_o.weight) official_mha.out_proj.bias.copy_(my_mha.W_o.bias) my_output = my_mha(x, causal_mask=False ) official_output, _ = official_mha(x, x, x) max_diff = (my_output - official_output).abs ().max ().item() print (f"最大差异: {max_diff:.2 e} " ) assert max_diff < 1e-5 print ("✓ 验证通过!" ) verify_implementation()
Coding Exercise B:实现 KV Cache 的增量 Decode
要求:实现 prefill(处理完整 prompt)+ decode(逐 token,复用 cached K/V),并实现无 cache baseline,验证两者输出 bit-identical 且展示加速。提示:decode 时只计算新 token 的 Q,与拼接后的完整 K/V 做 attention;cache shape 为 (batch, num_heads, past_len, head_dim)。
本文是 ML Systems 系列 Chapter 1。正文 markdown 渲染,5 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。