triton.language.dot_scaled¶
- triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32, _semantic=None)¶
返回两个以微缩放格式表示的块的矩阵乘积。
lhs 和 rhs 使用此处描述的微缩放格式:https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
软件仿真使得能够针对不具备原生微缩放操作支持的硬件架构。目前,在这种情况下,微缩放的 lhs/rhs 会在点计算前被上转换为
bf16
元素类型,但有一个例外:对于 AMD CDNA3,如果其中一个输入是fp16
元素类型,则另一个输入也会被上转换为fp16
元素类型。此行为是实验性的,未来可能会有所更改。- 参数:
lhs (表示 fp4、fp8 或 bf16 元素的二维张量。Fp4 元素被打包成 uint8 输入,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第一个要相乘的张量。
lhs_scale (表示为 uint8 张量的 e8m0 类型。) – lhs 张量的缩放因子。
lhs_format (str) – lhs 张量的格式。可用格式:{
e2m1
,e4m3
,e5m2
,bf16
,fp16
}。rhs (表示 fp4、fp8 或 bf16 元素的二维张量。Fp4 元素被打包成 uint8 输入,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第二个要相乘的张量。
rhs_scale (表示为 uint8 张量的 e8m0 类型。) – rhs 张量的缩放因子。
rhs_format (str) – rhs 张量的格式。可用格式:{
e2m1
,e4m3
,e5m2
,bf16
,fp16
}。acc – 累加器张量。如果不是 None,则将结果添加到此张量。
lhs_k_pack (bool, 可选) – 如果为 false,则 lhs 张量沿 M 维度打包为 uint8。
rhs_k_pack (bool, 可选) – 如果为 false,则 rhs 张量沿 N 维度打包为 uint8。