triton.language.reduce¶
- triton.language.reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None)¶
将
combine_fn应用于input张量中沿指定axis的所有元素- 参数:
input (Tensor) – 输入张量,或张量元组
axis (int | None) – 进行归约运算的维度。如果为 None,则归约所有维度
combine_fn (Callable) – 一个用于组合两组标量张量的函数(必须使用 @triton.jit 标记)
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
此函数也可以作为成员函数在
tensor上调用,形式为x.reduce(...)而不是reduce(x, ...)。