Skip to content

CUTLASS CUTE 7 hpc-ops 之 group gemm & tma

约 3080 个字 260 行代码 1 张图片 预计阅读时间 19 分钟

本笔记目的在于学习 Tencent/hpc-ops: High Performance LLM Inference Operator Library 中对 fp8 kernel 的高效实现。hpc-ops 的特点之一也在其项目 readme 当中提到

A Modern CUDA Tutorial: Hands-on examples of building SOTA kernels with CuTe and CUTLASS in just hundreds of lines.

其代码虽然非常精简,但是包含了不少的优化技巧。我将分两篇博客来整理,本篇博客将主要介绍什么是 group gemm,以及 tma 在 group gemm 当中的应用技巧。在第二篇将介绍 hpc-ops 在 group gemm 上的实现细节,包含:transposed mma, scheduler, scale for dequant。博客主要帮助大家理解代码,会涉及到不少代码细节,所以建议配合代码一起阅读

Group GEMM 详解

What's A Group Gemm?

要理解 Group GEMM,我们先从普通的 GEMM(通用矩阵乘法)开始。

普通 GEMM 计算的是:

Text Only
Y = X * W^T

其中:

  • X 形状为 [M, K](输入激活)
  • W 形状为 [N, K](权重)
  • Y 形状为 [M, N](输出)

Group GEMM 是 GEMM 的扩展,它在一次操作中执行多个独立的矩阵乘法。可以把它想象成"批量处理"不同组的矩阵乘法。

在 Group GEMM 中:

  • 我们有 num_group 个独立的权重矩阵 W_0, W_1, ..., W_{num_group-1}
  • 输入 X 被分割成 num_group 个连续的组 X_0, X_1, ..., X_{num_group-1}
  • 每个 X_i 与对应的 W_i 相乘
  • 所有结果拼接起来得到最终输出 Y

使用简介的代码语言来表示

Python
# X (M, K)
# W (num_group, N, K)
# Y (M, N)
# M is the length of all tokens from different groups
Y[start_i : end_i, :] = X[start_i : end_i, :] * W[i, :, :]^T

# A torch version
def naive_group_gemm_pertensor_fp8(x, w, seqlens, cu_seqlens, scale):
    # Step 1: get tensor shapes
    m, k = x.shape           # m = total_seq, k = hidden_size
    num_group, n, _ = w.shape  # n = output_dim

    # Step 2: init output tensor
    y = torch.zeros((m, n), dtype=torch.bfloat16, device=x.device)

    # Step 3: iterate each group and compute independent GEMM
    start_idx = 0
    for i in range(num_group):
        # get start and end positions of current group
        start_idx = int(cu_seqlens[i].item())
        end_idx = int(start_idx + seqlens[i].item())

        # skip empty groups
        if seqlens[i].item() == 0:
            continue

        # extract input and weight for current group
        x_group = x[start_idx:end_idx]  # shape: [seqlens[i], k]
        w_group = w[i]                   # shape: [n, k]

        # perform matmul (with scaled FP8)
        y_group = torch._scaled_mm(
            x_group, w_group.t(),        # w_group.t() shape: [k, n]
            scale_a=scale, scale_b=scale,
            bias=None, out_dtype=torch.bfloat16
        )

        # write result to corresponding output region
        y[start_idx:end_idx] = y_group

    return y

Why Group Gemm?

Group GEMM 的主要应用场景是混合专家模型(Mixture of Experts, MoE)的推理。如果不使用 Group GEMM,我们需要把输入 X 按照专家分组拆开,对每个专家分别调用一次 GEMM,最后再把结果拼回去

这样做的问题:

  • 多次 kernel launch 开销:每个专家都需要一次独立的 kernel launch,launch 本身有固定开销
  • 分散的内存访问:每个小 kernel 独立访问内存,难以形成高效的流水线访问模式。例如:前一个 group 的最后一部分数据在进行计算时,就可以开始预取下一个 group 的数据了,但独立的 kernel launch 无法做到这一点。另外,GroupGemm 最后也不需要再对各个 group 的计算结果进行拼接,减少数据读写

