triton.language.topk

triton.language.topk(x, k: constexpr, dim: constexpr = None, descending: constexpr = True)

返回沿指定维度输入张量的 k 个最大(或最小)元素。

返回的元素按排序顺序排列(最大的优先)。

参数:
  • x (Tensor) – 输入张量。

  • k (int) – 要返回的顶部元素数量。必须是 2 的幂。

  • dim (int, optional) – 要查找 top k 元素的维度。如果为 None,则使用最后一个维度。目前仅支持最后一个维度。

  • descending (bool, optional) – 如果设置为 True,则返回 k 个最大元素。如果设置为 False,则返回 k 个最小元素。

返回:

一个包含沿指定维度 k 个最大元素的张量。

返回类型:

Tensor

示例

# Get top 4 elements from a 1D tensor
x = tl.arange(0, 16)
top4 = tl.topk(x, 4)  # Returns [15, 14, 13, 12]