Multi-CTA

在 Hopper 架构中,NVIDIA 在线程层级结构中增加了一个新的级别:CGA(在 CUDA 中也称为线程块簇,thread block cluster)。CGA 是一个最多包含 16 个 CTA 的组,组内的 CTA 可以相互协作。具体来说,它们可以:

  • 通过 TMA 组播(multicast)协作地从 HBM 加载数据

  • 通过访问彼此的共享内存来交换数据。这通常被称为“使用分布式共享内存”

  • 从 Blackwell 架构开始,CTA 对可以协作计算矩阵乘法的结果

  • CGA 集群的子集可以通过 mbarrier 进行选择性同步

当然,不同的 CTA 可能会也可能不会被分配到同一个 SM(实际上文档对此没有任何保证),因此诸如同步或访问彼此共享内存之类的操作比在单个 CTA 内访问共享内存或同步线程要昂贵得多。因此,在使用 CGA 时,核心策略是最大化协作,同时避免引入不必要的同步点。

Multi-CTA 布局

布局可以以自然的方式跨 CTA 进行分片。例如,我们可以对包含 4 个 warp 和 2 个 CTA 的程序采用如下形式的分块布局:

gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[1, 0]])

cga_layout 参数 [[1, 0]] 以线性布局的形式描述 CTA 的分片方式(每个基底对应 CTA ID 的一个二进制位)。在这种情况下,它表示两个 CTA 沿着第 0 维将张量分成两个连续的子张量。

类似地,如果我们有 8 个 CTA,并且想让共享内存描述符先沿第 0 维分成 4 片(由 CTA ID 的低 2 位决定),再沿第 1 维分成 2 片(由 CTA ID 的最高位决定),布局可能如下所示:

gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=[[1, 0], [2, 0], [0, 1]])

cga_layout 将始终具有 log2(numCTAs) 个基底,并且始终表示将完整张量分片为连续块。对于那些 CTA 可能不会将张量分片为连续子张量的分片模式,例如:

| CTA0 warp0 | CTA1 warp0 | CTA0 warp1 | CTA1 warp1 |

可以使用 LinearEncoding 处理寄存器中的数据,使用 SharedLinearEncoding 处理共享内存中的数据。在这些情况下,CGA 布局不是作为名为 CGALayout 的属性,而是作为 LinearLayout 的一部分,编码在名为 block 的输入维度下。上述示例看起来如下所示:

gl.LinearEncoding(warps=[[2]], block=[[1]])

因为我们首先沿着 CTA 进行分片,然后沿着 warp 进行分片。


import importlib
import pytest
import torch
import triton

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    clc,
    tcgen05_commit,
    tcgen05_mma,
    tcgen05_mma_barrier_count,
    tensor_memory_descriptor,
)
from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

# Re-use the baseline warp-specialization tutorial; its ``Counter`` and
# ``cublas`` attributes are aliased further down. The module name starts with a
# digit, so it cannot be loaded with a plain ``import`` statement.
t8 = importlib.import_module(name="08-warp-specialization")


def is_hopper_or_newer():
    """Return True when running on an NVIDIA GPU with compute capability 9.x or newer."""
    if not torch.cuda.is_available():
        return False
    if triton.runtime.driver.active.get_current_target().backend != "cuda":
        return False
    major_capability = torch.cuda.get_device_capability()[0]
    return major_capability >= 9


def is_blackwell():
    """Return True when running on an NVIDIA GPU with compute capability 10.x (Blackwell)."""
    if not torch.cuda.is_available():
        return False
    if triton.runtime.driver.active.get_current_target().backend != "cuda":
        return False
    major_capability = torch.cuda.get_device_capability()[0]
    return major_capability == 10


if __name__ == "__main__":
    # Running this tutorial as a script requires a Blackwell GPU.
    if not is_blackwell():
        raise RuntimeError("This tutorial requires a Blackwell NVIDIA GPU")


def tflops(ms, M, N, K):
    """Return TFLOP/s for an ``M x N x K`` matmul (2*M*N*K FLOPs) that took ``ms`` milliseconds."""
    flop_count = 2 * M * N * K
    seconds = ms * 1e-3
    return flop_count * 1e-12 / seconds


def gbps(ms, num_bytes):
    """Return the bandwidth in GB/s for moving ``num_bytes`` bytes in ``ms`` milliseconds."""
    seconds = ms * 1e-3
    return num_bytes * 1e-9 / seconds


def pick_multicta_softmax_config(n_cols):
    """Choose launch parameters for ``multicta_softmax_kernel`` from the row width.

    Narrow rows use fewer warps. Rows wider than 16K columns are spread over more
    CTAs: the CTA count doubles with each doubling of the width, capped at 16.

    Returns:
        dict with keys ``"num_warps"`` and ``"num_ctas"``.
    """

    def first_fit(table, fallback):
        # Value of the first (limit, value) entry whose limit covers n_cols.
        for limit, value in table:
            if n_cols <= limit:
                return value
        return fallback

    warps = first_fit(((3072, 1), (6144, 2)), 4)
    ctas = first_fit(((16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)), 16)
    return {"num_warps": warps, "num_ctas": ctas}

Multi-CTA 内核启动时使用 num_ctas > 1,其中 num_ctas 是 2 的幂且 num_ctas <= 16

诸如 gl.convert_layout、gl.reduce、gl.sum 等布局驱动的操作,当源布局和目标布局对 CTA 维度的分片方式不同时,会自动使用集群。

下面的内核将一行数据分片到多个 CTA 上,并利用 gl.max 和 gl.sum 中的自动跨 CTA 归约来计算数值稳定的行级 Softmax。

如果没有 CGA,一旦行数据对于单个 CTA 来说太宽,我们就需要切换到迭代归约或多内核方法。



@gluon.jit
def multicta_softmax_kernel(
    x_ptr,
    out_ptr,
    x_row_stride,
    out_row_stride,
    BLOCK_N: gl.constexpr,
):
    # Numerically stable softmax of one row per program. The row is spread over
    # all CTAs of the cluster through the CGA layout below, and gl.max / gl.sum
    # reduce across those CTAs automatically.
    #
    # x_ptr / out_ptr: base pointers; columns within a row are addressed at
    #   consecutive offsets (row_start + offs_n), i.e. the column stride is
    #   assumed to be 1.
    # x_row_stride / out_row_stride: distance in elements between rows.
    # BLOCK_N: number of columns processed per row; it is the gl.arange extent
    #   (presumably must be a power of two — TODO confirm).
    pid = gl.program_id(0)
    # One CGA basis per CTA-ID bit (log2(num_ctas) of them), so the row is cut
    # into num_ctas contiguous chunks, one per CTA.
    cga_layout: gl.constexpr = ((1, ), (2, ), (4, ), (8, ), (16, ))[:gl.num_ctas().bit_length() - 1]
    layout: gl.constexpr = gl.BlockedLayout([4], [32], [gl.num_warps()], [0], cga_layout=cga_layout)
    offs_n = gl.arange(0, BLOCK_N, layout)
    # NOTE(review): offs_n spans exactly [0, BLOCK_N), so this mask is always
    # true; it would only matter if BLOCK_N were padded past the row length.
    mask = offs_n < BLOCK_N
    row_start = pid * x_row_stride
    out_row_start = pid * out_row_stride
    x = gl.load(x_ptr + row_start + offs_n, mask=mask, other=-float("inf"))
    # Subtract the row maximum before exponentiating for numerical stability.
    # With num_ctas > 1 both reductions cross CTA boundaries.
    row_max = gl.max(x, axis=0)
    y = gl.exp(x - row_max)
    row_sum = gl.sum(y, axis=0)
    z = y * (1.0 / row_sum)
    gl.store(out_ptr + out_row_start + offs_n, z, mask=mask)


