triton.language.reduce¶
- triton.language.reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None)¶
沿提供的
axis对input张量中的所有元素应用 combine_fn- 参数:
input (Tensor) – 输入张量,或张量元组
axis (int | None) – 应该执行归约的维度。如果为 None,则归约所有维度
combine_fn (Callable) – 用于组合两组标量张量的函数(必须使用 @triton.jit 标记)
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
此函数也可以作为
tensor的成员函数调用,例如x.reduce(...)而不是reduce(x, ...)。