triton.language

编程模型

tensor

表示 N 维值或指针数组。

tensor_descriptor

表示全局内存中张量的描述符。

program_id

返回当前程序实例沿给定 axis 的 ID。

num_programs

返回沿给定 axis 启动的程序实例数量。

创建操作 (Creation Ops)

arange

返回半开区间 [start, end) 内的连续值。

cat

连接给定的块

full

返回一个填充了给定 shapedtype 的标量值的张量。

zeros

返回一个填充了给定 shapedtype 的标量值 0 的张量。

zeros_like

返回一个形状和类型与给定张量相同的零张量。

cast

将张量转换为给定的 dtype

形状操作 (Shape Manipulation Ops)

broadcast

尝试将两个给定块广播到共同的兼容形状。

broadcast_to

尝试将给定张量广播到新的 shape

expand_dims

通过插入新的长度为 1 的维度来扩展张量的形状。

interleave

沿最后维度交错两个张量的值。

join

在新的次要维度中连接给定张量。

permute

置换张量的维度。

ravel

返回 x 的连续扁平视图。

reshape

返回一个元素数量与输入相同但形状不同的张量。

split

沿其最后维度将张量分成两半,该维度必须大小为 2。

trans

置换张量的维度。

view

返回一个元素与 input 相同但形状不同的张量。

线性代数操作 (Linear Algebra Ops)

dot

返回两个块的矩阵乘积。

dot_scaled

返回两个块的微缩放格式的矩阵乘积。

内存/指针操作 (Memory/Pointer Ops)

load

返回一个数据张量,其值从 pointer 定义的内存位置加载

store

将数据张量存储到 pointer 定义的内存位置。

make_tensor_descriptor

创建张量描述符对象

load_tensor_descriptor

从张量描述符加载数据块。

store_tensor_descriptor

将数据块存储到张量描述符。

make_block_ptr

返回父张量中块的指针

advance

移动块指针

索引操作 (Indexing Ops)

flip

沿维度 dim 翻转张量 x

where

根据 condition 返回来自 xy 的元素张量。

swizzle2d

将行主序 size_i * size_j 矩阵的索引转换为每个 size_g 行组的列主序矩阵的索引。

数学操作 (Math Ops)

abs

计算 x 的元素级绝对值。

cdiv

计算 x 除以 div 的向上取整除法。

ceil

计算 x 的元素级向上取整。

clamp

将输入张量 x 钳制在 [min, max] 范围内。

cos

计算 x 的元素级余弦。

div_rn

计算 xy 的元素级精确除法(根据 IEEE 标准四舍五入到最近的整数)。

erf

计算 x 的元素级误差函数。

exp

计算 x 的元素级指数。

exp2

计算 x 的元素级指数(以 2 为底)。

fdiv

计算 xy 的元素级快速除法。

floor

计算 x 的元素级向下取整。

fma

计算 xyz 的元素级融合乘加运算。

log

计算 x 的元素级自然对数。

log2

计算 x 的元素级对数(以 2 为底)。

maximum

计算 xy 的元素级最大值。

minimum

计算 xy 的元素级最小值。

rsqrt

计算 x 的元素级平方根倒数。

sigmoid

计算 x 的元素级 Sigmoid。

sin

计算 x 的元素级正弦。

softmax

计算 x 的元素级 Softmax。

sqrt

计算 x 的元素级快速平方根。

sqrt_rn

计算 x 的元素级精确平方根(根据 IEEE 标准四舍五入到最近的整数)。

umulhi

计算 xy 的 2N 位乘积的元素级最高 N 位。

归约操作 (Reduction Ops)

argmax

返回 input 张量中沿给定 axis 的所有元素的最大索引。

argmin

返回 input 张量中沿给定 axis 的所有元素的最小索引。

max

返回 input 张量中沿给定 axis 的所有元素的最大值。

min

返回 input 张量中沿给定 axis 的所有元素的最小值。

reduce

将 combine_fn 应用于 input 张量中沿给定 axis 的所有元素。

sum

返回 input 张量中沿给定 axis 的所有元素的和。

xor_sum

返回 input 张量中沿给定 axis 的所有元素的异或和。

扫描/排序操作 (Scan/Sort Ops)

associative_scan

将 combine_fn 应用于 input 张量中沿给定 axis 的每个带有进位的元素,并更新进位。

cumprod

返回 input 张量中沿给定 axis 的所有元素的累积积。

cumsum

返回 input 张量中沿给定 axis 的所有元素的累积和。

histogram

根据输入张量计算直方图,具有 num_bins 个箱子,箱子宽度为 1 并从 0 开始。

sort

gather

沿给定维度从张量中收集。

原子操作 (Atomic Ops)

atomic_add

pointer 指定的内存位置执行原子加法。

atomic_and

pointer 指定的内存位置执行原子逻辑与。

atomic_cas

pointer 指定的内存位置执行原子比较并交换。

atomic_max

pointer 指定的内存位置执行原子最大值。

atomic_min

pointer 指定的内存位置执行原子最小值。

atomic_or

pointer 指定的内存位置执行原子逻辑或。

atomic_xchg

pointer 指定的内存位置执行原子交换。

atomic_xor

pointer 指定的内存位置执行原子逻辑异或。

随机数生成

randint4x

给定一个 seed 标量和一个 offset 块,返回四个随机 int32 块。

randint

给定一个 seed 标量和一个 offset 块,返回一个随机 int32 块。

rand

给定一个 seed 标量和一个 offset 块,返回一个在 \(U(0, 1)\) 中的随机 float32 块。

randn

给定一个 seed 标量和一个 offset 块,返回一个在 \(\mathcal{N}(0, 1)\) 中的随机 float32 块。

迭代器

range

永远向上计数的迭代器。

static_range

永远向上计数的迭代器。

内联汇编

inline_asm_elementwise

在张量上执行内联汇编。

编译器提示操作 (Compiler Hint Ops)

assume

允许编译器假定 cond 为 True。

debug_barrier

插入一个屏障以同步块中的所有线程。

max_constancy

告知编译器 input 中的前 value 个值是常量。

max_contiguous

告知编译器 input 中的前 value 个值是连续的。

multiple_of

告知编译器 input 中的所有值都是 value 的倍数。

调试操作 (Debug Ops)

static_print

在编译时打印值。

static_assert

在编译时断言条件。

device_print

从设备在运行时打印值。

device_assert

从设备在运行时断言条件。