triton.language.device_print

triton.language.device_print(prefix, *args, hex=False)

从设备在运行时打印值。字符串格式化不适用于运行时值,因此应将要打印的值作为参数提供。第一个值必须是字符串,所有后续值必须是标量或张量。

调用 Python 内置的 print 函数与调用此函数相同,并且参数要求将与此函数匹配(而非 print 的常规要求)。

tl.device_print("pid", pid)
print("pid", pid)

在 CUDA 上,printf 通过有限大小的缓冲区进行流式传输(在单个主机上,我们测得默认大小为 6912 KiB,但这在不同 GPU 和 CUDA 版本之间可能不一致)。如果您注意到某些 printf 被丢弃,可以通过调用以下函数来增加缓冲区大小:

triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)

如果您在使用 printf 的内核运行后尝试更改此值,CUDA 可能会引发错误。此处设置的值可能仅影响当前设备(因此如果您有多块 GPU,则需要多次调用此函数)。

参数:
  • prefix – 在值之前打印的前缀。这必须是字符串文字。

  • args – 要打印的值。可以是任何张量或标量。

  • hex – 将所有值打印为十六进制而不是十进制