triton.language.trans

triton.language.trans(input: tensor, *dims)

重排张量的维度。

如果未指定参数 dims,函数默认执行 (1,0) 置换,实际上是转置一个二维张量。

参数:
  • input – 输入张量。

  • dims – 期望的维度顺序。例如,(2, 1, 0) 会反转 3D 张量中维度的顺序。

dims 可以作为元组或单独的参数传递

# These are equivalent
trans(x, (2, 1, 0))
trans(x, 2, 1, 0)

permute() 函数与此函数等价,但它没有未指定排列时的特殊情况。

此函数也可以作为 tensor 的成员函数调用,形式为 x.trans(...),而不是 trans(x, ...)