triton.language.cumsum

triton.language.cumsum(input, axis=0, reverse=False, dtype: constexpr | None = None)

返回 input 张量中沿给定 axis 的所有元素的累积和。

参数:
  • input (Tensor) – 输入值

  • axis (int) – 执行扫描操作的维度

  • reverse (bool) – 如果为 true,则沿反方向执行扫描

  • dtype (tl.dtype) – 返回张量所需的数据类型。如果指定,输入张量将在执行操作前转换为 dtype。如果未指定,则小整数类型(小于 32 位)将被提升以防止溢出。请注意,tl.bfloat16 输入会自动提升为 tl.float32

此函数也可以作为成员函数在 tensor 上调用,以 x.cumsum(...) 的形式,而不是 cumsum(x, ...)