triton.language.topk
- triton.language.topk(x, k: constexpr, dim: constexpr = None, descending: constexpr = True)
返回沿指定维度输入张量的 k 个最大(或最小)元素。
返回的元素按排序顺序排列(最大的优先)。
- 参数:
x (Tensor) – 输入张量。
k (int) – 要返回的顶部元素数量。必须是 2 的幂。
dim (int, optional) – 要查找 top k 元素的维度。如果为 None,则使用最后一个维度。目前仅支持最后一个维度。
descending (bool, optional) – 如果设置为 True,则返回 k 个最大元素。如果设置为 False,则返回 k 个最小元素。
- 返回:
一个包含沿指定维度 k 个最大元素的张量。
- 返回类型:
Tensor
示例
# Get top 4 elements from a 1D tensor x = tl.arange(0, 16) top4 = tl.topk(x, 4) # Returns [15, 14, 13, 12]