triton.language.device_print
- triton.language.device_print(prefix, *args, hex=False, _semantic=None)
在运行时从设备打印数值。字符串格式化不适用于运行时数值,因此您应将要打印的数值作为参数提供。第一个值必须是字符串,后续所有值都必须是标量或张量。
调用 Python 内置的
print等同于调用此函数,且对参数的要求与此函数保持一致(而非print函数的常规要求)。tl.device_print("pid", pid) print("pid", pid)
在 CUDA 上,printf 内容通过有限大小的缓冲区进行流式传输(在某台主机上,我们测得默认大小为 6912 KiB,但这在不同 GPU 和 CUDA 版本之间可能不一致)。如果您发现某些 printf 内容丢失,可以通过调用来增加缓冲区大小:
triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)
如果您在运行使用 printf 的内核后尝试更改此值,CUDA 可能会报错。此处设置的值可能仅影响当前设备(因此如果您有多个 GPU,则需要多次调用它)。
- 参数:
prefix – 在数值前打印的前缀。此项必须为字符串字面量。
args – 要打印的数值。可以是任何张量或标量。
hex – 将所有数值以十六进制而非十进制打印