GEMM Tiling:矩阵乘法如何拆到 Tensor Core
矩阵乘法(GEMM)是深度学习里占绝大多数算力的算子。本文从一个具体的 C[64, 24576] = A[64, 8192] × B[8192, 24576] 出发,一层层拆解它如何映射到 GPU 的 Grid / Block / Warp / Tensor Core,以及为什么 Decode 阶段会浪费算力。正文为 markdown,关键机制配有可交互动画。
1. 全局视角:分块(Tiling)的动机
考虑一个典型的 LLM 前向 GEMM:
其中
下面的动画用一个缩小的 8×8 = 8×8 × 8×8、block 大小 2×2×2 的例子,逐步演示这套 tiling 调度:每个 program(线程块)负责一个
对应的 Triton 风格伪代码:
1 | |
关键点:累加器
acc全程留在寄存器里,整个 K-loop 期间不写回 Global Memory,只在最后一次性tl.store。这把的写入次数从 次降到 1 次。
2. 层级拆解:从 Grid 到 Tensor Core
回到真实尺寸的 GEMM,GPU 用四个层级把它拆开执行。下面的动画分场景演示整个层级:Grid 切空间 → Block 切 K + 加载 SRAM → Warp 发指令 → Tensor Core 算 16×16×16,最后展示内存层级和 Decode/Prefill 的差异。
① Grid 层:按输出空间切 Block
把输出
② Block 层:加载 SRAM + K 循环
每个 Block 由一个 SM 上的 256 个线程(= 8 个 Warp)执行。256 个线程协作,从 HBM 搬一块 mma.K = 16。
③ Warp → Register → Tensor Core
每个 Warp 用 wmma::load_matrix_sync 把 SRAM 里的 tile 加载成 16×16 的 fragment 放进寄存器,再用 wmma::mma_sync 喂给 Tensor Core。结果累加到寄存器里的 accumulator
④ Tensor Core:16×16×16 = 8192 FLOPs
Tensor Core 的一条 MMA 指令算
3. 内存层级:HBM 是唯一瓶颈
数据流是 HBM → SRAM → Register → Tensor Core,越往下越快越小:
| 层级 | 容量 | 带宽 | 存放内容 |
|---|---|---|---|
| HBM | 80 GB | 3.35 TB/s | 全部权重(如 384 MB) |
| SRAM (Shared Memory) | 228 KB/SM | ~30 TB/s | A_tile + B_tile |
| Register | 256 KB/SM | ~60 TB/s | Fragment 16×16 |
| Tensor Core | — | 989 TFLOPS | 8192 FLOPs/指令 |
算术强度(Arithmetic Intensity) 决定瓶颈在哪:
也就是说,每从 HBM 读 1 字节,至少要做 295 次浮点运算才能喂饱 Tensor Core。
瓶颈:HBM → SRAM。Decode 阶段的 GEMV 算术强度只有 ~1 FLOP/Byte,远低于 295,于是 Tensor Core 99% 的时间都在等数据。
解法:加大 Batch。batch=1 时权重读 1 次只用 1 次;batch=64 时读 1 次复用 64 次,TC 利用率从 ~6% 提升到接近 100%。
4. 为什么 Decode 慢:fragment 的 padding 问题
Tensor Core 的最小粒度是 16×16,
| 场景 | 有效行 | 效率 |
|---|---|---|
| Decode:batch=1 | 1 / 16 行有数据 | ~6% |
| Prefill:batch=64 | 16 / 16 行全是数据 | 100% |
batch=1 时 fragment 15/16 行是零,TC 约 94% 在做无用功;加大 batch 填满这 16 行,每个 FLOP 才都是有效计算。
Continuous Batching 的本质:凑足够多的请求填满 fragment 的 16 行 —— 这就是为什么 LLM serving 要把多个请求的 token 拼到一起做大 GEMM,把内存受限的 Decode 转化为计算受限的高效计算。
5. 小结
- Tiling 把放不下的大矩阵切成片上能容纳的小块,output-stationary 方案让累加器全程留在寄存器,只写回一次。
- GPU 用 Grid(切输出空间)→ Block(切 K,加载 SRAM)→ Warp(发 MMA 指令)→ Tensor Core(16×16×16) 四级把 GEMM 拆开。
- 数据流 HBM → SRAM → Register → TC,HBM 带宽是唯一瓶颈,机器平衡点约 295 FLOP/Byte。
- 加大 batch 提高数据复用率与 fragment 填充率,是把 Decode 从 memory-bound 拉回 compute-bound 的关键,也是 Continuous Batching 的核心动机。
本文正文 markdown 渲染,2 个交互动画通过自定义 {% anim %} 标签以隔离 iframe 嵌入,源自 Arkive 教程。