Triton Semantics¶
Triton 大体上遵循 NumPy 的语义,但有少数例外。本文档将介绍 Triton 中支持的一些数组计算特性,并涵盖 Triton 语义与 NumPy 不同的地方。
Type Promotion¶
类型提升(Type Promotion)发生在不同数据类型的张量用于运算时。对于与 dunder 方法相关的二元运算以及三元函数 tl.where 的最后两个参数,Triton 会根据一个类型种类(dtypes 的集合)层次结构自动将输入张量转换为一个公共数据类型: {bool} < {integral dypes} < {floating point dtypes}。
算法如下:
种类 如果一个张量的 dtype 属于更高种类,则另一个张量将被提升到该 dtype:
(int32, bfloat16) -> bfloat16宽度 如果两个张量的 dtypes 属于相同种类,其中一个比另一个宽度更高,则另一个张量将被提升到该 dtype:
(float32, float16) -> float32优先 float16 如果两个张量具有相同的宽度和符号性但不同的 dtype(
float16和bfloat16或不同的fp8类型),它们都将被提升到float16。(float16, bfloat16) -> float16优先无符号 否则(相同宽度,不同符号性),它们将被提升到无符号 dtype:
(int32, uint32) -> uint32
当涉及标量时,规则会略有不同。这里我们说的标量是指数值字面量、标记为 tl.constexpr 的变量或它们的组合。这些由 NumPy 标量表示,具有 bool、int 和 float 类型。
当运算涉及张量和标量时:
如果标量的种类低于或等于张量,它将不参与类型提升:
(uint8, int) -> uint8如果标量的种类更高,我们选择一个最小的 dtype 来容纳它,整数类型按
int32<uint32<int64<uint64排序,浮点数类型按float32<float64排序。然后,张量和标量都被提升到该 dtype:(int16, 4.0) -> float32
Broadcasting¶
广播(Broadcasting)允许对不同形状的张量进行运算,通过自动扩展它们的形状使其兼容,而无需复制数据。这遵循以下规则:
如果其中一个张量的形状较短,则在左侧用 1 填充,直到两个张量具有相同的维度数:
((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))两个维度兼容,如果它们相等,或者其中一个为 1。值为 1 的维度将被扩展以匹配另一个张量的维度。
((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))
Differences with NumPy¶
整数除法的 C 舍入 Triton 中的运算符出于效率考虑遵循 C 语义而非 Python 语义。因此,对于符号不同的整数,int // int 实现的是 C 语言中的向零舍入,而不是 Python 中的向负无穷舍入。出于同样的原因,模运算符 int % int(其定义为 a % b = a - b * (a // b))也遵循 C 语义而不是 Python 语义。
或许令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循 Python 语义。