triton.language.reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None)

combine_fn 应用于 input 张量中沿指定 axis 的所有元素

参数:
  • input (Tensor) – 输入张量,或张量元组

  • axis (int | None) – 进行归约运算的维度。如果为 None,则归约所有维度

  • combine_fn (Callable) – 一个用于组合两组标量张量的函数(必须使用 @triton.jit 标记)

  • keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度

此函数也可以作为成员函数在 tensor 上调用,形式为 x.reduce(...) 而不是 reduce(x, ...)