triton.language.where

triton.language.where(condition, x, y)

根据 condition 返回一个包含来自 xy 元素的张量。

注意,无论 condition 的值如何,xy 都会被求值。

如果想避免意外的内存操作,请改用 triton.loadtriton.store 中的 mask 参数。

xy 的形状都会被广播到 condition 的形状。 xy 必须具有相同的数据类型。

参数:
  • condition (triton.bool 块) – 当为 True (非零) 时,返回 x;否则返回 y。

  • x – condition 为 True 的索引处选定的值。

  • y – condition 为 False 的索引处选定的值。