triton.language.dot_scaled

triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32, _semantic=None)

返回两个微缩放(microscaling)格式块的矩阵乘积。

lhs 和 rhs 使用此处描述的微缩放格式: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

软件模拟使得在不支持原生微缩放操作的硬件架构上也能进行计算。目前对于此类情况,微缩放的 lhs/rhs 会在点积计算前预先转换为 bf16 元素类型,但有一个例外:对于 AMD CDNA3,如果其中一个输入是 fp16 元素类型,则另一个输入也会被转换为 fp16 元素类型。此行为属于实验性功能,未来可能会发生变化。

参数:
  • lhs (表示 fp4fp8 或 bf16 元素的 2D 张量。Fp4 元素被打包进 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 要相乘的第一个张量。

  • lhs_scale (表示为 uint8 张量的 e8m0 类型,或 None。) – lhs 张量的缩放因子。当 lhs 为 [M, K] 时,形状应为 [M, K//group_size],其中如果缩放类型为 e8m0,则 group_size 为 32。

  • lhs_format (str) – lhs 张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • rhs (表示 fp4fp8 或 bf16 元素的 2D 张量。Fp4 元素被打包进 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 要相乘的第二个张量。

  • rhs_scale (表示为 uint8 张量的 e8m0 类型,或 None。) – rhs 张量的缩放因子。当 rhs 为 [K, N] 时,形状应为 [N, K//group_size]。重要提示:请勿转置 rhs_scale。

  • rhs_format (str) – rhs 张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • acc – 累加器张量。如果不为 None,结果将加到此张量上。

  • lhs_k_pack (bool, 可选) – 如果为 false,则 lhs 张量将沿 M 维度打包为 uint8。

  • rhs_k_pack (bool, 可选) – 如果为 false,则 rhs 张量将沿 N 维度打包为 uint8。