def multicta_softmax_f32(x, out=None):
    """Row-wise softmax of a 2D tensor using ``multicta_softmax_kernel``.

    One program is launched per row. Each row is sharded across the CTAs of a
    cluster as chosen by ``pick_multicta_softmax_config``.

    Args:
        x: 2D tensor of shape ``(M, N)``. The kernel addresses the columns of a
            row with unit stride, so ``x.stride(1)`` must be 1 (when ``N > 1``).
            ``N`` becomes the kernel's ``BLOCK_N`` (the ``gl.arange`` extent)
            and must be a power of two.
        out: optional output tensor with the same shape as ``x`` and unit column
            stride. Allocated with ``torch.empty_like(x)`` when omitted.

    Returns:
        The tensor holding the softmax of each row of ``x``.

    Raises:
        ValueError: if ``N`` is not a positive power of two, if ``x`` or ``out``
            has a non-unit column stride, or if ``out`` does not match ``x``'s shape.
    """
    M, N = x.shape
    # gl.arange(0, BLOCK_N) requires a power-of-two extent; fail with a clear
    # message instead of a compilation error.
    if N <= 0 or N & (N - 1):
        raise ValueError(f"row length must be a positive power of two, got {N}")
    # The kernel computes addresses as row_start + arange(BLOCK_N) and never
    # applies a column stride, so strided views (e.g. x.t()) would otherwise be
    # read and written at the wrong addresses.
    if N > 1 and x.stride(1) != 1:
        raise ValueError(f"x must have unit column stride, got strides {tuple(x.stride())}")
    if out is None:
        out = torch.empty_like(x)
    else:
        if out.shape != x.shape:
            raise ValueError(f"out shape {tuple(out.shape)} does not match x shape {tuple(x.shape)}")
        if N > 1 and out.stride(1) != 1:
            raise ValueError(f"out must have unit column stride, got strides {tuple(out.stride())}")

    cfg = pick_multicta_softmax_config(N)
    multicta_softmax_kernel[(M, )](
        x,
        out,
        x.stride(0),
        out.stride(0),
        BLOCK_N=N,
        num_warps=cfg["num_warps"],
        num_ctas=cfg["num_ctas"],
    )
    return out


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
@pytest.mark.parametrize("M, N", [(64, 64), (64, 256), (16, 2**16)])
def test_multicta_softmax_f32(M, N):
    """The multi-CTA softmax must agree with torch.softmax along each row."""
    inputs = torch.randn((M, N), device="cuda", dtype=torch.float32)
    actual = multicta_softmax_f32(inputs)
    expected = torch.softmax(inputs, dim=1)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)


def benchmark_multicta_softmax_f32():
    """Time ``multicta_softmax_f32`` over a sweep of row widths and print the results.

    Raises:
        RuntimeError: when no Hopper-or-newer NVIDIA GPU is available.
    """
    if not is_hopper_or_newer():
        raise RuntimeError("softmax benchmark requires Hopper or newer")

    # 2**15 rows for widths 2**8 .. 2**16; the two widest shapes use fewer rows
    # so that, like 2**15 x 2**16, they hold 2**31 elements.
    SOFTMAX_BENCH_SHAPES = [(2**15, 2**log_n) for log_n in range(8, 17)]
    SOFTMAX_BENCH_SHAPES += [(2**14, 2**17), (2**13, 2**18)]

    header = (
        "Benchmarking multicta_softmax",
        "============================",
        "  shape         CTAs  warps  time (ms)  bandwidth (GB/s)",
    )
    for line in header:
        print(line)

    for M, N in SOFTMAX_BENCH_SHAPES:
        cfg = pick_multicta_softmax_config(N)
        x = torch.empty((M, N), device="cuda", dtype=torch.float32).uniform_(-1, 1)
        out = torch.empty_like(x)
        ms = triton.testing.do_bench_cudagraph(lambda x=x, out=out: multicta_softmax_f32(x, out))
        # One full read of x plus one full write of out.
        moved_bytes = 2 * x.numel() * x.element_size()
        print(f"{M:>6} x {N:<6}  {cfg['num_ctas']:>4}  {cfg['num_warps']:>5}  {ms:>9.3f}  {gbps(ms, moved_bytes):>16.2f}")


# Run the benchmark only when this file is executed as a script. Importing the
# module (e.g. during pytest collection) then neither launches GPU work nor
# raises on machines without a Hopper-or-newer GPU.
if __name__ == "__main__":
    benchmark_multicta_softmax_f32()

Softmax 基准测试结果

Benchmarking multicta_softmax
============================
  shape         CTAs  warps  time (ms)  bandwidth (GB/s)
 32768 x 256        1      1      0.018           3661.46
 32768 x 512        1      1      0.020           6746.45
 32768 x 1024       1      1      0.040           6740.50
 32768 x 2048       1      1      0.078           6920.01
 32768 x 4096       1      2      0.152           7065.25
 32768 x 8192       1      4      0.301           7136.76
 32768 x 16384      1      4      0.600           7157.74
 32768 x 32768      2      4      1.312           6545.11
 32768 x 65536      4      4      2.836           6057.26
 16384 x 131072     8      4      3.142           5468.66
  8192 x 262144    16      4      3.627           4736.15

我们看到,借助 multi-CTA,我们在所有测试形状上都获得了很好的性能。

Multi-CTA 同步

由于 CTA 可能位于不同的 SM 上,同步比在单个 CTA 内慢得多。因此,gluon 提供了一种相当保守的自动同步保证,其余的同步工作由用户负责。

当源布局和目标布局对 CTA 维度的分片方式不同时,Gluon 会在 gl.convert_layout、gl.reduce、gl.sum 等操作之间插入同步原语。所有其他操作(如 TMA、WGMMA、TCGen5MMA 等)应由内核编写者通过 mbarrier 进行同步,就像在单 CTA 内核中所做的那样。

多 CTA mbarrier 的 cga_layout 语义略有不同:如 02-layouts.py 中所述,线性布局表示从 F_2^n 到 F_2^m 的映射。对 mbarrier 而言,cga_layout 是从 CTA ID(共 numCTAs 个,为 2 的幂)到屏障索引的映射,映射的像的大小就是它所代表的屏障数量。例如,我们可以有一个 mbarrier 布局,其中每个 CTA 都有自己的屏障。

num_ctas: gl.constexpr = 8
bar = gl.allocate_shared_memory(gl.int64, [num_ctas], MBarrierLayout(cga_layout=[[1], [2], [4]]))

因此,我们通过列(二进制)定义 cga_layout 矩阵,它代表 3x3 单位矩阵。由于这种模式非常常见,gluon 提供了一个辅助函数来创建它。

bar = mbarrier.allocate_mbarrier()

现在,屏障布局也允许跨 CTA 同步。例如,我们可以将 8CTA 内核的 2-CTA mbarrier 定义为:

bar = gl.allocate_shared_memory(gl.int64, [4], MBarrierLayout(cga_layout=[[0], [1], [2]]))

请注意,现在非零基底只有 [1] 和 [2],因此只有 2^2 = 4 个屏障。由于这是 8 CTA 内核,布局共有 log2(8) = 3 个基底。第 0 个基底为零,即该布局在 CTA ID 的第 0 位上广播:任意两个仅在第 0 位上不同的 CTA 共享同一个屏障。例如,CTA0 和 CTA1 共享一个屏障,CTA2 和 CTA3 也是如此,以此类推。引导 CTA 是组中 CTA ID 最小的那个;对于此布局,偶数 CTA ID 是引导 CTA。

通常,mbarrier 的 cga_layout 由序列 [[2**i] for i in range(k)](其中 k <= log2(num_ctas))组成,并在其中穿插 log2(num_ctas) - k 个零基底 [0]。

所有作用于屏障的操作都能自然地推广到多 CTA 屏障。更具体地说:

  • mbarrier.init 将计数参数乘以组中的 CTA 数量,并且仅在引导 CTA 上初始化;

  • mbarrier.expect:size_per_cta 参数会乘以组中的 CTA 数量,并且只在引导 CTA 上设置预期。由于 expect 本身计为一次到达,所有非引导 CTA 也会向引导 CTA 发出一次到达信号;

  • mbarrier.arrive:组中的每个 CTA 都会在引导 CTA 的屏障上到达;

  • mbarrier.wait:只有引导 CTA 会等待该屏障;

跨 CTA 使用的屏障需要一个额外的排序规则:每个相关的 mbarrier.init 必须在任何 CTA 使用该屏障之前完成。这既适用于跨 CTA 本身的屏障,也适用于由组播或 2CTA 操作消耗的按 CTA 的屏障。编译器在顶级 init 序列之后和第一次使用之前插入所需的 fence.mbarrier_init.release.cluster 以及一个松散的集群屏障。由于集群屏障必须在 warp_specialize 之外执行,请在进入 warp 专用化之前在内核的顶层部分初始化这些屏障。

