triton.language.reduce¶
- triton.language.reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None)¶
将
combine_fn
应用于input
张量中沿指定axis
的所有元素- 参数:
input (Tensor) – 输入张量,或张量元组
axis (int | None) – 进行归约运算的维度。如果为 None,则归约所有维度
combine_fn (Callable) – 一个用于组合两组标量张量的函数(必须使用 @triton.jit 标记)
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
此函数也可以作为成员函数在
tensor
上调用,形式为x.reduce(...)
而不是reduce(x, ...)
。