In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivation

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
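Written out for one row of \(x \in R^{M \times N}\), the (shift-invariant) softmax computed below is:

\[ \mathrm{softmax}(x)_{ij} = \frac{e^{x_{ij} - \max_k x_{ik}}}{\sum_{k=1}^{N} e^{x_{ik} - \max_k x_{ik}}} \]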

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we would prefer a custom "fused" kernel that reads X only once and does all the necessary computation on-chip. Doing so would require reading \(MN\) elements and writing back \(MN\) elements, so we could expect a theoretical speed-up of about 4x (i.e., \((8MN + 4M) / 2MN\)). torch.jit.script aims to perform this kind of "kernel fusion" automatically, but, as we will see later, it is still far from ideal.
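As a quick back-of-the-envelope check on that arithmetic, here is a small helper (an addition for illustration, not part of the tutorial code; the element counts simply restate the per-line accounting annotated in naive_softmax below):

def theoretical_speedup(M, N):
    # naive softmax: 5MN + 2M elements read and 3MN + 2M elements written
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
    # fused kernel: read X once (MN elements) and write Y once (MN elements)
    fused_traffic = 2 * M * N
    return naive_traffic / fused_traffic

print(theoretical_speedup(4096, 4096))  # ~4.0, i.e. (8MN + 4M) / 2MN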

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
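As a quick usage check (an addition, not part of the original tutorial), the reference implementation can be compared against torch.softmax directly:

# Hypothetical sanity check: naive_softmax should match PyTorch's row-wise softmax.
_x = torch.randn(4, 16, device=DEVICE)
torch.testing.assert_close(naive_softmax(_x), torch.softmax(_x, dim=1))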

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the result back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input shape.
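To see why padding with -inf is safe, here is a small PyTorch-only sketch (illustrative, not part of the kernel): padded entries contribute \(e^{-\infty} = 0\) to the row sum and can never win the row max, so they do not change the result for the valid columns.

_row = torch.randn(781, device=DEVICE)
_BLOCK = triton.next_power_of_2(781)                       # 1024
_padded = torch.full((_BLOCK,), float('-inf'), device=DEVICE)
_padded[:781] = _row
# softmax over the padded row matches softmax over the original row on the valid columns
torch.testing.assert_close(torch.softmax(_padded, dim=0)[:781], torch.softmax(_row, dim=0))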


@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
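Before building the full, occupancy-aware launcher below, note that the kernel can also be launched in a simpler, non-persistent fashion: with one program per row, tl.num_programs(0) equals n_rows and each program's loop runs exactly one iteration. A minimal sketch with hypothetical sizes (no tuning of num_warps or num_stages):

_x = torch.randn(1823, 781, device=DEVICE)
_y = torch.empty_like(_x)
_n_rows, _n_cols = _x.shape
_BLOCK_SIZE = triton.next_power_of_2(_n_cols)
# one program per row; each program's tl.range loop executes exactly once
softmax_kernel[(_n_rows, )](_y, _x, _x.stride(0), _y.stride(0), _n_rows, _n_cols,
                            BLOCK_SIZE=_BLOCK_SIZE, num_stages=2)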

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        # By default, assume all registers can be used as regular-purpose registers.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor)  in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
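To make the occupancy arithmetic above concrete, here is a worked example with made-up numbers (hypothetical values, not queried from a real GPU), following the non-HIP branch:

_NUM_REGS, _WARP_SIZE, _SIZE_SMEM, _NUM_SM = 65536, 32, 228 * 1024, 132   # assumed device properties
_n_regs, _size_smem, _num_warps = 32, 16384, 8                            # assumed per-kernel usage
_occupancy = _NUM_REGS // (_n_regs * _WARP_SIZE * _num_warps)             # 8 programs per SM by register usage
_occupancy = min(_occupancy, _SIZE_SMEM // _size_smem)                    # min(8, 14) -> 8, shared memory is not the limit
print(_NUM_SM * _occupancy)                                               # 1056 persistent programs, before clamping to n_rows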

Unit Test

We make sure to test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works. As expected, the results are identical.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
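Optionally, one more check on a different irregular shape (hypothetical sizes) exercises the padding path again:

x_odd = torch.randn(497, 1000, device=DEVICE)
torch.testing.assert_close(softmax(x_odd), torch.softmax(x_odd, dim=1))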

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: 02 fused softmax, softmax-performance: throughput (GB/s) vs. N for Triton and Torch]
softmax-performance:
          N       Triton        Torch
0     256.0   469.732076   692.689492
1     384.0   652.828102   795.504136
2     512.0   795.015507   923.740799
3     640.0   807.852620   958.504916
4     768.0   875.382000  1019.091935
5     896.0   944.951634  1071.444269
6    1024.0  1011.768310  1122.081397
7    1152.0  1111.014697  1035.029144
8    1280.0  1147.043527  1068.101974
9    1408.0  1154.049121  1101.867086
10   1536.0  1190.612728  1134.838469
11   1664.0  1217.798415  1175.845186
12   1792.0  1234.334401  1190.140659
13   1920.0  1249.736187  1194.084888
14   2048.0  1275.059404  1223.597841
15   2176.0  1235.480696   965.158612
16   2304.0  1240.129928   997.779550
17   2432.0  1269.126226  1037.400082
18   2560.0  1283.319247  1065.043622
19   2688.0  1290.506231  1100.126722
20   2816.0  1294.861797  1121.683566
21   2944.0  1303.655261  1148.005162
22   3072.0  1329.067794  1166.196581
23   3200.0  1327.660707  1172.287463
24   3328.0  1337.282252  1198.259152
25   3456.0  1349.204652  1223.622217
26   3584.0  1347.700761  1245.028181
27   3712.0  1367.456264  1265.842376
28   3840.0  1371.590501  1288.562620
29   3968.0  1373.529987  1302.027667
30   4096.0  1377.958594  1321.597803
31   4224.0  1336.742331  1291.117294
32   4352.0  1337.776161  1318.557875
33   4480.0  1351.075845  1333.451508
34   4608.0  1360.825927  1351.586619
35   4736.0  1358.873199  1370.251089
36   4864.0  1378.304233  1381.059865
37   4992.0  1369.892084  1396.660748
38   5120.0  1376.840949  1408.474925
39   5248.0  1374.494614  1366.456404
40   5376.0  1378.071805  1385.120751
41   5504.0  1380.035006  1395.895439
42   5632.0  1387.620633  1411.242738
43   5760.0  1396.716359  1417.261053
44   5888.0  1389.196130  1438.290106
45   6016.0  1402.831700  1439.457717
46   6144.0  1406.955159  1441.475421
47   6272.0  1415.969958  1395.716261
48   6400.0  1414.891479  1423.016301
49   6528.0  1410.184963  1432.280886
50   6656.0  1416.526416  1444.118325
51   6784.0  1411.076547  1443.480082
52   6912.0  1424.090608  1456.725618
53   7040.0  1418.970537  1464.651645
54   7168.0  1429.525133  1464.807791
55   7296.0  1429.362638  1084.764623
56   7424.0  1427.155530  1100.744607
57   7552.0  1430.106467  1112.697569
58   7680.0  1435.360501  1126.038568
59   7808.0  1432.615390  1135.080828
60   7936.0  1439.267951  1145.385607
61   8064.0  1435.956362  1151.788163
62   8192.0  1439.406032  1156.491679
63   8320.0  1392.595787  1113.680219
64   8448.0  1385.771028  1124.086154
65   8576.0  1398.346265  1122.554068
66   8704.0  1392.877428  1128.744883
67   8832.0  1392.475742  1129.015317
68   8960.0  1403.050563  1134.713606
69   9088.0  1413.563332  1133.665378
70   9216.0  1409.955820  1127.222604
71   9344.0  1405.859161  1422.962682
72   9472.0  1403.946226  1431.461395
73   9600.0  1393.724015  1428.577664
74   9728.0  1407.513490  1438.615060
75   9856.0  1420.547375  1442.005285
76   9984.0  1402.673009  1448.820885
77  10112.0  1418.148573  1450.572145
78  10240.0  1422.633858  1464.953933
79  10368.0  1417.277556  1461.188145
80  10496.0  1420.189939  1465.491636
81  10624.0  1421.473094  1463.260484
82  10752.0  1409.427255  1470.049996
83  10880.0  1404.088724  1477.157006
84  11008.0  1422.004351  1475.620959
85  11136.0  1425.758185  1485.202505
86  11264.0  1431.351554  1484.921399
87  11392.0  1424.294121  1491.254743
88  11520.0  1428.022545  1494.247371
89  11648.0  1427.784029  1498.814314
90  11776.0  1436.136658  1502.769595
91  11904.0  1447.339033  1509.698889
92  12032.0  1427.186683  1512.825150
93  12160.0  1423.977867  1516.111112
94  12288.0  1438.619150  1422.520536
95  12416.0  1454.030155  1395.575299
96  12544.0  1447.487422  1394.804423
97  12672.0  1453.132657  1392.027710
In the above plot, we can see that:

  • Triton is 4x faster than the Torch JIT. This confirms our suspicion that the Torch JIT does not perform any fusion here.

  • Triton is noticeably faster than torch.softmax, in addition to being easier to read, understand, and maintain. Note, however, that the PyTorch softmax operation is more general and works on tensors of any shape.
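Since the text above also mentions comparing against naive_softmax, the benchmark can be extended with a third provider. A sketch of the changes (keeping the same Benchmark arguments, with extra entries in line_vals, line_names, and styles; the name benchmark_with_naive is ours):

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'], x_vals=[128 * i for i in range(2, 100)],
        line_arg='provider',
        line_vals=['triton', 'torch', 'naive'],
        line_names=['Triton', 'Torch', 'Naive'],
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],
        ylabel="GB/s", plot_name="softmax-performance-with-naive", args={'M': 4096},
    ))
def benchmark_with_naive(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    elif provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    else:  # provider == 'naive'
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

# benchmark_with_naive.run(show_plots=True, print_data=True)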
