Multi-CTA
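In Hopper, NVIDIA added a new level to the thread hierarchy: the CGA (a cluster of CTAs). A CGA is a group of up to 16 CTAs that can cooperate with each other. Specifically, they can:
cooperatively load data from HBM using TMA multicast
exchange data by accessing each other's shared memory, which is usually referred to as "using distributed shared memory"
starting with Blackwell, compute the result of a matrix multiplication cooperatively in pairs of CTAs
synchronize selectively, in subsets of the CGA, through mbarriers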
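Of course, different CTAs may or may not be placed on the same SM (the documentation gives no guarantee either way). Operations such as synchronizing or accessing each other's shared memory are therefore much more expensive than accessing shared memory or synchronizing threads within a single CTA. The core strategy when using CGAs is to maximize cooperation while avoiding unnecessary synchronization points.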
Multi-CTA layouts
Layouts can be sharded across CTAs in a natural way. For example, for a program with 4 warps and 2 CTAs we can use a blocked layout such as:
gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[1, 0]])
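The cga_layout [[1, 0]] is a linear layout. Here it says that the two CTAs shard the tensor along dimension 0 into two contiguous subtensors.
Similarly, suppose we have 8 CTAs and want to shard a shared-memory descriptor into 4 chunks along dimension 0 using the two low bits of the CTA id, and then into 2 chunks along dimension 1 using the high bit. CTAs 0-3 then cover the first half of dimension 1 and CTAs 4-7 the second half. The layout could look like this: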
gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=[[1, 0], [2, 0], [0, 1]])
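A cga_layout is a linear layout over F_2. The chunk a CTA owns is therefore the XOR of the bases selected by the set bits of its CTA id. The pure-Python sketch below evaluates the 8-CTA layout above on the host. The helpers `cta_chunk` and `per_cta_shape` are our own illustration, not Gluon API. `per_cta_shape` assumes, as in these examples, that the bases along each dimension are distinct powers of two.

```python
def cta_chunk(cga_layout, cta_id):
    """Chunk coordinate owned by `cta_id`: XOR of the bases selected by its set bits."""
    coord = [0] * len(cga_layout[0])
    for bit, basis in enumerate(cga_layout):
        if (cta_id >> bit) & 1:
            coord = [c ^ b for c, b in zip(coord, basis)]
    return tuple(coord)


def per_cta_shape(shape, cga_layout):
    """Shape of the contiguous block each CTA holds (assumes distinct power-of-two bases)."""
    return tuple(
        size >> sum(1 for basis in cga_layout if basis[dim] != 0)
        for dim, size in enumerate(shape)
    )


layout = [[1, 0], [2, 0], [0, 1]]
for cta in range(8):
    print(f"CTA{cta} -> chunk {cta_chunk(layout, cta)}")
print(per_cta_shape((256, 128), layout))  # (64, 64)
```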
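A cga_layout always has log2(numCTAs) bases and always describes a sharding of the full tensor into contiguous blocks. Some sharding patterns do not give each CTA a contiguous subtensor, for example: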
| CTA0 warp0 | | CTA1 warp0 | | CTA0 warp1 | | CTA1 warp1 |
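Such patterns can be expressed with LinearEncoding for data in registers and with SharedLinearEncoding for data in shared memory. In these cases the CGA layout is not a separate attribute called CGALayout. Instead it is part of the LinearLayout itself, encoded under an input dimension named block. The example above would look like this: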
gl.LinearEncoding(warps=[[2]], block=[[1]])
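The CTA bit selects between adjacent chunks (offset 1) and the warp bit selects between pairs of chunks (offset 2). In other words, we shard first along CTAs and then along warps.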
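The same evaluation rule explains the interleaved example. The sketch below uses our own dictionary-of-bases representation, not the Gluon constructor. The CTA bit contributes chunk offset 1 and the warp bit contributes offset 2, which reproduces the ordering in the diagram:

```python
def chunk_index(bases, ids):
    """Evaluate a 1D linear layout: XOR together the bases selected by each input id's set bits."""
    idx = 0
    for dim, dim_bases in bases.items():
        for bit, (basis,) in enumerate(dim_bases):
            if (ids[dim] >> bit) & 1:
                idx ^= basis
    return idx


# Mirrors gl.LinearEncoding(warps=[[2]], block=[[1]]) for 2 CTAs x 2 warps.
bases = {"block": [[1]], "warps": [[2]]}
owners = {}
for cta in range(2):
    for warp in range(2):
        owners[chunk_index(bases, {"block": cta, "warps": warp})] = f"CTA{cta} warp{warp}"
order = [owners[i] for i in range(4)]
print(" | ".join(order))
```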
import importlib
import pytest
import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    clc,
    tcgen05_commit,
    tcgen05_mma,
    tcgen05_mma_barrier_count,
    tensor_memory_descriptor,
)
from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
# Re-use baseline tutorials for comparisons.
t8 = importlib.import_module("08-warp-specialization")
def is_hopper_or_newer():
    if not torch.cuda.is_available():
        return False
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda" and torch.cuda.get_device_capability()[0] >= 9
def is_blackwell():
    if not torch.cuda.is_available():
        return False
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 10
if __name__ == "__main__" and not is_blackwell():
    raise RuntimeError("This tutorial requires a Blackwell NVIDIA GPU")
def tflops(ms, M, N, K):
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)
def gbps(ms, num_bytes):
    return num_bytes * 1e-9 / (ms * 1e-3)
def pick_multicta_softmax_config(n_cols):
    warp_thresholds = [(3072, 1), (6144, 2)]
    cluster_thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
    num_warps = next((v for limit, v in warp_thresholds if n_cols <= limit), 4)
    cluster_n = next((v for limit, v in cluster_thresholds if n_cols <= limit), 16)
    return {
        "num_warps": num_warps,
        "num_ctas": cluster_n,
    }
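The thresholds in pick_multicta_softmax_config keep each CTA's share of a row bounded. The host-side sketch below copies the function and derives how many columns each CTA and each thread handle. `per_cta_work` is a hypothetical helper. The per-thread count assumes the row divides evenly over all threads of the cluster.

```python
def pick_multicta_softmax_config(n_cols):
    # Same thresholds as the tutorial's host-side helper.
    warp_thresholds = [(3072, 1), (6144, 2)]
    cluster_thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
    num_warps = next((v for limit, v in warp_thresholds if n_cols <= limit), 4)
    cluster_n = next((v for limit, v in cluster_thresholds if n_cols <= limit), 16)
    return {"num_warps": num_warps, "num_ctas": cluster_n}


def per_cta_work(n_cols):
    """Return (num_ctas, columns per CTA, elements per thread) for one row of n_cols."""
    cfg = pick_multicta_softmax_config(n_cols)
    num_ctas, num_warps = cfg["num_ctas"], cfg["num_warps"]
    # Launch constraint: num_ctas must be a power of two and at most 16.
    assert num_ctas & (num_ctas - 1) == 0 and num_ctas <= 16
    cols_per_cta = n_cols // num_ctas
    elems_per_thread = cols_per_cta // (num_warps * 32)
    return num_ctas, cols_per_cta, elems_per_thread


for n in [2**8, 2**12, 2**14, 2**15, 2**16, 2**18]:
    print(n, per_cta_work(n))
```

For every width benchmarked below, up to 2**18 columns, no CTA handles more than 16384 columns.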
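Multi-CTA kernels are launched with num_ctas > 1, where num_ctas is a power of 2 and num_ctas <= 16.
Layout-driven operations such as gl.convert_layout, gl.reduce and gl.sum automatically use the cluster when the source and destination layouts shard the CTA dimension differently.
The kernel below shards a row across multiple CTAs. It computes a numerically stable row-wise softmax using the automatic cross-CTA reductions in gl.max and gl.sum.
Without CGAs, once a row became too wide for a single CTA we would have to switch to an iterative reduction or a multi-kernel approach.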
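Conceptually, each cross-CTA gl.max or gl.sum is a local reduction over the CTA's own shard, followed by a combine across the cluster. The plain-Python model below mimics that data flow on the host and checks it against an unsharded softmax. It models the semantics only, not how Gluon actually exchanges the partial results between CTAs.

```python
import math
import random


def reference_softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def sharded_softmax(row, num_ctas):
    """Model of the multi-CTA softmax: local reductions per shard, then a cross-CTA combine."""
    chunk = len(row) // num_ctas
    assert chunk * num_ctas == len(row)
    shards = [row[i * chunk:(i + 1) * chunk] for i in range(num_ctas)]
    # gl.max: each CTA reduces its shard, then the partial maxima are combined across the cluster.
    row_max = max(max(shard) for shard in shards)
    # gl.sum: each CTA exponentiates and sums locally, then the partial sums are combined.
    exps = [[math.exp(v - row_max) for v in shard] for shard in shards]
    row_sum = sum(sum(shard) for shard in exps)
    # Every CTA now holds the same row_max and row_sum and normalizes its own shard.
    return [e * (1.0 / row_sum) for shard in exps for e in shard]


random.seed(0)
row = [random.uniform(-1, 1) for _ in range(1024)]
out = sharded_softmax(row, num_ctas=4)
ref = reference_softmax(row)
print(max(abs(a - b) for a, b in zip(out, ref)))
```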
@gluon.jit
def multicta_softmax_kernel(
    x_ptr,
    out_ptr,
    x_row_stride,
    out_row_stride,
    BLOCK_N: gl.constexpr,
):
    # One program (a cluster of num_ctas CTAs) handles one row.
    pid = gl.program_id(0)
    # One basis per CTA-id bit: CTA i owns the i-th contiguous chunk of the row.
    cga_layout: gl.constexpr = ((1, ), (2, ), (4, ), (8, ), (16, ))[:gl.num_ctas().bit_length() - 1]
    layout: gl.constexpr = gl.BlockedLayout([4], [32], [gl.num_warps()], [0], cga_layout=cga_layout)
    offs_n = gl.arange(0, BLOCK_N, layout)
    mask = offs_n < BLOCK_N
    row_start = pid * x_row_stride
    out_row_start = pid * out_row_stride
    x = gl.load(x_ptr + row_start + offs_n, mask=mask, other=-float("inf"))
    # gl.max and gl.sum reduce across all CTAs of the cluster automatically.
    row_max = gl.max(x, axis=0)
    y = gl.exp(x - row_max)
    row_sum = gl.sum(y, axis=0)
    z = y * (1.0 / row_sum)
    gl.store(out_ptr + out_row_start + offs_n, z, mask=mask)
def multicta_softmax_f32(x, out=None):
    M, N = x.shape
    cfg = pick_multicta_softmax_config(N)
    if out is None:
        out = torch.empty_like(x)
    multicta_softmax_kernel[(M, )](
        x,
        out,
        x.stride(0),
        out.stride(0),
        BLOCK_N=N,
        num_warps=cfg["num_warps"],
        num_ctas=cfg["num_ctas"],
    )
    return out
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
@pytest.mark.parametrize("M, N", [(64, 64), (64, 256), (16, 2**16)])
def test_multicta_softmax_f32(M, N):
    x = torch.randn((M, N), device="cuda", dtype=torch.float32)
    out = multicta_softmax_f32(x)
    ref = torch.softmax(x, dim=1)
    torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
def benchmark_multicta_softmax_f32():
    if not is_hopper_or_newer():
        raise RuntimeError("softmax benchmark requires Hopper or newer")
    SOFTMAX_BENCH_SHAPES = [
        (2**15, 2**8),
        (2**15, 2**9),
        (2**15, 2**10),
        (2**15, 2**11),
        (2**15, 2**12),
        (2**15, 2**13),
        (2**15, 2**14),
        (2**15, 2**15),
        (2**15, 2**16),
        (2**14, 2**17),
        (2**13, 2**18),
    ]
    print("Benchmarking multicta_softmax")
    print("============================")
    print("          shape CTAs warps time (ms) bandwidth (GB/s)")
    for M, N in SOFTMAX_BENCH_SHAPES:
        cfg = pick_multicta_softmax_config(N)
        x = torch.empty((M, N), device="cuda", dtype=torch.float32).uniform_(-1, 1)
        out = torch.empty_like(x)
        ms = triton.testing.do_bench_cudagraph(lambda: multicta_softmax_f32(x, out))
        num_bytes = 2 * x.numel() * x.element_size()
        print(f"{M:>6} x {N:<6} {cfg['num_ctas']:>4} {cfg['num_warps']:>5} {ms:>9.3f} {gbps(ms, num_bytes):>16.2f}")
if __name__ == "__main__":
    benchmark_multicta_softmax_f32()
Softmax benchmark results
Benchmarking multicta_softmax
============================
          shape CTAs warps time (ms) bandwidth (GB/s)
 32768 x 256       1     1     0.018          3661.46
 32768 x 512       1     1     0.020          6746.45
 32768 x 1024      1     1     0.040          6740.50
 32768 x 2048      1     1     0.078          6920.01
 32768 x 4096      1     2     0.152          7065.25
 32768 x 8192      1     4     0.301          7136.76
 32768 x 16384     1     4     0.600          7157.74
 32768 x 32768     2     4     1.312          6545.11
 32768 x 65536     4     4     2.836          6057.26
 16384 x 131072    8     4     3.142          5468.66
  8192 x 262144   16     4     3.627          4736.15
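We see that multi-CTA gives very good performance across the board. Bandwidth stays above 4.7 TB/s from 512 columns up to 262144 columns, where a single row is split across 16 CTAs.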
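The bandwidth column counts one read of x and one write of out, both float32. Plugging the 32768 x 16384 row into the tutorial's gbps formula reproduces the reported figure, up to the rounding of the printed time:

```python
def gbps(ms, num_bytes):
    # Same formula as the tutorial: bytes -> GB, milliseconds -> seconds.
    return num_bytes * 1e-9 / (ms * 1e-3)


M, N = 32768, 16384
bytes_per_elem = 4  # float32
num_bytes = 2 * M * N * bytes_per_elem  # read x once + write out once
bw = gbps(0.600, num_bytes)
print(f"{num_bytes / 2**30:.1f} GiB moved, {bw:.0f} GB/s")
```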
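Multi-CTA synchronization
Because CTAs may live on different SMs, synchronization across them is much slower than within a single CTA. Gluon therefore makes a fairly conservative set of automatic synchronization guarantees and leaves the rest of the synchronization to the user.
Gluon inserts synchronization primitives around operations such as gl.convert_layout, gl.reduce and gl.sum when the source and destination layouts shard the CTA dimension differently. All other operations (TMA, WGMMA, TCGen5MMA, etc.) must be synchronized by the kernel author with mbarriers, just as in a single-CTA kernel.
The cga_layout of a multi-CTA mbarrier has slightly different semantics. As described in 02-layouts.py, a linear layout is a map from F_2^n to F_2^m. For an mbarrier, cga_layout maps the CTA id (numCTAs is a power of 2) to the index of the barrier that CTA uses. For example, we can have an mbarrier layout in which each CTA has its own barrier: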
num_ctas: gl.constexpr = 8
bar = gl.allocate_shared_memory(gl.int64, [num_ctas], MBarrierLayout(cga_layout=[[1], [2], [4]]))
Reading the bases as the columns of a binary matrix, this cga_layout is the 3x3 identity matrix. Because this pattern is so common, Gluon provides a helper that creates it:
bar = mbarrier.allocate_mbarrier()
Barrier layouts also allow synchronization across CTAs. For example, in an 8-CTA kernel we can define 2-CTA mbarriers as:
bar = gl.allocate_shared_memory(gl.int64, [4], MBarrierLayout(cga_layout=[[0], [1], [2]]))
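Note that the only non-zero bases are now [1] and [2], so there are just 2^2 = 4 barriers. Because this is an 8-CTA kernel, the layout still has log2(8) = 3 bases. The first basis is zero, which means the layout broadcasts along bit 0 of the CTA id: any CTAs whose ids differ only in bit 0 share a barrier. For example, CTA0 and CTA1 share one barrier, CTA2 and CTA3 share another, and so on. The lead CTA of a group is the one with the smallest CTA id, so for this layout the even CTA ids are the lead CTAs.
In general, an mbarrier cga_layout is the sequence [[2**i] for i in range(k)], with k <= log2(num_ctas), interspersed with log2(num_ctas) - k zero bases.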
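The mapping from CTA id to barrier, and the choice of lead CTA, can be checked on the host. The sketch below uses hypothetical helpers, not Gluon API. It evaluates a 1D mbarrier cga_layout as a linear map over F_2 and groups the CTAs that share each barrier:

```python
def barrier_index(cga_layout, cta_id):
    """Barrier used by `cta_id`: XOR of the 1D bases selected by its set bits."""
    idx = 0
    for bit, (basis,) in enumerate(cga_layout):
        if (cta_id >> bit) & 1:
            idx ^= basis
    return idx


def barrier_groups(cga_layout):
    """Map barrier index -> (lead CTA, CTAs sharing that barrier)."""
    num_ctas = 1 << len(cga_layout)
    groups = {}
    for cta in range(num_ctas):
        groups.setdefault(barrier_index(cga_layout, cta), []).append(cta)
    # The lead CTA of each group is the smallest CTA id in it.
    return {bar: (min(ctas), ctas) for bar, ctas in sorted(groups.items())}


print(barrier_groups([[1], [2], [4]]))  # one barrier per CTA
print(barrier_groups([[0], [1], [2]]))  # 2-CTA barriers in an 8-CTA kernel
```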
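All operations on barriers generalize naturally to multi-CTA barriers. More specifically:
mbarrier.init: multiplies the count argument by the number of CTAs in the group and initializes the barrier only on the lead CTA.
mbarrier.expect: multiplies the size_per_cta argument by the number of CTAs in the group and registers the expectation only on the lead CTA. Because an expect counts as one arrival, every non-lead CTA also signals one arrival on the lead CTA.
mbarrier.arrive: every CTA in the group arrives on the lead CTA's barrier.
mbarrier.wait: only the lead CTA waits on the barrier.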
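Here is a toy host-side model of a 2-CTA barrier showing how the scaling in init and expect balances out. The class is purely illustrative, not a Gluon API. It assumes that the TMA bytes landing in every CTA of the group are credited to the lead CTA's barrier, which is what scaling the expected byte count by the group size implies.

```python
class GroupBarrierModel:
    """Toy host-side model of a multi-CTA mbarrier living on the lead CTA (not a Gluon API)."""

    def __init__(self, count, group_size):
        # mbarrier.init: count is scaled by the group size; only the lead CTA holds state.
        self.init_arrivals = count * group_size
        self.group_size = group_size
        self.pending_arrivals = self.init_arrivals
        self.pending_bytes = 0
        self.phase = 0

    def expect(self, cta_rank, size_per_cta):
        # mbarrier.expect: the lead registers size_per_cta * group_size bytes. The expect
        # itself counts as the lead's arrival, and each non-lead CTA arrives once on the lead.
        if cta_rank == 0:
            self.pending_bytes += size_per_cta * self.group_size
        self.arrive()

    def arrive(self):
        # mbarrier.arrive: every CTA in the group arrives on the lead CTA's barrier.
        self.pending_arrivals -= 1
        self._maybe_complete()

    def complete_tx(self, num_bytes):
        # Assumed: a TMA load landing in any CTA of the group decrements the lead's byte count.
        self.pending_bytes -= num_bytes
        self._maybe_complete()

    def _maybe_complete(self):
        if self.pending_arrivals == 0 and self.pending_bytes == 0:
            self.phase ^= 1  # mbarrier.wait on the lead CTA can now return
            self.pending_arrivals = self.init_arrivals


bar = GroupBarrierModel(count=1, group_size=2)
tile_bytes = 16384 + 8192  # per-CTA bytes of A and B
bar.expect(cta_rank=0, size_per_cta=tile_bytes)
bar.expect(cta_rank=1, size_per_cta=tile_bytes)
bar.complete_tx(tile_bytes)  # CTA0's loads land
phase_after_one = bar.phase
bar.complete_tx(tile_bytes)  # CTA1's loads land
print(phase_after_one, bar.phase)
```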
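Barriers used across CTAs require one extra ordering rule: every relevant mbarrier.init must complete before any CTA uses the barrier. This applies both to barriers that are themselves cross-CTA and to per-CTA barriers consumed by multicast or 2CTA operations. The compiler inserts the required fence.mbarrier_init.release.cluster plus a relaxed cluster barrier after the top-level init sequence and before the first use. Because the cluster barrier must execute outside warp_specialize, initialize these barriers in the top-level part of the kernel, before entering warp specialization.
A final note on synchronization: cluster.arrive / cluster.wait (the CGA barrier, the cluster-level equivalent of bar.sync within a CTA) must be executed by every thread in the kernel. They therefore cannot be used inside a warp_specialize block.
Likewise, operations such as convert_layout, reduce, sum and max emit a CGA barrier when they cross CTAs. When they may span multiple CTAs, they are therefore not allowed inside a warp_specialize block either.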
2CTA TCGen5MMA
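In 2CTA mode, the tcgen05_mma instruction computes its result using data from both CTAs of a pair (CTA0 and CTA1, CTA2 and CTA3, and so on). Mathematically, it computes an outer product over the pair: the LHS holds the input sharded along M and the RHS holds the input sharded along N. In terms of cga_layouts, the first basis of the LHS is (1, 0) and the first basis of the RHS is (0, 1). The accumulator is sharded as (1, 0) as well.
As in the single-CTA case, the blockM of the TensorMemoryLayout is the instruction shape and can be 64 or 128.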
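A host-side model of the data placement (not of the hardware) shows why the pair must share data. Each CTA owns half of A's rows and half of B's columns, but its half of C needs all of B's columns. The helpers below are our own illustration:

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def two_cta_mma_model(a, b):
    """Host-side model of the 2CTA data placement; not a model of the hardware."""
    m, n = len(a), len(b[0])
    # LHS, cga basis (1, 0): CTA r holds rows [r*m/2, (r+1)*m/2) of A.
    a_shards = [a[: m // 2], a[m // 2:]]
    # RHS, cga basis (0, 1): CTA r holds columns [r*n/2, (r+1)*n/2) of B.
    b_shards = [[row[: n // 2] for row in b], [row[n // 2:] for row in b]]
    acc = []
    for cta in range(2):
        # Each CTA's M-half of C needs BOTH N-halves of B, i.e. data from its peer CTA.
        left = matmul(a_shards[cta], b_shards[0])
        right = matmul(a_shards[cta], b_shards[1])
        acc.append([l_row + r_row for l_row, r_row in zip(left, right)])
    # Accumulator, cga basis (1, 0): stacking the two M-halves reconstructs C.
    return acc[0] + acc[1]


a = [[i + j for j in range(3)] for i in range(4)]      # 4 x 3
b = [[i * j - 1 for j in range(6)] for i in range(3)]  # 3 x 6
print(two_cta_mma_model(a, b) == matmul(a, b))
```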
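The outer-product layouts above describe the required data placement: LHS (1, 0), RHS (0, 1), accumulator/output (1, 0). On top of that, enabling 2CTA mode requires two code changes, plus a third one when waiting for MMA completion:
Use TensorMemoryLayout(..., two_ctas=True) on the accumulator to select 2CTA mode.
Use mbarrier.allocate_mbarrier(..., two_ctas=True) for any barrier that must be waited on before the lead CTA issues tcgen05_mma.
Pass two_ctas=... to tcgen05_mma_barrier_count when initializing the mbarrier used to wait for MMA completion.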
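The pre-MMA barrier needs two_ctas=True so that the lead CTA waits for both CTAs of the pair before issuing the collective MMA. In this single-tile example that barrier is tma_bar. In the pipelined matmul below, the corresponding barriers are load_ready_bars and acc_empty_bars. These are the same 2CTA sites used by 03-matmul-multicta.py.
mma_bar itself does not need two_ctas=True, because tcgen05_mma multicasts its completion signal to both CTAs of the pair. Once any tcgen05_mma in a kernel uses 2CTA mode, every tcgen05_mma instruction in that kernel must use 2CTA mode.
The kernel two_cta_tcgen05_kernel shows the 2CTA TCGen5MMA pattern on a single tile.
Note that the pattern changes somewhat once a TMA has to wait for a tcgen05_mma. We discuss that in the next section.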
@gluon.jit
def two_cta_tcgen05_kernel(a_desc, b_desc, c_desc):
    gl.static_assert(gl.num_ctas() == 2)
    cluster_m: gl.constexpr = a_desc.block_shape[0]
    tile_n: gl.constexpr = b_desc.block_shape[1]
    cta_m: gl.constexpr = cluster_m // 2
    cga_layout: gl.constexpr = c_desc.layout.cga_layout
    smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
    smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)
    # The lead CTA must wait for both CTAs' TMA loads before issuing the 2CTA MMA.
    tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
    # tcgen05_mma multicasts its completion to both CTAs, so a per-CTA barrier is enough.
    mma_bar = mbarrier.allocate_mbarrier()
    mbarrier.init(tma_bar, count=1)
    mbarrier.init(mma_bar, count=1)
    # The byte count is per CTA, not for the whole pair.
    mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
    tma.async_load(a_desc, [0, 0], tma_bar, smem_a)
    tma.async_load(b_desc, [0, 0], tma_bar, smem_b)
    mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b])
    mbarrier.invalidate(tma_bar)
    acc_layout: gl.constexpr = TensorMemoryLayout(
        block=(cta_m, tile_n),
        col_stride=1,
        cga_layout=cga_layout,
        two_ctas=True,
    )
    acc = allocate_tensor_memory(gl.float32, [cluster_m, tile_n], acc_layout)
    tcgen05_mma(smem_a, smem_b, acc, use_acc=False, mbarriers=[mma_bar])
    mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b])
    mbarrier.invalidate(mma_bar)
    c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_shape, c_desc.layout)
    c_smem.store(acc.load().to(c_desc.dtype))
    tma.async_copy_shared_to_global(c_desc, [0, 0], c_smem)
    # Make sure the TMA store has finished reading c_smem before the kernel exits.
    tma.store_wait(pendings=0)
def run_two_cta_tcgen05(a, b, c):
    M, N, K = a.shape[0], b.shape[1], a.shape[1]
    a_layout = gl.NVMMASharedLayout.get_default_for([M, K], gl.float16, cga_layout=[(1, 0)])
    b_layout = gl.NVMMASharedLayout.get_default_for([K, N], gl.float16, cga_layout=[(0, 1)])
    c_layout = gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=[(1, 0)])
    a_desc = TensorDescriptor.from_tensor(a, [M, K], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [K, N], b_layout)
    c_desc = TensorDescriptor.from_tensor(c, [M, N], c_layout)
    two_cta_tcgen05_kernel[(1, )](a_desc, b_desc, c_desc, num_warps=4, num_ctas=2)
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_two_cta_tcgen05():
    M, N, K = 256, 128, 64
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c = torch.empty((M, N), device="cuda", dtype=torch.float16)
    run_two_cta_tcgen05(a, b, c)
    torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
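Relative to the single-CTA case, the code above therefore changes the operand layouts, sets two_ctas=True on the accumulator layout and sets two_ctas=True on the TMA handoff barrier. The mbarrier.expect byte count stays per CTA rather than covering the whole pair.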
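The per-CTA byte count passed to mbarrier.expect can be worked out by hand for the shapes in test_two_cta_tcgen05 (M, N, K = 256, 128, 64). The sketch below assumes that nbytes_per_cta equals the block's size divided by 2 for every non-zero cga basis. `nbytes_per_cta` here is a hypothetical host-side helper, not the TensorDescriptor attribute itself:

```python
from math import prod


def nbytes_per_cta(block_shape, cga_layout, elem_bytes):
    """Bytes of one TMA block held by a single CTA (hypothetical helper).

    Every non-zero basis halves the block along some dimension; zero bases are
    broadcasts (extra copies) and do not shrink it.
    """
    splits = sum(1 for basis in cga_layout if any(c != 0 for c in basis))
    return prod(block_shape) * elem_bytes >> splits


FP16 = 2
a_bytes = nbytes_per_cta([256, 64], [(1, 0)], FP16)   # A split along M
b_bytes = nbytes_per_cta([64, 128], [(0, 1)], FP16)   # B split along N
print(a_bytes, b_bytes, a_bytes + b_bytes)
```

The expect therefore covers 24576 bytes per CTA, half of the 49152 bytes the pair loads in total.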
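Note that a few more changes are needed once this is used inside a for loop and/or combined with TMA multicast. The next sections cover them.
TMA with multicast
Starting with Hopper, TMA can multicast data to multiple CTAs. This is useful in multi-CTA kernels on Hopper, where wgmma has no 2CTA mode, and in Blackwell+ kernels that use more than 2 CTAs.
In that case, given the accumulator's cga_layout, the layouts of A and B can be computed as follows: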
# Example cga_layout
cga_layout = [(1, 0), (2, 0), (0, 1)]
def get_cga_layout(layout, op_idx, two_ctas):
    assert op_idx in (0, 1)
    if not layout:
        return layout
    # Broadcast along K (the reduction dimension)
    # We multiply by 2 for op_idx == 1, as we have added the (0, 1) basis.
    def broadcast(b):
        mul = 2 if two_ctas else 1
        return (b[0], 0) if op_idx == 0 else (0, mul * b[1])
    if not two_ctas:
        return tuple(map(broadcast, layout))
    # 2CTA performs an outer product so bases are [1, 0] and [0, 1]
    assert layout[0] == (1, 0)
    first = (1, 0) if op_idx == 0 else (0, 1)
    return (first, *map(broadcast, layout[1:]))
cga_layout_a = get_cga_layout(cga_layout, 0, two_ctas=False)
cga_layout_b = get_cga_layout(cga_layout, 1, two_ctas=False)
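In other words, the cga_layout of A keeps only the M component of each basis of C's layout, and the cga_layout of B keeps only the N component. The K component is always zero.
This means some bases of A and/or B are zero, so different CTAs load the same data. Multicast lets those CTAs fetch the shared tile from L2 once and deliver it to all of them, instead of each CTA reading it separately.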
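The rule can be checked by running get_cga_layout, copied unchanged, on the 8-CTA example. The extra `sharing` helper is ours. It counts how many CTAs receive each operand tile, which is 2 raised to the number of zero bases. The last two calls reproduce the hardcoded operand layouts used by tma_tcgen05_example below.

```python
def get_cga_layout(layout, op_idx, two_ctas):
    # Copied from the tutorial text above.
    assert op_idx in (0, 1)
    if not layout:
        return layout

    def broadcast(b):
        mul = 2 if two_ctas else 1
        return (b[0], 0) if op_idx == 0 else (0, mul * b[1])

    if not two_ctas:
        return tuple(map(broadcast, layout))
    assert layout[0] == (1, 0)
    first = (1, 0) if op_idx == 0 else (0, 1)
    return (first, *map(broadcast, layout[1:]))


def sharing(layout):
    """Number of CTAs that receive each operand tile: 2 ** (number of zero bases)."""
    return 1 << sum(1 for basis in layout if basis == (0, 0))


c_layout = [(1, 0), (2, 0), (0, 1)]
a1, b1 = get_cga_layout(c_layout, 0, False), get_cga_layout(c_layout, 1, False)
a2, b2 = get_cga_layout(c_layout, 0, True), get_cga_layout(c_layout, 1, True)
print(a1, b1, sharing(a1), sharing(b1))
print(a2, b2, sharing(a2), sharing(b2))
a_4cta = get_cga_layout(((1, 0), (2, 0)), 0, True)
b_4cta = get_cga_layout(((1, 0), (2, 0)), 1, True)
print(a_4cta, b_4cta)
```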
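The broadcast layouts above describe which CTAs should receive the same tile. Given those, the pipelined matmul has three multicast=True sites:
tma.async_load(..., multicast=True) sends the tile to every CTA in the broadcast group.
tcgen05_mma(..., multicast=True) multicasts its mbarrier arrival, so that the next TMA does not overwrite the shared-memory tile until every CTA has consumed it.
tcgen05_mma_barrier_count(..., multicast=True, ...) initializes that mbarrier with the matching arrival count.
The TMA synchronization pattern itself is the same as for a regular load: initialize a barrier, expect the byte count, issue the TMA, then wait on the barrier.
This works because the TMA instruction atomically broadcasts its arrival to every CTA in the multicast group, so the waiting side does not need a different API.
The TMA destination must use the broadcast cga_layout so that both CTAs receive the same shared-memory tile. Unless the kernel is in 2CTA mode, the barrier remains a regular per-CTA TMA barrier.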
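Which CTAs form a multicast group follows from the same evaluation: CTAs whose ids map to the same tile coordinate receive the same data. The helper below is a hypothetical host-side illustration:

```python
def multicast_groups(cga_layout):
    """Group CTA ids by the tile they receive; CTAs in one group share a multicast load."""
    groups = {}
    for cta in range(1 << len(cga_layout)):
        coord = (0, 0)
        for bit, basis in enumerate(cga_layout):
            if (cta >> bit) & 1:
                coord = (coord[0] ^ basis[0], coord[1] ^ basis[1])
        groups.setdefault(coord, []).append(cta)
    return sorted(groups.values())


print(multicast_groups([(1, 0), (2, 0), (0, 0)]))  # A operand of an 8-CTA, non-2CTA kernel
print(multicast_groups([(0, 0), (0, 0), (0, 1)]))  # matching B operand
print(multicast_groups([(0, 0)]))                  # layout of tma_multicast_copy_kernel
```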
The example below is deliberately simple: it multicasts one tile into shared memory and then stores that tile back to global memory unchanged.
@gluon.jit
def tma_multicast_copy_kernel(in_desc, out_desc):
    gl.static_assert(gl.num_ctas() == 2)
    smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout)
    # This kernel is not in 2CTA mode, so the TMA barrier is per-CTA.
    bar = mbarrier.allocate_mbarrier()
    mbarrier.init(bar, count=1)
    mbarrier.expect(bar, in_desc.nbytes_per_cta)
    tma.async_load(in_desc, [0, 0], bar, smem, multicast=True)
    mbarrier.wait(bar, phase=0, deps=[smem])
    tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
    # Make sure the TMA store has finished reading smem before the kernel exits.
    tma.store_wait(pendings=0)
def run_tma_multicast_copy(inp, out):
    layout = gl.NVMMASharedLayout.get_default_for(inp.shape, gl.float16, cga_layout=[[0, 0]])
    in_desc = TensorDescriptor.from_tensor(inp, inp.shape, layout)
    out_desc = TensorDescriptor.from_tensor(out, inp.shape, layout)
    tma_multicast_copy_kernel[(1, )](in_desc, out_desc, num_warps=4, num_ctas=2)
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_tma_multicast_copy():
    M, N = 128, 128
    inp = torch.randn((M, N), device="cuda", dtype=torch.float16)
    out = torch.empty_like(inp)
    run_tma_multicast_copy(inp, out)
    torch.testing.assert_close(out, inp, atol=0, rtol=0)
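TMA to MMA in a loop
Here we show the fully general way to feed TMA loads (with or without multicast) into a tcgen05_mma pipeline.
In this setting, the tcgen05_mma must wait for every CTA in its multicast group to finish before the next iteration proceeds. Otherwise the next iteration's TMA load could overwrite the shared-memory data before the previous iteration has fully consumed it.
We therefore compute the arrival count over the multicast group with tcgen05_mma_barrier_count. We also set multicast=True on the tcgen05_mma instruction to indicate that it must wait for the multicast group before proceeding.
These functions are generic, so the same pattern also works for non-multicast kernels and for non-2CTA kernels.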
@gluon.jit
def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_tmem_layout: gl.constexpr):
    block_m: gl.constexpr = a_desc.block_shape[0]
    block_k: gl.constexpr = a_desc.block_shape[1]
    block_n: gl.constexpr = b_desc.block_shape[1]
    smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
    smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)
    acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout)
    tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
    mma_bar = mbarrier.allocate_mbarrier()
    mbarrier.init(tma_bar, count=1)
    # The MMA barrier collects arrivals from the whole multicast group.
    mbarrier.init(
        mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True,
                                                 two_ctas=acc_tmem.type.layout.two_ctas))
    phase_tma = 0
    phase_mma = 0
    for k in range(NUM_K_TILES):
        mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
        tma.async_load(a_desc, [0, k * block_k], tma_bar, smem_a, multicast=True)
        tma.async_load(b_desc, [k * block_k, 0], tma_bar, smem_b, multicast=True)
        mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b])
        phase_tma ^= 1
        tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=(k != 0), multicast=True, mbarriers=[mma_bar])
        # Do not issue the next TMA until every CTA in the group is done with smem_a/smem_b.
        mbarrier.wait(mma_bar, phase=phase_mma, deps=[smem_a, smem_b])
        phase_mma ^= 1
    mbarrier.invalidate(tma_bar)
    mbarrier.invalidate(mma_bar)
    out_smem = gl.allocate_shared_memory(out_desc.dtype, out_desc.block_shape, out_desc.layout)
    out_smem.store(acc_tmem.load().to(out_desc.dtype))
    tma.async_copy_shared_to_global(out_desc, [0, 0], out_smem)
    # Make sure the TMA store has finished reading out_smem before the kernel exits.
    tma.store_wait(pendings=0)
def tma_tcgen05_example(a, b):
    BLOCK_M = 512
    BLOCK_N = 128
    BLOCK_K = 64
    NUM_K_TILES = 2
    cga_layout_a = ((1, 0), (2, 0))
    cga_layout_b = ((0, 1), (0, 0))
    cga_layout_c = ((1, 0), (2, 0))
    M, K = a.shape
    Kb, N = b.shape
    if K != Kb:
        raise ValueError(f"inner dimensions must match, got {K} and {Kb}")
    if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K * NUM_K_TILES:
        raise ValueError(f"expected shapes {(BLOCK_M, BLOCK_K * NUM_K_TILES)} x "
                         f"{(BLOCK_K * NUM_K_TILES, BLOCK_N)}, got {tuple(a.shape)} x {tuple(b.shape)}")
    out = torch.empty((M, N), device="cuda", dtype=torch.float16)
    a_layout = gl.NVMMASharedLayout.get_default_for([M, BLOCK_K], gl.float16, cga_layout=cga_layout_a)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, N], gl.float16, cga_layout=cga_layout_b)
    c_layout = gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=cga_layout_c)
    acc_tmem_layout = TensorMemoryLayout(block=(128, N), col_stride=1, cga_layout=cga_layout_c, two_ctas=True)
    a_desc = TensorDescriptor.from_tensor(a, [M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_K, N], b_layout)
    c_desc = TensorDescriptor.from_tensor(out, [M, N], c_layout)
    tma_tcgen05_kernel[(1, )](
        a_desc,
        b_desc,
        c_desc,
        NUM_K_TILES,
        acc_tmem_layout,
        num_warps=4,
        num_ctas=4,
    )
    return out
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_tma_tcgen05():
    M = 512
    N = 128
    K = 128
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    out = tma_tcgen05_example(a, b)
    torch.testing.assert_close(out, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
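A blazing-fast matmul kernel
Here we show the fully general way to write a matmul kernel that feeds TMA loads (possibly multicast) into a warp-specialized tcgen05_mma pipeline.
For this example we extend the CLC idea introduced in 12-cluster-launch-control.py to a warp-specialized kernel. We add a new partition that drives CLC and broadcasts its result to all CTAs.
We also use an extra helper, _planar_snake, that remaps program ids to improve L2 locality.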
Counter = t8.Counter
cublas = t8.cublas
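Counter is imported from 08-warp-specialization and is not shown in this file. Judging by how it is used here (Counter.create(phase, num_stages), .index, .phase, .next()), it behaves like a ring-buffer cursor whose phase flips on wrap-around. The stand-in below is our assumption of that behavior, not its actual definition:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CounterModel:
    """Assumed behavior of t8.Counter: a cursor over `num_stages` ring-buffer slots."""
    index: int
    phase: int
    num_stages: int

    @staticmethod
    def create(phase, num_stages):
        return CounterModel(0, phase, num_stages)

    def next(self):
        # Advance to the next slot and flip the mbarrier phase each time we wrap around.
        if self.index + 1 == self.num_stages:
            return CounterModel(0, self.phase ^ 1, self.num_stages)
        return CounterModel(self.index + 1, self.phase, self.num_stages)


c = CounterModel.create(0, 2)
trace = []
for _ in range(5):
    trace.append((c.index, c.phase))
    c = c.next()
print(trace)
```

Producers in the kernel start at phase 1 (for example, Counter.create(1, concurrent_loads)), consistent with waiting on "empty" barriers that have not completed a phase yet.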
@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
    return 1 << sum(b[dim] != 0 for b in cga_layout)
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
    major_size = n_tiles if minor_dim == 0 else m_tiles
    minor_size = m_tiles if minor_dim == 0 else n_tiles
    full_minor_tiles = minor_size // tile_width
    full_minor_size = full_minor_tiles * tile_width
    full_elements = full_minor_tiles * tile_width * major_size
    minor_tile_idx = lin_idx // (tile_width * major_size)
    full_minor_within = lin_idx % tile_width
    full_major_within = (lin_idx // tile_width) % major_size
    full_minor = minor_tile_idx * tile_width + full_minor_within
    full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)
    partial_width = minor_size - full_minor_size
    partial_width = gl.where(partial_width > 0, partial_width, 1)
    partial_lin = lin_idx - full_elements
    partial_minor_within = partial_lin % partial_width
    partial_major_within = (partial_lin // partial_width) % major_size
    partial_minor = minor_tile_idx * tile_width + partial_minor_within
    partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)
    in_full_tile = lin_idx < full_elements
    minor = gl.where(in_full_tile, full_minor, partial_minor)
    major = gl.where(in_full_tile, full_major, partial_major)
    if minor_dim == 0:
        return minor, major
    return major, minor
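A plain-Python port of _planar_snake makes it easy to check on the host that the remapping visits every tile exactly once, and to see the snake order. The port uses a single branch in place of the two gl.where paths; otherwise the arithmetic is unchanged.

```python
def planar_snake(lin_idx, m_tiles, n_tiles, minor_dim, tile_width):
    """Plain-Python port of _planar_snake: returns (pid_m, pid_n) for a linear tile id."""
    major_size = n_tiles if minor_dim == 0 else m_tiles
    minor_size = m_tiles if minor_dim == 0 else n_tiles
    full_minor_tiles = minor_size // tile_width
    full_elements = full_minor_tiles * tile_width * major_size
    # Which strip of `tile_width` tiles along the minor dimension we are in.
    minor_tile_idx = lin_idx // (tile_width * major_size)
    if lin_idx < full_elements:
        width, local = tile_width, lin_idx
    else:
        # Last, narrower strip when minor_size is not a multiple of tile_width.
        width, local = max(minor_size - full_minor_tiles * tile_width, 1), lin_idx - full_elements
    minor = minor_tile_idx * tile_width + local % width
    major_within = (local // width) % major_size
    # Odd strips walk the major dimension backwards, giving the "snake" order.
    major = major_within if minor_tile_idx % 2 == 0 else major_size - 1 - major_within
    return (minor, major) if minor_dim == 0 else (major, minor)


order = [planar_snake(i, 4, 3, 0, 2) for i in range(12)]
print(order[:8])
```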
@gluon.aggregate
class ClcTileSchedulerConsumer:
    has_work: gl.tensor
    tile_id: gl.tensor
    pid_m: gl.tensor
    pid_n: gl.tensor
    num_pid_m: gl.tensor
    num_pid_n: gl.tensor
    TILE_M: gl.constexpr
    TILE_N: gl.constexpr
    MINOR_DIM: gl.constexpr
    GRID_TILE_WIDTH: gl.constexpr
    clc_result_buffers: gl.shared_memory_descriptor
    clc_barriers: gl.shared_memory_descriptor
    clc_planar_pid_buffers: gl.shared_memory_descriptor
    clc_planar_ready_bars: gl.shared_memory_descriptor
    clc_consumed_bars: gl.shared_memory_descriptor
    counter: Counter
    consumed_counter: Counter
    @gluon.jit
    def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
                   GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
                   clc_planar_ready_bars, clc_consumed_bars):
        tile_id = gl.program_id(axis=0)
        num_pid_m = gl.cdiv(M, TILE_M)
        num_pid_n = gl.cdiv(N, TILE_N)
        pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
        return ClcTileSchedulerConsumer(
            gl.to_tensor(True),
            tile_id,
            pid_m,
            pid_n,
            num_pid_m,
            num_pid_n,
            TILE_M,
            TILE_N,
            MINOR_DIM,
            GRID_TILE_WIDTH,
            clc_result_buffers,
            clc_barriers,
            clc_planar_pid_buffers,
            clc_planar_ready_bars,
            clc_consumed_bars,
            Counter.create(0, clc_barriers.shape[0]),
            Counter.create(0, clc_barriers.shape[0]),
        )
    @gluon.jit
    def get_offsets(self):
        return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N
    @gluon.jit
    def step(self, iteration):
        consumed_counter = self.consumed_counter
        if iteration > 0:
            mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
            consumed_counter = consumed_counter.next()
        counter = self.counter
        barrier = self.clc_barriers.index(counter.index)
        result = self.clc_result_buffers.index(counter.index)
        mbarrier.wait(barrier, counter.phase)
        clc_res = clc.load_result(result)
        mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
        planar_slot = self.clc_planar_pid_buffers.index(counter.index)
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        packed_pid = planar_slot.load(planar_layout).reshape([])
        pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
        pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)
        has_work = clc_res.is_canceled()
        tile_id = self.tile_id
        if has_work:
            tile_id = clc_res.program_id(0)
        return ClcTileSchedulerConsumer(
            has_work,
            tile_id,
            pid_m,
            pid_n,
            self.num_pid_m,
            self.num_pid_n,
            self.TILE_M,
            self.TILE_N,
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
            counter.next(),
            consumed_counter,
        )
@gluon.aggregate
class MatmulPartitionArgs:
    a_desc: tma.tensor_descriptor
    b_desc: tma.tensor_descriptor
    c_desc: tma.tensor_descriptor
    a_bufs: gl.shared_memory_descriptor
    b_bufs: gl.shared_memory_descriptor
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor
    clc_result_buffers: gl.shared_memory_descriptor
    clc_barriers: gl.shared_memory_descriptor
    clc_planar_pid_buffers: gl.shared_memory_descriptor
    clc_planar_ready_bars: gl.shared_memory_descriptor
    clc_consumed_bars: gl.shared_memory_descriptor
    MINOR_DIM: gl.constexpr
    GRID_TILE_WIDTH: gl.constexpr
    SUBTILE_STAGES: gl.constexpr
    @gluon.jit
    def get_clc_consumer(self):
        return ClcTileSchedulerConsumer.initialize(
            self.c_desc.shape[0],
            self.c_desc.shape[1],
            self.a_desc.block_shape[0],
            self.b_desc.block_shape[1],
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
        )
@gluon.jit
def matmul_clc_partition(p):
    tile_m: gl.constexpr = p.a_desc.block_shape[0]
    tile_n: gl.constexpr = p.b_desc.block_shape[1]
    has_work = gl.to_tensor(True)
    num_pid_m = gl.cdiv(p.c_desc.shape[0], tile_m)
    num_pid_n = gl.cdiv(p.c_desc.shape[1], tile_n)
    state = Counter.create(0, p.clc_barriers.shape[0])
    consumed_state = Counter.create(1, p.clc_barriers.shape[0])
    acc_stages: gl.constexpr = p.clc_barriers.shape[0]
    i = 0
    while has_work:
        mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase, pred=(i >= acc_stages))
        barrier = p.clc_barriers.index(state.index)
        result = p.clc_result_buffers.index(state.index)
        mbarrier.expect(barrier, 16)
        clc.try_cancel(result, barrier)
        mbarrier.wait(barrier, state.phase)
        clc_res = clc.load_result(result)
        has_work = clc_res.is_canceled()
        pid_m = gl.to_tensor(0)
        pid_n = gl.to_tensor(0)
        if has_work:
            tile_id = clc_res.program_id(0)
            pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
        packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
        planar_slot = p.clc_planar_pid_buffers.index(state.index)
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
        mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))
        state = state.next()
        consumed_state = consumed_state.next()
        i += 1
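The CLC partition publishes both tile coordinates through a single int64 shared-memory slot. The host-side sketch below mirrors the shift-and-mask arithmetic of matmul_clc_partition and ClcTileSchedulerConsumer.step, and shows that the round trip is lossless for non-negative 32-bit tile indices. The helper names are ours.

```python
MASK32 = 0xFFFFFFFF


def pack_pid(pid_m, pid_n):
    # Same bit layout as the CLC partition: pid_m in the high 32 bits, pid_n in the low 32 bits.
    return (pid_m << 32) | (pid_n & MASK32)


def unpack_pid(packed):
    # Same masks as ClcTileSchedulerConsumer.step.
    return (packed >> 32) & MASK32, packed & MASK32


for pid_m, pid_n in [(0, 0), (65535, 3), (2**31 - 1, 2**31 - 1)]:
    packed = pack_pid(pid_m, pid_n)
    print(pid_m, pid_n, hex(packed), unpack_pid(packed))
```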
@gluon.jit
def matmul_load_partition(p):
    block_k: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]
    concurrent_loads: gl.constexpr = p.load_ready_bars.shape[0]
    state = Counter.create(1, concurrent_loads)
    scheduler = p.get_clc_consumer()
    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        for k in range(0, K, block_k):
            pred = (i > 0) or (k >= block_k * concurrent_loads)
            mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)
            bar = p.load_ready_bars.index(state.index)
            mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
            tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
            tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
            state = state.next()
        scheduler = scheduler.step(i)
        i += 1
@gluon.jit
def matmul_mma_partition(p):
    block_k: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]
    acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, acc_stages)
    scheduler = p.get_clc_consumer()
    i = 0
    while scheduler.has_work:
        acc_buf = p.acc_bufs.index(acc_state.index)
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase, pred=(i >= acc_stages))
        use_acc = False
        for k in range(0, K, block_k):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            tcgen05_mma(
                p.a_bufs.index(load_state.index),
                p.b_bufs.index(load_state.index),
                acc_buf,
                use_acc=use_acc,
                multicast=True,
                mbarriers=[p.load_empty_bars.index(load_state.index)],
            )
            load_state = load_state.next()
            use_acc = True
        tcgen05_commit(p.acc_ready_bars.index(acc_state.index), descs=[p.a_bufs.index(0), p.b_bufs.index(0)])
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1
@gluon.jit
def matmul_epilogue_partition(p):
    tile_m: gl.constexpr = p.a_desc.block_shape[0]
    tile_n: gl.constexpr = p.b_desc.block_shape[1]
    split_tile_n: gl.constexpr = p.c_desc.block_shape[1]
    # Separate knobs: SUBTILE_STAGES controls shared-memory usage,
    # and SUBTILE_FACTOR is the maximum number of subtiles into which we can split the tile.
    subtile_factor: gl.constexpr = tile_n // split_tile_n
    subtile_stages: gl.constexpr = p.SUBTILE_STAGES
    acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]
    dtype: gl.constexpr = p.c_desc.dtype
    acc_state = Counter.create(0, acc_stages)
    acc_smems = gl.allocate_shared_memory(dtype, [subtile_stages, tile_m, split_tile_n], p.c_desc.layout)
    sub_acc_state = Counter.create(0, subtile_stages)
    scheduler = p.get_clc_consumer()
    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        for s in gl.static_range(subtile_factor):
            acc_sub = acc_buf.slice(split_tile_n * s, split_tile_n)
            acc_smem = acc_smems.index(sub_acc_state.index)
            acc = acc_sub.load().to(dtype)
            tma.store_wait(pendings=subtile_stages - 1)
            acc_smem.store(acc)
            tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + split_tile_n * s], acc_smem)
            sub_acc_state = sub_acc_state.next()
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index))
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1
# The entry kernel allocates and initializes every barrier before
# `gl.warp_specialize`. Some of these barriers are cross-CTA barriers, and some
# otherwise per-CTA barriers are consumed by multicast or 2CTA operations. The
# compiler inserts the required init fence and relaxed cluster barrier after
# this top-level init sequence, before any partition can use the barriers.
# Keeping the init sequence here also keeps that mandatory cluster sync outside
# `warp_specialize`.
@gluon.jit
def matmul_multicta_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    BLOCK_SIZE_M: gl.constexpr,
    BLOCK_SIZE_N: gl.constexpr,
    BLOCK_SIZE_K: gl.constexpr,
    GRID_MINOR_DIM: gl.constexpr,
    GRID_TILE_WIDTH: gl.constexpr,
    STAGES: gl.constexpr,
    ACC_STAGES: gl.constexpr,
    CGA_LAYOUT: gl.constexpr,
    EPILOGUE_SIZE_N: gl.constexpr,
    SUBTILE_STAGES: gl.constexpr,
):
    block_m: gl.constexpr = a_desc.block_shape[0]
    block_n: gl.constexpr = b_desc.block_shape[1]
    two_ctas: gl.constexpr = gl.num_ctas() > 1
    n_partitions: gl.constexpr = 4
    dtype: gl.constexpr = a_desc.dtype
    a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
    tmem_layout: gl.constexpr = TensorMemoryLayout(
        [BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)],
        col_stride=1,
        cga_layout=CGA_LAYOUT,
        two_ctas=two_ctas,
    )
    acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout)
    mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True,
                                                                two_ctas=acc_bufs.index(0).type.layout.two_ctas)
    load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
    load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas)
    for i in gl.static_range(STAGES):
        mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
        mbarrier.init(load_ready_bars.index(i), count=1)
    acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
    acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(acc_empty_bars.index(i), count=1)
        mbarrier.init(acc_ready_bars.index(i), count=mma_barrier_count)
    clc_barriers = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_consumed_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(clc_barriers.index(i), count=1)
        mbarrier.init(clc_planar_ready_bars.index(i), count=1)
        mbarrier.init(clc_consumed_bars.index(i), count=n_partitions - 1)
    cga_layout: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
    clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout)
    clc_result_buffers = gl.allocate_shared_memory(
        gl.int64,
        [clc_barriers.shape[0], 2],
        clc_layout,
    )
    clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)
    p = MatmulPartitionArgs(
        a_desc,
        b_desc,
        c_desc,
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
        clc_result_buffers,
        clc_barriers,
        clc_planar_pid_buffers,
        clc_planar_ready_bars,
        clc_consumed_bars,
        GRID_MINOR_DIM,
        GRID_TILE_WIDTH,
        SUBTILE_STAGES,
    )
    gl.warp_specialize([
        (matmul_epilogue_partition, (p, )),
        (matmul_load_partition, (p, )),
        (matmul_mma_partition, (p, )),
        (matmul_clc_partition, (p, )),
    ], [1, 1, 1], [24, 24, 24])
def matmul_multicta(
    a,
    b,
    out=None,
    *,
    block_size_m=128,
    block_size_n=256,
    block_size_k=64,
    grid_minor_dim=0,
    grid_tile_width=16,
    stages=6,
    acc_stages=2,
    cga_layout=((1, 0), ),
    epilogue_size_n=32,
    subtile_stages=4,
):
    if block_size_n // get_split_dim(cga_layout, 1) > 256:
        raise ValueError(
            f"cga_layout={list(cga_layout)} only supports BLOCK_SIZE_N <= {256 * get_split_dim(cga_layout, 1)}")
    M, K = a.shape
    K1, N = b.shape
    if K != K1:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if a.dtype != torch.float16 or b.dtype != torch.float16:
        raise ValueError("matmul only supports fp16 inputs")
    if out is None:
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    else:
        if out.shape != (M, N):
            raise ValueError(f"Output has invalid shape {out.shape}, expected {(M, N)}")
        c = out
    tile_m = block_size_m * get_split_dim(cga_layout, 0)
    two_ctas = bool(cga_layout)
    a_layout = gl.NVMMASharedLayout.get_default_for([tile_m, block_size_k], gl.float16,
                                                    cga_layout=get_cga_layout(cga_layout, 0, two_ctas))
    b_layout = gl.NVMMASharedLayout.get_default_for([block_size_k, block_size_n], gl.float16,
                                                    cga_layout=get_cga_layout(cga_layout, 1, two_ctas))
    c_layout = gl.NVMMASharedLayout.get_default_for([tile_m, epilogue_size_n], gl.float16, cga_layout=cga_layout)
    a_desc = TensorDescriptor.from_tensor(a, [tile_m, block_size_k], a_layout)
    b_desc = TensorDescriptor.from_tensor(b, [block_size_k, block_size_n], b_layout)
    c_desc = TensorDescriptor.from_tensor(c, [tile_m, epilogue_size_n], c_layout)
    def grid(meta):
        tile_m = meta["BLOCK_SIZE_M"] * get_split_dim(meta["CGA_LAYOUT"], 0)
        tile_n = meta["BLOCK_SIZE_N"]
        num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
        return (num_tiles, )
    matmul_multicta_kernel[grid](
        a_desc,
        b_desc,
        c_desc,
        M,
        N,
        K,
        block_size_m,
        block_size_n,
        block_size_k,
        grid_minor_dim,
        grid_tile_width,
        stages,
        acc_stages,
        cga_layout,
        epilogue_size_n,
        subtile_stages,
        num_warps=4,
        num_ctas=2**len(cga_layout),
    )
    return c
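A back-of-the-envelope budget for the default configuration helps explain the choice of STAGES=6 and ACC_STAGES=2. The sketch assumes fp16 tiles and the 2CTA operand splits produced by get_cga_layout. It also assumes that the epilogue staging buffer inherits the M split of c_desc.layout, and that the accumulator uses one 32-bit tensor-memory column per output column. Barriers and the small CLC buffers are ignored.

```python
# Default matmul_multicta configuration: fp16 operands/output, cga_layout=((1, 0),) -> 2 CTAs.
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
STAGES, ACC_STAGES = 6, 2
EPILOGUE_N, SUBTILE_STAGES = 32, 4
FP16_BYTES = 2

# A tile [2*BLOCK_M, BLOCK_K] is split along M across the pair -> [BLOCK_M, BLOCK_K] per CTA.
a_stage = BLOCK_M * BLOCK_K * FP16_BYTES
# B tile [BLOCK_K, BLOCK_N] is split along N across the pair -> [BLOCK_K, BLOCK_N // 2] per CTA.
b_stage = BLOCK_K * (BLOCK_N // 2) * FP16_BYTES
# Epilogue staging: SUBTILE_STAGES buffers of [2*BLOCK_M, EPILOGUE_N], split along M (assumed).
epilogue = SUBTILE_STAGES * BLOCK_M * EPILOGUE_N * FP16_BYTES
smem_bytes = STAGES * (a_stage + b_stage) + epilogue

# Tensor memory: each CTA keeps a [BLOCK_M, BLOCK_N] fp32 accumulator per stage,
# i.e. BLOCK_N 32-bit columns across its 128 lanes.
tmem_columns = ACC_STAGES * BLOCK_N

print(f"shared memory per CTA: {smem_bytes // 1024} KiB (barriers and CLC buffers excluded)")
print(f"tensor memory columns per CTA: {tmem_columns}")
```

Under these assumptions the pipeline uses about 224 KiB of shared memory per CTA, close to the roughly 228 KiB a Blackwell SM provides. The two accumulator stages occupy all 512 tensor-memory columns, so neither STAGES nor ACC_STAGES can grow without shrinking the tiles.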
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_matmul_multicta():
    M, N, K = 1024, 1024, 512
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c = matmul_multicta(a, b)
    torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
if __name__ == "__main__" and is_blackwell():
    print("Benchmarking matmul_multicta")
    print("============================")
    cfg = {
        "block_size_m": 128,
        "block_size_n": 256,
        "block_size_k": 64,
        "grid_minor_dim": 0,
        "grid_tile_width": 16,
        "stages": 6,
        "acc_stages": 2,
        "cga_layout": ((1, 0), ),
        "epilogue_size_n": 32,
        "subtile_stages": 4,
    }
    M, N = 8192, 8192
    C = torch.empty((M, N), device="cuda", dtype=torch.float16)
    print("    K         multi-CTA    cublas")
    for K in [2**i for i in range(9, 15)]:
        A = torch.randn((M, K), device="cuda", dtype=torch.float16)
        B = torch.randn((K, N), device="cuda", dtype=torch.float16)
        BT = B.T.contiguous()
        r0 = tflops(triton.testing.do_bench(lambda: matmul_multicta(A, B, out=C, **cfg), warmup=200, rep=1000), M, N, K)
        r1 = tflops(triton.testing.do_bench(lambda: cublas.matmul(A, BT, C), warmup=200, rep=1000), M, N, K)
        print(f"{K:>5} {r0:>17.2f} {r1:>9.2f}")
Benchmarking matmul_multicta
============================
    K         multi-CTA    cublas
  512           1096.31   1190.98
 1024           1306.07   1344.48
 2048           1379.80   1374.48
 4096           1444.26   1431.93
 8192           1302.33   1347.82
16384           1292.40   1371.82
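We are competitive with cuBLAS and, with this particular configuration, slightly ahead of it at K = 2048 and K = 4096. If we picked a different configuration for different shapes, we would be able to beat cuBLAS over a wider range of shapes.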
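The relative numbers behind this comparison can be recomputed directly from the measured table. The dictionary below just copies those values:

```python
results = {  # K: (multi-CTA TFLOPS, cuBLAS TFLOPS), copied from the table above
    512: (1096.31, 1190.98),
    1024: (1306.07, 1344.48),
    2048: (1379.80, 1374.48),
    4096: (1444.26, 1431.93),
    8192: (1302.33, 1347.82),
    16384: (1292.40, 1371.82),
}
ratios = {k: ours / ref for k, (ours, ref) in results.items()}
for k, r in ratios.items():
    print(f"K={k:>5}: multi-CTA / cuBLAS = {r:.3f}")
wins = [k for k, r in ratios.items() if r > 1.0]
print("multi-CTA ahead at K =", wins)
```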