这么看来 GroupGemm 的优势就明显了:一次 kernel launch + 高效的内存流水线访问模式

GroupGemm in hpc-ops Overview

本次学习的 kernel 代码在 src/group_gemm/kernels.cuh,这部分代码包含了对 group gemm pertensor & blockwise 进行了优化实现。我们先从简单的 pertensor group gemm 开始,当我们吃透了这部分代码过后,再来看下 blockwise 的实现有什么细节上的更改

接下来是针对 pertensor group gemm 算法的精髓总结:

  1. Warp Specialization: 384 线程 = 256 线程 (数学计算 warpgroup) + 128 线程 (数据加载 warpgroup)
  2. Producer-Consumer Model: 使用 mbarrier 在阶段之间进行同步和流水线控制
  3. TMA for Groups: 为每个 group 预配置独立的 TMA descriptor,实现高效数据搬运(而不是所有的 group 使用一个 tma descriptor)
  4. Adaptive Tile Sizing (自适应 Tile 大小): 根据每组的平均序列长度选择 kTileM (16/32/48/64)
  5. Dual Scheduling Modes (双调度模式): Horizontal 模式(小矩阵用线性扫描)vs Vertical 模式(大矩阵用二分查找)

前两个算法算是老生常谈了,是 GEMM 算法中的基本。后面三个优化就是 hpc-ops 中的核心。其中我想提前提下第 4 点,对于 Tile Size Configuration,kernel 会根据 num_seq_per_group_avg(每组平均序列长度)选择不同的 tile 大小:

avg_seqlen kTileM kTileN kTileK kStage
≤16 16 128 128 8
≤32 32 128 128 8
≤48 48 128 128 8
>48 64 128 128 8

这样可以确保在不同的序列长度分布下都有良好的硬件利用率。这里的 Tile 看上去很奇怪,一般来说 Tensor Core 都是固定矩阵乘当中的 M 维度,i.e. kTileM = 128,而在这里确是固定了 kTileN = 128,这是因为 hpc-ops 对于 mma 进行了转置处理,这是一个非常巧妙的用法。这样的转置一下子就让 M 维度的粒度变得非常细,对于小 M 的场景非常友好

Pseudocode with Producer-Consumer Structure

下面是一个简洁的伪代码,帮助我们抓住整体的算法流程。其中隐藏了对 schduler & tma & mbarrier & mma 等模块的大量细节,但不妨碍我们理解这个 producer-consumer gemm 的核心思想

