triton.language

编程模型

tensor

表示 N 维值的数组或指针。

program_id

返回当前程序实例在给定 axis 方向上的 ID。

num_programs

返回沿给定 axis 方向启动的程序实例数。

创建操作

arange

返回半开区间 [start, end) 内的连续值。

cat

连接给定的块

full

返回给定 shapedtype 并填充标量值的张量。

zeros

返回给定 shapedtype 并填充标量值 0 的张量。

zeros_like

返回一个与给定张量具有相同形状和类型的全零张量。

cast

将张量转换为给定的 dtype

形状操作

broadcast

尝试将给定的两个块广播到兼容的通用形状。

broadcast_to

尝试将给定的张量广播到新的 shape

expand_dims

通过插入新的长度为 1 的维度来扩展张量的形状。

interleave

沿最后维度交错两个张量的值。

join

在新的较小维度上连接给定的张量。

permute

置换张量的维度。

ravel

返回 x 的连续展平视图。

reshape

返回一个与输入张量元素数量相同但具有指定形状的张量。

split

沿最后一个维度将张量分成两部分,该维度的大小必须为 2。

trans

置换张量的维度。

view

返回一个与 input 具有相同元素但形状不同的张量。

线性代数操作

dot

返回两个块的矩阵乘积。

dot_scaled

以微缩格式返回两个块的矩阵乘积。

内存/指针操作

load

返回一个数据张量,其值从由 pointer 定义的内存位置加载。

store

将数据张量存储到由 pointer 定义的内存位置。

make_block_ptr

返回父张量中一个块的指针

advance

前进块指针

索引操作

flip

沿维度 dim 翻转张量 x

where

根据 condition 的结果,返回一个包含来自 xy 元素的张量。

swizzle2d

将行主序 size_i * size_j 矩阵的索引转换为列主序矩阵的索引,适用于每组 size_g 行。

数学操作

abs

计算 x 的元素级绝对值。

cdiv

计算 x 除以 div 的向上取整结果

ceil

计算 x 的元素级向上取整。

clamp

将输入张量 x 限制在 [min, max] 范围内。

cos

计算 x 的元素级余弦值。

div_rn

计算 xy 的元素级精确除法(根据 IEEE 标准四舍五入到最接近的整数)。

erf

计算 x 的元素级误差函数。

exp

计算 x 的元素级指数。

exp2

计算 x 的元素级指数(以 2 为底)。

fdiv

计算 xy 的元素级快速除法。

floor

计算 x 的元素级向下取整。

fma

计算 xyz 的元素级融合乘加。

log

计算 x 的元素级自然对数。

log2

计算 x 的元素级对数(以 2 为底)。

maximum

计算 xy 的元素级最大值。

minimum

计算 xy 的元素级最小值。

rsqrt

计算 x 的元素级平方根倒数。

sigmoid

计算 x 的元素级 Sigmoid 函数。

sin

计算 x 的元素级正弦值。

softmax

计算 x 的元素级 Softmax 函数。

sqrt

计算 x 的元素级快速平方根。

sqrt_rn

计算 x 的元素级精确平方根(根据 IEEE 标准四舍五入到最接近的整数)。

umulhi

计算 xy 的 2N 位乘积中元素级的最高 N 位。

归约操作

argmax

返回 input 张量沿给定 axis 方向上所有元素的最大索引。

argmin

返回 input 张量沿给定 axis 方向上所有元素的最小索引。

max

返回 input 张量沿给定 axis 方向上所有元素的最大值。

min

返回 input 张量沿给定 axis 方向上所有元素的最小值。

reduce

将 combine_fn 应用于 input 张量沿给定 axis 方向上的所有元素。

sum

返回 input 张量沿给定 axis 方向上所有元素的总和。

xor_sum

返回 input 张量沿给定 axis 方向上所有元素的异或总和。

扫描/排序操作

associative_scan

将 combine_fn 应用于 input 张量沿给定 axis 方向上带有进位的每个元素,并更新进位

cumprod

返回 input 张量沿给定 axis 方向上所有元素的累积乘积。

cumsum

返回 input 张量沿给定 axis 方向上所有元素的累积总和。

histogram

根据输入张量计算直方图,该直方图包含 num_bins 个 bin,这些 bin 的宽度为 1 且从 0 开始。

sort

gather

沿给定维度从张量中收集元素。

原子操作

atomic_add

在由 pointer 指定的内存位置执行原子加法。

atomic_and

在由 pointer 指定的内存位置执行原子逻辑与。

atomic_cas

在由 pointer 指定的内存位置执行原子比较并交换。

atomic_max

在由 pointer 指定的内存位置执行原子最大值操作。

atomic_min

在由 pointer 指定的内存位置执行原子最小值操作。

atomic_or

在由 pointer 指定的内存位置执行原子逻辑或。

atomic_xchg

在由 pointer 指定的内存位置执行原子交换。

atomic_xor

在由 pointer 指定的内存位置执行原子逻辑异或。

随机数生成

randint4x

给定一个 seed 标量和一个 offset 块,返回四个随机的 int32 块。

randint

给定一个 seed 标量和一个 offset 块,返回一个随机的 int32 块。

rand

给定一个 seed 标量和一个 offset 块,返回一个在 \(U(0, 1)\) 中的随机 float32 块。

randn

给定一个 seed 标量和一个 offset 块,返回一个在 \(\mathcal{N}(0, 1)\) 中的随机 float32 块。

迭代器

range

永远向上计数的迭代器。

static_range

永远向上计数的迭代器。

内联汇编

inline_asm_elementwise

在张量上执行内联汇编。

编译器提示操作

assume

允许编译器假定 cond 为 True。

debug_barrier

插入一个屏障以同步块中的所有线程。

max_constancy

告知编译器 input 中的前 value 个值是常量。

max_contiguous

告知编译器 input 中的前 value 个值是连续的。

multiple_of

告知编译器 input 中的所有值都是 value 的倍数。

调试操作

static_print

在编译时打印值。

static_assert

在编译时断言条件。

device_print

在运行时从设备端打印值。

device_assert

在运行时从设备端断言条件。