triton.language.inline_asm_elementwise

triton.language.inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: dtype | Sequence[dtype], is_pure: bool, pack: int)

在张量上执行内联汇编。本质上,这相当于一个以内联汇编作为函数的 map 操作。

输入张量 args 会隐式地广播到相同的形状。

dtype 可以是类型的元组,在这种情况下,输出将是张量的元组。

每次调用内联汇编会处理 pack 个元素。具体哪些输入元素会分配给一个块是未指定的。小于 4 字节的输入元素会被打包到 4 字节寄存器中。

此操作不支持空的 dtype – 内联汇编必须至少返回一个张量,即使您不需要它。您可以通过返回一个任意类型的虚拟张量来规避此问题;如果您不使用它,这不应产生任何额外开销。

使用 PTX 汇编的示例

@triton.jit
def kernel(A, B, C, D, BLOCK: tl.constexpr):
    a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor
    b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor

    # For each (a,b) in zip(a,b), perform the following:
    # - Let ai be `a` converted to int32.
    # - Let af be `a` converted to float.
    # - Let m be the max of ai and b.
    # - Return ai and mi.
    # Do the above 4 elements at a time.
    (c, d) = tl.inline_asm_elementwise(
        asm="""
        {
            // Unpack `a` into `ai`.
            .reg .b8 tmp<4>;
            mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
            cvt.u32.u8 $0, tmp0;
            cvt.u32.u8 $1, tmp1;
            cvt.u32.u8 $2, tmp2;
            cvt.u32.u8 $3, tmp3;
        }
        // Convert `ai` to float.
        cvt.rn.f32.s32 $4, $0;
        cvt.rn.f32.s32 $5, $1;
        cvt.rn.f32.s32 $6, $2;
        cvt.rn.f32.s32 $7, $3;
        // Take max of `ai` and `b`.
        max.f32 $4, $4, $9;
        max.f32 $5, $5, $10;
        max.f32 $6, $6, $11;
        max.f32 $7, $7, $12;
        """,
        constraints=(
            # 8 output registers, namely
            #   $0=ai0, $1=ai1, $2=ai2, $3=ai3,
            #   $4=m0,  $5=m1,  $6=m2,  $7=m3.
            "=r,=r,=r,=r,=r,=r,=r,=r,"
            # 5 input registers, namely
            #   $8=ai,
            #   $9=b0, $10=b1, $11=b2, $12=b3.
            # The four elements from `a` are all packed into one register.
            "r,r,r,r,r"),
        args=[a, b],
        dtype=(tl.int32, tl.float32),
        is_pure=True,
        pack=4,
    )
    tl.store(C + tl.arange(0, BLOCK), c)
    tl.store(D + tl.arange(0, BLOCK), d)
参数:
  • asm – 要执行的汇编代码。必须与目标的汇编格式匹配。

  • constraintsLLVM 格式的汇编约束。

  • args – 输入张量,其值会传递给汇编块。

  • dtype – 返回张量的元素类型或类型元组。

  • is_pure – 如果为 true,编译器假定汇编块没有副作用。

  • pack – 一个内联汇编实例处理的元素数量。

  • _builder – 构建器。

返回:

一个张量或给定 dtype 的张量元组。