C++
__global__ void group_gemm_pertensor_fp8_kernel(...) {

    // ===== PRELOGUE =====
    int idx = threadIdx.x;
    bool is_producer = (idx >= 256);  // 128 threads for load
    bool is_consumer = (idx < 256);    // 256 threads for math

    extern __shared__ uint8_t shm_data[];
    auto* shm_a = (Tin*)shm_data;          // [kTileM, kTileK, kStage]
    auto* shm_b = shm_a + ...;              // [kTileN, kTileK, kStage]
    int* shm_tiles = (int*)(shm_data + ...); // For scheduler

    // Initialize mbarriers (producer-consumer sync)
    if (is_leader) {
    for (int s = 0; s < kStage; s++) {
        initialize_barrier(readable[s], 1);  // Consumer waits on this
        initialize_barrier(writable[s], num_mma_warpgroup);  // Producer waits on this
    }
    }
    // Load scheduler metadata to shared memory
    for (int i = idx; i < num_group; i += 384)
    shm_tiles[i] = tiles_ptr[i];
    __syncthreads();

    // ===== MAINLOOP: Producer-Consumer Pipeline =====
    int phase = 0;

    if (is_producer && is_leader_in_load) {
    // PRODUCER: Load data via TMA
    int s_write = 0;  // Current stage to write
    while (true) {
        // SCHEDULER: Get next tile (igroup, itile_m, itile_n)
        if (!get_next_tile(shm_tiles, iblock, ...)) break;

        for (int k = 0; k < ntile_k; k++) {
        // MBARRIER: Wait for consumer to release stage
        wait_barrier(writable[s_write], phase);

        // TMA load X and W tiles to shared memory
        tma_copy(shm_a[_, _, s_write], global_X[igroup, itile_m, k]);
        tma_copy(shm_b[_, _, s_write], global_W[igroup, itile_n, k]);

        // MBARRIER: Signal consumer data is ready
        set_barrier_transaction_bytes(readable[s_write], ...);

        // Circular stage buffer
        s_write = (s_write + 1) % kStage;
        if (s_write == 0) phase ^= 1;
        }
    }
    }

    if (is_consumer) {
    // CONSUMER: Compute via GMMA
    int s_read = 0;  // Current stage to read
    while (true) {
        // SCHEDULER: Same as producer, get next tile
        if (!get_next_tile(shm_tiles, iblock, ...)) break;

        for (int k = 0; k < ntile_k; k++) {
        // MBARRIER: Wait for producer to fill stage
        wait_barrier(readable[s_read], phase);

        // GMMA: Compute on shared memory data
        gemm(shm_a[_, _, s_read], shm_b[_, _, s_read], accum);

        // MBARRIER: Signal producer stage is consumed
        if (is_leader_in_warpgroup)
            arrive_barrier(writable[s_read]);

        // Circular stage buffer
        s_read = (s_read + 1) % kStage;
        if (s_read == 0) phase ^= 1;
        }

        // ===== EPILOGUE =====
        cast_to_bf16(accum, output, pertensor_scale);
        tma_store(output, global_Y[igroup, itile_m, itile_n]);
    }
    }
}

下面我们将对这些核心优化进行逐个分析,把他们的原理和实现解释清楚。

TMA for Group Gemm

在 hpc-ops 的 Group GEMM 实现中,TMA(Tensor Memory Accelerator)的使用非常巧妙。不同于普通 GEMM 中使用单一 TMA descriptor,这里为每个 group 预配置了独立的 TMA descriptor。这一节我们来详细分析这个设计。

Why update_grouped_tma

核心原因就是每个 group 的数据位置不同:X 张量在全局内存中是连续存储的 [total_seq, k],第 igroup 个 group 的起始位置是 x_ptr + cu_seqlens[igroup] * k

如果我们只有一个 tma descriptor,则只能按照这个 tma 的 gmem coord + copy box offset 的方式进行 copy。虽然 gmem coord 可以是任意设置的,但是在实际使用中,我们习惯使用 TiledCopy partition 对 gmem tensor 进行划分,此时 partitioned tensor 是 copy box 对齐的。对于 group gemm 来说,每个 group 的数据起始位置不可能都正好在 copy box offset 中。因此我们有两个选项:1. 把原始数据 Padding 为 copy box aligned 结构,这样每一个 group 都能和 copy box offset 对齐;2. 给每一个 group 都配置一个独立的 tma descriptor,这样每个 group 的数据都能按照自己的起始位置进行 copy

update_grouped_tma 代码解读

首先从 kernel launch 配置出发,来看下整体的问题划分

C++
constexpr int kGroupPerThread = 8;
constexpr int kThreadPerBlock = 32;
kernels::update_grouped_tma<...>
    <<<num_group + 1, kThreadPerBlock, 0, stream>>>(...);
  • Grid/Block 配置num_group + 1 个 block,每个 block 32 个线程
  • Block 分工
    • Block 0 ~ num_group-1:每个 block 处理一个 group,更新该 group 的 X 和 Y 的 TMA descriptor
    • Block num_group:计算所有 group 的 tile 统计信息

总结:update_grouped_tma 在一个 kernel 中完成两件事:为每个 group 更新独立的 TMA descriptor,同时用 BlockScan 计算 tile 统计信息供后续 scheduler 使用。下面是完整的伪代码:

C++
template <typename Tin, typename Tout, typename TmaX, typename TmaY,
            int kTileM, int kGroupPerThread, int kThreadPerBlock>
