triton.language.where

triton.language.where(condition, x, y, _semantic=None)

根据 condition 的值,返回一个由 xy 中的元素组成的张量。

请注意,无论 condition 的值如何,xy 都会被评估。

如果您想避免意外的内存操作,请改用 triton.loadtriton.store 中的 mask 参数。

xy 的形状都将广播到 condition 的形状。xy 必须具有相同的数据类型。

参数:
  • condition (Block of triton.bool) – 当为 True(非零)时,返回 x,否则返回 y。

  • x – 在 condition 为 True 的索引处选择的值。

  • y – 在 condition 为 False 的索引处选择的值。