关于同步的最后一点说明:cluster.arrive / cluster.wait(即 CGA 屏障,等同于 CTA 的 bar.sync)必须由内核中的所有线程执行。因此,它们不能在 warp_specialize 块内使用。

此外,诸如 convert_layout、reduce、sum、max 等操作在跨越 CTA 时会发出 CGA 屏障。因此,当这些操作可能跨越多个 CTA 时,也不允许在 warp_specialize 块内使用。

2CTA TCGen5MMA

在 2CTA 模式下,tcgen05_mma 指令使用成对 CTA(即 CTA0 与 CTA1、CTA2 与 CTA3 等)中两个 CTA 的数据共同计算结果。用数学术语来说,它计算两个操作数的外积,其中 LHS 持有沿 M 维度分片的输入,RHS 持有沿 N 维度分片的输入。就 cga_layout 而言,LHS 的第一个基底为 (1, 0),RHS 的第一个基底为 (0, 1)。累加器的第一个基底同样为 (1, 0)。

与单 CTA 一样,TensorMemoryLayout 的 blockM 形状就是指令的形状,可以是 64 或 128。

上面的外积布局描述了所需的数据放置:LHS (1, 0),RHS (0, 1),累加器/输出 (1, 0)。在此基础上,启用 2CTA 需要两处必需的代码修改,等待 MMA 完成时还需要第三处修改:

  • 使用 TensorMemoryLayout(..., two_ctas=True) 在累加器上选择 2CTA 模式;

  • 使用 mbarrier.allocate_mbarrier(..., two_ctas=True) 创建任何在引导 CTA 发出 tcgen05_mma 之前必须等待的屏障;

  • 初始化用于等待 MMA 完成的 mbarrier 时,将 two_ctas=... 传递给 tcgen05_mma_barrier_count

这种 MMA 前的屏障需要 two_ctas=True,以便引导 CTA 在发出集合操作之前,等待 CTA 对中的两个 CTA 都准备就绪。在这个单 tile 示例中,这就是 tma_bar;在下面的流水线矩阵乘法中,对应的是 load_ready_bars 和 acc_empty_bars。这些也正是 03-matmul-multicta.py 中使用 2CTA 屏障的位置。

mma_bar 本身不需要 two_ctas=True:tcgen05_mma 会将其完成信号组播到 CTA 对中的两个 CTA。一旦内核中的某个 tcgen05_mma 使用了 2CTA 模式,该内核中的所有 tcgen05_mma 指令都必须使用 2CTA 模式。

内核 two_cta_tcgen05_kernel 展示了单个 tile 上的 2CTA TCGen5MMA 模式。

值得注意的是,一旦 TMA 必须等待 tcgen05_mma,模式就会发生一些变化。我们将在下一节中讨论这个问题。



@gluon.jit
def two_cta_tcgen05_kernel(a_desc, b_desc, c_desc):
    # Single-tile C = A @ B computed with a 2CTA tcgen05 MMA. Must be launched
    # with exactly two CTAs (checked at compile time below).
    #
    # a_desc: A descriptor; its block covers the M rows of the whole CTA pair.
    # b_desc: B descriptor; its block covers the full tile_n columns.
    # c_desc: output descriptor; its cga_layout is reused for the accumulator.
    gl.static_assert(gl.num_ctas() == 2)

    cluster_m: gl.constexpr = a_desc.block_shape[0]
    tile_n: gl.constexpr = b_desc.block_shape[1]
    # Each CTA of the pair owns half of the pair's M rows.
    cta_m: gl.constexpr = cluster_m // 2
    cga_layout: gl.constexpr = c_desc.layout.cga_layout

    smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
    smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)

    # tma_bar is a 2CTA barrier so the leader CTA waits for both CTAs' loads
    # before issuing the collective MMA. mma_bar stays a regular barrier:
    # tcgen05_mma multicasts its completion to both CTAs of the pair.
    tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
    mma_bar = mbarrier.allocate_mbarrier()
    mbarrier.init(tma_bar, count=1)
    mbarrier.init(mma_bar, count=1)

    # expect() takes the per-CTA byte count, not the byte count of the whole pair.
    mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
    tma.async_load(a_desc, [0, 0], tma_bar, smem_a)
    tma.async_load(b_desc, [0, 0], tma_bar, smem_b)
    mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b])
    mbarrier.invalidate(tma_bar)

    # two_ctas=True selects 2CTA mode; `block` is the per-instruction shape
    # (per-CTA M rows by tile_n columns).
    acc_layout: gl.constexpr = TensorMemoryLayout(
        block=(cta_m, tile_n),
        col_stride=1,
        cga_layout=cga_layout,
        two_ctas=True,
    )
    acc = allocate_tensor_memory(gl.float32, [cluster_m, tile_n], acc_layout)

    # use_acc=False: the accumulator is overwritten, not accumulated into.
    tcgen05_mma(smem_a, smem_b, acc, use_acc=False, mbarriers=[mma_bar])
    mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b])
    mbarrier.invalidate(mma_bar)

    # Write the result back: TMEM -> registers -> shared memory -> global (TMA store).
    # NOTE(review): there is no tma.store_wait before the kernel returns —
    # confirm that outstanding TMA stores are guaranteed to complete.
    c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_shape, c_desc.layout)
    c_smem.store(acc.load().to(c_desc.dtype))
    tma.async_copy_shared_to_global(c_desc, [0, 0], c_smem)


def run_two_cta_tcgen05(a, b, c):
    """Compute ``c = a @ b`` as a single tile with ``two_cta_tcgen05_kernel`` on a 2-CTA cluster.

    Each matrix is loaded as one TMA block. The CTA pair splits A and C along M
    (basis (1, 0)) and B along N (basis (0, 1)), the operand placement that a
    2CTA tcgen05 MMA expects.
    """
    M, K = a.shape[0], a.shape[1]
    N = b.shape[1]

    def fp16_layout(shape, cga_basis):
        return gl.NVMMASharedLayout.get_default_for(shape, gl.float16, cga_layout=[cga_basis])

    a_layout = fp16_layout([M, K], (1, 0))
    b_layout = fp16_layout([K, N], (0, 1))
    c_layout = fp16_layout([M, N], (1, 0))

    a_desc = TensorDescriptor.from_tensor(a, [M, K], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [K, N], b_layout)
    c_desc = TensorDescriptor.from_tensor(c, [M, N], c_layout)

    grid = (1, )
    two_cta_tcgen05_kernel[grid](a_desc, b_desc, c_desc, num_warps=4, num_ctas=2)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_two_cta_tcgen05():
    """The single-tile 2CTA tcgen05 matmul must match torch.matmul."""
    M, N, K = 256, 128, 64
    fp16_cuda = dict(device="cuda", dtype=torch.float16)
    a = torch.randn((M, K), **fp16_cuda)
    b = torch.randn((K, N), **fp16_cuda)
    c = torch.empty((M, N), **fp16_cuda)

    run_two_cta_tcgen05(a, b, c)
    expected = torch.matmul(a, b)
    torch.testing.assert_close(c, expected, atol=1e-1, rtol=1e-2)

相对于单 CTA 情况,上面的代码因此改变了操作数布局,在累加器布局上使用 two_ctas=True,并在 TMA 交接屏障上使用 two_ctas=True。mbarrier.expect 的字节计数仍按每个 CTA 计算,而不是按整个 CTA 对计算。

请注意,一旦在 for 循环中和/或与带组播的 TMA 一起使用时,会有一些额外的变化。下一节将详细介绍。

带组播的 TMA

从 Hopper 开始,TMA 具备了向多个 CTA 组播数据的能力。这对 Hopper 上的 multi-CTA 内核非常有用(因为 wgmma 没有 2CTA 模式);对在 Blackwell 及更新架构上使用超过 2 个 CTA 的内核也同样有用。

在这种情况下,对于累加器的 cga_layout,我们可以按照以下方式计算 A 和 B 的布局:


