triton.language.cumprod

triton.language.cumprod(input, axis=0, reverse=False)

返回 input 张量中所有元素沿提供的 axis 的累积乘积

参数:
  • input (Tensor) – 输入值

  • axis (int) – 进行扫描操作的维度

此函数也可以作为成员函数在 tensor 上调用,例如 x.cumprod(...) 而不是 cumprod(x, ...)