Matrix Multiplication
In this tutorial, you will write a very short high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS or rocBLAS.
You will specifically learn about:
Block-level matrix multiplications.
Multi-dimensional pointer arithmetic.
Program re-ordering for improved L2 cache hit rate.
Automatic performance tuning.
Motivations
Matrix multiplications are a key building block of most modern high-performance computing systems. They are notoriously hard to optimize, hence their implementation is generally done by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). Unfortunately, these libraries are often proprietary and cannot easily be customized to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). In this tutorial, you will learn how to implement efficient matrix multiplications by yourself with Triton, in a way that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked algorithm to multiply an (M, K) matrix by a (K, N) matrix:
# Do in parallel
for m in range(0, M, BLOCK_SIZE_M):
  # Do in parallel
  for n in range(0, N, BLOCK_SIZE_N):
    acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
    for k in range(0, K, BLOCK_SIZE_K):
      a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
      b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
      acc += dot(a, b)
    C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
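To make this blocked decomposition concrete before moving to Triton, here is a minimal NumPy sketch of the same algorithm (blocked_matmul, the shapes, and the block sizes below are illustrative choices, not values used by the kernel later in this tutorial). It runs sequentially, but every (m, n) iteration is independent, which is exactly what the Triton launch grid parallelizes:

import numpy as np

def blocked_matmul(A, B, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=np.float32)
    # Each (m, n) iteration is independent and maps to one Triton program instance.
    for m in range(0, M, BLOCK_M):
        for n in range(0, N, BLOCK_N):
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
            for k in range(0, K, BLOCK_K):
                a = A[m:m + BLOCK_M, k:k + BLOCK_K]
                b = B[k:k + BLOCK_K, n:n + BLOCK_N]
                acc += a @ b
            C[m:m + BLOCK_M, n:n + BLOCK_N] = acc
    return C

A = np.random.randn(64, 48).astype(np.float32)
B = np.random.randn(48, 32).astype(np.float32)
assert np.allclose(blocked_matmul(A, B), A @ B, atol=1e-4)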
Compute Kernel
The above algorithm is, in fact, fairly straightforward to implement in Triton. The main difficulty comes from the computation of the memory locations at which blocks of A and B must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.
Pointer Arithmetic
For a row-major 2D tensor X, the memory location of X[i, j] is given by &X[i, j] = X + i*stride_xi + j*stride_xj. Therefore, blocks of pointers for A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] and B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] can be defined in pseudo-code as:
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
This means that pointers for blocks of A and B can be initialized in Triton as the following code (i.e., k=0). Also note that we need an extra modulo to handle the case where M is not a multiple of BLOCK_SIZE_M or N is not a multiple of BLOCK_SIZE_N, in which case we can pad the data with some useless values that will not contribute to the results. For the K dimension, we will handle that later using masked load semantics.
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k[None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k[:, None]*stride_bk + offs_bn[None, :]*stride_bn)
And then updated in the inner loop as follows:
a_ptrs += BLOCK_SIZE_K * stride_ak;
b_ptrs += BLOCK_SIZE_K * stride_bk;
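As a CPU-side sanity check of this offset arithmetic (not part of the kernel), the sketch below reproduces the same index computation with NumPy: it gathers a block out of a flat row-major buffer using the offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak pattern, advances the offsets by BLOCK_SIZE_K * stride_ak, and compares against plain slicing. Strides are expressed in elements, which is how Triton's pointer arithmetic on typed pointers behaves; the shapes, block sizes, and pid_m value are arbitrary illustrative choices:

import numpy as np

M, K = 8, 12
BLOCK_SIZE_M, BLOCK_SIZE_K = 4, 4
A = np.arange(M * K, dtype=np.float32).reshape(M, K)
flat = A.reshape(-1)                 # the "memory" behind a_ptr
stride_am, stride_ak = K, 1          # row-major element strides, like a.stride(0), a.stride(1)

pid_m = 1                            # pretend this program handles the second row of blocks
offs_am = (pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)) % M
offs_k = np.arange(BLOCK_SIZE_K)
offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak

# First K-block (k = 0), then advance along K exactly as the inner loop does.
assert np.array_equal(flat[offsets], A[4:8, 0:4])
offsets += BLOCK_SIZE_K * stride_ak
assert np.array_equal(flat[offsets], A[4:8, 4:8])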
L2 Cache Optimizations
As mentioned above, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C. It is important to remember that the order in which these blocks are computed does matter, since it affects the L2 cache hit rate of our program, and unfortunately, a simple row-major ordering:
pid = tl.program_id(axis=0)
grid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // grid_n
pid_n = pid % grid_n
is just not going to cut it.
One possible solution is to launch blocks in an order that promotes data reuse. This can be done by "super-grouping" blocks into groups of GROUP_SIZE_M rows before switching to the next column:
# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of program ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
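Since this remapping is plain integer arithmetic, it can be previewed outside of Triton. The sketch below (grouped_order is a hypothetical helper written for this illustration) applies the same formulas to a small launch grid with num_pid_m = num_pid_n = 4 and GROUP_SIZE_M = 2, and prints the (pid_m, pid_n) pairs in launch order, which makes the column-major traversal within each group of rows visible:

def grouped_order(num_pid_m, num_pid_n, GROUP_SIZE_M):
    order = []
    for pid in range(num_pid_m * num_pid_n):
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order

print(grouped_order(4, 4, GROUP_SIZE_M=2))
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3),
#  (2, 0), (3, 0), (2, 1), (3, 1), (2, 2), (3, 2), (2, 3), (3, 3)]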
For example, in a matmul where each matrix consists of 9 blocks by 9 blocks, we can see that if we compute the output in row-major ordering, we need to load 90 blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped ordering, we only need to load 54.
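That 90 vs. 54 count can be double-checked with a short counting script. Assuming GROUP_SIZE_M = 3, so that the first 9 programs of the grouped schedule cover a 3x3 tile of output blocks (a choice made for this illustration), each program needs one row of A blocks and one column of B blocks across the 9 K-steps, and counting the distinct blocks gives exactly the numbers above:

def blocks_loaded(program_ids, num_k_blocks=9):
    # Count distinct A blocks (pid_m, k) and B blocks (k, pid_n) touched
    # by the given output-block coordinates.
    a_blocks, b_blocks = set(), set()
    for pid_m, pid_n in program_ids:
        for k in range(num_k_blocks):
            a_blocks.add((pid_m, k))
            b_blocks.add((k, pid_n))
    return len(a_blocks) + len(b_blocks)

row_major = [(0, n) for n in range(9)]                   # first 9 programs, row-major ordering
grouped = [(m, n) for n in range(3) for m in range(3)]   # first 9 programs, GROUP_SIZE_M = 3
print(blocks_loaded(row_major), blocks_loaded(grouped))  # 90 54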
In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on A100).
Final Result
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def is_hip_cdna2():
target = triton.runtime.driver.active.get_current_target()
return target.backend == 'hip' and target.arch == 'gfx90a'
def get_cuda_autotune_config():
return [
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
num_warps=2),
# Good config for fp8 inputs.
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4)
]
def get_hip_autotune_config():
return [
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
num_warps=4, num_stages=2),
]
def get_autotune_config():
if is_cuda():
return get_cuda_autotune_config()
else:
return get_hip_autotune_config()
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=get_autotune_config(),
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
stride_bk, stride_bn, #
stride_cm, stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
We can now create a convenience wrapper function that only takes two input tensors, and (1) checks any shape constraints, (2) allocates the output, and (3) launches the above kernel.
def matmul(a, b, activation=""):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
ACTIVATION=activation #
)
return c
Unit Test
We can test our custom matrix multiplication operation against a native PyTorch implementation (i.e., cuBLAS).
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD CDNA2 devices.
# CDNA2 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.ac.cn/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_cdna2() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
a = a.to(torch.float8_e5m2)
# pre-transpose b for efficiency.
b = b.T
b = b.to(torch.float8_e5m2)
triton_output = matmul(a, b)
torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
print(f"triton_output_with_fp8_inputs={triton_output}")
print(f"torch_output_with_fp8_inputs={torch_output}")
if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
triton_output_with_fp16_inputs=tensor([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],
[ 26.8438, 10.0469, -5.4297, ..., -11.2969, -8.5312, 30.7500],
[-13.2578, 15.8516, 18.0781, ..., -21.7656, -8.6406, 10.2031],
...,
[ 40.2812, 18.6094, -25.6094, ..., -2.7598, -3.2441, 41.0000],
[ -6.1211, -16.8281, 4.4844, ..., -21.0312, 24.7031, 15.0234],
[-17.0938, -19.0000, -0.3831, ..., 21.5469, -30.2344, -13.2188]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-10.9531, -4.7109, 15.6953, ..., -28.4062, 4.3320, -26.4219],
[ 26.8438, 10.0469, -5.4297, ..., -11.2969, -8.5312, 30.7500],
[-13.2578, 15.8516, 18.0781, ..., -21.7656, -8.6406, 10.2031],
...,
[ 40.2812, 18.6094, -25.6094, ..., -2.7598, -3.2441, 41.0000],
[ -6.1211, -16.8281, 4.4844, ..., -21.0312, 24.7031, 15.0234],
[-17.0938, -19.0000, -0.3831, ..., 21.5469, -30.2344, -13.2188]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
triton_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],
[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],
[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],
...,
[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],
[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],
[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., 28.7031, 8.6719, -40.7500],
[ 10.0000, 37.0000, -5.5664, ..., 20.9844, 46.8125, 30.8281],
[ 19.5625, -3.0078, -20.0469, ..., -2.1309, -8.0625, 12.5625],
...,
[-18.1562, -34.1562, -27.4219, ..., -27.3906, -24.0938, -12.3516],
[ -3.3945, -8.6250, -23.6562, ..., -4.1094, -3.5332, -16.0781],
[-23.9688, -3.2637, -33.6875, ..., 17.3125, -36.6250, 25.8594]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
Benchmark
Square Matrix Performance
Here we focus on square matrices, but feel free to arrange this script to benchmark any other matrix shape.
ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'
configs = []
for fp8_inputs in [False, True]:
if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
continue
configs.append(
triton.testing.Benchmark(
x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name`
line_arg="provider", # Argument name whose value corresponds to a different line in the plot
# Possible values for `line_arg`
# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"], # Label name for the lines
line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Line styles
styles=[("green", "-"), ("blue", "-")],
ylabel="TFLOPS", # Label name for the y-axis
plot_name="matmul-performance-" +
("fp16" if not fp8_inputs else "fp8"), # Name for the plot, used also as a file name for saving the plot.
args={"fp8_inputs": fp8_inputs},
))
@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
if TORCH_HAS_FP8 and fp8_inputs:
a = a.to(torch.float8_e5m2)
b = b.T
b = b.to(torch.float8_e5m2)
quantiles = [0.5, 0.2, 0.8]
if provider == ref_lib.lower():
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)
matmul-performance-fp16:
M N K cuBLAS Triton
0 256.0 256.0 256.0 4.096000 4.096000
1 384.0 384.0 384.0 11.059200 12.288000
2 512.0 512.0 512.0 26.214401 26.214401
3 640.0 640.0 640.0 42.666665 42.666665
4 768.0 768.0 768.0 63.195428 58.982401
5 896.0 896.0 896.0 78.051553 82.642822
6 1024.0 1024.0 1024.0 110.376426 80.659693
7 1152.0 1152.0 1152.0 135.726544 106.642284
8 1280.0 1280.0 1280.0 157.538463 132.129034
9 1408.0 1408.0 1408.0 151.438217 118.516867
10 1536.0 1536.0 1536.0 176.947204 144.446699
11 1664.0 1664.0 1664.0 179.978245 160.694855
12 1792.0 1792.0 1792.0 172.914215 184.252856
13 1920.0 1920.0 1920.0 200.347822 150.260866
14 2048.0 2048.0 2048.0 223.696203 167.772164
15 2176.0 2176.0 2176.0 214.081356 173.479720
16 2304.0 2304.0 2304.0 236.513589 182.350177
17 2432.0 2432.0 2432.0 205.069087 175.590404
18 2560.0 2560.0 2560.0 222.911566 195.047621
19 2688.0 2688.0 2688.0 199.647657 167.845378
20 2816.0 2816.0 2816.0 210.696652 169.705085
21 2944.0 2944.0 2944.0 221.493479 175.478980
22 3072.0 3072.0 3072.0 208.941345 186.874926
23 3200.0 3200.0 3200.0 216.216207 192.192190
24 3328.0 3328.0 3328.0 208.067338 168.597886
25 3456.0 3456.0 3456.0 216.143621 175.646117
26 3584.0 3584.0 3584.0 216.663602 189.295559
27 3712.0 3712.0 3712.0 208.990259 194.731662
28 3840.0 3840.0 3840.0 214.741739 178.951459
29 3968.0 3968.0 3968.0 211.114084 182.944428
30 4096.0 4096.0 4096.0 219.668951 195.652669
matmul-performance-fp8:
M N K Triton
0 256.0 256.0 256.0 4.096000
1 384.0 384.0 384.0 12.288000
2 512.0 512.0 512.0 26.214401
3 640.0 640.0 640.0 46.545454
4 768.0 768.0 768.0 58.982401
5 896.0 896.0 896.0 87.808000
6 1024.0 1024.0 1024.0 95.325090
7 1152.0 1152.0 1152.0 119.439363
8 1280.0 1280.0 1280.0 146.285712
9 1408.0 1408.0 1408.0 132.970149
10 1536.0 1536.0 1536.0 157.286398
11 1664.0 1664.0 1664.0 157.875646
12 1792.0 1792.0 1792.0 184.252856
13 1920.0 1920.0 1920.0 166.554219
14 2048.0 2048.0 2048.0 188.508043
15 2176.0 2176.0 2176.0 181.294124
16 2304.0 2304.0 2304.0 202.439587
17 2432.0 2432.0 2432.0 195.100438
18 2560.0 2560.0 2560.0 206.088047
19 2688.0 2688.0 2688.0 190.618370
20 2816.0 2816.0 2816.0 202.856788
21 2944.0 2944.0 2944.0 205.086550
22 3072.0 3072.0 3072.0 199.377121
23 3200.0 3200.0 3200.0 202.531652
24 3328.0 3328.0 3328.0 197.778282
25 3456.0 3456.0 3456.0 204.105230
26 3584.0 3584.0 3584.0 207.656790
27 3712.0 3712.0 3712.0 205.128011
28 3840.0 3840.0 3840.0 197.133682
29 3968.0 3968.0 3968.0 209.303487
30 4096.0 4096.0 4096.0 215.784121
Total running time of the script: (2 minutes 8.779 seconds)