# Example accumulator cga_layout for an 8-CTA cluster: CTA-ID bits 0 and 1
# split dimension 0 (M) into 4 chunks, bit 2 splits dimension 1 (N) in two.
cga_layout = [
    (1, 0),
    (2, 0),
    (0, 1),
]


def get_cga_layout(layout, op_idx, two_ctas):
    """Derive an MMA operand's CGA layout from the accumulator's CGA layout.

    Each accumulator (C) basis ``(m, n)`` is mapped to an operand basis that
    broadcasts along the reduction dimension K:

    * A (``op_idx == 0``, shape ``[M, K]``) keeps the M component: ``(m, 0)``.
    * B (``op_idx == 1``, shape ``[K, N]``) keeps the N component: ``(0, n)``.

    In 2CTA mode the first accumulator basis must be ``(1, 0)``. The MMA is an
    outer product, so the CTA pair splits A along M (first basis ``(1, 0)``) and
    B along N (first basis ``(0, 1)``). B's N dimension already carries that pair
    split, so the remaining N components of B are doubled.

    Args:
        layout: accumulator CGA bases, a sequence of 2-element sequences (tuples
            or lists). An empty layout is returned unchanged.
        op_idx: 0 for the A operand, 1 for the B operand.
        two_ctas: whether the MMA runs in 2CTA mode.

    Returns:
        A tuple of ``(int, int)`` bases, or ``layout`` itself when it is empty.

    Raises:
        ValueError: if ``op_idx`` is not 0 or 1, or if ``two_ctas`` is set and the
            first basis is not ``(1, 0)``.
    """
    if op_idx not in (0, 1):
        raise ValueError(f"op_idx must be 0 (A) or 1 (B), got {op_idx!r}")
    if not layout:
        return layout

    # In 2CTA mode B's N bases are doubled, because the (0, 1) pair basis is
    # placed ahead of them.
    n_scale = 2 if two_ctas else 1

    def broadcast(b):
        # Broadcast along K (the reduction dimension).
        return (b[0], 0) if op_idx == 0 else (0, n_scale * b[1])

    if not two_ctas:
        return tuple(map(broadcast, layout))

    # 2CTA performs an outer product, so the first bases are (1, 0) for A and
    # (0, 1) for B. Compare as a tuple so list-of-list layouts are accepted too.
    if tuple(layout[0]) != (1, 0):
        raise ValueError(f"2CTA mode requires the first accumulator basis to be (1, 0), got {tuple(layout[0])}")
    first = (1, 0) if op_idx == 0 else (0, 1)
    return (first, *map(broadcast, layout[1:]))


# Operand layouts for the example accumulator layout above (non-2CTA mode):
# A first (op_idx 0), then B (op_idx 1).
cga_layout_a, cga_layout_b = (get_cga_layout(cga_layout, op_idx, two_ctas=False) for op_idx in (0, 1))

换句话说,A 和 B 的 cga_layout 是 C 的布局,但将每个的内部维度置零。

这意味着 A 和/或 B 的某些基底为零,因此不同的 CTA 将加载相同的数据。组播将允许这些 CTA 有效地命中 L2 缓存。

上面的广播布局描述了哪些 CTA 应该接收相同的 tile。在此基础上,流水线矩阵乘法中有三处使用 multicast=True:

  • tma.async_load(..., multicast=True) 将 tile 发送到广播组中的每个 CTA;

  • tcgen05_mma(..., multicast=True) 在下一个 TMA 覆盖共享内存 tile 之前组播其 mbarrier 到达信号;

  • tcgen05_mma_barrier_count(..., multicast=True, ...) 使用匹配的到达计数初始化该 mbarrier。

TMA 同步模式本身与常规加载相同:初始化一个屏障,expect 字节计数,发出 TMA,然后等待屏障。

之所以有效,是因为 TMA 指令以原子方式将其到达信号广播到组播组中的每个 CTA,因此等待端不需要不同的 API。

TMA 目标必须使用广播 cga_layout,以便两个 CTA 接收相同的共享内存 tile。除非内核处于 2CTA 模式,否则屏障保持为常规的 1D TMA 屏障。

下面的示例刻意保持简单:它将一个 tile 组播到共享内存中,然后将该 tile 原样存回全局内存。



@gluon.jit
def tma_multicast_copy_kernel(in_desc, out_desc):
    # Minimal multicast example: load one tile into shared memory with a
    # multicast TMA load, then store it unchanged to out_desc. Must be launched
    # with exactly two CTAs (checked at compile time below).
    gl.static_assert(gl.num_ctas() == 2)

    smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout)
    # This kernel is not in 2CTA mode, so the TMA barrier is per-CTA.
    bar = mbarrier.allocate_mbarrier()
    mbarrier.init(bar, count=1)

    # Each CTA expects its own per-CTA byte count; the multicast load delivers
    # the tile (and its arrival) to every CTA of the multicast group.
    mbarrier.expect(bar, in_desc.nbytes_per_cta)
    tma.async_load(in_desc, [0, 0], bar, smem, multicast=True)
    mbarrier.wait(bar, phase=0, deps=[smem])

    # NOTE(review): the barrier is never invalidated and there is no
    # tma.store_wait before returning — confirm neither is required here.
    tma.async_copy_shared_to_global(out_desc, [0, 0], smem)


def run_tma_multicast_copy(inp, out):
    """Copy ``inp`` into ``out`` through shared memory via a multicast TMA load on a 2-CTA cluster."""
    # A single all-zero CGA basis makes the CTA dimension a pure broadcast:
    # both CTAs hold the full tile.
    shared_layout = gl.NVMMASharedLayout.get_default_for(inp.shape, gl.float16, cga_layout=[[0, 0]])
    tile_shape = inp.shape
    in_desc, out_desc = (TensorDescriptor.from_tensor(t, tile_shape, shared_layout) for t in (inp, out))

    grid = (1, )
    tma_multicast_copy_kernel[grid](in_desc, out_desc, num_warps=4, num_ctas=2)


@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_tma_multicast_copy():
    """A multicast TMA copy must reproduce its input exactly."""
    tile = (128, 128)
    src = torch.randn(tile, device="cuda", dtype=torch.float16)
    dst = torch.empty_like(src)

    run_tma_multicast_copy(src, dst)
    torch.testing.assert_close(dst, src, atol=0, rtol=0)

循环中的 TMA 到 MMA

在这里,我们说明了将 TMA(带或不带组播)混合到 tcgen05_mma 流水线中的完全通用方法。

在这种情况下,tcgen05_mma 指令需要等待其组播组中的所有 CTA 完成,然后才能继续下一次迭代,否则下一次迭代的 TMA 加载会在前一次迭代完全消耗完之前覆盖其共享内存数据。

因此,我们需要使用 tcgen05_mma_barrier_count 计算组播组中的 CTA 数量。同样,我们在 tcgen05_mma 指令上设置 multicast=True 标志,以说明它必须等待组播组完成才能继续。

这些函数是通用的,因此这种形式的模式也适用于非组播内核或非 2CTA 内核。



@gluon.jit
def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_tmem_layout: gl.constexpr):
    # Single-output-tile matmul over NUM_K_TILES K steps. Each step loads the A
    # and B tiles with multicast TMA and then accumulates them with tcgen05_mma.
    # The loop is fully serialized (load -> wait -> MMA -> wait), so one shared
    # memory buffer per operand is enough.
    #
    # acc_tmem_layout: tensor-memory layout of the accumulator; its two_ctas
    #   flag is forwarded to tcgen05_mma_barrier_count.
    block_m: gl.constexpr = a_desc.block_shape[0]
    block_k: gl.constexpr = a_desc.block_shape[1]
    block_n: gl.constexpr = b_desc.block_shape[1]

    smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
    smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)

    acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout)
    # NOTE(review): tma_bar is always allocated with two_ctas=True, regardless
    # of acc_tmem_layout.two_ctas — confirm this is intended for non-2CTA layouts.
    tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
    mma_bar = mbarrier.allocate_mbarrier()
    mbarrier.init(tma_bar, count=1)
    # The MMA barrier must collect arrivals from every CTA of the multicast
    # group, so its count comes from tcgen05_mma_barrier_count.
    mbarrier.init(
        mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True,
                                                 two_ctas=acc_tmem.type.layout.two_ctas))

    phase_tma = 0
    phase_mma = 0

    for k in range(NUM_K_TILES):
        # Per-CTA byte count, as in the non-looped examples.
        mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
        tma.async_load(a_desc, [0, k * block_k], tma_bar, smem_a, multicast=True)
        tma.async_load(b_desc, [k * block_k, 0], tma_bar, smem_b, multicast=True)
        mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b])
        phase_tma ^= 1

        # multicast=True makes the MMA signal completion to the whole multicast
        # group. Waiting on mma_bar keeps the next iteration's TMA loads from
        # overwriting smem_a / smem_b while any CTA is still reading them.
        tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=(k != 0), multicast=True, mbarriers=[mma_bar])
        mbarrier.wait(mma_bar, phase=phase_mma, deps=[smem_a, smem_b])
        phase_mma ^= 1

    mbarrier.invalidate(tma_bar)
    mbarrier.invalidate(mma_bar)

    # Write the result back: TMEM -> registers -> shared memory -> global (TMA store).
    out_smem = gl.allocate_shared_memory(out_desc.dtype, out_desc.block_shape, out_desc.layout)
    out_smem.store(acc_tmem.load().to(out_desc.dtype))
    tma.async_copy_shared_to_global(out_desc, [0, 0], out_smem)


