triton.language.where

triton.language.where(condition, x, y, _semantic=None)

根据 condition 返回从 xy 中选取的元素组成的张量。

请注意,无论 condition 的值如何,xy 都会被求值。

如果您想避免意外的内存操作,请改用 triton.loadtriton.store 中的 mask 参数。

xy 的形状都会被广播(broadcast)为 condition 的形状。xy 必须具有相同的数据类型。

参数:
  • condition (triton.bool 的块) – 当为 True(非零)时,输出 x,否则输出 y。

  • x – 在 condition 为 True 的索引处选取的数值。

  • y – 在 condition 为 False 的索引处选取的数值。