triton.language.inline_asm_elementwise¶
- triton.language.inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: dtype | Sequence[dtype], is_pure: bool, pack: int, _semantic=None)¶
在张量上执行内联汇编。本质上,这是以内联汇编作为函数的
map
操作。输入张量
args
会隐式地广播到相同的形状。dtype
可以是类型的元组,在这种情况下,输出将是张量的元组。内联汇编的每次调用会一次性处理
pack
个元素。块接收哪组输入是未指定的。小于 4 字节的输入元素会被打包到 4 字节寄存器中。此操作不支持空的
dtype
– 内联汇编必须返回至少一个张量,即使您不需要它。您可以通过返回任意类型的虚拟张量来解决此问题;如果您不使用它,这应该不会产生任何开销。使用 PTX 汇编的示例
@triton.jit def kernel(A, B, C, D, BLOCK: tl.constexpr): a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor # For each (a,b) in zip(a,b), perform the following: # - Let ai be `a` converted to int32. # - Let af be `a` converted to float. # - Let m be the max of ai and b. # - Return ai and mi. # Do the above 4 elements at a time. (c, d) = tl.inline_asm_elementwise( asm=""" { // Unpack `a` into `ai`. .reg .b8 tmp<4>; mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; cvt.u32.u8 $0, tmp0; cvt.u32.u8 $1, tmp1; cvt.u32.u8 $2, tmp2; cvt.u32.u8 $3, tmp3; } // Convert `ai` to float. cvt.rn.f32.s32 $4, $0; cvt.rn.f32.s32 $5, $1; cvt.rn.f32.s32 $6, $2; cvt.rn.f32.s32 $7, $3; // Take max of `ai` and `b`. max.f32 $4, $4, $9; max.f32 $5, $5, $10; max.f32 $6, $6, $11; max.f32 $7, $7, $12; """, constraints=( # 8 output registers, namely # $0=ai0, $1=ai1, $2=ai2, $3=ai3, # $4=m0, $5=m1, $6=m2, $7=m3. "=r,=r,=r,=r,=r,=r,=r,=r," # 5 input registers, namely # $8=ai, # $9=b0, $10=b1, $11=b2, $12=b3. # The four elements from `a` are all packed into one register. "r,r,r,r,r"), args=[a, b], dtype=(tl.int32, tl.float32), is_pure=True, pack=4, ) tl.store(C + tl.arange(0, BLOCK), c) tl.store(D + tl.arange(0, BLOCK), d)
- 参数:
asm – 要运行的汇编代码。必须与目标的汇编格式匹配。
constraints – LLVM 格式的汇编约束。
args – 输入张量,其值会传递给汇编块
dtype – 返回张量的元素类型
is_pure – 如果为真,编译器假定汇编块没有副作用
pack – 单个内联汇编实例要处理的元素数量
- 返回:
一个张量或给定数据类型的张量元组