def tma_tcgen05_example(a, b):
    """Compute ``a @ b`` for one fixed-size tile with ``tma_tcgen05_kernel`` on a 4-CTA cluster.

    Shapes are fixed by the tile configuration below: ``a`` must be
    ``(BLOCK_M, BLOCK_K * NUM_K_TILES) = (512, 128)`` and ``b`` must be
    ``(128, BLOCK_N) = (128, 128)``. All shared-memory layouts are built for
    float16, so both inputs must be float16.

    Args:
        a: float16 tensor of shape (512, 128).
        b: float16 tensor of shape (128, 128).

    Returns:
        float16 tensor of shape (512, 128) on the same device as ``a``.

    Raises:
        ValueError: if the inner dimensions differ, the shapes do not match the
            fixed tile configuration, or either input is not float16.
    """
    BLOCK_M = 512
    BLOCK_N = 128
    BLOCK_K = 64
    NUM_K_TILES = 2
    # Accumulator (C) split along M over the 4 CTAs. A uses the same split. B is
    # split along N across each 2CTA pair and broadcast across pairs
    # (get_cga_layout(cga_layout_c, 1, two_ctas=True) gives the same bases).
    cga_layout_a = ((1, 0), (2, 0))
    cga_layout_b = ((0, 1), (0, 0))
    cga_layout_c = ((1, 0), (2, 0))

    M, K = a.shape
    Kb, N = b.shape
    if K != Kb:
        raise ValueError(f"inner dimensions must match, got {K} and {Kb}")
    if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K * NUM_K_TILES:
        raise ValueError(f"expected shapes {(BLOCK_M, BLOCK_K * NUM_K_TILES)} x "
                         f"{(BLOCK_K * NUM_K_TILES, BLOCK_N)}, got {tuple(a.shape)} x {tuple(b.shape)}")
    # The NVMMA shared layouts below are float16 layouts.
    if a.dtype != torch.float16 or b.dtype != torch.float16:
        raise ValueError(f"expected float16 inputs, got {a.dtype} and {b.dtype}")

    # Allocate on the inputs' device rather than the default CUDA device.
    out = torch.empty((M, N), device=a.device, dtype=torch.float16)
    a_layout = gl.NVMMASharedLayout.get_default_for([M, BLOCK_K], gl.float16, cga_layout=cga_layout_a)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, N], gl.float16, cga_layout=cga_layout_b)
    c_layout = gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=cga_layout_c)
    # 128-row MMA instruction per CTA, in 2CTA mode.
    acc_tmem_layout = TensorMemoryLayout(block=(128, N), col_stride=1, cga_layout=cga_layout_c, two_ctas=True)
    a_desc = TensorDescriptor.from_tensor(a, [M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_K, N], b_layout)
    c_desc = TensorDescriptor.from_tensor(out, [M, N], c_layout)

    tma_tcgen05_kernel[(1, )](
        a_desc,
        b_desc,
        c_desc,
        NUM_K_TILES,
        acc_tmem_layout,
        num_warps=4,
        num_ctas=4,
    )
    return out


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_tma_tcgen05():
    """The looped multicast-TMA + 2CTA tcgen05 example must match torch.matmul."""
    M, N, K = 512, 128, 128
    a, b = (torch.randn(shape, device="cuda", dtype=torch.float16) for shape in ((M, K), (K, N)))

    out = tma_tcgen05_example(a, b)
    torch.testing.assert_close(out, torch.matmul(a, b), atol=1e-1, rtol=1e-2)

极速矩阵乘法内核

在这里,我们说明了编写矩阵乘法内核的完全通用方法,该内核使用 TMA(可能带有组播)到 warp 专用化的 tcgen05_mma 流水线中。

对于此示例,我们将 12-cluster-launch-control.py 中提出的 CLC 思想推广到一个 warp 专用化的内核,方法是添加一个处理 CLC 生成的新分区,并将其广播给所有 CTA。

我们还使用了一个名为 _planar_snake 的额外辅助函数来调整程序 ID,以提高 L2 局部性。


# Share the Counter helper and the cublas object with the warp-specialization tutorial.
Counter, cublas = t8.Counter, t8.cublas


@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
    """Return how many pieces ``cga_layout`` splits dimension ``dim`` into.

    Every basis with a non-zero component along ``dim`` doubles the split, so
    the result is ``2 ** (number of such bases)``.
    """
    splitting_bases = sum(basis[dim] != 0 for basis in cga_layout)
    return 1 << splitting_bases


@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
    # Map a linear tile index to (pid_m, pid_n) in a serpentine order meant to
    # improve L2 locality.
    #
    # The tile grid is cut into strips that are tile_width tiles wide along the
    # minor dimension (m when minor_dim == 0, otherwise n):
    # - Inside a strip, consecutive indices walk the minor dimension first,
    #   then advance along the major dimension.
    # - Even strips walk the major dimension forward; odd strips walk it
    #   backward (the "snake").
    # - A last strip narrower than tile_width is handled by the partial_* path.
    #
    # Returns (pid_m, pid_n).
    major_size = n_tiles if minor_dim == 0 else m_tiles
    minor_size = m_tiles if minor_dim == 0 else n_tiles

    # Tiles covered by the full-width strips.
    full_minor_tiles = minor_size // tile_width
    full_minor_size = full_minor_tiles * tile_width
    full_elements = full_minor_tiles * tile_width * major_size

    # Index of the strip that lin_idx falls into.
    minor_tile_idx = lin_idx // (tile_width * major_size)

    # Coordinates assuming lin_idx lies in a full-width strip.
    full_minor_within = lin_idx % tile_width
    full_major_within = (lin_idx // tile_width) % major_size
    full_minor = minor_tile_idx * tile_width + full_minor_within
    full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)

    # Coordinates assuming lin_idx lies in the trailing partial strip. The width
    # is clamped to 1 to avoid dividing by zero when there is no partial strip
    # (the results are then discarded by the selection below).
    partial_width = minor_size - full_minor_size
    partial_width = gl.where(partial_width > 0, partial_width, 1)
    partial_lin = lin_idx - full_elements
    partial_minor_within = partial_lin % partial_width
    partial_major_within = (partial_lin // partial_width) % major_size
    partial_minor = minor_tile_idx * tile_width + partial_minor_within
    partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)

    in_full_tile = lin_idx < full_elements
    minor = gl.where(in_full_tile, full_minor, partial_minor)
    major = gl.where(in_full_tile, full_major, partial_major)

    if minor_dim == 0:
        return minor, major
    return major, minor


