triton.language.sum

triton.language.sum(input, axis=None, keep_dims=False, dtype: constexpr = None)

返回 input 张量中沿给定 axis 的所有元素的和。

归约操作应具有结合律和交换律。

参数:
  • input (Tensor) – 输入值

  • axis (int) – 执行归约操作的维度。如果为 None,则归约所有维度

  • keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度

  • dtype (tl.dtype) – 返回张量所期望的数据类型。如果指定了该参数,输入张量将在执行操作前被转换为 dtype。这对于防止数据溢出非常有用。如果不指定,整数和布尔类型会被向上转换为 tl.int32,浮点类型会被向上转换为至少 tl.float32

该函数也可以作为 tensor 的成员函数调用,即使用 x.sum(...) 而不是 sum(x, ...)