triton.language.dot_scaled

triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32)

返回两个块在微缩放格式下的矩阵乘积。

lhs 和 rhs 使用此处描述的微缩放格式:https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

软件模拟支持没有原生微缩放操作支持的硬件架构。目前对于这种情况,微缩放的 lhs/rhs 在点积计算前会先向上转换为 bf16 元素类型,只有一个例外:对于 AMD CDNA3,如果其中一个输入的元素类型是 fp16,则另一个输入也会向上转换为 fp16 元素类型。这种行为是实验性的,未来可能会有所改变。

参数:
  • lhs (表示 fp4、fp8 或 bf16 元素的 2D 张量。Fp4 元素打包到 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第一个要相乘的张量。

  • lhs_scale (表示为 uint8 张量的 e8m0 类型。) – lhs 张量的比例因子。

  • lhs_format (str) – lhs 张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • rhs (表示 fp4、fp8 或 bf16 元素的 2D 张量。Fp4 元素打包到 uint8 输入中,第一个元素位于低位。Fp8 存储为 uint8 或相应的 fp8 类型。) – 第二个要相乘的张量。

  • rhs_scale (表示为 uint8 张量的 e8m0 类型。) – rhs 张量的比例因子。

  • rhs_format (str) – rhs 张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • acc – 累加器张量。如果不是 None,结果会添加到此张量。

  • lhs_k_pack (bool, optional) – 如果为 false,lhs 张量沿 M 维打包到 uint8 中。

  • rhs_k_pack (bool, optional) – 如果为 false,rhs 张量沿 N 维打包到 uint8 中。