__global__ void update_grouped_tma(
    const vec_t<TmaDescriptor, 2> td_xy,  // template TMA descriptors (X, Y) from host
    TmaDescriptor *tma_xy,                // [num_group * 2]: X desc at 2*i, Y desc at 2*i+1
    const Tin *x_ptr,                     // X (activation) data
    const Tout *y_ptr,                    // Y (output) data
    const int *seqlens_ptr,               // [num_group]: seqlen per group
    const int *cu_seqlens_ptr,            // [num_group + 1]: cumulative seqlen
    int *tiles_ptr,                       // [num_group]: tile M count per group
    int *cu_tiles_ptr,                    // [num_group + 1]: cumulative tile M count
    int num_group, int m, int n, int k) {

    int idx = threadIdx.x;
    int igroup = blockIdx.x;

    if (igroup == num_group) {
    // ---- Case 1: Compute tile statistics (last block) ----
    int tiles[kGroupPerThread];

    // Step 1: Each thread computes tile counts for its assigned groups
    for (int i = 0; i < kGroupPerThread; i++) {
        int g = idx * kGroupPerThread + i;
        if (g < num_group) {
        tiles[i] = (seqlens_ptr[g] + kTileM - 1) / kTileM;
        tiles_ptr[g] = tiles[i];
        } else {
        tiles[i] = 0;
        }
    }

    // Step 2: Exclusive scan to compute cumulative tile M counts
    using BlockScan = cub::BlockScan<int, kThreadPerBlock>;
    __shared__ typename BlockScan::TempStorage temp_storage;
    int block_aggregate;
    BlockScan(temp_storage).ExclusiveSum(tiles, tiles, block_aggregate);

    // Step 3: Write cumulative results to global memory
    for (int i = 0; i < kGroupPerThread; i++) {
        int g = idx * kGroupPerThread + i;
        if (g < num_group) {
        cu_tiles_ptr[g] = tiles[i];
        }
    }
    if (idx == 0) {
        cu_tiles_ptr[num_group] = block_aggregate;  // total tile M
    }

    } else {
    // ---- Case 2: Update TMA descriptors for group igroup ----
    __shared__ TmaDescriptor smem_tma_desc[2];

    int num_seq = seqlens_ptr[igroup];
    int cu_seqlen = cu_seqlens_ptr[igroup];

    // Step 1: Copy template descriptors to shared memory
    if (idx < 2) {
        smem_tma_desc[idx] = td_xy[idx];
    }
    __syncwarp();

    // Step 2: Thread 0 — update X (activation) TMA descriptor
    if (idx == 0) {
        auto gX = make_tensor(
            make_gmem_ptr(x_ptr + cu_seqlen * k),
            make_shape(num_seq, k),
            make_stride(k, Int<1>{}));              // stride unchanged: same k for all groups
        update_tma_gtensor<TmaX>(smem_tma_desc[0], gX);
    }

    // Step 3: Thread 1 — update Y (output) TMA descriptor
    if (idx == 1) {
        auto gY = make_tensor(
            make_gmem_ptr(y_ptr + cu_seqlen * n),
            make_shape(n, num_seq),
            make_stride(Int<1>{}, n));              // stride unchanged: same n for all groups
        update_tma_gtensor<TmaY>(smem_tma_desc[1], gY);
    }

    // Step 4: Commit smem writes, then copy descriptors from smem to gmem
    for (int i = 0; i < 2; i++) {
        __syncwarp();
        if (elect_one_sync()) {
        tma_desc_commit_group();
        tma_desc_wait_group();
        }
        tma_descriptor_cp_fence_release(tma_xy + igroup * 2 + i, smem_tma_desc[i]);
    }
    }
}

在 Group Gemm 中使用 TMA

TMA for X (activation)

在 hpc-ops 当中,其使用 tma 的方式是

C++
copy(tma_origin.with(new_tma_descriptor, mbarrier), gmem_tensor, smem_tensor);

