Triton 语义

Triton 大部分遵循 NumPy 的语义,只有少数例外。本文档将介绍 Triton 中支持的一些数组计算特性,并涵盖 Triton 语义与 NumPy 不同之处的例外情况。

类型提升

类型提升发生在操作中使用不同数据类型的张量时。对于与双下划线方法 (dunder methods)相关的二元操作以及三元函数tl.where的最后两个参数,Triton 会根据种类(dtypes 集合)的层次结构自动将输入张量转换为通用数据类型:{bool} < {integral dypes} < {floating point dtypes}

算法如下所示

  1. 种类 如果一个张量的 dtype 属于更高级的种类,则另一个张量将被提升到这个 dtype:(int32, bfloat16) -> bfloat16

  2. 宽度 如果两个张量的 dtype 属于相同的种类,并且其中一个具有更高的宽度,则另一个将被提升到这个 dtype:(float32, float16) -> float32

  3. 优先使用 float16 如果两个张量宽度和符号属性相同,但 dtype 不同(float16bfloat16 或不同的 fp8 类型),它们都将被提升到 float16(float16, bfloat16) -> float16

  4. 优先使用无符号类型 否则(宽度相同,符号属性不同),它们将被提升到无符号 dtype:(int32, uint32) -> uint32

当涉及标量时,规则略有不同。这里的标量指代数字字面值、标记为tl.constexpr的变量或它们的组合。它们由 NumPy 标量表示,类型为boolintfloat

当操作涉及张量和标量时

  1. 如果标量的种类低于或等于张量,它将不参与类型提升:(uint8, int) -> uint8

  2. 如果标量属于更高级的种类,我们会选择它能适应的最低 dtype,对于整数,顺序是 int32 < uint32 < int64 < uint64;对于浮点数,顺序是 float32 < float64。然后,张量和标量都将被提升到这个 dtype:(int16, 4.0) -> float32

广播

广播允许在不同形状的张量上执行操作,通过自动将它们的形状扩展到兼容的大小,而无需复制数据。这遵循以下规则

  1. 如果一个张量形状维度较少,则在左侧用 1 填充,直到两个张量具有相同的维度数:((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))

  2. 如果两个维度相等,或者其中一个为 1,则它们是兼容的。值为 1 的维度将被扩展以匹配另一个张量的维度。((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))

与 NumPy 的区别

整数除法中的 C 风格舍入 Triton 中的运算符为了效率遵循 C 语义而不是 Python 语义。因此,int // int对于混合符号的整数实现的是像 C 一样向零舍入,而不是像 Python 那样向负无穷舍入。出于同样的原因,模运算符int % int(定义为a % b = a - b * (a // b))也遵循 C 语义而不是 Python 语义。

可能令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循 Python 语义。