@gluon.aggregate
class ClcTileSchedulerConsumer:
    # Per-partition view of the CLC-driven tile schedule. Each worker partition
    # holds its own copy and advances it with step(). The first tile comes from
    # gl.program_id. Later tile coordinates are published in shared memory by
    # matmul_clc_partition.
    has_work: gl.tensor  # True for the initial tile, then clc_res.is_canceled() of the last slot read
    tile_id: gl.tensor  # linear tile id of the current tile
    pid_m: gl.tensor  # current tile coordinate along M (in units of TILE_M)
    pid_n: gl.tensor  # current tile coordinate along N (in units of TILE_N)
    num_pid_m: gl.tensor  # number of tiles along M
    num_pid_n: gl.tensor  # number of tiles along N
    TILE_M: gl.constexpr
    TILE_N: gl.constexpr
    MINOR_DIM: gl.constexpr  # forwarded to _planar_snake
    GRID_TILE_WIDTH: gl.constexpr  # forwarded to _planar_snake
    clc_result_buffers: gl.shared_memory_descriptor  # ring of CLC responses
    clc_barriers: gl.shared_memory_descriptor  # one barrier per CLC response slot
    clc_planar_pid_buffers: gl.shared_memory_descriptor  # packed (pid_m << 32 | pid_n) per slot
    clc_planar_ready_bars: gl.shared_memory_descriptor  # arrived on once a packed pid slot is written
    clc_consumed_bars: gl.shared_memory_descriptor  # arrived on by consumers when releasing a slot
    counter: Counter  # ring index/phase of the next slot to read
    consumed_counter: Counter  # ring index/phase of the next slot to release

    @gluon.jit
    def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
                   GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
                   clc_planar_ready_bars, clc_consumed_bars):
        # The first tile of every program is its own program id, mapped to tile
        # coordinates locally; no CLC slot is consumed for it.
        tile_id = gl.program_id(axis=0)
        num_pid_m = gl.cdiv(M, TILE_M)
        num_pid_n = gl.cdiv(N, TILE_N)
        pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
        return ClcTileSchedulerConsumer(
            gl.to_tensor(True),
            tile_id,
            pid_m,
            pid_n,
            num_pid_m,
            num_pid_n,
            TILE_M,
            TILE_N,
            MINOR_DIM,
            GRID_TILE_WIDTH,
            clc_result_buffers,
            clc_barriers,
            clc_planar_pid_buffers,
            clc_planar_ready_bars,
            clc_consumed_bars,
            Counter.create(0, clc_barriers.shape[0]),
            Counter.create(0, clc_barriers.shape[0]),
        )

    @gluon.jit
    def get_offsets(self):
        # Element offsets (row, column) of the current tile's top-left corner.
        return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N

    @gluon.jit
    def step(self, iteration):
        # Advance to the next tile. On every call after the first, first release
        # the slot read by the previous call so the CLC partition may refill it.
        consumed_counter = self.consumed_counter
        if iteration > 0:
            mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
            consumed_counter = consumed_counter.next()

        counter = self.counter
        barrier = self.clc_barriers.index(counter.index)
        result = self.clc_result_buffers.index(counter.index)
        # Wait for the CLC response, then for the packed tile coordinates that
        # the CLC partition derived from it.
        mbarrier.wait(barrier, counter.phase)
        clc_res = clc.load_result(result)
        mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
        planar_slot = self.clc_planar_pid_buffers.index(counter.index)
        # All-zero CGA bases: every CTA reads the same (broadcast) value.
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        packed_pid = planar_slot.load(planar_layout).reshape([])
        # Unpack: pid_m in the high 32 bits, pid_n in the low 32 bits.
        pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
        pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)
        has_work = clc_res.is_canceled()
        tile_id = self.tile_id
        if has_work:
            tile_id = clc_res.program_id(0)
        return ClcTileSchedulerConsumer(
            has_work,
            tile_id,
            pid_m,
            pid_n,
            self.num_pid_m,
            self.num_pid_n,
            self.TILE_M,
            self.TILE_N,
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
            counter.next(),
            consumed_counter,
        )


@gluon.aggregate
class MatmulPartitionArgs:
    # Bundle of descriptors, buffers and barriers shared by the warp-specialized
    # matmul partitions (CLC, load, MMA, epilogue).
    a_desc: tma.tensor_descriptor
    b_desc: tma.tensor_descriptor
    c_desc: tma.tensor_descriptor
    a_bufs: gl.shared_memory_descriptor  # ring of A tiles, indexed by load stage
    b_bufs: gl.shared_memory_descriptor  # ring of B tiles, indexed by load stage
    load_empty_bars: gl.shared_memory_descriptor  # MMA -> load: stage may be overwritten
    load_ready_bars: gl.shared_memory_descriptor  # load -> MMA: stage has been filled
    acc_bufs: tensor_memory_descriptor  # ring of TMEM accumulators
    acc_empty_bars: gl.shared_memory_descriptor  # epilogue -> MMA: accumulator may be reused
    acc_ready_bars: gl.shared_memory_descriptor  # MMA -> epilogue: accumulator holds a finished tile
    clc_result_buffers: gl.shared_memory_descriptor
    clc_barriers: gl.shared_memory_descriptor
    clc_planar_pid_buffers: gl.shared_memory_descriptor
    clc_planar_ready_bars: gl.shared_memory_descriptor
    clc_consumed_bars: gl.shared_memory_descriptor
    MINOR_DIM: gl.constexpr  # forwarded to _planar_snake
    GRID_TILE_WIDTH: gl.constexpr  # forwarded to _planar_snake
    SUBTILE_STAGES: gl.constexpr  # number of epilogue shared-memory buffers

    @gluon.jit
    def get_clc_consumer(self):
        # Fresh scheduler view for one partition. Tiles are
        # a_desc.block_shape[0] x b_desc.block_shape[1] over the c_desc shape.
        return ClcTileSchedulerConsumer.initialize(
            self.c_desc.shape[0],
            self.c_desc.shape[1],
            self.a_desc.block_shape[0],
            self.b_desc.block_shape[1],
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
        )


@gluon.jit
def matmul_clc_partition(p):
    # CLC producer partition. It repeatedly asks cluster launch control to
    # cancel a pending program (clc.try_cancel) and take over its tile. It then
    # maps the obtained tile id to (pid_m, pid_n) with _planar_snake and
    # publishes the packed coordinates in shared memory for the other partitions.
    tile_m: gl.constexpr = p.a_desc.block_shape[0]
    tile_n: gl.constexpr = p.b_desc.block_shape[1]
    has_work = gl.to_tensor(True)
    num_pid_m = gl.cdiv(p.c_desc.shape[0], tile_m)
    num_pid_n = gl.cdiv(p.c_desc.shape[1], tile_n)
    state = Counter.create(0, p.clc_barriers.shape[0])
    consumed_state = Counter.create(1, p.clc_barriers.shape[0])
    # Depth of the CLC slot ring (named acc_stages because the ring is sized
    # like the accumulator ring).
    acc_stages: gl.constexpr = p.clc_barriers.shape[0]
    i = 0
    while has_work:
        # Before reusing a slot, wait until consumers have released it. The
        # first acc_stages iterations use never-filled slots, so the wait is
        # predicated off.
        mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase, pred=(i >= acc_stages))
        barrier = p.clc_barriers.index(state.index)
        result = p.clc_result_buffers.index(state.index)
        # 16 bytes: the size of one CLC response, tracked via the barrier's
        # transaction count.
        mbarrier.expect(barrier, 16)
        clc.try_cancel(result, barrier)
        mbarrier.wait(barrier, state.phase)
        clc_res = clc.load_result(result)
        has_work = clc_res.is_canceled()
        pid_m = gl.to_tensor(0)
        pid_n = gl.to_tensor(0)
        if has_work:
            tile_id = clc_res.program_id(0)
            pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
        # Pack pid_m into the high 32 bits and pid_n into the low 32 bits. This
        # is published even when no tile was obtained (as (0, 0)), so consumers
        # waiting on the ready barrier are always released; they read has_work
        # from the CLC result themselves.
        packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
        planar_slot = p.clc_planar_pid_buffers.index(state.index)
        # All-zero CGA bases: the stored value is broadcast across CTAs.
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
        mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))
        state = state.next()
        consumed_state = consumed_state.next()
        i += 1


@gluon.jit
def matmul_load_partition(p):
    # TMA producer partition. For each scheduled tile, it streams the A and B
    # K-slices into the a_bufs/b_bufs ring using multicast TMA loads.
    block_k: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]

    concurrent_loads: gl.constexpr = p.load_ready_bars.shape[0]
    state = Counter.create(1, concurrent_loads)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        for k in range(0, K, block_k):
            # Wait until the MMA has released this stage. The first
            # concurrent_loads stages of the first tile were never filled, so
            # their wait is predicated off.
            pred = (i > 0) or (k >= block_k * concurrent_loads)
            mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)
            bar = p.load_ready_bars.index(state.index)
            # Per-CTA byte count for both operand tiles.
            mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
            tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
            tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
            state = state.next()
        scheduler = scheduler.step(i)
        i += 1


