Triton 语义¶
Triton 大部分遵循 NumPy 的语义,只有少数例外。本文档将介绍 Triton 中支持的一些数组计算特性,并涵盖 Triton 语义与 NumPy 不同之处的例外情况。
类型提升¶
类型提升发生在操作中使用不同数据类型的张量时。对于与双下划线方法 (dunder methods)相关的二元操作以及三元函数tl.where
的最后两个参数,Triton 会根据种类(dtypes 集合)的层次结构自动将输入张量转换为通用数据类型:{bool} < {integral dypes} < {floating point dtypes}
。
算法如下所示
种类 如果一个张量的 dtype 属于更高级的种类,则另一个张量将被提升到这个 dtype:
(int32, bfloat16) -> bfloat16
宽度 如果两个张量的 dtype 属于相同的种类,并且其中一个具有更高的宽度,则另一个将被提升到这个 dtype:
(float32, float16) -> float32
优先使用 float16 如果两个张量宽度和符号属性相同,但 dtype 不同(
float16
和bfloat16
或不同的fp8
类型),它们都将被提升到float16
。(float16, bfloat16) -> float16
优先使用无符号类型 否则(宽度相同,符号属性不同),它们将被提升到无符号 dtype:
(int32, uint32) -> uint32
当涉及标量时,规则略有不同。这里的标量指代数字字面值、标记为tl.constexpr的变量或它们的组合。它们由 NumPy 标量表示,类型为bool
、int
和float
。
当操作涉及张量和标量时
如果标量的种类低于或等于张量,它将不参与类型提升:
(uint8, int) -> uint8
如果标量属于更高级的种类,我们会选择它能适应的最低 dtype,对于整数,顺序是
int32
<uint32
<int64
<uint64
;对于浮点数,顺序是float32
<float64
。然后,张量和标量都将被提升到这个 dtype:(int16, 4.0) -> float32
广播¶
广播允许在不同形状的张量上执行操作,通过自动将它们的形状扩展到兼容的大小,而无需复制数据。这遵循以下规则
如果一个张量形状维度较少,则在左侧用 1 填充,直到两个张量具有相同的维度数:
((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))
如果两个维度相等,或者其中一个为 1,则它们是兼容的。值为 1 的维度将被扩展以匹配另一个张量的维度。
((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))
与 NumPy 的区别¶
整数除法中的 C 风格舍入 Triton 中的运算符为了效率遵循 C 语义而不是 Python 语义。因此,int // int
对于混合符号的整数实现的是像 C 一样向零舍入,而不是像 Python 那样向负无穷舍入。出于同样的原因,模运算符int % int
(定义为a % b = a - b * (a // b)
)也遵循 C 语义而不是 Python 语义。
可能令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循 Python 语义。