In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.
In doing so, you will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivation¶
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we would prefer a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so would only require reading and writing back \(MN\) elements each, so we could expect a theoretical speed-up of about 4x (i.e., \((8MN + 4M) / 2MN\)). torch.jit.script is meant to perform this kind of "kernel fusion" automatically, but, as we will see later, it is still far from ideal.
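To make this estimate concrete, here is a quick back-of-the-envelope calculation in plain Python (the shape M=4096, N=1024 is an arbitrary example, not one used elsewhere in the tutorial):
# Back-of-the-envelope DRAM traffic, in element counts (M, N are arbitrary example values).
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes = 8MN + 4M
fused_traffic = 2 * M * N                                  # read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0, the theoretical speed-up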
import torch
import triton
import triton.language as tl
from triton.runtime import driver
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')
def naive_softmax(x):
"""Compute row-wise softmax of X using native pytorch
We subtract the maximum element in order to avoid overflows. Softmax is invariant to
this shift.
"""
# read MN elements ; write M elements
x_max = x.max(dim=1)[0]
# read MN + M elements ; write MN elements
z = x - x_max[:, None]
# read MN elements ; write MN elements
numerator = torch.exp(z)
# read MN elements ; write M elements
denominator = numerator.sum(dim=1)
# read MN + M elements ; write MN elements
ret = numerator / denominator[:, None]
# in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
return ret
Compute Kernel¶
Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the total number of programs, normalizes them, and writes the results back to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so if we want to handle any possible input shape we need to internally "pad" each row and guard the memory operations properly.
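To make the row-partitioning scheme concrete, here is a small, hedged pure-Python illustration of the strided loop that each program executes (the number of rows and programs are made up):
# Hypothetical example: 10 rows handled by 4 persistent programs.
n_rows, num_programs = 10, 4
for pid in range(num_programs):  # pid plays the role of tl.program_id(0)
    rows = list(range(pid, n_rows, num_programs))  # mirrors tl.range(row_start, n_rows, row_step)
    print(f"program {pid} processes rows {rows}")
# program 0 processes rows [0, 4, 8]
# program 1 processes rows [1, 5, 9]
# program 2 processes rows [2, 6]
# program 3 processes rows [3, 7]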
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
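Before building the full helper below, it may help to see the simplest possible way to launch this kernel: one program per row, with no occupancy tuning. This is only a minimal sketch (softmax_simple is not part of the tutorial, and the num_warps/num_stages values are arbitrary):
def softmax_simple(x):
    n_rows, n_cols = x.shape
    # Each row must fit in one power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    y = torch.empty_like(x)
    # A grid of n_rows programs: row_step == n_rows, so each program handles exactly one row.
    softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                               BLOCK_SIZE=BLOCK_SIZE, num_stages=2, num_warps=8)
    return y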
We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
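# Query device properties (SM count, register file size, shared memory size, warp size);
# softmax() below uses them to derive an occupancy-based grid of persistent programs.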
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}
def softmax(x):
n_rows, n_cols = x.shape
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8
# Number of software pipelining stages.
num_stages = 4 if SIZE_SMEM > 200000 else 2
# Allocate output
y = torch.empty_like(x)
# pre-compile kernel to get register usage and compute thread occupancy.
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
num_programs = min(num_programs, n_rows)
# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
return y
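To make the occupancy arithmetic above concrete, here is a small, hedged example of the non-HIP path with made-up device numbers (illustrative only, not measured from any particular GPU):
# Hypothetical numbers, for illustration only.
NUM_REGS_EX, WARP_SIZE_EX, NUM_SM_EX = 65536, 32, 108   # registers per SM, threads per warp, SM count
n_regs_ex, num_warps_ex = 32, 8                         # registers per thread, warps per program
size_smem_ex, SIZE_SMEM_EX = 16384, 166912              # shared memory per program / per SM, in bytes
occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 8 programs fit by register usage
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)           # 10 fit by shared memory, so still 8
num_programs_ex = NUM_SM_EX * occupancy_ex                               # 864 persistent programs before the n_rows cap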
Unit Test¶
We make sure to test our kernel on a matrix with an irregular number of rows and columns. This allows us to verify that our padding mechanism works.
torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
As expected, the results are identical.
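As an additional (optional) sanity check, one could also compare against the naive_softmax reference defined earlier:
# Optional extra check against the pure-PyTorch reference implementation above.
assert torch.allclose(y_triton, naive_softmax(x))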
Benchmark¶
Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
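The throughput reported below is in GB/s and assumes the ideal fused traffic of one full read plus one full write of the tensor, matching the gbps formula in the benchmark code. A quick, hedged sanity check with made-up numbers:
# Hypothetical sanity check of the GB/s formula (all numbers are made up).
M, N, elem_bytes = 4096, 1024, 4                 # float32, example shape
ms = 0.025                                       # hypothetical kernel time in milliseconds
gbps = 2 * M * N * elem_bytes * 1e-9 / (ms * 1e-3)
print(f"{gbps:.0f} GB/s")                        # ~1342 GB/s for these made-up numbers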
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'], # argument names to use as an x-axis for the plot
x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name`
line_arg='provider', # argument name whose value corresponds to a different line in the plot
line_vals=['triton', 'torch'], # possible values for `line_arg``
line_names=[
"Triton",
"Torch",
], # label name for the lines
styles=[('blue', '-'), ('green', '-')], # line styles
ylabel="GB/s", # label name for the y-axis
plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot.
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name`
))
def benchmark(M, N, provider):
x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
stream = getattr(torch, DEVICE.type).Stream()
getattr(torch, DEVICE.type).set_stream(stream)
if provider == 'torch':
ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
if provider == 'triton':
ms = triton.testing.do_bench(lambda: softmax(x))
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms)
benchmark.run(show_plots=True, print_data=True)

softmax-performance:
N Triton Torch
0 256.0 469.732076 692.689492
1 384.0 652.828102 795.504136
2 512.0 795.015507 923.740799
3 640.0 807.852620 958.504916
4 768.0 875.382000 1019.091935
5 896.0 944.951634 1071.444269
6 1024.0 1011.768310 1122.081397
7 1152.0 1111.014697 1035.029144
8 1280.0 1147.043527 1068.101974
9 1408.0 1154.049121 1101.867086
10 1536.0 1190.612728 1134.838469
11 1664.0 1217.798415 1175.845186
12 1792.0 1234.334401 1190.140659
13 1920.0 1249.736187 1194.084888
14 2048.0 1275.059404 1223.597841
15 2176.0 1235.480696 965.158612
16 2304.0 1240.129928 997.779550
17 2432.0 1269.126226 1037.400082
18 2560.0 1283.319247 1065.043622
19 2688.0 1290.506231 1100.126722
20 2816.0 1294.861797 1121.683566
21 2944.0 1303.655261 1148.005162
22 3072.0 1329.067794 1166.196581
23 3200.0 1327.660707 1172.287463
24 3328.0 1337.282252 1198.259152
25 3456.0 1349.204652 1223.622217
26 3584.0 1347.700761 1245.028181
27 3712.0 1367.456264 1265.842376
28 3840.0 1371.590501 1288.562620
29 3968.0 1373.529987 1302.027667
30 4096.0 1377.958594 1321.597803
31 4224.0 1336.742331 1291.117294
32 4352.0 1337.776161 1318.557875
33 4480.0 1351.075845 1333.451508
34 4608.0 1360.825927 1351.586619
35 4736.0 1358.873199 1370.251089
36 4864.0 1378.304233 1381.059865
37 4992.0 1369.892084 1396.660748
38 5120.0 1376.840949 1408.474925
39 5248.0 1374.494614 1366.456404
40 5376.0 1378.071805 1385.120751
41 5504.0 1380.035006 1395.895439
42 5632.0 1387.620633 1411.242738
43 5760.0 1396.716359 1417.261053
44 5888.0 1389.196130 1438.290106
45 6016.0 1402.831700 1439.457717
46 6144.0 1406.955159 1441.475421
47 6272.0 1415.969958 1395.716261
48 6400.0 1414.891479 1423.016301
49 6528.0 1410.184963 1432.280886
50 6656.0 1416.526416 1444.118325
51 6784.0 1411.076547 1443.480082
52 6912.0 1424.090608 1456.725618
53 7040.0 1418.970537 1464.651645
54 7168.0 1429.525133 1464.807791
55 7296.0 1429.362638 1084.764623
56 7424.0 1427.155530 1100.744607
57 7552.0 1430.106467 1112.697569
58 7680.0 1435.360501 1126.038568
59 7808.0 1432.615390 1135.080828
60 7936.0 1439.267951 1145.385607
61 8064.0 1435.956362 1151.788163
62 8192.0 1439.406032 1156.491679
63 8320.0 1392.595787 1113.680219
64 8448.0 1385.771028 1124.086154
65 8576.0 1398.346265 1122.554068
66 8704.0 1392.877428 1128.744883
67 8832.0 1392.475742 1129.015317
68 8960.0 1403.050563 1134.713606
69 9088.0 1413.563332 1133.665378
70 9216.0 1409.955820 1127.222604
71 9344.0 1405.859161 1422.962682
72 9472.0 1403.946226 1431.461395
73 9600.0 1393.724015 1428.577664
74 9728.0 1407.513490 1438.615060
75 9856.0 1420.547375 1442.005285
76 9984.0 1402.673009 1448.820885
77 10112.0 1418.148573 1450.572145
78 10240.0 1422.633858 1464.953933
79 10368.0 1417.277556 1461.188145
80 10496.0 1420.189939 1465.491636
81 10624.0 1421.473094 1463.260484
82 10752.0 1409.427255 1470.049996
83 10880.0 1404.088724 1477.157006
84 11008.0 1422.004351 1475.620959
85 11136.0 1425.758185 1485.202505
86 11264.0 1431.351554 1484.921399
87 11392.0 1424.294121 1491.254743
88 11520.0 1428.022545 1494.247371
89 11648.0 1427.784029 1498.814314
90 11776.0 1436.136658 1502.769595
91 11904.0 1447.339033 1509.698889
92 12032.0 1427.186683 1512.825150
93 12160.0 1423.977867 1516.111112
94 12288.0 1438.619150 1422.520536
95 12416.0 1454.030155 1395.575299
96 12544.0 1447.487422 1394.804423
97 12672.0 1453.132657 1392.027710
In the plot above, we can see that:
- Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not perform any fusion here.
- Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand, and maintain. Note, however, that the PyTorch softmax operation is more general and works on tensors of any shape.

Total running time of the script: (0 minutes 23.354 seconds)