triton.language.associative_scan

triton.language.associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None)

将 combine_fn 应用于 input 张量沿所提供的 axis 的每个元素,并更新进位。

参数:
  • input (Tensor) – 输入张量,或张量元组

  • axis (int) – 执行归约的维度

  • combine_fn (Callable) – 一个用于组合两组标量张量的函数(必须用 @triton.jit 标记)

  • reverse (bool) – 是否沿轴反向应用关联扫描

此函数也可以作为成员函数在 tensor 上调用,形式为 x.associative_scan(...) 而不是 associative_scan(x, ...)