Note
Go to the end to download the full example code.
Fused Softmax¶
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations¶
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we would prefer a custom "fused" kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing only \(MN\) elements each, i.e., \(2MN\) transfers in total, so we could expect a theoretical speed-up of about 4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically, but in practice it is still far from ideal.
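As a quick sanity check on this arithmetic, we can plug a concrete shape into the formula above. The snippet below is purely illustrative and not part of the original tutorial; the shape is a hypothetical example.

M, N = 4096, 1024  # hypothetical example shape, chosen only for illustration
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # naive softmax: reads + writes
fused_traffic = 2 * M * N  # ideal fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0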
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the result back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so if we want to handle any possible input shape we need to internally "pad" each row and guard the memory operations properly.
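To see why padding the out-of-bounds lanes with -inf is harmless, consider the small PyTorch illustration below (an illustrative sketch, not part of the original tutorial): padded lanes become exp(-inf) = 0 and therefore contribute nothing to the row sum.

row = torch.tensor([1.0, 2.0, float('-inf'), float('-inf')])  # 2 real values + 2 padding lanes
z = row - row.max()  # the row maximum ignores the -inf padding
print(torch.exp(z) / torch.exp(z).sum())  # tensor([0.2689, 0.7311, 0.0000, 0.0000])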
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
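Before wrapping the kernel in the occupancy-aware helper below, it may help to see the simplest possible launch. The following sketch is an assumption for illustration (not part of the original tutorial): it launches one program per row, so each program runs a single iteration of the tl.range loop.

def softmax_simple(x):
    # Minimal, illustrative launch: one program per row, no persistent scheduling.
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)  # each row must fit in one power-of-two block
    y = torch.empty_like(x)
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                               BLOCK_SIZE=BLOCK_SIZE, num_stages=2)
    return y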
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
NUM_GPRS = NUM_REGS
if is_cdna():
NUM_GPRS = NUM_REGS * 2
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
return y
Unit Test¶
We make sure to test our kernel on a matrix with an irregular number of rows and columns. This allows us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch', 'naive_softmax'], # possible values for `line_arg``
line_names=["Triton", "Torch", "Naive Softmax"], # label name for the lines
styles=[('blue', '-'), ('green', '-'), ('red', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
if provider == 'naive_softmax':
ms = triton.testing.do_bench(lambda: naive_softmax(x))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch Naive Softmax
0 256.0 477.462712 703.822901 206.253863
1 384.0 664.863479 824.797972 264.411742
2 512.0 811.265215 946.849992 302.867896
3 640.0 877.802125 958.214572 332.318560
4 768.0 956.897098 1028.029727 350.327420
5 896.0 1012.016262 1074.654123 353.915462
6 1024.0 1058.936830 1095.562592 353.520017
7 1152.0 1097.253126 1038.944131 349.141133
8 1280.0 1134.885047 1076.550819 349.465844
9 1408.0 1162.877066 1104.759849 340.882321
10 1536.0 1197.078594 1142.965734 334.313559
11 1664.0 1208.426362 1162.107467 329.747045
12 1792.0 1232.047263 1194.460566 326.213875
13 1920.0 1254.221134 1198.008237 325.163443
14 2048.0 1270.868793 1224.541007 324.795622
15 2176.0 1238.981103 966.640005 325.543620
16 2304.0 1249.673662 999.671124 325.991668
17 2432.0 1274.103191 1036.851407 326.855959
18 2560.0 1285.843872 1074.691278 328.032597
19 2688.0 1289.919793 1103.789172 329.656577
20 2816.0 1302.932329 1122.763803 329.773201
21 2944.0 1321.736236 1150.159162 331.548506
22 3072.0 1320.358598 1168.747194 333.497314
23 3200.0 1339.986863 1177.635725 335.169487
24 3328.0 1341.911072 1205.791789 336.464770
25 3456.0 1345.134990 1223.641691 336.768093
26 3584.0 1356.459754 1243.844716 337.903484
27 3712.0 1365.017159 1263.592994 340.479845
28 3840.0 1367.038641 1282.181248 340.032085
29 3968.0 1374.065700 1295.951169 341.083274
30 4096.0 1388.242775 1323.372434 338.706470
31 4224.0 1327.245712 1297.914009 343.016525
32 4352.0 1333.047076 1321.803972 345.187550
33 4480.0 1340.221524 1336.673665 345.867393
34 4608.0 1354.191504 1356.376747 346.447012
35 4736.0 1352.559607 1366.564458 348.256379
36 4864.0 1361.121268 1380.510427 349.001035
37 4992.0 1366.842846 1385.203084 350.386939
38 5120.0 1370.427007 1409.077324 350.707591
39 5248.0 1371.916948 1367.277118 351.671604
40 5376.0 1372.539057 1380.455673 351.756040
41 5504.0 1375.324435 1393.329007 353.731835
42 5632.0 1387.278094 1408.649175 353.417598
43 5760.0 1389.558013 1412.832607 355.181756
44 5888.0 1386.995394 1421.744729 355.255180
45 6016.0 1392.629236 1441.846884 356.578110
46 6144.0 1402.673006 1453.298007 357.339149
47 6272.0 1406.107165 1400.065488 357.645651
48 6400.0 1407.036565 1413.802463 358.641086
49 6528.0 1408.586249 1423.978952 359.125847
50 6656.0 1411.428892 1428.567559 359.514908
51 6784.0 1416.933998 1451.190287 360.450094
52 6912.0 1421.213683 1454.336125 360.927162
53 7040.0 1413.755012 1448.599521 360.791609
54 7168.0 1413.913697 1468.536980 361.787232
55 7296.0 1421.170734 1086.708410 362.712744
56 7424.0 1427.251521 1101.957011 362.809103
57 7552.0 1424.181749 1115.340071 363.521974
58 7680.0 1428.553742 1124.356520 363.473484
59 7808.0 1425.336361 1136.225566 364.548596
60 7936.0 1427.406822 1147.406355 364.741411
61 8064.0 1429.642807 1151.662541 365.039918
62 8192.0 1427.365779 1157.348999 364.245581
63 8320.0 1388.879335 1114.048896 361.891111
64 8448.0 1383.076857 1121.800331 362.518902
65 8576.0 1385.886635 1122.658140 363.566352
66 8704.0 1383.643694 1131.870819 364.601126
67 8832.0 1394.404237 1128.487748 365.267494
68 8960.0 1388.944135 1133.899457 365.772050
69 9088.0 1394.915427 1130.696173 366.804487
70 9216.0 1397.713073 1131.887190 367.480474
71 9344.0 1390.586198 1419.391670 367.889989
72 9472.0 1397.986576 1426.103825 369.256259
73 9600.0 1399.901578 1430.734912 369.163540
74 9728.0 1395.921156 1441.210094 369.834441
75 9856.0 1399.127980 1437.192995 370.102272
76 9984.0 1391.327978 1448.384086 370.116911
77 10112.0 1405.548575 1457.076805 371.459744
78 10240.0 1403.022935 1463.082375 371.710032
79 10368.0 1412.699700 1460.521474 369.328354
80 10496.0 1406.066570 1465.764544 370.044838
81 10624.0 1403.932071 1464.924666 371.304464
82 10752.0 1396.644627 1471.257164 371.087410
83 10880.0 1390.184004 1476.115203 371.591115
84 11008.0 1414.465179 1476.394860 372.301729
85 11136.0 1420.541108 1485.171304 373.187862
86 11264.0 1407.170750 1485.963982 372.868373
87 11392.0 1417.656797 1488.788444 374.032385
88 11520.0 1410.525833 1495.357355 373.914388
89 11648.0 1418.141926 1499.384979 374.516050
90 11776.0 1425.531967 1500.781373 374.340854
91 11904.0 1422.144565 1507.026601 375.417407
92 12032.0 1411.299068 1508.145651 375.628754
93 12160.0 1408.965220 1515.398698 375.992370
94 12288.0 1421.246722 1418.903649 376.141048
95 12416.0 1431.567369 1397.187103 374.857999
96 12544.0 1437.219073 1394.035935 375.617189
97 12672.0 1427.185539 1390.400542 375.339096
In the plot above, we can see that:
- Triton is roughly 4x faster than naive_softmax. This matches our earlier analysis: the naive PyTorch implementation performs no kernel fusion and is bound by its redundant memory traffic.
- Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand and maintain. Note, however, that the PyTorch softmax operation is more general and works on tensors of any shape.
Total running time of the script: (0 minutes 35.218 seconds)