triton.language

编程模型

tensor

表示 N 维值或指针数组。

tensor_descriptor

表示全局内存中的张量的描述符。

program_id

返回沿给定 axis 的当前程序实例的 ID。

num_programs

返回沿给定 axis 启动的程序实例数。

创建操作

arange

返回半开区间 [start, end) 内的连续值。

cat

连接给定的块

full

返回一个张量,其中填充了给定 shapedtype 的标量值。

zeros

返回一个张量,其中填充了给定 shapedtype 的标量值 0。

zeros_like

返回一个与给定张量具有相同形状和类型的零张量。

cast

将张量转换为给定的 dtype

形状操作

broadcast

尝试将两个给定块广播到共同兼容的形状。

broadcast_to

尝试将给定张量广播到新的 shape

expand_dims

通过插入新的长度为 1 的维度来扩展张量的形状。

interleave

沿最后一个维度交错两个张量的值。

join

将给定的张量连接到一个新的次要维度中。

permute

置换张量的维度。

ravel

返回 x 的连续扁平视图。

reshape

返回一个张量,其元素数量与输入相同,但具有提供的形状。

split

沿最后一个维度将张量拆分为两部分,其大小必须为 2。

trans

置换张量的维度。

view

返回一个张量,其元素与 input 相同,但形状不同。

线性代数操作

dot

返回两个块的矩阵乘积。

dot_scaled

以微缩放格式返回两个块的矩阵乘积。

内存/指针操作

load

返回一个数据张量,其值从 pointer 定义的内存位置加载

store

将数据张量存储到 pointer 定义的内存位置。

make_tensor_descriptor

创建一个张量描述符对象

load_tensor_descriptor

从张量描述符加载数据块。

store_tensor_descriptor

将数据块存储到张量描述符。

make_block_ptr

返回指向父张量中块的指针

advance

移动块指针

索引操作

flip

沿维度 dim 翻转张量 x

where

根据 condition 返回来自 xy 的元素张量。

swizzle2d

对于每组 size_g 行,将行主序 size_i * size_j 矩阵的索引转换为列主序矩阵的索引。

数学操作

abs

计算 x 的按元素绝对值。

cdiv

计算 x 除以 div 的向上取整除法。

ceil

计算 x 的按元素上取整。

clamp

将输入张量 x 钳制在范围 [min, max] 内。

cos

计算 x 的按元素余弦。

div_rn

计算 xy 的按元素精确除法(根据 IEEE 标准四舍五入到最近)。

erf

计算 x 的按元素误差函数。

exp

计算 x 的按元素指数。

exp2

计算 x 的按元素指数(以 2 为底)。

fdiv

计算 xy 的按元素快速除法。

floor

计算 x 的按元素下取整。

fma

计算 xyz 的按元素融合乘加。

log

计算 x 的按元素自然对数。

log2

计算 x 的按元素对数(以 2 为底)。

maximum

计算 xy 的按元素最大值。

minimum

计算 xy 的按元素最小值。

rsqrt

计算 x 的按元素逆平方根。

sigmoid

计算 x 的元素级 Sigmoid。

sin

计算 x 的按元素正弦。

softmax

计算 x 的元素级 Softmax。

sqrt

计算 x 的按元素快速平方根。

sqrt_rn

计算 x 的按元素精确平方根(根据 IEEE 标准四舍五入到最近)。

umulhi

计算 xy 的 2N 位乘积的按元素最高 N 位。

归约操作

argmax

返回 input 张量中沿给定 axis 的所有元素的最大索引。

argmin

返回 input 张量中沿给定 axis 的所有元素的最小索引。

max

返回 input 张量中沿给定 axis 的所有元素的最大值。

min

返回 input 张量中沿给定 axis 的所有元素的最小值。

reduce

沿提供的 axisinput 张量中的所有元素应用 combine_fn

sum

返回 input 张量中沿给定 axis 的所有元素的和。

xor_sum

返回 input 张量中沿给定 axis 的所有元素的异或和。

扫描/排序操作

associative_scan

沿提供的 axisinput 张量中的每个元素应用 combine_fn 并更新进位

cumprod

返回 input 张量中沿给定 axis 的所有元素的累积积。

cumsum

返回 input 张量中沿给定 axis 的所有元素的累积和。

histogram

根据输入张量计算直方图,其中包含 num_bins 个 bin,bin 的宽度为 1,从 0 开始。

sort

topk

返回沿指定维度输入张量的 k 个最大元素。

gather

沿给定维度从张量中收集。

原子操作

atomic_add

pointer 指定的内存位置执行原子加法。

atomic_and

pointer 指定的内存位置执行原子逻辑与。

atomic_cas

pointer 指定的内存位置执行原子比较和交换。

atomic_max

pointer 指定的内存位置执行原子最大值。

atomic_min

pointer 指定的内存位置执行原子最小值。

atomic_or

pointer 指定的内存位置执行原子逻辑或。

atomic_xchg

pointer 指定的内存位置执行原子交换。

atomic_xor

pointer 指定的内存位置执行原子逻辑异或。

随机数生成

randint4x

给定一个 seed 标量和一个 offset 块,返回四个随机 int32 块。

randint

给定一个 seed 标量和一个 offset 块,返回一个随机 int32 块。

rand

给定一个 seed 标量和一个 offset 块,返回一个在 \(U(0, 1)\) 中的随机 float32 块。

randn

给定一个 seed 标量和一个 offset 块,返回一个在 \(\mathcal{N}(0, 1)\) 中的随机 float32 块。

迭代器

range

永远向上计数的迭代器。

static_range

永远向上计数的迭代器。

内联汇编

inline_asm_elementwise

对张量执行内联汇编。

编译器提示操作

assume

允许编译器假设 cond 为 True。

debug_barrier

插入一个屏障以同步块中的所有线程。

max_constancy

告知编译器 input 中的前 value 个值是常量。

max_contiguous

告知编译器 input 中的前 value 个值是连续的。

multiple_of

告知编译器 input 中的值都是 value 的倍数。

调试操作

static_print

在编译时打印值。

static_assert

在编译时断言条件。

device_print

在运行时从设备打印值。

device_assert

在运行时从设备断言条件。