@gluon.jit
def matmul_mma_partition(p):
    # MMA partition. For each scheduled tile, it accumulates all K-slices into
    # one TMEM accumulator from the acc_bufs ring, then signals the epilogue.
    block_k: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]
    acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]

    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, acc_stages)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        acc_buf = p.acc_bufs.index(acc_state.index)
        # Wait for the epilogue to drain this accumulator. The first acc_stages
        # tiles use untouched accumulators, so the wait is predicated off.
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase, pred=(i >= acc_stages))
        use_acc = False
        for k in range(0, K, block_k):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            # On completion the MMA arrives (multicast to the whole group) on
            # this stage's load_empty barrier, letting the load partition refill it.
            tcgen05_mma(
                p.a_bufs.index(load_state.index),
                p.b_bufs.index(load_state.index),
                acc_buf,
                use_acc=use_acc,
                multicast=True,
                mbarriers=[p.load_empty_bars.index(load_state.index)],
            )
            load_state = load_state.next()
            use_acc = True
        # Signal the epilogue once all MMAs issued for this tile complete.
        # NOTE(review): descs presumably lets the commit derive the same
        # multicast arrival pattern as the MMAs above — confirm.
        tcgen05_commit(p.acc_ready_bars.index(acc_state.index), descs=[p.a_bufs.index(0), p.b_bufs.index(0)])
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1


@gluon.jit
def matmul_epilogue_partition(p):
    # Epilogue partition. For each scheduled tile, it copies the finished TMEM
    # accumulator to global memory in split_tile_n-wide subtiles, staged through
    # a ring of SUBTILE_STAGES shared-memory buffers and written with TMA stores.
    tile_m: gl.constexpr = p.a_desc.block_shape[0]
    tile_n: gl.constexpr = p.b_desc.block_shape[1]
    split_tile_n: gl.constexpr = p.c_desc.block_shape[1]
    # Separate knobs: SUBTILE_STAGES controls shared-memory usage (number of
    # staging buffers), while subtile_factor is the number of split_tile_n-wide
    # subtiles that make up one tile_n-wide tile.
    subtile_factor: gl.constexpr = tile_n // split_tile_n
    subtile_stages: gl.constexpr = p.SUBTILE_STAGES
    acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]
    dtype: gl.constexpr = p.c_desc.dtype

    acc_state = Counter.create(0, acc_stages)
    acc_smems = gl.allocate_shared_memory(dtype, [subtile_stages, tile_m, split_tile_n], p.c_desc.layout)
    sub_acc_state = Counter.create(0, subtile_stages)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)

        for s in gl.static_range(subtile_factor):
            acc_sub = acc_buf.slice(split_tile_n * s, split_tile_n)
            acc_smem = acc_smems.index(sub_acc_state.index)
            acc = acc_sub.load().to(dtype)
            # Allow at most subtile_stages - 1 outstanding TMA stores, so the
            # staging buffer about to be overwritten is no longer being read.
            tma.store_wait(pendings=subtile_stages - 1)
            acc_smem.store(acc)
            tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + split_tile_n * s], acc_smem)
            sub_acc_state = sub_acc_state.next()
        # Every subtile has been loaded from TMEM, so the MMA partition may reuse
        # this accumulator.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index))
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1
    # NOTE(review): no tma.store_wait(pendings=0) after the loop; confirm the
    # remaining TMA stores are drained before acc_smems is released.


# The entry kernel allocates and initializes every barrier before
# `gl.warp_specialize`. Some of these barriers are cross-CTA barriers, and some
# otherwise per-CTA barriers are consumed by multicast or 2CTA operations. The
# compiler inserts the required init fence and relaxed cluster barrier after
# this top-level init sequence, before any partition can use the barriers.
# Keeping the init sequence here also keeps that mandatory cluster sync outside
# `warp_specialize`.
@gluon.jit
def matmul_multicta_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    BLOCK_SIZE_M: gl.constexpr,
    BLOCK_SIZE_N: gl.constexpr,
    BLOCK_SIZE_K: gl.constexpr,
    GRID_MINOR_DIM: gl.constexpr,
    GRID_TILE_WIDTH: gl.constexpr,
    STAGES: gl.constexpr,
    ACC_STAGES: gl.constexpr,
    CGA_LAYOUT: gl.constexpr,
    EPILOGUE_SIZE_N: gl.constexpr,
    SUBTILE_STAGES: gl.constexpr,
):
    # Entry point of the multi-CTA, warp-specialized matmul. It allocates the
    # shared/tensor memory and mbarriers used by the partitions, bundles them
    # into MatmulPartitionArgs, and launches the four partitions.
    #
    # M, N, K, BLOCK_SIZE_N, BLOCK_SIZE_K and EPILOGUE_SIZE_N are not read in
    # this body; tile shapes come from the descriptors' block_shape instead.
    # The host-side grid(meta) callback in matmul_multicta reads BLOCK_SIZE_M,
    # BLOCK_SIZE_N and CGA_LAYOUT from the launch metadata.
    block_m: gl.constexpr = a_desc.block_shape[0]
    block_n: gl.constexpr = b_desc.block_shape[1]
    # two_ctas is forwarded to the TMEM layout and to several barrier allocations below.
    two_ctas: gl.constexpr = gl.num_ctas() > 1
    # Epilogue, load, MMA and CLC partitions (see the warp_specialize call below).
    n_partitions: gl.constexpr = 4

    dtype: gl.constexpr = a_desc.dtype
    # STAGES-deep operand buffers (one A tile and one B tile per stage).
    a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
    # Per-CTA TMEM block: BLOCK_SIZE_M rows and block_n columns divided by
    # get_split_dim(CGA_LAYOUT, 1) (presumably the number of CTAs splitting dim 1).
    tmem_layout: gl.constexpr = TensorMemoryLayout(
        [BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)],
        col_stride=1,
        cga_layout=CGA_LAYOUT,
        two_ctas=two_ctas,
    )
    # ACC_STAGES fp32 accumulator buffers; the epilogue cycles through them
    # using acc_ready_bars / acc_empty_bars.
    acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout)
    # Expected arrival count for barriers signalled by the (multicast) MMA; used
    # as the init count of load_empty_bars and acc_ready_bars below.
    mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True,
                                                                two_ctas=acc_bufs.index(0).type.layout.two_ctas)

    # Operand pipeline barriers: "empty" expects the MMA's arrivals, "ready"
    # expects a single arrival.
    load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
    load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas)
    for i in gl.static_range(STAGES):
        mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
        mbarrier.init(load_ready_bars.index(i), count=1)

    # Accumulator handoff barriers: the epilogue arrives once on "empty"; the
    # epilogue waits on "ready" before reading an accumulator.
    acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
    acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(acc_empty_bars.index(i), count=1)
        mbarrier.init(acc_ready_bars.index(i), count=mma_barrier_count)

    # Cluster-launch-control (CLC) scheduling barriers, one per ACC_STAGES slot.
    # clc_consumed_bars expect n_partitions - 1 arrivals, presumably one from each
    # non-CLC partition that consumes a scheduling result (TODO confirm).
    clc_barriers = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_consumed_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(clc_barriers.index(i), count=1)
        mbarrier.init(clc_planar_ready_bars.index(i), count=1)
        mbarrier.init(clc_consumed_bars.index(i), count=n_partitions - 1)

    # log2(num_ctas) all-zero CGA bases: the CLC buffers are not sharded across
    # CTAs, so every CTA holds its own full copy.
    cga_layout: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
    clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout)
    # Two int64 words per slot (presumably the raw CLC response) plus one int64
    # per slot for the "planar" pid.
    clc_result_buffers = gl.allocate_shared_memory(
        gl.int64,
        [clc_barriers.shape[0], 2],
        clc_layout,
    )
    clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)

    p = MatmulPartitionArgs(
        a_desc,
        b_desc,
        c_desc,
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
        clc_result_buffers,
        clc_barriers,
        clc_planar_pid_buffers,
        clc_planar_ready_bars,
        clc_consumed_bars,
        GRID_MINOR_DIM,
        GRID_TILE_WIDTH,
        SUBTILE_STAGES,
    )

    # Four partitions but only three warp counts / register budgets: the load,
    # MMA and CLC workers each get 1 warp and 24 registers; the first entry
    # (epilogue) presumably runs as the default partition with num_warps.
    gl.warp_specialize([
        (matmul_epilogue_partition, (p, )),
        (matmul_load_partition, (p, )),
        (matmul_mma_partition, (p, )),
        (matmul_clc_partition, (p, )),
    ], [1, 1, 1], [24, 24, 24])


