摘要
很多人优化 AI 推理时,第一反应是“换更强的 GPU”。但在不少场景里,模型慢并不是因为计算单元不够,而是因为数据搬来搬去太频繁。
算子融合就是为了解决这个问题:把多个连续的小操作合并成一个更大的执行单元,减少中间结果读写,让数据尽量停留在更近、更快的地方。本文用一个简单例子讲清楚,为什么算子融合能提升推理性能,以及它并不是所有场景都适合。
一、模型推理慢,可能慢在“搬数据”
假设有这样一段计算:
y = relu(x * w + b)从数学上看,它只是三步:
- 乘法;
- 加法;
- ReLU 激活。
如果按普通方式执行,可能会变成三个独立算子:
x 和 w 读入内存 -> 计算乘法 -> 写出中间结果 tmp1 tmp1 和 b 读入内存 -> 计算加法 -> 写出中间结果 tmp2 tmp2 读入内存 -> 计算 ReLU -> 写出最终结果 y问题来了:每一步都要读写内存。
当张量很大时,真正耗时的不一定是乘法和加法,而是中间结果在显存和缓存之间不断搬运。计算单元可能还没吃饱,内存通道已经堵住了。
这就是所谓的内存带宽瓶颈。
二、算子融合到底融合了什么
算子融合的核心思想很朴素:
既然这几个操作总是连续执行,为什么不放在一起算?
融合后,执行过程可能变成:
读入 x、w、b 在一个 kernel 内完成乘法、加法、ReLU 直接写出 y中间的 tmp1、tmp2 不再频繁写回全局内存,而是在寄存器或更近的缓存里完成处理。
这样做带来的收益主要有三个:
- 减少中间张量写入;
- 减少 kernel 启动次数;
- 提高数据局部性。
对于很多小算子密集的网络,这种优化非常明显。
三、一个直观类比:厨房流水线
可以把推理过程想象成做饭。
没有融合时:
切菜的人切完,把菜端到仓库; 炒菜的人再去仓库取菜; 调味的人调完,又送回仓库; 装盘的人再去仓库取。每一步本身都不复杂,但来回搬运太多。
算子融合后:
同一个工作台上完成切菜、翻炒、调味、装盘。不是厨师变成超人了,而是少走了很多路。
AI 推理也是这样。融合优化的本质不是让每个计算更神奇,而是让数据少旅行。
四、哪些算子适合融合
并不是所有操作都适合融合。通常比较适合的有:
- element-wise 操作,比如 add、mul、relu、sigmoid;
- 形状不变的连续计算;
- 前后依赖明确的小算子;
- 计算量不大但内存访问频繁的操作;
- 常见模式,比如 Conv + Bias + Activation。
例如:
MatMul + Add Conv + BatchNorm + ReLU LayerNorm 中的部分连续步骤 Attention 里的某些投影和后处理这些模式如果拆开执行,会生成大量中间张量。融合后可以减少数据落地。
五、哪些情况融合反而不划算
算子融合不是万能药。下面几类场景要谨慎:
第一,融合后 kernel 太复杂。
如果把太多操作硬塞进一个 kernel,寄存器压力可能变大,反而降低并行度。
第二,不同算子的访存模式差异太大。
有些操作适合顺序访问,有些操作访问很跳跃,强行融合可能让缓存命中率更差。
第三,动态 shape 太多。
输入尺寸变化频繁时,编译器很难提前生成稳定高效的融合策略。
第四,调试成本上升。
多个操作合成一个执行单元后,定位数值误差会更麻烦。尤其在混合精度场景里,一个小小的融合顺序变化,可能影响最终结果。
所以工程上不是“能融合就融合”,而是“收益大于代价才融合”。
六、编译器在这里做了什么
现代 AI 推理引擎和编译器通常会做几件事:
- 分析计算图;
- 找到连续可合并的算子模式;
- 判断 shape、数据类型和设备约束;
- 生成融合后的执行计划;
- 在运行时选择合适 kernel。
从开发者角度看,你可能只是写了普通模型代码。但在底层,编译器会尝试把碎片化计算改写成更适合硬件执行的形式。
这也是为什么同一个模型,在不同推理引擎上的速度可能差很多。差别不一定只在硬件,而在图优化、kernel 选择和内存调度。
七、如何判断是不是内存带宽瓶颈
如果你遇到推理慢,可以观察几个现象:
- GPU 利用率不高,但显存带宽占用很高;
- kernel 数量很多,每个 kernel 执行时间很短;
- profiler 里大量时间花在数据移动和小算子上;
- batch 增大后吞吐提升不明显;
- 删除一些 element-wise 操作后速度明显变化。
这时可以重点检查:
是否有大量中间张量? 是否存在连续小算子? 是否启用了图优化? 是否使用了合适的数据格式? 是否存在动态 shape 阻碍优化?优化方向不一定是买更强的卡,而可能是减少访存、打开编译优化、调整模型结构或切换推理引擎。
八、开发者能做什么
如果你不是编译器开发者,也可以做一些事:
- 尽量使用推理框架推荐的模型导出方式;
- 避免在推理路径里插入大量零散 Python 逻辑;
- 固定输入 shape,减少动态分支;
- 打开图优化或编译模式;
- 用 profiler 看 kernel 数量和显存读写;
- 对性能敏感模块做单独 benchmark;
- 谨慎使用会打断计算图的自定义操作。
很多性能问题不是模型结构本身造成的,而是模型落到工程代码后,计算图被拆碎了。
总结
算子融合解决的是一个很现实的问题:AI 推理不只需要“会算”,还要“少搬”。
当多个小算子连续执行时,中间张量的读写会浪费大量内存带宽。通过融合,推理引擎可以减少数据搬运、降低 kernel 启动开销,并提升硬件利用率。
但融合不是越多越好。真正成熟的优化,需要在计算复杂度、寄存器压力、访存模式、动态 shape 和调试成本之间做平衡。
一句话记住:推理慢不一定是算力不够,很多时候是数据在路上堵车了。