triton.language.where
- triton.language.where(condition, x, y, _semantic=None)
根据
condition返回从x或y中选取的元素组成的张量。请注意,无论
condition的值如何,x和y都会被求值。如果您想避免意外的内存操作,请改用 triton.load 和 triton.store 中的
mask参数。x和y的形状都会被广播(broadcast)为condition的形状。x和y必须具有相同的数据类型。- 参数:
condition (triton.bool 的块) – 当为 True(非零)时,输出 x,否则输出 y。
x – 在 condition 为 True 的索引处选取的数值。
y – 在 condition 为 False 的索引处选取的数值。