triton.language.reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None)

沿提供的 axisinput 张量中的所有元素应用 combine_fn

参数:
  • input (Tensor) – 输入张量,或张量元组

  • axis (int | None) – 应该执行归约的维度。如果为 None,则归约所有维度

  • combine_fn (Callable) – 用于组合两组标量张量的函数(必须使用 @triton.jit 标记)

  • keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度

此函数也可以作为 tensor 的成员函数调用,例如 x.reduce(...) 而不是 reduce(x, ...)