triton.language.make_tensor_descriptor

triton.language.make_tensor_descriptor(base: tensor, shape: List[tensor], strides: List[tensor], block_shape: List[constexpr], padding_option='zero', _semantic=None) tensor_descriptor

创建一个张量描述符对象

参数:
  • base – 张量的基指针,必须是 16 字节对齐的

  • shape – 一个表示张量形状的非负整数列表

  • strides – 张量的步长列表。前导维度必须是 16 字节步长的倍数,且最后一个维度必须是连续的。

  • block_shape – 从全局内存加载/存储的数据块的形状

注意

在支持 TMA 的 NVIDIA GPU 上,这将产生一个 TMA 描述符对象,对此描述符的加载和存储将由 TMA 硬件支持。

目前仅支持 2-5 维张量。

示例

@triton.jit
def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    desc = tl.make_tensor_descriptor(
        in_out_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[M_BLOCK, N_BLOCK],
    )

    moffset = tl.program_id(0) * M_BLOCK
    noffset = tl.program_id(1) * N_BLOCK

    value = desc.load([moffset, noffset])
    desc.store([moffset, noffset], tl.abs(value))

# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
    return torch.empty(size, device="cuda", dtype=torch.int8)

triton.set_allocator(alloc_fn)

M, N = 256, 256
x = torch.randn(M, N, device="cuda")
M_BLOCK, N_BLOCK = 32, 32
grid = (M / M_BLOCK, N / N_BLOCK)
inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)