triton.language.join

triton.language.join(a, b, _semantic=None)

在一个新的、次要的维度中连接给定的张量。

例如,给定两个形状为 (4,8) 的张量,会生成一个形状为 (4,8,2) 的新张量。给定两个标量,会返回一个形状为 (2) 的张量。

这两个输入会被广播为相同的形状。

如果要连接两个以上的元素,可以多次调用此函数。这反映了 Triton 中的一个约束,即张量的大小必须是2的幂。

join 是 split 的逆操作。

参数:
  • a (张量) – 第一个输入张量。

  • b (张量) – 第二个输入张量。