Fused Softmax

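In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.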

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivation

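Custom GPU kernels for element-wise addition are educationally valuable but won't get you very far in practice. Let us instead consider a simple (numerically stable) softmax operation: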

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
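The docstring above says that subtracting the row maximum avoids overflow and that softmax is invariant to this shift. A small pure-Python sketch can check both claims.

The sketch makes some assumptions that differ from the tutorial's setting:
- It uses Python floats (float64) and `math.exp`, which raises `OverflowError` instead of returning `inf` as a float32 GPU tensor would.
- The helper names `softmax_unshifted` and `softmax_shifted` are hypothetical.

```python
import math

def softmax_unshifted(row):
    # Direct exponentiation: math.exp raises OverflowError for arguments above ~709.78
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_shifted(row):
    # Subtract the row maximum first, so every exponent is <= 0 and exp() stays in (0, 1]
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
# For moderate inputs both versions agree, illustrating shift invariance
assert all(math.isclose(a, b) for a, b in zip(softmax_unshifted(small), softmax_shifted(small)))

large = [1000.0, 1001.0, 1002.0]
try:
    softmax_unshifted(large)
except OverflowError:
    print("unshifted softmax overflowed")
# After the shift, `large` becomes [-2.0, -1.0, 0.0], exactly like `small`
print(softmax_shifted(large))
```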

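When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful. We would prefer a custom "fused" kernel that reads X only once and does all the necessary computation on-chip. Such a kernel would read and write back only \(MN\) elements each, so we could expect a theoretical speedup of about 4x, i.e. \((8MN + 4M) / 2MN\). The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically, but it is still far from ideal.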

Compute Kernel

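Our softmax kernel works as follows. Each program loads a set of rows of the input matrix X, strided by the number of programs. It normalizes each row and writes the result back to the output Y.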
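The row assignment is a grid-stride loop: program `pid` handles rows `pid`, `pid + num_programs`, `pid + 2 * num_programs`, and so on. A pure-Python sketch of that schedule follows. `rows_for_program` is a hypothetical helper that mirrors the kernel's `tl.range(row_start, n_rows, row_step)` loop.

```python
def rows_for_program(pid, num_programs, n_rows):
    # Same bounds as the kernel loop: start at the program id, step by the grid size
    return list(range(pid, n_rows, num_programs))

n_rows, num_programs = 10, 4
for pid in range(num_programs):
    print(pid, rows_for_program(pid, num_programs, n_rows))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Every row is visited by exactly one program. This is why the launcher can cap the grid at an occupancy-based number of persistent programs instead of launching one program per row.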

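Note that one important limitation of Triton is that each block must have a power-of-two number of elements. To handle any possible input shape, we therefore need to internally "pad" each row and guard the memory operations properly.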
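Padding is harmless for softmax if the padded lanes are filled with negative infinity: they never win the max, and `exp(-inf)` is 0, so they add nothing to the sum. The pure-Python sketch below emulates this. The names are hypothetical:
- `next_power_of_2` is a stand-in for `triton.next_power_of_2`.
- `masked_load` imitates `tl.load(..., mask=..., other=-inf)`.

```python
import math

def next_power_of_2(n):
    # Hypothetical stand-in for triton.next_power_of_2: smallest power of two >= n
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def masked_load(row, block_size, other=-math.inf):
    # Lanes past the end of the row (mask is False) receive `other` instead of data
    return [row[i] if i < len(row) else other for i in range(block_size)]

def padded_softmax(row):
    block = masked_load(row, next_power_of_2(len(row)))
    m = max(block)                          # -inf padding never wins the max
    num = [math.exp(v - m) for v in block]  # exp(-inf) == 0.0, so padding adds nothing
    den = sum(num)
    # Like the masked tl.store, keep only the lanes that belong to the real row
    return [v / den for v in num][:len(row)]

print(next_power_of_2(781))  # block size used for the 781-column test matrix
print(padded_softmax([0.5, -1.0, 2.0]))
```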

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
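The non-HIP occupancy estimate in `softmax` can be replayed with plain integers, which makes the register and shared-memory limits easier to see. This is a sketch under several assumptions:
- The device numbers below are illustrative, not taken from any specific GPU.
- `estimate_num_programs` is a hypothetical helper that covers only the non-HIP branch.
- Unlike the helper above, it skips the shared-memory cap when the kernel reports zero shared memory. The original expression `SIZE_SMEM // size_smem` would raise `ZeroDivisionError` in that case.

```python
def estimate_num_programs(num_sm, num_regs, max_shared_mem, warp_size,
                          n_regs, kernel_smem, num_warps, n_rows):
    # Register limit: one program needs n_regs registers per thread * warp_size * num_warps threads
    occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared-memory limit (skipped when the kernel reports no shared memory, to avoid dividing by zero)
    if kernel_smem > 0:
        occupancy = min(occupancy, max_shared_mem // kernel_smem)
    # Persistent grid: occupancy programs per SM, but never more programs than rows
    return min(num_sm * occupancy, n_rows)

# Illustrative numbers only
print(estimate_num_programs(num_sm=100, num_regs=65536, max_shared_mem=100000, warp_size=32,
                            n_regs=32, kernel_smem=2048, num_warps=8, n_rows=1823))
```

With these numbers, registers limit each SM to 65536 // 8192 = 8 programs, so the grid has 800 programs. Each program then loops over roughly two or three of the 1823 rows.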

Unit Test

We make sure to test our kernel on a matrix with an irregular number of rows and columns. This verifies that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match within the default tolerances of torch.allclose.
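`torch.allclose` does not test exact equality. It checks element-wise that \(|a - b| \le \text{atol} + \text{rtol} \cdot |b|\), with defaults `rtol=1e-05` and `atol=1e-08`. A pure-Python sketch of that check follows. The `allclose` function here is a hypothetical re-implementation on flat lists, not the PyTorch API.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    # Same shape requirement, then |a - b| <= atol + rtol * |b| for every element
    if len(a) != len(b):
        return False
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

print(allclose([1.0, 2.0], [1.0 + 1e-7, 2.0]))  # tiny float-rounding-sized gap
print(allclose([1.0], [1.001]))                 # gap far above the tolerance
```

The tolerance is asymmetric: it scales with `b`, which is `y_torch` in the test above.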

Benchmark

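Here we benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.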
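The benchmark reports effective bandwidth. It assumes each element of the M x N matrix is read once and written once, and divides those bytes by the measured time. Every provider is credited with the same byte count, so naive_softmax's actual DRAM traffic is higher than its reported GB/s suggests. The sketch below restates the formula with a hypothetical 0.1 ms runtime.

```python
def gbps(numel, element_size, ms):
    # Bytes for one read plus one write of the matrix, in GB, divided by seconds
    return 2 * numel * element_size * 1e-9 / (ms * 1e-3)

# 4096 x 4096 float32 matrix (4 bytes per element), hypothetical 0.1 ms runtime
print(gbps(4096 * 4096, 4, 0.1))  # about 1342 GB/s
```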

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Figure: 02 fused softmax, "softmax-performance" plot of GB/s versus N for Triton, Torch, and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     505.091135    704.663800            208.360473
1     384.0     701.666012    821.081367            263.815333
2     512.0     839.053880    915.648283            299.846547
3     640.0     836.248289    919.478119            330.189166
4     768.0     913.085118    988.818136            348.525727
5     896.0     978.250002   1036.138136            353.011207
6    1024.0    1027.567946   1075.608731            353.078014
7    1152.0    1020.403231   1074.346139            347.483118
8    1280.0    1081.425317   1113.767076            348.154863
9    1408.0    1109.168672   1138.073973            340.538983
10   1536.0    1151.131931   1159.791310            333.857494
11   1664.0    1186.794173   1192.747490            330.325436
12   1792.0    1209.712832   1193.947161            325.571096
13   1920.0    1231.513068   1216.430415            323.893440
14   2048.0    1246.003820   1240.355203            323.997644
15   2176.0    1251.873738    958.875122            325.248353
16   2304.0    1270.532454   1003.469335            325.454217
17   2432.0    1297.308523   1031.974732            325.751576
18   2560.0    1305.650668   1071.105447            327.617594
19   2688.0    1317.944677   1096.202036            328.719881
20   2816.0    1336.459195   1126.166522            329.115028
21   2944.0    1335.325468   1148.503207            331.800589
22   3072.0    1350.900592   1170.833674            333.977867
23   3200.0    1354.954042   1173.750537            334.656529
24   3328.0    1363.400746   1198.776127            335.699413
25   3456.0    1368.535502   1219.435537            336.453224
26   3584.0    1373.404895   1246.439737            337.956270
27   3712.0    1384.320789   1266.360853            339.608390
28   3840.0    1391.712179   1284.087060            340.398059
29   3968.0    1393.908332   1301.188741            340.822090
30   4096.0    1392.134322   1317.236891            338.648011
31   4224.0    1340.732034   1279.716717            343.213636
32   4352.0    1348.794994   1301.148704            345.463574
33   4480.0    1353.848701   1317.764264            345.720205
34   4608.0    1371.630996   1337.040876            346.656879
35   4736.0    1367.329019   1346.499776            347.893119
36   4864.0    1369.122336   1359.099200            348.682513
37   4992.0    1373.710016   1375.520204            349.929713
38   5120.0    1381.778143   1385.157838            350.390671
39   5248.0    1379.258132   1357.134306            351.669737
40   5376.0    1383.871485   1374.128292            351.766064
41   5504.0    1389.060859   1383.602927            353.899771
42   5632.0    1405.035935   1388.506074            353.273466
43   5760.0    1402.257802   1408.161096            355.219008
44   5888.0    1401.086146   1410.544243            354.934511
45   6016.0    1405.699133   1423.702123            356.266752
46   6144.0    1412.752097   1437.974378            357.325234
47   6272.0    1415.147031   1393.634599            357.728797
48   6400.0    1419.826346   1414.034104            357.894735
49   6528.0    1420.650638   1414.967906            359.047536
50   6656.0    1422.947119   1431.716575            359.172476
51   6784.0    1424.771552   1433.870547            360.417870
52   6912.0    1429.434720   1441.100003            360.629150
53   7040.0    1426.459014   1454.511219            361.025640
54   7168.0    1428.267961   1454.322971            361.488321
55   7296.0    1431.902852   1087.788108            362.316380
56   7424.0    1435.514788   1098.713069            362.603159
57   7552.0    1433.403792   1106.041760            363.013294
58   7680.0    1437.768870   1122.248941            363.396705
59   7808.0    1432.988078   1127.893168            364.299401
60   7936.0    1437.932261   1141.581655            364.618457
61   8064.0    1439.533572   1146.141182            365.067099
62   8192.0    1433.177630   1151.175017            363.489382
63   8320.0    1380.992766   1118.468245            361.685522
64   8448.0    1382.900020   1126.024670            362.402981
65   8576.0    1384.226968   1129.407179            363.098682
66   8704.0    1376.622534   1133.384119            364.209023
67   8832.0    1392.287640   1132.847732            364.916002
68   8960.0    1385.909942   1139.871201            365.598815
69   9088.0    1388.678577   1137.819103            366.551094
70   9216.0    1399.293321   1143.915202            367.436023
71   9344.0    1389.467705   1421.005094            367.667099
72   9472.0    1394.237786   1431.952964            368.211882
73   9600.0    1398.902078   1432.758715            368.783088
74   9728.0    1395.673001   1440.225255            369.772317
75   9856.0    1401.245989   1438.244396            369.974576
76   9984.0    1388.509195   1450.417851            370.689487
77  10112.0    1403.161525   1457.662239            371.149073
78  10240.0    1406.549851   1467.240928            371.257859
79  10368.0    1412.932494   1461.312823            369.864041
80  10496.0    1405.100786   1464.091827            370.195569
81  10624.0    1403.815503   1468.138385            370.823800
82  10752.0    1392.255136   1471.737129            371.176177
83  10880.0    1392.371296   1478.272957            371.934906
84  11008.0    1415.319736   1476.131684            372.429986
85  11136.0    1414.775748   1485.278001            372.629118
86  11264.0    1407.132118   1489.836145            373.378925
87  11392.0    1419.900430   1492.070874            374.360070
88  11520.0    1409.547705   1491.767950            373.686969
89  11648.0    1416.659104   1495.810039            374.100361
90  11776.0    1431.850794   1503.958123            374.504725
91  11904.0    1429.674070   1507.655768            374.786913
92  12032.0    1412.896597   1509.056160            376.327660
93  12160.0    1408.256166   1516.310629            376.240648
94  12288.0    1422.338488   1422.762538            376.116457
95  12416.0    1429.857847   1397.478687            375.145581
96  12544.0    1440.553753   1395.929581            375.533729
97  12672.0    1431.380337   1396.077793            375.180844
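In the results above, we can see that:
  • Triton is up to about 4x faster than naive_softmax (for example, 1392 vs. 339 GB/s at N = 4096). This matches the roughly 4x reduction in memory traffic estimated earlier and indicates that the eager PyTorch version performs no fusion.

  • Triton is not uniformly faster than torch.softmax. In this run, Triton is clearly ahead (up to about 1.3x) where torch.softmax's throughput drops, e.g. for N from 2176 to about 3500 and from 7296 to 9216. It is slower for small N (about 28% slower at N = 256) and for N from 9344 to 12288 (by up to about 7%). Elsewhere the two are within a few percent. The Triton version is also easier to read, understand, and maintain. Note, however, that PyTorch's softmax is more general and works on tensors of any shape.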
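The comparisons above are simple throughput ratios. The sketch below recomputes a few of them from rows of the printed table. Only four N values are copied, chosen to show each regime. The `results` dictionary is just those numbers, not a new measurement.

```python
# (Triton, Torch, Naive Softmax) GB/s, copied from the printed table above
results = {
    256: (505.091135, 704.663800, 208.360473),
    2176: (1251.873738, 958.875122, 325.248353),
    4096: (1392.134322, 1317.236891, 338.648011),
    12160: (1408.256166, 1516.310629, 376.240648),
}

for n, (triton_bw, torch_bw, naive_bw) in results.items():
    # Ratios above 1 mean Triton achieved higher effective bandwidth
    print(f"N={n}: vs naive {triton_bw / naive_bw:.2f}x, vs torch {triton_bw / torch_bw:.2f}x")
```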

Total running time of the script: (0 minutes 34.899 seconds)
