triton.language.sum

triton.language.sum(input, axis=None, keep_dims=False, dtype: constexpr | None = None)

返回沿给定 axisinput 张量中所有元素的总和

参数:
  • input (张量) – 输入值

  • axis (int) – 应进行归约的维度。如果为 None,则对所有维度进行归约

  • keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度

  • dtype (tl.dtype) – 返回张量所需的数据类型。如果指定,输入张量在执行操作前会转换为 dtype。这对于防止数据溢出很有用。如果未指定,整数和布尔类型会提升为 tl.int32,浮点类型会提升为至少 tl.float32

此函数也可以作为成员函数在 tensor 上调用,例如 x.sum(...) 而非 sum(x, ...)