def matmul_multicta(
        a,
        b,
        out=None,
        *,
        block_size_m=128,
        block_size_n=256,
        block_size_k=64,
        grid_minor_dim=0,
        grid_tile_width=16,
        stages=6,
        acc_stages=2,
        cga_layout=((1, 0), ),
        epilogue_size_n=32,
        subtile_stages=4,
):
    """Compute ``a @ b`` with the multi-CTA, warp-specialized Gluon kernel.

    Args:
        a: fp16 tensor of shape (M, K).
        b: fp16 tensor of shape (K, N).
        out: Optional (M, N) fp16 output tensor on the same device as ``a``.
            A new tensor is allocated when ``None``.
        block_size_m: Output rows per CTA. The rows covered by one cluster tile
            are ``block_size_m`` times the CTA split along dim 0 of ``cga_layout``.
        block_size_n: Output columns per cluster tile. Divided by the CTA split
            along dim 1 of ``cga_layout``, it must not exceed 256.
        block_size_k: Reduction depth of each A/B tile load.
        grid_minor_dim, grid_tile_width: Forwarded to the kernel's partitions
            (presumably they control the tile visiting order).
        stages: Number of shared-memory stages for the A/B tiles.
        acc_stages: Number of tensor-memory accumulator buffers.
        cga_layout: CGA layout bases. The kernel is launched with
            ``2 ** len(cga_layout)`` CTAs.
        epilogue_size_n: Width of each output subtile stored by the epilogue.
            It must be a positive divisor of ``block_size_n``.
        subtile_stages: Number of shared-memory buffers for epilogue subtiles.

    Returns:
        The (M, N) product; ``out`` itself when it was provided.

    Raises:
        ValueError: for an unsupported tile configuration, incompatible or
            non-fp16 inputs, or an ``out`` with the wrong shape, dtype or device.
    """
    n_split = get_split_dim(cga_layout, 1)
    if block_size_n // n_split > 256:
        raise ValueError(
            f"cga_layout={list(cga_layout)} only supports BLOCK_SIZE_N <= {256 * n_split}")
    # The epilogue stores block_size_n // epilogue_size_n subtiles per tile.
    # Any remainder columns would silently never be written.
    if epilogue_size_n <= 0 or block_size_n % epilogue_size_n != 0:
        raise ValueError(f"epilogue_size_n={epilogue_size_n} must be a positive divisor of "
                         f"block_size_n={block_size_n}")

    M, K = a.shape
    K1, N = b.shape
    if K != K1:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if a.dtype != torch.float16 or b.dtype != torch.float16:
        raise ValueError("matmul only supports fp16 inputs")

    if out is None:
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    else:
        if out.shape != (M, N):
            raise ValueError(f"Output has invalid shape {out.shape}, expected {(M, N)}")
        # The C descriptor below is built with an fp16 shared layout.
        if out.dtype != torch.float16:
            raise ValueError(f"Output has invalid dtype {out.dtype}, expected {torch.float16}")
        if out.device != a.device:
            raise ValueError(f"Output is on device {out.device}, expected {a.device}")
        c = out

    tile_m = block_size_m * get_split_dim(cga_layout, 0)
    # An empty cga_layout means a single CTA per cluster.
    two_ctas = bool(cga_layout)
    a_layout = gl.NVMMASharedLayout.get_default_for([tile_m, block_size_k], gl.float16,
                                                    cga_layout=get_cga_layout(cga_layout, 0, two_ctas))
    b_layout = gl.NVMMASharedLayout.get_default_for([block_size_k, block_size_n], gl.float16,
                                                    cga_layout=get_cga_layout(cga_layout, 1, two_ctas))
    c_layout = gl.NVMMASharedLayout.get_default_for([tile_m, epilogue_size_n], gl.float16, cga_layout=cga_layout)

    a_desc = TensorDescriptor.from_tensor(a, [tile_m, block_size_k], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [block_size_k, block_size_n], b_layout)
    c_desc = TensorDescriptor.from_tensor(c, [tile_m, epilogue_size_n], c_layout)

    def grid(meta):
        # One program per (cluster_tile_m x BLOCK_SIZE_N) output tile.
        cluster_tile_m = meta["BLOCK_SIZE_M"] * get_split_dim(meta["CGA_LAYOUT"], 0)
        cluster_tile_n = meta["BLOCK_SIZE_N"]
        num_tiles = triton.cdiv(M, cluster_tile_m) * triton.cdiv(N, cluster_tile_n)
        return (num_tiles, )

    matmul_multicta_kernel[grid](
        a_desc,
        b_desc,
        c_desc,
        M,
        N,
        K,
        block_size_m,
        block_size_n,
        block_size_k,
        grid_minor_dim,
        grid_tile_width,
        stages,
        acc_stages,
        cga_layout,
        epilogue_size_n,
        subtile_stages,
        num_warps=4,
        num_ctas=2**len(cga_layout),
    )
    return c


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_matmul_multicta():
    """Check the multi-CTA matmul against torch.matmul on a 1024x1024x512 fp16 problem."""
    rows, cols, depth = 1024, 1024, 512
    lhs = torch.randn((rows, depth), device="cuda", dtype=torch.float16)
    rhs = torch.randn((depth, cols), device="cuda", dtype=torch.float16)
    actual = matmul_multicta(lhs, rhs)
    expected = torch.matmul(lhs, rhs)
    torch.testing.assert_close(actual, expected, atol=1e-1, rtol=1e-2)


if __name__ == "__main__" and is_blackwell():
    print("Benchmarking matmul_multicta")
    print("============================")
    # Explicit copy of matmul_multicta's default configuration, kept here so it
    # is easy to tweak when exploring other tile shapes.
    cfg = dict(
        block_size_m=128,
        block_size_n=256,
        block_size_k=64,
        grid_minor_dim=0,
        grid_tile_width=16,
        stages=6,
        acc_stages=2,
        cga_layout=((1, 0), ),
        epilogue_size_n=32,
        subtile_stages=4,
    )

    M, N = 8192, 8192
    C = torch.empty((M, N), device="cuda", dtype=torch.float16)

    def _bench_tflops(fn, depth):
        # Time fn with do_bench and convert the runtime to TFLOP/s for an (M, N, depth) matmul.
        elapsed = triton.testing.do_bench(fn, warmup=200, rep=1000)
        return tflops(elapsed, M, N, depth)

    print("    K         multi-CTA    cublas")
    for K in (1 << exp for exp in range(9, 15)):
        A = torch.randn((M, K), device="cuda", dtype=torch.float16)
        B = torch.randn((K, N), device="cuda", dtype=torch.float16)
        # The cuBLAS wrapper is given B transposed and made contiguous.
        BT = B.T.contiguous()
        ours = _bench_tflops(lambda: matmul_multicta(A, B, out=C, **cfg), K)
        theirs = _bench_tflops(lambda: cublas.matmul(A, BT, C), K)
        print(f"{K:>5} {ours:>17.2f} {theirs:>9.2f}")

Benchmarking matmul_multicta
============================
    K         multi-CTA    cublas
  512           1096.31   1190.98
 1024           1306.07   1344.48
 2048           1379.80   1374.48
 4096           1444.26   1431.93
 8192           1302.33   1347.82
16384           1292.40   1371.82

我们的实现能够与 cuBLAS 相竞争：在这一特定配置下，当 K = 2048 和 K = 4096 时甚至略胜于 cuBLAS。如果针对不同的形状选择不同的配置，我们应当能够在更广泛的形状范围内超越 cuBLAS。