triton.language.split

triton.language.split(a, _semantic=None, _generator=None) tuple[tensor, tensor]

沿最后一个维度将张量一分为二,该维度的大小必须为 2。

例如,给定一个形状为 (4,8,2) 的张量,生成两个形状为 (4,8) 的张量。给定一个形状为 (2) 的张量,返回两个标量。

如果您想要拆分成超过两个部分,可以多次调用此函数(可能还需要调用 reshape)。这反映了 Triton 中的一个约束:张量必须具有 2 的幂次方大小。

split 是 join 的逆运算。

参数:

a (Tensor) – 要拆分的张量。

此函数也可以作为 tensor 的成员函数调用,即使用 x.split() 而非 split(x)