调试 Triton

本教程提供了调试 Triton 程序的指南,主要面向 Triton 用户。对探索 Triton 后端(包括 MLIR 代码转换和 LLVM 代码生成)感兴趣的开发人员,可以参考此部分来了解调试选项。

有关浮点计算的编译器级检测,请参阅浮点消毒器 (FpSan)

使用 Triton 的调试操作

Triton 包含四个调试运算符,允许用户检查和查看张量值。

  • static_printstatic_assert 用于编译时调试。

  • device_printdevice_assert 用于运行时调试。

device_assert 仅在 TRITON_DEBUG 设置为 1 时执行。其他调试运算符无论 TRITON_DEBUG 的值如何都会执行。

使用解释器

解释器是调试 Triton 程序的一种直接且有用的工具。它允许 Triton 用户在 CPU 上运行 Triton 程序,并检查每个操作的中间结果。要启用解释器模式,请将环境变量 TRITON_INTERPRET 设置为 1。此设置会导致所有 Triton 内核跳过编译,并由解释器使用 Triton 操作的 numpy 等价物进行模拟。解释器按顺序处理每个 Triton 程序实例,一次执行一个操作。

使用解释器主要有三种方法:

  • 使用 Python print 函数打印每个操作的中间结果。要检查整个张量,请使用 print(tensor)。要检查 idx 处的单个张量值,请使用 print(tensor.handle.data[idx])

  • 附加 pdb 以对 Triton 程序进行分步调试。

    TRITON_INTERPRET=1 pdb main.py
    b main.py:<line number>
    r
    
  • 导入 pdb 包并在 Triton 程序中设置断点。

    import triton
    import triton.language as tl
    import pdb
    
    @triton.jit
    def kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
      pdb.set_trace()
      offs = tl.arange(0, BLOCK_SIZE)
      x = tl.load(x_ptr + offs)
      tl.store(y_ptr + offs, x)
    

局限性

解释器有几个已知的局限性:

  • 它不支持对 bfloat16 数值类型进行操作。要对 bfloat16 张量进行操作,请使用 tl.cast(tensor) 将张量转换为 float32

  • 它不支持间接内存访问模式,例如:

    ptr = tl.load(ptr)
    x = tl.load(ptr)
    

使用第三方工具

对于 NVIDIA GPU 上的调试,compute-sanitizer 是检查数据竞争和内存访问问题的有效工具。要使用它,请在运行 Triton 程序的命令前加上 compute-sanitizer

对于 AMD GPU 上的调试,您可以尝试用于 ROCm 的 LLVM AddressSanitizer

如需详细可视化 Triton 程序中的内存访问,请考虑使用 triton-viz 工具,它与底层 GPU 无关。