triton.language.argmax

triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)

返回 input 张量中沿给定 axis 的所有元素的最大索引。

归约操作应具有结合律和交换律。

参数:
  • input (Tensor) – 输入值

  • axis (int) – 执行归约操作的维度。如果为 None,则归约所有维度

  • keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度

  • tie_break_left (bool) – 如果为 true,则在出现平局(即多个元素具有相同的最大索引值)时,对于非 NaN 值,返回最左边的索引

此函数也可以作为 tensor 的成员函数调用,例如 x.argmax(...),而不是 argmax(x, ...)