CANN/catlass MXFP4 Matrix Multiplication Example
2026/6/24 6:36:10

MXFP4 Matmul TLA Example Readme

catlass: the CANN operator template library, providing high-performance matrix multiplication and related fused-operator template samples for NPUs. Project URL: https://gitcode.com/cann/catlass

Function Introduction

  • Demonstrates MX FP4 matrix multiplication on Ascend 950: the left matrix A and the right matrix B are scaled by MX scale factors (float8_e8m0), and the multiply-accumulate is then performed on the Cube unit. The output is FP32.
  • In this example, the element type of A and B is float4_e2m1x2_t (packed E2M1), and the scale factor type is float8_e8m0_t. Bias is not enabled (ElementBias is void).
  • The default layouts are RowMajor for A, ColumnMajor for B, and RowMajor for C, which matches the data that gen_data.py generates with trans_a=0 and trans_b=1.
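To make the MX scaling concrete, here is a small Python sketch (not taken from catlass or gen_data.py) that decodes one 4-bit E2M1 code and one float8_e8m0 scale byte. It assumes the OCP Microscaling (MX) format definitions: E2M1 has 1 sign, 2 exponent and 1 mantissa bit with exponent bias 1, and E8M0 is an unsigned power-of-two exponent with bias 127, where 0xFF encodes NaN. How two E2M1 values are packed into one float4_e2m1x2_t byte is deliberately not modeled here.

```python
import math

def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code: 1 sign bit, 2 exponent bits, 1 mantissa bit, bias 1."""
    sign = -1.0 if (nibble >> 3) & 0x1 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        # Subnormal: 0.m * 2^(1 - bias) = m * 0.5
        return sign * man * 0.5
    return sign * (1.0 + man / 2.0) * 2.0 ** (exp - 1)

def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale: a pure power of two, 2^(byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)

# The eight non-negative E2M1 magnitudes
positives = [decode_e2m1(code) for code in range(8)]
print(positives)  # [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

# An MX-scaled element is simply scale * element, e.g. 2^1 * 3.0 = 6.0
print(decode_e8m0(128) * decode_e2m1(0b0101))  # 6.0
```

Because every E8M0 scale is an exact power of two, applying it only shifts the exponent of the FP4 value, which is why the scale can be folded into the multiply-accumulate without extra rounding.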

Code Organization

├── 54_ascend950_fp4_mx_matmul
│   ├── CMakeLists.txt          # CMake build file
│   ├── README.md
│   ├── gen_data.py
│   ├── fp4_mx_matmul.cpp       # Main file
│   └── fp4_mx_matmul_aswt.cpp  # BlockSchedulerAswt variant

Usage Example

  • After obtaining the code, compile the corresponding operator executable (see the Template Library Quick Start). This example is an Ascend 950 (3510) operator, so add -DCATLASS_ARCH=3510 when compiling. The L1 tile shape is 256×256×448 and the L0 tile shape is 256×256×128, chosen to satisfy the L1 (512 KiB) and L0 capacity constraints. Do not increase the L1 K dimension arbitrarily; otherwise, the L1TileShape may exceed the available L1 space.
  • Run the operator.
# Compile the specified case
bash scripts/build.sh 54_ascend950_fp4_mx_matmul -DCATLASS_ARCH=3510
# Generate test cases (creates input/ and golden/ under examples/54_ascend950_fp4_mx_matmul/data)
python3 examples/54_ascend950_fp4_mx_matmul/gen_data.py 256 512 1024 0 1
# Optional: --data-root <DIR> generates the data under DIR/data/ (by default, under the script directory)
# The input parameters are m, n, k, trans_a, and trans_b, in that order
# trans_a: whether matrix A is transposed (0 = not transposed, 1 = transposed)
# trans_b: whether matrix B is transposed (0 = not transposed, 1 = transposed)
# Run the test case
./output/bin/54_ascend950_fp4_mx_matmul 256 512 1024 0
# ASWT scheduling variant (uses the same data/ as above; run gen_data.py first)
bash scripts/build.sh 54_ascend950_fp4_mx_matmul_aswt -DCATLASS_ARCH=3510
./output/bin/54_ascend950_fp4_mx_matmul_aswt 256 512 1024 0
# Arguments: executable name | m axis | n axis | k axis | Device ID
# Device ID is optional and defaults to 0
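The build note above warns that enlarging the L1 K dimension can overflow L1. The rough Python sketch below shows why the footprint grows with k1. It is a hypothetical model, not catlass's allocator: it assumes FP4 data takes half a byte per element, one E8M0 scale byte covers 32 K elements, one scale residency spans L1_SCALE_FACTOR_K L1 K stripes, and scale buffers are counted only once. The real buffer layout may add alignment or other buffers, so treat the numbers only as an order-of-magnitude check.

```python
import math

L1_CAPACITY = 512 * 1024  # 512 KiB, the L1 size quoted in this README

def l1_estimate(m1, n1, k1, l1a_stages=2, l1b_stages=2, l1_scale_factor_k=16):
    """Rough L1 bytes for one MX FP4 block (hypothetical model, not catlass's allocator)."""
    a_tile = m1 * k1 // 2   # FP4 data: two elements per byte
    b_tile = k1 * n1 // 2
    # Assumption: one E8M0 scale byte per 32 K elements; one scale residency
    # covers l1_scale_factor_k L1 K stripes; scale buffers counted once.
    scale_cols = math.ceil(k1 * l1_scale_factor_k / 32)
    a_scale = m1 * scale_cols
    b_scale = n1 * scale_cols
    return a_tile * l1a_stages + b_tile * l1b_stages + a_scale + b_scale

for k1 in (448, 896):
    used = l1_estimate(256, 256, k1)
    print(f"k1={k1}: {used // 1024} KiB, fits={used <= L1_CAPACITY}")
```

Under these assumptions the example's tile needs about 336 KiB, while doubling k1 to 896 needs about 672 KiB and no longer fits, because both the data tiles and the resident scales scale linearly with k1.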

The following output indicates that the precision comparison passed:

Compare success.

Usage Notes

  1. gen_data.py accepts trans_a and trans_b as inputs, but the 54_ascend950_fp4_mx_matmul executable does not; it only covers the case trans_a = 0 and trans_b = 1.

To support another transposition case, modify the layouts in the example, because the layout implicitly encodes the transposition state: layout::RowMajor means not transposed, and layout::ColumnMajor means transposed.

The mapping between the flags and the layouts is shown in the following table:

| trans_a | trans_b | LayoutA | LayoutB |
|---|---|---|---|
| 0 | 0 | layout::RowMajor | layout::RowMajor |
| 0 | 1 | layout::RowMajor | layout::ColumnMajor |
| 1 | 0 | layout::ColumnMajor | layout::RowMajor |
| 1 | 1 | layout::ColumnMajor | layout::ColumnMajor |
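Why does ColumnMajor mean "transposed"? The small Python sketch below (illustrative only; `layouts_for` and `element_offset` are hypothetical helpers, not catlass APIs) encodes the table and computes linear element offsets. It shows that a k × n matrix stored ColumnMajor occupies exactly the same positions as its n × k transpose stored RowMajor. For packed FP4, the byte offset would be roughly `offset // 2`, assuming two elements per byte.

```python
LAYOUT_NAMES = {0: "layout::RowMajor", 1: "layout::ColumnMajor"}

def layouts_for(trans_a, trans_b):
    """Return (LayoutA, LayoutB) for gen_data.py's trans_a / trans_b flags, as in the table."""
    return LAYOUT_NAMES[trans_a], LAYOUT_NAMES[trans_b]

def element_offset(layout, row, col, rows, cols):
    """Linear element offset of (row, col) in a rows x cols matrix."""
    if layout == "layout::RowMajor":
        return row * cols + col
    if layout == "layout::ColumnMajor":
        return row + col * rows
    raise ValueError(f"unknown layout: {layout}")

# B is k x n. Stored ColumnMajor, B[kk][j] sits where a RowMajor n x k matrix
# (B transposed) keeps element [j][kk]: ColumnMajor storage is "transposed" storage.
k, n = 4, 3
same = all(
    element_offset("layout::ColumnMajor", kk, j, k, n)
    == element_offset("layout::RowMajor", j, kk, n, k)
    for kk in range(k)
    for j in range(n)
)
print(same)               # True
print(layouts_for(0, 1))  # ('layout::RowMajor', 'layout::ColumnMajor')
```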
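  1. This example performs MX-quantized matrix multiplication: C = (MxScaleA ⊙ A) × (MxScaleB ⊙ B) + Bias, where ⊙ applies each scale factor to its block of 32 consecutive K elements (Bias is not enabled in this example). A and B support the float4_e1m2 and float4_e2m1 data types; MxScaleA and MxScaleB use the float8_e8m0 data type.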

The layout requirements for MxScaleA and MxScaleB are as follows:
  • A is RowMajor: MxScaleA has shape (m, ceil(k/64), 2).
  • A is ColumnMajor: MxScaleA has shape (ceil(k/64), m, 2).
  • B is RowMajor: MxScaleB has shape (ceil(k/64), n, 2).
  • B is ColumnMajor: MxScaleB has shape (n, ceil(k/64), 2).
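The following Python sketch is a plain reference for the default layouts (A RowMajor, B ColumnMajor). It is not gen_data.py's golden computation. It derives the scale shapes from the rules above and computes C on already-decoded values. The K-to-scale mapping `(kk // 64, (kk % 64) // 32)` is an assumption inferred from the `(…, ceil(k/64), 2)` shape, and the optional per-column `bias` is hypothetical, since this example disables bias.

```python
import math

def mx_scale_shape(operand, row_major, outer, k):
    """Scale tensor shape per the rules above; outer is m for A and n for B."""
    groups = math.ceil(k / 64)
    if operand == "A":
        return (outer, groups, 2) if row_major else (groups, outer, 2)
    return (groups, outer, 2) if row_major else (outer, groups, 2)

def scale_slot(kk):
    """Assumed mapping of K index kk to (group, slot): one scale per 32 K elements."""
    return kk // 64, (kk % 64) // 32

def mx_matmul_ref(a, b_t, scale_a, scale_b, bias=None):
    """Reference C = (MxScaleA * A) x (MxScaleB * B) + Bias for the default layouts.

    a:       m x k decoded values (A RowMajor)
    b_t:     n x k decoded values (B ColumnMajor, i.e. stored as B transposed)
    scale_a: m x ceil(k/64) x 2 decoded scales
    scale_b: n x ceil(k/64) x 2 decoded scales
    bias:    optional length-n list (hypothetical; this example disables bias)
    """
    m, k, n = len(a), len(a[0]), len(b_t)
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kk in range(k):
                g, s = scale_slot(kk)
                acc += (scale_a[i][g][s] * a[i][kk]) * (scale_b[j][g][s] * b_t[j][kk])
            c[i][j] = acc + (bias[j] if bias is not None else 0.0)
    return c

print(mx_scale_shape("A", True, 256, 1024))   # (256, 16, 2)
print(mx_scale_shape("B", False, 512, 1024))  # (512, 16, 2)
# k = 64: the first 32 products use scales 2.0 * 1.0, the last 32 use 1.0 * 0.5
print(mx_matmul_ref([[1.0] * 64], [[1.0] * 64], [[[2.0, 1.0]]], [[[1.0, 0.5]]]))  # [[80.0]]
```

For the sample run above (m=256, n=512, k=1024), this gives MxScaleA as (256, 16, 2) and MxScaleB as (512, 16, 2): 32 scales per row of A and per column of B, one per 32-element K block.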

  1. The DispatchPolicy used by MxMatmulTla with BlockMmadTla is Gemm::MmadMx (defined in include/catlass/gemm/dispatch_policy.hpp). Its template parameters, in order, and their default values are as follows:

| Template Parameter | Default Value | Description |
|---|---|---|
| ArchTag | (no default) | Architecture tag, for example, Arch::Ascend950 |
| ENABLE_UNIT_FLAG | false | Whether to enable UnitFlag; must be false when L0C_STAGES > 1 (L0C multi-buffering) |
| L1_SCALE_FACTOR_K | 16 | Number of L1 K-direction stripes covered by each GM-to-L1 load of the MX scale; with a value of 1, the scale is loaded once per L1 K stripe. See the comments in the type definition |
| L0C_STAGES | 1 | Number of L0C buffer stages; 2 enables L0C double buffering, which requires ENABLE_UNIT_FLAG to be false |
| ENABLE_L1_RESIDENT | false | Whether to enable L1 residency |
| L1A_STAGES | 2 | Number of L1 buffers for loading matrix A |
| L1B_STAGES | 2 | Number of L1 buffers for loading matrix B |
| L0A_STAGES | 2 | Number of L0 buffers for loading matrix A |
| L0B_STAGES | 2 | Number of L0 buffers for loading matrix B |
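The real policy is a C++ template, but its one documented cross-parameter rule is easy to check mechanically. The Python sketch below is a hypothetical mirror of Gemm::MmadMx (the class name `MmadMxParams` and method `violations` are illustrative, not catlass APIs). It keeps the table's parameter order and defaults and reports only the constraint the table states.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MmadMxParams:
    """Hypothetical Python mirror of Gemm::MmadMx's template parameters (order and defaults from the table)."""
    arch_tag: str                     # ArchTag, no default (e.g. "Ascend950")
    enable_unit_flag: bool = False    # ENABLE_UNIT_FLAG
    l1_scale_factor_k: int = 16       # L1_SCALE_FACTOR_K
    l0c_stages: int = 1               # L0C_STAGES
    enable_l1_resident: bool = False  # ENABLE_L1_RESIDENT
    l1a_stages: int = 2               # L1A_STAGES
    l1b_stages: int = 2               # L1B_STAGES
    l0a_stages: int = 2               # L0A_STAGES
    l0b_stages: int = 2               # L0B_STAGES

    def violations(self):
        """Return the documented constraints this combination breaks (empty list if none)."""
        problems = []
        if self.l0c_stages > 1 and self.enable_unit_flag:
            problems.append("ENABLE_UNIT_FLAG must be false when L0C_STAGES > 1")
        return problems

print(MmadMxParams("Ascend950", l0c_stages=2).violations())  # []
print(MmadMxParams("Ascend950", enable_unit_flag=True, l0c_stages=2).violations())
```

The design choice mirrors the table: L0C double buffering and UnitFlag are mutually exclusive, so enabling L0C_STAGES=2 always goes together with ENABLE_UNIT_FLAG=false.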

Assume the matrix shape is M × N × K and the L1 tile shape is m1 × n1 × k1. The number of tiles in the M direction is mTiles = CeilDiv(M, m1), the number of tiles in the N direction is nTiles = CeilDiv(N, n1), and the total number of tasks is taskBlocks = mTiles * nTiles. enableL1Resident can be enabled in the following two cases:

  1. mTiles = 1, nTiles > CoreNum, and K < 2 * k1. In this case, l0CStages=2 can also be set (enableUnitFlag must be disabled). If there is not enough space to set l0CStages=2, halve n1.

  2. nTiles = 1, mTiles > CoreNum, and K < 2 * k1. In this case, l0CStages=2 can also be set (enableUnitFlag must be disabled). If there is not enough space to set l0CStages=2, halve m1.
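The two conditions above can be checked directly from the shapes. The Python sketch below (the helper `l1_resident_case` is hypothetical, not a catlass function) returns which case applies, or None. `core_num` is an input because the core count depends on the device; the value 32 used in the usage lines is an assumption for illustration only.

```python
def ceil_div(a, b):
    """Integer ceiling division, as CeilDiv in the text."""
    return (a + b - 1) // b

def l1_resident_case(M, N, K, m1, n1, k1, core_num):
    """Return 1 or 2 for the matching case above, or None if enableL1Resident does not apply."""
    m_tiles = ceil_div(M, m1)
    n_tiles = ceil_div(N, n1)
    if K >= 2 * k1:
        return None
    if m_tiles == 1 and n_tiles > core_num:
        return 1  # if l0CStages=2 does not fit, halve n1
    if n_tiles == 1 and m_tiles > core_num:
        return 2  # if l0CStages=2 does not fit, halve m1
    return None

# Hypothetical shapes and core count, using this example's L1 tile 256 x 256 x 448
print(l1_resident_case(128, 16384, 448, 256, 256, 448, core_num=32))   # 1
print(l1_resident_case(16384, 200, 448, 256, 256, 448, core_num=32))   # 2
print(l1_resident_case(256, 512, 1024, 256, 256, 448, core_num=32))    # None
```

The sample run from this README (256 × 512 × 1024) does not qualify, since K = 1024 is not below 2 * k1 = 896.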


Disclosure: parts of this article were generated with AI assistance (AIGC) and are for reference only.
