triton.language.dot_scaled¶
- triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32)¶
返回两个块在微缩放格式下的矩阵乘积。
lhs 和 rhs 使用此处描述的微缩放格式:https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
软件模拟支持没有原生微缩放操作支持的硬件架构。目前对于这种情况,微缩放的 lhs/rhs 在点积计算前会先向上转换为
bf16
元素类型,只有一个例外:对于 AMD CDNA3,如果其中一个输入的元素类型是fp16
,则另一个输入也会向上转换为fp16
元素类型。这种行为是实验性的,未来可能会有所改变。- 参数:
lhs (表示 fp4、fp8 或 bf16 元素的 2D 张量。Fp4 元素打包到 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第一个要相乘的张量。
lhs_scale (表示为 uint8 张量的 e8m0 类型。) – lhs 张量的比例因子。
lhs_format (str) – lhs 张量的格式。可用格式:{
e2m1
,e4m3
,e5m2
,bf16
,fp16
}。rhs (表示 fp4、fp8 或 bf16 元素的 2D 张量。Fp4 元素打包到 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第二个要相乘的张量。
rhs_scale (表示为 uint8 张量的 e8m0 类型。) – rhs 张量的比例因子。
rhs_format (str) – rhs 张量的格式。可用格式:{
e2m1
,e4m3
,e5m2
,bf16
,fp16
}。acc – 累加器张量。如果不是 None,结果会添加到此张量。
lhs_k_pack (bool, optional) – 如果为 false,lhs 张量沿 M 维打包到 uint8 中。
rhs_k_pack (bool, optional) – 如果为 false,rhs 张量沿 N 维打包到 uint8 中。