此时,可以认为 copy 所使用的 tma descriptor 就不是 tma_origin 中原来在 Host 端定义的 tma descriptor 了,而是我们的 new_tma_descriptor。其 gmem ptr 和 shape 都发生了改变,以适应 group gemm 当中不同 group 的 activation 数据搬运。由于 tma descriptor 的改变,我们的 gmem coord 也需要进行相应的适配。在 hpc-ops 当中的 copy 代码如下

C++
// itile_m is not a global idx, it is relative to the group
cute::copy(tma_a.with(td_x, readable[ismem_write]), tAg(_, itile_m, itile_k), tAs(_, 0, 0, ismem_write));

在代码中只有一个 partitioned tensor tAg,但是所有的 group 都使用这个 tAg,这是正确的吗?其实这是一个 coordinate tensor,我们只需要填入正确的 coordinate 即可。所以对于 itile_m,其 tma descriptor 更新为了 td_x,我们需要计算当前的 m coordinate 是相对于该 group 的首地址的偏移量,而不是全局的偏移量即可。这一点我在之后的 scheduler 笔记当中也会再此提到

TMA for W (weight)

对于权重来说,其维度是三维的 (n, k, num_group)。相应的,我们在定义 tiled copy 时所使用的 weight gmem tensor 也是三维的,不过需要注意的是:所使用的 copy box 却是二维的

C++
using SLayoutW = decltype(tile_to_shape(SLayoutWAtom{}, 
                            make_shape(Int<kTileN>{}, Int<kTileK>{}, Int<kStage>{})));
// w is 3-dim tensor, but copy box is 2-dim
auto tma_w = make_tma_copy(SM90_TMA_LOAD{}, w, take<0, 2>(SLayoutW{}));

我之前的理解是:tma 在搬运 tensor 的时候是根据首坐标 + box dim 来确定搬运数据的范围。此时首坐标是 3D 的,box dim 是 2D 的,这似乎挑战了我之前的理解。不过回答也非常简单,现在的数据范围是根据 3D 中的前 2D 坐标 + box dim 来确定的。合理猜测,如果我们把 gmem 维度顺序变成 (num_group, n, k),但 slayout 保持不变,那么:

Text Only
- slayout 第 0 维 (kTileN) → gmem 第 0 维 (num_group)                                                
- slayout 第 1 维 (kTileK) → gmem 第 1 维 (n)

copy box 就会沿着 num_group 和 n 维度进行 copy,这显然不是我们想要的结果

实现细节补充

以下是对上述代码中涉及的关键原语的补充说明。

BlockScan

cub::BlockScan 是一个并行前缀和计算原语。这里使用的是 Exclusive Sum Scan

Exclusive Scan 是一种并行计算原语,对数组进行前缀和计算,但每个位置的结果是该位置之前所有元素的和。

Text Only
示例:
输入:  [a, b, c, d]
输出:  [0, a, a+b, a+b+c]  ← exclusive sum
总和:  a+b+c+d               ← block_aggregate

对比 Inclusive Scan:
输入:  [a, b, c, d]
输出:  [a, a+b, a+b+c, a+b+c+d]  ← inclusive sum

可以从 hpc-ops 中的代码代表了 block scan 的一般用法

C++
// Line 88: define BlockScan type
using BlockScan = cub::BlockScan<int, kThreadPerBlock>;
// - template param 1: int - data type for scan
// - template param 2: kThreadPerBlock = 32 - threads per block

// Line 89: allocate shared memory
__shared__ typename BlockScan::TempStorage temp_storage;
// - TempStorage is a struct defined internally by cub
// - shared memory needed to coordinate thread communication
// - size computed automatically by cub

// Line 90: used to return total sum
int block_aggregate;

// Line 91: execute Exclusive Sum Scan
BlockScan(temp_storage).ExclusiveSum(tiles, tiles, block_aggregate);
// Parameter description:
// - tiles (input): data array contributed by each thread
// - tiles (output): scan result (modified in-place)
// - block_aggregate: returns total sum of the block

TMA Descriptor 更新

