triton.language.cumsum

triton.language.cumsum(input, axis=0, reverse=False, dtype: constexpr | None = None)

返回 input 张量中沿给定 axis 的所有元素的累积和。

参数:
  • input (Tensor) – 输入值

  • axis (int) – 执行扫描的维度

  • reverse (bool) – 如果为 true,则沿反向执行扫描

  • dtype (tl.dtype) – 返回张量所需的数据类型。如果指定,输入张量将在执行操作之前转换为 dtype。如果未指定,较小的整数类型(< 32 位)将向上转换以防止溢出。请注意,tl.bfloat16 输入会自动提升为 tl.float32

此函数也可以作为成员函数在 tensor 上调用,例如 x.cumsum(...) 而不是 cumsum(x, ...)