triton.language.dot_scaled¶
- triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32, _semantic=None)¶
返回两个微缩放(microscaling)格式块的矩阵乘积。
lhs 和 rhs 使用此处描述的微缩放格式:https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
软件模拟使得可以面向没有原生微缩放操作支持的硬件架构。目前在这种情况下,微缩放的 lhs/rhs 会预先向上转换为
bf16元素类型进行点积计算,但有一个例外:特别是在 AMD CDNA3 上,如果其中一个输入的元素类型是fp16,那么另一个输入也会被向上转换为fp16元素类型。此行为是实验性的,未来可能会发生变化。- 参数:
lhs (表示 fp4、fp8 或 bf16 元素的二维张量。Fp4 元素被打包到 uint8 输入中,第一个元素在低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 要相乘的第一个张量。
lhs_scale (表示为 uint8 张量的 e8m0 类型,或 None。) – lhs 张量的缩放因子。当 lhs 形状为 [M, K] 时,其形状应为 [M, K//group_size],其中如果缩放类型为 e8m0,则 group_size 为 32。
lhs_format (str) – lhs 张量的格式。可用格式:{
e2m1,e4m3,e5m2,bf16,fp16}。rhs (表示 fp4、fp8 或 bf16 元素的二维张量。Fp4 元素被打包到 uint8 输入中,第一个元素在低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 要相乘的第二个张量。
rhs_scale (表示为 uint8 张量的 e8m0 类型,或 None。) – rhs 张量的缩放因子。当 rhs 形状为 [K, N] 时,其形状应为 [N, K//group_size]。重要提示:不要转置 rhs_scale
rhs_format (str) – rhs 张量的格式。可用格式:{
e2m1,e4m3,e5m2,bf16,fp16}。acc – 累加器张量。如果不为 None,则结果会加到此张量上。
lhs_k_pack (bool, 可选) – 如果为 false,则 lhs 张量会沿 M 维度打包成 uint8。
rhs_k_pack (bool, 可选) – 如果为 false,则 rhs 张量会沿 N 维度打包成 uint8。