CUTLASS CUTE 7 hpc-ops 之 group gemm & tma¶
约 3080 个字 260 行代码 1 张图片 预计阅读时间 19 分钟
本笔记目的在于学习 Tencent/hpc-ops: High Performance LLM Inference Operator Library 中对 fp8 kernel 的高效实现。hpc-ops 的特点之一也在其项目 readme 当中提到
A Modern CUDA Tutorial: Hands-on examples of building SOTA kernels with CuTe and CUTLASS in just hundreds of lines.
其代码虽然非常精简,但是包含了不少的优化技巧。我将分两篇博客来整理,本篇博客将主要介绍什么是 group gemm,以及 tma 在 group gemm 当中的应用技巧。在第二篇将介绍 hpc-ops 在 group gemm 上的实现细节,包含:transposed mma, scheduler, scale for dequant。博客主要帮助大家理解代码,会涉及到不少代码细节,所以建议配合代码一起阅读
Group GEMM 详解¶
What's A Group Gemm?¶
要理解 Group GEMM,我们先从普通的 GEMM(通用矩阵乘法)开始。
普通 GEMM 计算的是:
其中:
- X 形状为 [M, K](输入激活)
- W 形状为 [N, K](权重)
- Y 形状为 [M, N](输出)
Group GEMM 是 GEMM 的扩展,它在一次操作中执行多个独立的矩阵乘法。可以把它想象成"批量处理"不同组的矩阵乘法。
在 Group GEMM 中:
- 我们有
num_group个独立的权重矩阵W_0, W_1, ..., W_{num_group-1} - 输入 X 被分割成
num_group个连续的组X_0, X_1, ..., X_{num_group-1} - 每个
X_i与对应的W_i相乘 - 所有结果拼接起来得到最终输出
Y
使用简介的代码语言来表示
# X (M, K)
# W (num_group, N, K)
# Y (M, N)
# M is the length of all tokens from different groups
Y[start_i : end_i, :] = X[start_i : end_i, :] * W[i, :, :]^T
# A torch version
def naive_group_gemm_pertensor_fp8(x, w, seqlens, cu_seqlens, scale):
# Step 1: get tensor shapes
m, k = x.shape # m = total_seq, k = hidden_size
num_group, n, _ = w.shape # n = output_dim
# Step 2: init output tensor
y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device)
# Step 3: iterate each group and compute independent GEMM
start_idx = 0
for i in range(num_group):
# get start and end positions of current group
start_idx = int(cu_seqlens[i].item())
end_idx = int(start_idx + seqlens[i].item())
# skip empty groups
if seqlens[i].item() == 0:
continue
# extract input and weight for current group
x_group = x[start_idx:end_idx] # shape: [seqlens[i], k]
w_group = w[i] # shape: [n, k]
# perform matmul (with scaled FP8)
y_group = torch._scaled_mm(
x_group, w_group.t(), # w_group.t() shape: [k, n]
scale_a=scale, scale_b=scale,
bias=None, out_dtype=torch.bfloat16
)
# write result to corresponding output region
y[start_idx:end_idx] = y_group
return y
Why Group Gemm?¶
Group GEMM 的主要应用场景是混合专家模型(Mixture of Experts, MoE)的推理。如果不使用 Group GEMM,我们需要把输入 X 按照专家分组拆开,对每个专家分别调用一次 GEMM,最后再把结果拼回去
这样做的问题:
- 多次 kernel launch 开销:每个专家都需要一次独立的 kernel launch,launch 本身有固定开销
- 分散的内存访问:每个小 kernel 独立访问内存,难以形成高效的流水线访问模式。例如:前一个 group 的最后一部分数据在进行计算时,就可以开始预取下一个 group 的数据了,但独立的 kernel launch 无法做到这一点。另外,GroupGemm 最后也不需要再对各个 group 的计算结果进行拼接,减少数据读写
这么看来 GroupGemm 的优势就明显了:一次 kernel launch + 高效的内存流水线访问模式
GroupGemm in hpc-ops Overview¶
本次学习的 kernel 代码在 src/group_gemm/kernels.cuh,这部分代码包含了对 group gemm pertensor & blockwise 进行了优化实现。我们先从简单的 pertensor group gemm 开始,当我们吃透了这部分代码过后,再来看下 blockwise 的实现有什么细节上的更改
接下来是针对 pertensor group gemm 算法的精髓总结:
- Warp Specialization: 384 线程 = 256 线程 (数学计算 warpgroup) + 128 线程 (数据加载 warpgroup)
- Producer-Consumer Model: 使用 mbarrier 在阶段之间进行同步和流水线控制
- TMA for Groups: 为每个 group 预配置独立的 TMA descriptor,实现高效数据搬运(而不是所有的 group 使用一个 tma descriptor)
- Adaptive Tile Sizing (自适应 Tile 大小): 根据每组的平均序列长度选择 kTileM (16/32/48/64)
- Dual Scheduling Modes (双调度模式): Horizontal 模式(小矩阵用线性扫描)vs Vertical 模式(大矩阵用二分查找)
前两个算法算是老生常谈了,是 GEMM 算法中的基本。后面三个优化就是 hpc-ops 中的核心。其中我想提前提下第 4 点,对于 Tile Size Configuration,kernel 会根据 num_seq_per_group_avg(每组平均序列长度)选择不同的 tile 大小:
| avg_seqlen | kTileM | kTileN | kTileK | kStage |
|---|---|---|---|---|
| ≤16 | 16 | 128 | 128 | 8 |
| ≤32 | 32 | 128 | 128 | 8 |
| ≤48 | 48 | 128 | 128 | 8 |
| >48 | 64 | 128 | 128 | 8 |
这样可以确保在不同的序列长度分布下都有良好的硬件利用率。这里的 Tile 看上去很奇怪,一般来说 Tensor Core 都是固定矩阵乘当中的 M 维度,i.e. kTileM = 128,而在这里确是固定了 kTileN = 128,这是因为 hpc-ops 对于 mma 进行了转置处理,这是一个非常巧妙的用法。这样的转置一下子就让 M 维度的粒度变得非常细,对于小 M 的场景非常友好
Pseudocode with Producer-Consumer Structure¶
下面是一个简洁的伪代码,帮助我们抓住整体的算法流程。其中隐藏了对 schduler & tma & mbarrier & mma 等模块的大量细节,但不妨碍我们理解这个 producer-consumer gemm 的核心思想
__global__ void group_gemm_pertensor_fp8_kernel(...) {
// ===== PRELOGUE =====
int idx = threadIdx.x;
bool is_producer = (idx >= 256); // 128 threads for load
bool is_consumer = (idx < 256); // 256 threads for math
extern __shared__ uint8_t shm_data[];
auto* shm_a = (Tin*)shm_data; // [kTileM, kTileK, kStage]
auto* shm_b = shm_a + ...; // [kTileN, kTileK, kStage]
int* shm_tiles = (int*)(shm_data + ...); // For scheduler
// Initialize mbarriers (producer-consumer sync)
if (is_leader) {
for (int s = 0; s < kStage; s++) {
initialize_barrier(readable[s], 1); // Consumer waits on this
initialize_barrier(writable[s], num_mma_warpgroup); // Producer waits on this
}
}
// Load scheduler metadata to shared memory
for (int i = idx; i < num_group; i += 384)
shm_tiles[i] = tiles_ptr[i];
__syncthreads();
// ===== MAINLOOP: Producer-Consumer Pipeline =====
int phase = 0;
if (is_producer && is_leader_in_load) {
// PRODUCER: Load data via TMA
int s_write = 0; // Current stage to write
while (true) {
// SCHEDULER: Get next tile (igroup, itile_m, itile_n)
if (!get_next_tile(shm_tiles, iblock, ...)) break;
for (int k = 0; k < ntile_k; k++) {
// MBARRIER: Wait for consumer to release stage
wait_barrier(writable[s_write], phase);
// TMA load X and W tiles to shared memory
tma_copy(shm_a[_, _, s_write], global_X[igroup, itile_m, k]);
tma_copy(shm_b[_, _, s_write], global_W[igroup, itile_n, k]);
// MBARRIER: Signal consumer data is ready
set_barrier_transaction_bytes(readable[s_write], ...);
// Circular stage buffer
s_write = (s_write + 1) % kStage;
if (s_write == 0) phase ^= 1;
}
}
}
if (is_consumer) {
// CONSUMER: Compute via GMMA
int s_read = 0; // Current stage to read
while (true) {
// SCHEDULER: Same as producer, get next tile
if (!get_next_tile(shm_tiles, iblock, ...)) break;
for (int k = 0; k < ntile_k; k++) {
// MBARRIER: Wait for producer to fill stage
wait_barrier(readable[s_read], phase);
// GMMA: Compute on shared memory data
gemm(shm_a[_, _, s_read], shm_b[_, _, s_read], accum);
// MBARRIER: Signal producer stage is consumed
if (is_leader_in_warpgroup)
arrive_barrier(writable[s_read]);
// Circular stage buffer
s_read = (s_read + 1) % kStage;
if (s_read == 0) phase ^= 1;
}
// ===== EPILOGUE =====
cast_to_bf16(accum, output, pertensor_scale);
tma_store(output, global_Y[igroup, itile_m, itile_n]);
}
}
}
下面我们将对这些核心优化进行逐个分析,把他们的原理和实现解释清楚。
TMA for Group Gemm¶
在 hpc-ops 的 Group GEMM 实现中,TMA(Tensor Memory Accelerator)的使用非常巧妙。不同于普通 GEMM 中使用单一 TMA descriptor,这里为每个 group 预配置了独立的 TMA descriptor。这一节我们来详细分析这个设计。
Why update_grouped_tma?¶
核心原因就是每个 group 的数据位置不同:X 张量在全局内存中是连续存储的 [total_seq, k],第 igroup 个 group 的起始位置是 x_ptr + cu_seqlens[igroup] * k
如果我们只有一个 tma descriptor,则只能按照这个 tma 的 gmem coord + copy box offset 的方式进行 copy。虽然 gmem coord 可以是任意设置的,但是在实际使用中,我们习惯使用 TiledCopy partition 对 gmem tensor 进行划分,此时 partitioned tensor 是 copy box 对齐的。对于 group gemm 来说,每个 group 的数据起始位置不可能都正好在 copy box offset 中。因此我们有两个选项:1. 把原始数据 Padding 为 copy box aligned 结构,这样每一个 group 都能和 copy box offset 对齐;2. 给每一个 group 都配置一个独立的 tma descriptor,这样每个 group 的数据都能按照自己的起始位置进行 copy
update_grouped_tma 代码解读¶
首先从 kernel launch 配置出发,来看下整体的问题划分
constexpr int kGroupPerThread = 8;
constexpr int kThreadPerBlock = 32;
kernels::update_grouped_tma<...>
<<<num_group + 1, kThreadPerBlock, 0, stream>>>(...);
- Grid/Block 配置:
num_group + 1个 block,每个 block 32 个线程 - Block 分工:
- Block
0 ~ num_group-1:每个 block 处理一个 group,更新该 group 的 X 和 Y 的 TMA descriptor - Block
num_group:计算所有 group 的 tile 统计信息
- Block
总结:update_grouped_tma 在一个 kernel 中完成两件事:为每个 group 更新独立的 TMA descriptor,同时用 BlockScan 计算 tile 统计信息供后续 scheduler 使用。下面是完整的伪代码:
template <typename Tin, typename Tout, typename TmaX, typename TmaY,
int kTileM, int kGroupPerThread, int kThreadPerBlock>
__global__ void update_grouped_tma(
const vec_t<TmaDescriptor, 2> td_xy, // template TMA descriptors (X, Y) from host
TmaDescriptor *tma_xy, // [num_group * 2]: X desc at 2*i, Y desc at 2*i+1
const Tin *x_ptr, // X (activation) data
const Tout *y_ptr, // Y (output) data
const int *seqlens_ptr, // [num_group]: seqlen per group
const int *cu_seqlens_ptr, // [num_group + 1]: cumulative seqlen
int *tiles_ptr, // [num_group]: tile M count per group
int *cu_tiles_ptr, // [num_group + 1]: cumulative tile M count
int num_group, int m, int n, int k) {
int idx = threadIdx.x;
int igroup = blockIdx.x;
if (igroup == num_group) {
// ---- Case 1: Compute tile statistics (last block) ----
int tiles[kGroupPerThread];
// Step 1: Each thread computes tile counts for its assigned groups
for (int i = 0; i < kGroupPerThread; i++) {
int g = idx * kGroupPerThread + i;
if (g < num_group) {
tiles[i] = (seqlens_ptr[g] + kTileM - 1) / kTileM;
tiles_ptr[g] = tiles[i];
} else {
tiles[i] = 0;
}
}
// Step 2: Exclusive scan to compute cumulative tile M counts
using BlockScan = cub::BlockScan<int, kThreadPerBlock>;
__shared__ typename BlockScan::TempStorage temp_storage;
int block_aggregate;
BlockScan(temp_storage).ExclusiveSum(tiles, tiles, block_aggregate);
// Step 3: Write cumulative results to global memory
for (int i = 0; i < kGroupPerThread; i++) {
int g = idx * kGroupPerThread + i;
if (g < num_group) {
cu_tiles_ptr[g] = tiles[i];
}
}
if (idx == 0) {
cu_tiles_ptr[num_group] = block_aggregate; // total tile M
}
} else {
// ---- Case 2: Update TMA descriptors for group igroup ----
__shared__ TmaDescriptor smem_tma_desc[2];
int num_seq = seqlens_ptr[igroup];
int cu_seqlen = cu_seqlens_ptr[igroup];
// Step 1: Copy template descriptors to shared memory
if (idx < 2) {
smem_tma_desc[idx] = td_xy[idx];
}
__syncwarp();
// Step 2: Thread 0 — update X (activation) TMA descriptor
if (idx == 0) {
auto gX = make_tensor(
make_gmem_ptr(x_ptr + cu_seqlen * k),
make_shape(num_seq, k),
make_stride(k, Int<1>{})); // stride unchanged: same k for all groups
update_tma_gtensor<TmaX>(smem_tma_desc[0], gX);
}
// Step 3: Thread 1 — update Y (output) TMA descriptor
if (idx == 1) {
auto gY = make_tensor(
make_gmem_ptr(y_ptr + cu_seqlen * n),
make_shape(n, num_seq),
make_stride(Int<1>{}, n)); // stride unchanged: same n for all groups
update_tma_gtensor<TmaY>(smem_tma_desc[1], gY);
}
// Step 4: Commit smem writes, then copy descriptors from smem to gmem
for (int i = 0; i < 2; i++) {
__syncwarp();
if (elect_one_sync()) {
tma_desc_commit_group();
tma_desc_wait_group();
}
tma_descriptor_cp_fence_release(tma_xy + igroup * 2 + i, smem_tma_desc[i]);
}
}
}
在 Group Gemm 中使用 TMA¶
TMA for X (activation)¶
在 hpc-ops 当中,其使用 tma 的方式是
此时,可以认为 copy 所使用的 tma descriptor 就不是 tma_origin 中原来在 Host 端定义的 tma descriptor 了,而是我们的 new_tma_descriptor。其 gmem ptr 和 shape 都发生了改变,以适应 group gemm 当中不同 group 的 activation 数据搬运。由于 tma descriptor 的改变,我们的 gmem coord 也需要进行相应的适配。在 hpc-ops 当中的 copy 代码如下
// itile_m is not a global idx, it is relative to the group
cute::copy(tma_a.with(td_x, readable[ismem_write]), tAg(_, itile_m, itile_k), tAs(_, 0, 0, ismem_write));
在代码中只有一个 partitioned tensor tAg,但是所有的 group 都使用这个 tAg,这是正确的吗?其实这是一个 coordinate tensor,我们只需要填入正确的 coordinate 即可。所以对于 itile_m,其 tma descriptor 更新为了 td_x,我们需要计算当前的 m coordinate 是相对于该 group 的首地址的偏移量,而不是全局的偏移量即可。这一点我在之后的 scheduler 笔记当中也会再此提到
TMA for W (weight)¶
对于权重来说,其维度是三维的 (n, k, num_group)。相应的,我们在定义 tiled copy 时所使用的 weight gmem tensor 也是三维的,不过需要注意的是:所使用的 copy box 却是二维的
using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{},
make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));
// w is 3-dim tensor, but copy box is 2-dim
auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{}));
我之前的理解是:tma 在搬运 tensor 的时候是根据首坐标 + box dim 来确定搬运数据的范围。此时首坐标是 3D 的,box dim 是 2D 的,这似乎挑战了我之前的理解。不过回答也非常简单,现在的数据范围是根据 3D 中的前 2D 坐标 + box dim 来确定的。合理猜测,如果我们把 gmem 维度顺序变成 (num_group, n, k),但 slayout 保持不变,那么:
- slayout 第 0 维 (kTileN) → gmem 第 0 维 (num_group)
- slayout 第 1 维 (kTileK) → gmem 第 1 维 (n)
copy box 就会沿着 num_group 和 n 维度进行 copy,这显然不是我们想要的结果
实现细节补充¶
以下是对上述代码中涉及的关键原语的补充说明。
BlockScan¶
cub::BlockScan 是一个并行前缀和计算原语。这里使用的是 Exclusive Sum Scan:
Exclusive Scan 是一种并行计算原语,对数组进行前缀和计算,但每个位置的结果是该位置之前所有元素的和。
示例:
输入: [a, b, c, d]
输出: [0, a, a+b, a+b+c] ← exclusive sum
总和: a+b+c+d ← block_aggregate
对比 Inclusive Scan:
输入: [a, b, c, d]
输出: [a, a+b, a+b+c, a+b+c+d] ← inclusive sum
可以从 hpc-ops 中的代码代表了 block scan 的一般用法
// Line 88: define BlockScan type
using BlockScan = cub::BlockScan<int, kThreadPerBlock>;
// - template param 1: int - data type for scan
// - template param 2: kThreadPerBlock = 32 - threads per block
// Line 89: allocate shared memory
__shared__ typename BlockScan::TempStorage temp_storage;
// - TempStorage is a struct defined internally by cub
// - shared memory needed to coordinate thread communication
// - size computed automatically by cub
// Line 90: used to return total sum
int block_aggregate;
// Line 91: execute Exclusive Sum Scan
BlockScan(temp_storage).ExclusiveSum(tiles, tiles, block_aggregate);
// Parameter description:
// - tiles (input): data array contributed by each thread
// - tiles (output): scan result (modified in-place)
// - block_aggregate: returns total sum of the block
TMA Descriptor 更新¶
当 blockIdx.x < num_group 时,为该 group 更新 TMA descriptor。注意,我们不是在 device 端从头创建 TMA descriptor,而是:
- Host 端创建模板:
td_xy包含了正确的 stride、tile size 等配置 - Device 端只更新必要字段:
- 全局内存地址:指向该 group 数据的起始位置
- Shape:根据该 group 的 seqlen 设置,例如对于输入 X (activation),其第 i 个 group 的 tma gmem tensor shape 应设置为
[seqlens_ptr[igroup], k]
为什么 stride 不需要更新?
| 张量 | Shape | Stride | 是否变化 |
|---|---|---|---|
| X | [num_seq, k] | (k, 1) | k 固定,所有 group 一样 |
| Y | [n, num_seq] | (1, n) | n 固定,所有 group 一样 |
k是隐藏层维度(hidden_size),对所有 group 相同n是输出维度(output_dim),对所有 group 相同- 只有
num_seq变化(seqlens_ptr[igroup])
update_tma_gtensor 的作用
该 device function 是更新 tma descriptor 的核心。会从 gmem tensor 中提取 shape & stride & gmem ptr,然后把这些信息更新到 TMA descriptor 中。我一开始还有疑问:为什么一定要用 shared memory 创建 cuTensorMap?虽然我之前了解到 tma 存储的信息都是放在 smem 当中的,但是我们仍然可以把这些信息放到寄存器当中,然后修改,最后再存回 gmem 当中呀。后来 agent 了解到 tma_descriptor_replace_shapes_in_shared_mem 该 PTX 要求操作源必须在 shared memory 当中,所以必须使用 smem
tma_desc_commit_group¶
在我们的代码中,同一个 warp 中的不同线程在修改不同的 TMA descriptor:
- 线程 0 更新
smem_tma_desc[0](X) - 线程 1 更新
smem_tma_desc[1](Y)
这时候需要用 tma_desc_commit_group 来确保 warp 中所有线程对 TMA descriptor 的修改都完成并且可见。这里的 PTX 和 tma_store_fence 是一样的,我们之前使用 tma_store_fence 是为了确保 tma store 操作必须要在 smem 写入完成之后。在这里起到同样的作用,因为我们之后要把修改好的 smem 内容写回到 gmem 中存储的 cuTensorMap 当中,必须要保证所有的 smem 写入完成才发起该操作。
我们在之前进行了 update_tma_tensor 的操作,其实这是一个在 async proxy 发起的对 smem 上的修改。所以我们必须要保证整个修改的完成,才能在之后进行 tma_descriptor_cp_fence_release,将 smem 当中的 descriptor copy 到 gmem 当中。此时就需要一个 fence 来保证顺序,而 commit & wait 就是不错的选择。这里仅使用了一个 thread 来进行 commit & wait,实际上由于每一个 CTA 只有一个 warp,这里单个 thread wait,其他 thread 并不能够偷跑到下一个代码进行整形,而是也在空转等待
tma_descriptor_cp_fence_release¶
这个函数做两件事(fused copy + fence):
- Copy:把 128 字节的 TMA descriptor 从 shared memory 拷贝到 global memory
- Fence:带
release语义的内存屏障。此屏障的作用是:确保之后使用 tma 的操作,都必须在该写入操作完成之后执行。可以想象为,这个 release fence 把之前的所有写代码都拦住了,编译器不可能把他们重排到这个 fence 之后。还有另一种带acquire语义的内存屏障,它会保证之后的所有读操作都必须在该读操作完成之后执行。这也是为什么acquire & release通常成对出现,我查阅了下tma_store_fence它到底属于 acquire 还是 release 呢?我认为答案应该是 both!我们既不能让 smem 写操作跨越该 fence,也不让 tma store 操作跨越该 fence
与主 kernel 配对使用:
Producer(update_grouped_tma):
tma_descriptor_cp_fence_release(tma_xy + i, smem_tma_desc[i]);
// "Release": ensure all prior writes are visible
Consumer(主 kernel):