blockIdx.x < num_group 时,为该 group 更新 TMA descriptor。注意,我们不是在 device 端从头创建 TMA descriptor,而是:

  1. Host 端创建模板td_xy 包含了正确的 stride、tile size 等配置
  2. Device 端只更新必要字段
    • 全局内存地址:指向该 group 数据的起始位置
    • Shape:根据该 group 的 seqlen 设置,例如对于输入 X (activation),其第 i 个 group 的 tma gmem tensor shape 应设置为 [seqlens_ptr[igroup], k]

为什么 stride 不需要更新?

张量 Shape Stride 是否变化
X [num_seq, k] (k, 1) k 固定,所有 group 一样
Y [n, num_seq] (1, n) n 固定,所有 group 一样
  • k 是隐藏层维度(hidden_size),对所有 group 相同
  • n 是输出维度(output_dim),对所有 group 相同
  • 只有 num_seq 变化(seqlens_ptr[igroup]

update_tma_gtensor 的作用

该 device function 是更新 tma descriptor 的核心。会从 gmem tensor 中提取 shape & stride & gmem ptr,然后把这些信息更新到 TMA descriptor 中。我一开始还有疑问:为什么一定要用 shared memory 创建 cuTensorMap?虽然我之前了解到 tma 存储的信息都是放在 smem 当中的,但是我们仍然可以把这些信息放到寄存器当中,然后修改,最后再存回 gmem 当中呀。后来 agent 了解到 tma_descriptor_replace_shapes_in_shared_mem 该 PTX 要求操作源必须在 shared memory 当中,所以必须使用 smem

tma_desc_commit_group

C++
if (cute::elect_one_sync()) {
    cute::tma_desc_commit_group();
    cute::tma_desc_wait_group();
}

在我们的代码中,同一个 warp 中的不同线程在修改不同的 TMA descriptor

  • 线程 0 更新 smem_tma_desc[0] (X)
  • 线程 1 更新 smem_tma_desc[1] (Y)

这时候需要用 tma_desc_commit_group 来确保 warp 中所有线程对 TMA descriptor 的修改都完成并且可见。这里的 PTX 和 tma_store_fence 是一样的,我们之前使用 tma_store_fence 是为了确保 tma store 操作必须要在 smem 写入完成之后。在这里起到同样的作用,因为我们之后要把修改好的 smem 内容写回到 gmem 中存储的 cuTensorMap 当中,必须要保证所有的 smem 写入完成才发起该操作。

我们在之前进行了 update_tma_tensor 的操作,其实这是一个在 async proxy 发起的对 smem 上的修改。所以我们必须要保证整个修改的完成,才能在之后进行 tma_descriptor_cp_fence_release,将 smem 当中的 descriptor copy 到 gmem 当中。此时就需要一个 fence 来保证顺序,而 commit & wait 就是不错的选择。这里仅使用了一个 thread 来进行 commit & wait,实际上由于每一个 CTA 只有一个 warp,这里单个 thread wait,其他 thread 并不能够偷跑到下一个代码进行整形,而是也在空转等待

tma_descriptor_cp_fence_release

这个函数做两件事(fused copy + fence):

  1. Copy:把 128 字节的 TMA descriptor 从 shared memory 拷贝到 global memory
  2. Fence:带 release 语义的内存屏障。此屏障的作用是:确保之后使用 tma 的操作,都必须在该写入操作完成之后执行。可以想象为,这个 release fence 把之前的所有写代码都拦住了,编译器不可能把他们重排到这个 fence 之后。还有另一种带 acquire 语义的内存屏障,它会保证之后的所有读操作都必须在该读操作完成之后执行。这也是为什么 acquire & release 通常成对出现,我查阅了下 tma_store_fence 它到底属于 acquire 还是 release 呢?我认为答案应该是 both!我们既不能让 smem 写操作跨越该 fence,也不让 tma store 操作跨越该 fence

与主 kernel 配对使用

Producer(update_grouped_tma):

C++
tma_descriptor_cp_fence_release(tma_xy + i, smem_tma_desc[i]);
// "Release": ensure all prior writes are visible

Consumer(主 kernel):

C++
tma_descriptor_fence_acquire(td_xy + i);
// "Acquire": ensure subsequent reads see the complete descriptor