GEMM Tiling:矩阵乘法如何拆到 Tensor Core

矩阵乘法(GEMM)是深度学习里占绝大多数算力的算子。本文从一个具体的 C[64, 24576] = A[64, 8192] × B[8192, 24576] 出发,一层层拆解它如何映射到 GPU 的 Grid / Block / Warp / Tensor Core,以及为什么 Decode 阶段会浪费算力。正文为 markdown,关键机制配有可交互动画。

1. 全局视角:分块(Tiling)的动机

考虑一个典型的 LLM 前向 GEMM:

其中 (batch)、(hidden)、(output)。整块矩阵无法放进任何一级片上存储,必须切成小块(tile),让每个计算单元只处理一块。核心思想是 output-stationary tiling:把输出 切成若干 tile,每个 tile 由一个线程块独立负责,沿 维循环累加。

下面的动画用一个缩小的 8×8 = 8×8 × 8×8、block 大小 2×2×2 的例子,逐步演示这套 tiling 调度:每个 program(线程块)负责一个 的输出 tile,沿 循环把 的 tile 从 Global Memory 读进寄存器累加,K-loop 结束后才把累加结果写回:

对应的 Triton 风格伪代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
A_ptr, B_ptr, C_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
# 每个 program 负责一个输出 tile C[pid_m, pid_n]
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)

offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)

# 累加器留在寄存器中,FP32 精度
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# 沿 K 维循环:每次读一块 A、一块 B,累加
for k in range(0, K, BLOCK_K):
a = tl.load(A_ptr + offs_m[:, None] * stride_am + (k + offs_k)[None, :] * stride_ak)
b = tl.load(B_ptr + (k + offs_k)[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc += tl.dot(a, b) # Tensor Core MMA

# K-loop 结束后才写回 Global Memory
c_ptrs = C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc)

关键点:累加器 acc 全程留在寄存器里,整个 K-loop 期间不写回 Global Memory,只在最后一次性 tl.store。这把 的写入次数从 '_' allowed only in math modeK/\text{BLOCK_K} 次降到 1 次。

2. 层级拆解:从 Grid 到 Tensor Core

回到真实尺寸的 GEMM,GPU 用四个层级把它拆开执行。下面的动画分场景演示整个层级:Grid 切空间 → Block 切 K + 加载 SRAM → Warp 发指令 → Tensor Core 算 16×16×16,最后展示内存层级和 Decode/Prefill 的差异。

① Grid 层:按输出空间切 Block

把输出 沿 切成 192 个 Block,每块 。每个 Block 需要 的全部行 对应的 128 列 。这 192 个 Block 被分配到 132 个 SM(以 H100 为例)上并行执行,Block 之间零通信

② Block 层:加载 SRAM + K 循环

每个 Block 由一个 SM 上的 256 个线程(= 8 个 Warp)执行。256 个线程协作,从 HBM 搬一块 (2KB)和 (4KB)到 SRAM(Shared Memory),8 个 Warp 共享读取。然后沿 维循环 512 次()累加,每次步长 —— 之所以是 16,是因为 Tensor Core 的 mma.K = 16

③ Warp → Register → Tensor Core

每个 Warp 用 wmma::load_matrix_sync 把 SRAM 里的 tile 加载成 16×16 的 fragment 放进寄存器,再用 wmma::mma_sync 喂给 Tensor Core。结果累加到寄存器里的 accumulator 留在寄存器不写回内存

④ Tensor Core:16×16×16 = 8192 FLOPs

Tensor Core 的一条 MMA 指令算 ,即 FLOPs,只需 1-2 个 cycle。物理上是 4×4 FMA 阵列 × 4 个 TC,硬件拆成 64 个子操作,但 ISA 把它封装成 16×16,对程序员透明。输入用 FP16,累加用 FP32。

3. 内存层级:HBM 是唯一瓶颈

数据流是 HBM → SRAM → Register → Tensor Core,越往下越快越小:

层级 容量 带宽 存放内容
HBM 80 GB 3.35 TB/s 全部权重(如 384 MB)
SRAM (Shared Memory) 228 KB/SM ~30 TB/s A_tile + B_tile
Register 256 KB/SM ~60 TB/s Fragment 16×16
Tensor Core 989 TFLOPS 8192 FLOPs/指令

算术强度(Arithmetic Intensity) 决定瓶颈在哪:

也就是说,每从 HBM 读 1 字节,至少要做 295 次浮点运算才能喂饱 Tensor Core。

瓶颈:HBM → SRAM。Decode 阶段的 GEMV 算术强度只有 ~1 FLOP/Byte,远低于 295,于是 Tensor Core 99% 的时间都在等数据

解法:加大 Batch。batch=1 时权重读 1 次只用 1 次;batch=64 时读 1 次复用 64 次,TC 利用率从 ~6% 提升到接近 100%。

4. 为什么 Decode 慢:fragment 的 padding 问题

Tensor Core 的最小粒度是 16×16, 维度小于 16 时不足部分会被补零(padding)。硬件照常跑完 8192 FLOPs,但大部分是 无用功。

场景 有效行 效率
Decode:batch=1 1 / 16 行有数据 ~6%
Prefill:batch=64 16 / 16 行全是数据 100%

batch=1 时 fragment 15/16 行是零,TC 约 94% 在做无用功;加大 batch 填满这 16 行,每个 FLOP 才都是有效计算。

Continuous Batching 的本质:凑足够多的请求填满 fragment 的 16 行 —— 这就是为什么 LLM serving 要把多个请求的 token 拼到一起做大 GEMM,把内存受限的 Decode 转化为计算受限的高效计算。

5. 小结

  • Tiling 把放不下的大矩阵切成片上能容纳的小块,output-stationary 方案让累加器全程留在寄存器,只写回一次。
  • GPU 用 Grid(切输出空间)→ Block(切 K,加载 SRAM)→ Warp(发 MMA 指令)→ Tensor Core(16×16×16) 四级把 GEMM 拆开。
  • 数据流 HBM → SRAM → Register → TC,HBM 带宽是唯一瓶颈,机器平衡点约 295 FLOP/Byte。
  • 加大 batch 提高数据复用率与 fragment 填充率,是把 Decode 从 memory-bound 拉回 compute-bound 的关键,也是 Continuous Batching 的核心动机。

本文正文 markdown 渲染,2 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。