triton.language.trans

triton.language.trans(input: tensor, *dims, _semantic=None)

置换张量的维度。

如果未指定参数 dims,则该函数默认交换最后两个轴,从而执行(可选批处理的)2D 转置。

参数:
  • input – 输入张量。

  • dims – 维度的期望顺序。例如,(2, 1, 0) 反转 3D 张量中的维度顺序。

dims 可以作为元组或作为单独的参数传入

# These are equivalent
trans(x, (2, 1, 0))
trans(x, 2, 1, 0)

permute() 与此函数等效,只是在未指定排列时没有特殊情况。

此函数也可以作为 tensor 的成员函数调用,例如 x.trans(...) 而不是 trans(x, ...)