Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we would prefer a custom "fused" kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of about 4x (i.e., \((8MN + 4M) / 2MN\)). torch.jit.script aims to perform this kind of "kernel fusion" automatically, but it is still far from ideal.
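
To make this estimate concrete, here is a quick back-of-the-envelope check in plain Python. The values of M and N below are made up purely for illustration:

# Hypothetical problem size, used only to illustrate the traffic estimate.
M, N = 4096, 1024

naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written
fused_traffic = M * N + M * N  # fused kernel: read X once, write Y once

print(naive_traffic / fused_traffic)  # 4.001953125, i.e. 4 + 2/N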

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
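
As a plain-Python sketch of this scheduling (the program and row counts below are made up for illustration), the grid-stride loop partitions rows like so:

num_programs, n_rows = 4, 10  # made-up illustration values
for program_id in range(num_programs):
    rows = list(range(program_id, n_rows, num_programs))
    print(f"program {program_id} processes rows {rows}")
# program 0 processes rows [0, 4, 8]
# program 1 processes rows [1, 5, 9]
# program 2 processes rows [2, 6]
# program 3 processes rows [3, 7]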

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shape.
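
To see why padding with -inf is harmless, here is a minimal PyTorch check (the row values are illustrative): exp(-inf) evaluates to 0, so padded entries contribute nothing to the row sum, and -inf can never win the row max.

# Pad a row of 5 elements up to the next power of two (8) with -inf.
row = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
padded = torch.cat([row, torch.full((3, ), -float('inf'))])
z = torch.exp(padded - padded.max())
print(z / z.sum())  # first 5 entries match torch.softmax(row, dim=0); last 3 are exactly 0

With that established, here is the kernel: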

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
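
The occupancy arithmetic above can be hard to parse at a glance. Here is a worked example of the CUDA branch with hypothetical, A100-like numbers; real values come from the device properties and the compiled kernel, and the shared-memory cap (SIZE_SMEM // size_smem) may lower the result further.

def estimated_num_programs(num_regs, warp_size, num_sm, n_regs, num_warps, n_rows):
    # Register-limited blocks per SM, mirroring the CUDA branch above.
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    return min(num_sm * occupancy, n_rows)


# 65536 regs/SM, 32 threads/warp, 108 SMs; a kernel using 32 regs/thread with 8 warps.
print(estimated_num_programs(65536, 32, 108, 32, 8, 1823))  # min(108 * 8, 1823) = 864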

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results are identical.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    # One read and one write of every element => 2 * numel * element_size bytes moved.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch  Naive Softmax
0     256.0   477.462712   703.822901     206.253863
1     384.0   664.863479   824.797972     264.411742
2     512.0   811.265215   946.849992     302.867896
3     640.0   877.802125   958.214572     332.318560
4     768.0   956.897098  1028.029727     350.327420
5     896.0  1012.016262  1074.654123     353.915462
6    1024.0  1058.936830  1095.562592     353.520017
7    1152.0  1097.253126  1038.944131     349.141133
8    1280.0  1134.885047  1076.550819     349.465844
9    1408.0  1162.877066  1104.759849     340.882321
10   1536.0  1197.078594  1142.965734     334.313559
11   1664.0  1208.426362  1162.107467     329.747045
12   1792.0  1232.047263  1194.460566     326.213875
13   1920.0  1254.221134  1198.008237     325.163443
14   2048.0  1270.868793  1224.541007     324.795622
15   2176.0  1238.981103   966.640005     325.543620
16   2304.0  1249.673662   999.671124     325.991668
17   2432.0  1274.103191  1036.851407     326.855959
18   2560.0  1285.843872  1074.691278     328.032597
19   2688.0  1289.919793  1103.789172     329.656577
20   2816.0  1302.932329  1122.763803     329.773201
21   2944.0  1321.736236  1150.159162     331.548506
22   3072.0  1320.358598  1168.747194     333.497314
23   3200.0  1339.986863  1177.635725     335.169487
24   3328.0  1341.911072  1205.791789     336.464770
25   3456.0  1345.134990  1223.641691     336.768093
26   3584.0  1356.459754  1243.844716     337.903484
27   3712.0  1365.017159  1263.592994     340.479845
28   3840.0  1367.038641  1282.181248     340.032085
29   3968.0  1374.065700  1295.951169     341.083274
30   4096.0  1388.242775  1323.372434     338.706470
31   4224.0  1327.245712  1297.914009     343.016525
32   4352.0  1333.047076  1321.803972     345.187550
33   4480.0  1340.221524  1336.673665     345.867393
34   4608.0  1354.191504  1356.376747     346.447012
35   4736.0  1352.559607  1366.564458     348.256379
36   4864.0  1361.121268  1380.510427     349.001035
37   4992.0  1366.842846  1385.203084     350.386939
38   5120.0  1370.427007  1409.077324     350.707591
39   5248.0  1371.916948  1367.277118     351.671604
40   5376.0  1372.539057  1380.455673     351.756040
41   5504.0  1375.324435  1393.329007     353.731835
42   5632.0  1387.278094  1408.649175     353.417598
43   5760.0  1389.558013  1412.832607     355.181756
44   5888.0  1386.995394  1421.744729     355.255180
45   6016.0  1392.629236  1441.846884     356.578110
46   6144.0  1402.673006  1453.298007     357.339149
47   6272.0  1406.107165  1400.065488     357.645651
48   6400.0  1407.036565  1413.802463     358.641086
49   6528.0  1408.586249  1423.978952     359.125847
50   6656.0  1411.428892  1428.567559     359.514908
51   6784.0  1416.933998  1451.190287     360.450094
52   6912.0  1421.213683  1454.336125     360.927162
53   7040.0  1413.755012  1448.599521     360.791609
54   7168.0  1413.913697  1468.536980     361.787232
55   7296.0  1421.170734  1086.708410     362.712744
56   7424.0  1427.251521  1101.957011     362.809103
57   7552.0  1424.181749  1115.340071     363.521974
58   7680.0  1428.553742  1124.356520     363.473484
59   7808.0  1425.336361  1136.225566     364.548596
60   7936.0  1427.406822  1147.406355     364.741411
61   8064.0  1429.642807  1151.662541     365.039918
62   8192.0  1427.365779  1157.348999     364.245581
63   8320.0  1388.879335  1114.048896     361.891111
64   8448.0  1383.076857  1121.800331     362.518902
65   8576.0  1385.886635  1122.658140     363.566352
66   8704.0  1383.643694  1131.870819     364.601126
67   8832.0  1394.404237  1128.487748     365.267494
68   8960.0  1388.944135  1133.899457     365.772050
69   9088.0  1394.915427  1130.696173     366.804487
70   9216.0  1397.713073  1131.887190     367.480474
71   9344.0  1390.586198  1419.391670     367.889989
72   9472.0  1397.986576  1426.103825     369.256259
73   9600.0  1399.901578  1430.734912     369.163540
74   9728.0  1395.921156  1441.210094     369.834441
75   9856.0  1399.127980  1437.192995     370.102272
76   9984.0  1391.327978  1448.384086     370.116911
77  10112.0  1405.548575  1457.076805     371.459744
78  10240.0  1403.022935  1463.082375     371.710032
79  10368.0  1412.699700  1460.521474     369.328354
80  10496.0  1406.066570  1465.764544     370.044838
81  10624.0  1403.932071  1464.924666     371.304464
82  10752.0  1396.644627  1471.257164     371.087410
83  10880.0  1390.184004  1476.115203     371.591115
84  11008.0  1414.465179  1476.394860     372.301729
85  11136.0  1420.541108  1485.171304     373.187862
86  11264.0  1407.170750  1485.963982     372.868373
87  11392.0  1417.656797  1488.788444     374.032385
88  11520.0  1410.525833  1495.357355     373.914388
89  11648.0  1418.141926  1499.384979     374.516050
90  11776.0  1425.531967  1500.781373     374.340854
91  11904.0  1422.144565  1507.026601     375.417407
92  12032.0  1411.299068  1508.145651     375.628754
93  12160.0  1408.965220  1515.398698     375.992370
94  12288.0  1421.246722  1418.903649     376.141048
95  12416.0  1431.567369  1397.187103     374.857999
96  12544.0  1437.219073  1394.035935     375.617189
97  12672.0  1427.185539  1390.400542     375.339096
In the above plot, we can see that:

  • Triton is about 4x faster than naive_softmax, which matches our theoretical estimate and confirms that the naive PyTorch implementation performs no fusion here.

  • Triton is competitive with torch.softmax, and noticeably faster for many shapes, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.

Total running time of the script: (0 minutes 35.218 seconds)
