Fused Softmax
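In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.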
In doing so, you will learn about:
The benefits of kernel fusion for bandwidth-bound operations.
Reduction operators in Triton.
Motivations
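Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation: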
import torch
import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read MN elements ; write M elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read MN elements ; write MN elements
    numerator = torch.exp(z)
    # read MN elements ; write M elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
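When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer a custom "fused" kernel that reads X only once and does all the necessary computations on-chip. Doing so requires reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of about 4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script decorator aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.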
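The element counts above can be checked with a little arithmetic. The following is a minimal sketch: the helper names `naive_traffic` and `fused_traffic` and the sizes M = 4096, N = 1024 are assumptions chosen for illustration and are not part of the tutorial code.

```python
def naive_traffic(M, N):
    """Elements moved by naive_softmax: reads (5MN + 2M) plus writes (3MN + 2M)."""
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    """Elements moved by an ideal fused kernel: read X once, write Y once."""
    return 2 * M * N


M, N = 4096, 1024  # hypothetical sizes, chosen only for illustration
ratio = naive_traffic(M, N) / fused_traffic(M, N)
print(f"theoretical speed-up: {ratio:.3f}x")
```

The ratio simplifies to 4 + 2/N, so it approaches 4 as the number of columns grows. This is an upper bound based on memory traffic alone; it ignores compute time and launch overhead.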
Compute Kernel
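Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the result back to the output Y.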
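To make the strided row assignment concrete, here is a small pure-Python sketch of which rows each program would visit. The helper `rows_for_program` and the sizes used are hypothetical; the sketch only mimics the loop bounds (start at the program id, step by the number of programs) and does not run on the GPU.

```python
def rows_for_program(pid, num_programs, n_rows):
    """Rows visited by program `pid`: pid, pid + num_programs, pid + 2 * num_programs, ..."""
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4  # hypothetical sizes, chosen only for illustration
assignment = {pid: rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)}
for pid, rows in assignment.items():
    print(f"program {pid} handles rows {rows}")
```

Every row is visited by exactly one program, and the per-program workloads differ by at most one row.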
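Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shape: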
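The padding idea can be illustrated on the CPU with plain Python. In this sketch, `next_pow2`, `softmax_row`, and `padded_softmax_row` are hypothetical helpers written for this example. Out-of-range positions are filled with negative infinity, so they contribute exp(-inf) = 0 to the sum and never win the maximum.

```python
import math


def next_pow2(n):
    """Smallest power of two that is >= n, for n >= 1 (illustrative helper)."""
    return 1 << (n - 1).bit_length()


def softmax_row(row):
    """Numerically stable softmax of a single row."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def padded_softmax_row(row):
    """Pad the row to a power-of-two length with -inf, then discard the padding."""
    n_cols = len(row)
    block_size = next_pow2(n_cols)
    mask = [i < n_cols for i in range(block_size)]
    padded = [row[i] if mask[i] else -math.inf for i in range(block_size)]
    out = softmax_row(padded)
    # Keep only the positions where the mask is set, like a masked store.
    return [o for o, keep in zip(out, mask) if keep]


row = [1.0, 2.0, 3.0, 4.0, 5.0]
reference = softmax_row(row)
padded_result = padded_softmax_row(row)
print(next_pow2(len(row)), padded_result)
```

The padded result agrees with the unpadded softmax, which is why padding with negative infinity is a safe choice for this particular reduction.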
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
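The occupancy heuristic in the non-HIP branch is plain integer arithmetic, so it can be traced by hand. The sketch below reproduces it in pure Python. The function `estimate_num_programs` and every number passed to it are hypothetical values chosen for illustration, not measurements from a real device or a compiled kernel. The guard against a zero shared-memory size is an addition made for this sketch and is not in the helper above.

```python
def estimate_num_programs(num_sm, num_regs, size_smem_device, warp_size,
                          n_regs, size_smem_kernel, num_warps, n_rows):
    """Mirror the non-HIP branch: limit by registers, then shared memory, then rows."""
    # Programs per multiprocessor that fit in the register file.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Programs per multiprocessor that fit in shared memory.
    # The zero check is an assumption added here to avoid a division by zero.
    if size_smem_kernel > 0:
        occupancy = min(occupancy, size_smem_device // size_smem_kernel)
    # Never launch more persistent programs than there are rows.
    return min(num_sm * occupancy, n_rows)


# All values below are hypothetical.
num_programs = estimate_num_programs(
    num_sm=108, num_regs=65536, size_smem_device=163840, warp_size=32,
    n_regs=32, size_smem_kernel=32768, num_warps=8, n_rows=1823)
print(num_programs)
```

With these made-up numbers, registers allow 8 programs per multiprocessor and shared memory allows 5, so shared memory is the limiting resource and 108 * 5 = 540 programs are launched.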
Unit Test
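We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.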
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results match to within the default tolerances of torch.allclose.
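Because the check relies on a tolerance, it helps to see what that tolerance means. torch.allclose tests |a - b| <= atol + rtol * |b| elementwise, with defaults rtol = 1e-5 and atol = 1e-8. The pure-Python helper below, `allclose_sketch`, is a hypothetical stand-in written for this example that applies the same rule to flat lists of floats.

```python
def allclose_sketch(a, b, rtol=1e-5, atol=1e-8):
    """Elementwise check of |x - y| <= atol + rtol * |y| over two equal-length lists."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# A relative difference of about 1e-6 passes; a difference of about 1e-3 does not.
small_diff_ok = allclose_sketch([1.0, 0.5], [1.0 + 1e-6, 0.5])
large_diff_ok = allclose_sketch([1.0, 0.5], [1.001, 0.5])
print(small_diff_ok, large_diff_ok)
```

Small differences of this kind are expected here, since the exponentiation used in the kernel is fast but approximate.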
Benchmark
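Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.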
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton (GB/s) Torch (GB/s) Naive Softmax (GB/s)
0 256.0 505.091135 704.663800 208.360473
1 384.0 701.666012 821.081367 263.815333
2 512.0 839.053880 915.648283 299.846547
3 640.0 836.248289 919.478119 330.189166
4 768.0 913.085118 988.818136 348.525727
5 896.0 978.250002 1036.138136 353.011207
6 1024.0 1027.567946 1075.608731 353.078014
7 1152.0 1020.403231 1074.346139 347.483118
8 1280.0 1081.425317 1113.767076 348.154863
9 1408.0 1109.168672 1138.073973 340.538983
10 1536.0 1151.131931 1159.791310 333.857494
11 1664.0 1186.794173 1192.747490 330.325436
12 1792.0 1209.712832 1193.947161 325.571096
13 1920.0 1231.513068 1216.430415 323.893440
14 2048.0 1246.003820 1240.355203 323.997644
15 2176.0 1251.873738 958.875122 325.248353
16 2304.0 1270.532454 1003.469335 325.454217
17 2432.0 1297.308523 1031.974732 325.751576
18 2560.0 1305.650668 1071.105447 327.617594
19 2688.0 1317.944677 1096.202036 328.719881
20 2816.0 1336.459195 1126.166522 329.115028
21 2944.0 1335.325468 1148.503207 331.800589
22 3072.0 1350.900592 1170.833674 333.977867
23 3200.0 1354.954042 1173.750537 334.656529
24 3328.0 1363.400746 1198.776127 335.699413
25 3456.0 1368.535502 1219.435537 336.453224
26 3584.0 1373.404895 1246.439737 337.956270
27 3712.0 1384.320789 1266.360853 339.608390
28 3840.0 1391.712179 1284.087060 340.398059
29 3968.0 1393.908332 1301.188741 340.822090
30 4096.0 1392.134322 1317.236891 338.648011
31 4224.0 1340.732034 1279.716717 343.213636
32 4352.0 1348.794994 1301.148704 345.463574
33 4480.0 1353.848701 1317.764264 345.720205
34 4608.0 1371.630996 1337.040876 346.656879
35 4736.0 1367.329019 1346.499776 347.893119
36 4864.0 1369.122336 1359.099200 348.682513
37 4992.0 1373.710016 1375.520204 349.929713
38 5120.0 1381.778143 1385.157838 350.390671
39 5248.0 1379.258132 1357.134306 351.669737
40 5376.0 1383.871485 1374.128292 351.766064
41 5504.0 1389.060859 1383.602927 353.899771
42 5632.0 1405.035935 1388.506074 353.273466
43 5760.0 1402.257802 1408.161096 355.219008
44 5888.0 1401.086146 1410.544243 354.934511
45 6016.0 1405.699133 1423.702123 356.266752
46 6144.0 1412.752097 1437.974378 357.325234
47 6272.0 1415.147031 1393.634599 357.728797
48 6400.0 1419.826346 1414.034104 357.894735
49 6528.0 1420.650638 1414.967906 359.047536
50 6656.0 1422.947119 1431.716575 359.172476
51 6784.0 1424.771552 1433.870547 360.417870
52 6912.0 1429.434720 1441.100003 360.629150
53 7040.0 1426.459014 1454.511219 361.025640
54 7168.0 1428.267961 1454.322971 361.488321
55 7296.0 1431.902852 1087.788108 362.316380
56 7424.0 1435.514788 1098.713069 362.603159
57 7552.0 1433.403792 1106.041760 363.013294
58 7680.0 1437.768870 1122.248941 363.396705
59 7808.0 1432.988078 1127.893168 364.299401
60 7936.0 1437.932261 1141.581655 364.618457
61 8064.0 1439.533572 1146.141182 365.067099
62 8192.0 1433.177630 1151.175017 363.489382
63 8320.0 1380.992766 1118.468245 361.685522
64 8448.0 1382.900020 1126.024670 362.402981
65 8576.0 1384.226968 1129.407179 363.098682
66 8704.0 1376.622534 1133.384119 364.209023
67 8832.0 1392.287640 1132.847732 364.916002
68 8960.0 1385.909942 1139.871201 365.598815
69 9088.0 1388.678577 1137.819103 366.551094
70 9216.0 1399.293321 1143.915202 367.436023
71 9344.0 1389.467705 1421.005094 367.667099
72 9472.0 1394.237786 1431.952964 368.211882
73 9600.0 1398.902078 1432.758715 368.783088
74 9728.0 1395.673001 1440.225255 369.772317
75 9856.0 1401.245989 1438.244396 369.974576
76 9984.0 1388.509195 1450.417851 370.689487
77 10112.0 1403.161525 1457.662239 371.149073
78 10240.0 1406.549851 1467.240928 371.257859
79 10368.0 1412.932494 1461.312823 369.864041
80 10496.0 1405.100786 1464.091827 370.195569
81 10624.0 1403.815503 1468.138385 370.823800
82 10752.0 1392.255136 1471.737129 371.176177
83 10880.0 1392.371296 1478.272957 371.934906
84 11008.0 1415.319736 1476.131684 372.429986
85 11136.0 1414.775748 1485.278001 372.629118
86 11264.0 1407.132118 1489.836145 373.378925
87 11392.0 1419.900430 1492.070874 374.360070
88 11520.0 1409.547705 1491.767950 373.686969
89 11648.0 1416.659104 1495.810039 374.100361
90 11776.0 1431.850794 1503.958123 374.504725
91 11904.0 1429.674070 1507.655768 374.786913
92 12032.0 1412.896597 1509.056160 376.327660
93 12160.0 1408.256166 1516.310629 376.240648
94 12288.0 1422.338488 1422.762538 376.116457
95 12416.0 1429.857847 1397.478687 375.145581
96 12544.0 1440.553753 1395.929581 375.533729
97 12672.0 1431.380337 1396.077793 375.180844
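The GB/s figures count one read and one write of the whole matrix, which is why the formula in the benchmark has a factor of 2. The sketch below works backwards from one table row (N = 4096, Triton column) to the kernel time it implies. The helpers `gbps` and `implied_ms` are hypothetical names introduced for this example; only the throughput value is taken from the table.

```python
def gbps(num_elements, element_size, ms):
    """Effective bandwidth in GB/s, counting one read and one write per element."""
    return 2 * num_elements * element_size * 1e-9 / (ms * 1e-3)


def implied_ms(num_elements, element_size, gb_per_s):
    """Invert gbps(): the runtime in milliseconds implied by a bandwidth figure."""
    return 2 * num_elements * element_size * 1e-9 / gb_per_s * 1e3


M, N = 4096, 4096
element_size = 4  # bytes per float32 element
triton_gbps = 1392.134322  # Triton column of the table at N = 4096
ms = implied_ms(M * N, element_size, triton_gbps)
print(f"implied kernel time: {ms:.4f} ms")
```

For this row, the matrix occupies about 67 MB, so roughly 134 MB are moved in total and the implied kernel time is just under 0.1 ms.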
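In the results above, we can see that:
Triton is roughly 4x faster than naive_softmax at large N (about 1400 GB/s versus about 370 GB/s), in line with the theoretical estimate. This supports our suspicion that no fusion takes place in the naive implementation.
Triton is competitive with torch.softmax in this run: it is clearly faster for N from 2176 to 4096 and from 7296 to 9216, and somewhat slower for N below 1792 and from 9344 to 12160. It is also easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.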
Total running time of the script: (0 minutes 34.899 seconds)