登录    注册      
    

News Message

FlashAttention-4



FlashAttention-4



3月5号,Tri Dao团队放出了FlashAttention-4的论文。如果你做大模型训练或者推理,FlashAttention这个名字你不可能没听过——它是目前几乎所有主流大模型底层都在用的Attention计算内核,从GPT到LLaMA到DeepSeek,都绑着它跑。每一代更新都意味着同样的卡,能跑更大的模型、更长的上下文、更快的速度。

而FA4这一版,我认为大概是这条优化路上的最后一次暴力美学了。

不是说Flash系列要完结——而是往上的空间,已经被物理定律卡死了。

FA4在NVIDIA最新的B200 GPU上打出1613 TFLOPS/s,硬件利用率71%。这数字什么概念?Attention本质上就是两坨矩阵乘法,中间夹一个softmax。FA4现在把这整套活儿的执行速度,逼到了跟"只算矩阵乘法、softmax当它不存在"几乎一样快的程度。

那个softmax的计算开销,被压成了一张纸。

放到时间线上看更有冲击力——2022年Tri Dao丢出第一版FlashAttention,大伙的反应是"哦,一个挺聪明的kernel trick"。到FA2、FA3,行内人开始意识到这东西不是trick,是基础设施。而现在FA4摸到了一个临界点:Attention层的计算效率,贴上了GPU硬件能力的物理天花板。

再往上?Sorry,没了。

Blackwell的跛脚问题

要理解FA4为什么能做到这一步、以及为什么说它接近天花板,得先搞清楚一个事实:NVIDIA的Blackwell架构是跛脚的。

Tensor Core的矩阵乘法吞吐量,从H100的1 PFLOPS翻倍到了B200的2.25 PFLOPS。但共享内存的带宽?没动。指数运算单元的吞吐?也没动,还是16 ops/clock/SM。

NVIDIA当然知道这个问题——在下一代B300上,指数运算单元的吞吐已经翻倍到32 ops/clock/SM了。但B300还没大规模铺开,眼下大家拿到手的B200就是这么个状况。

论文里有个roofline分析,数字非常直观:在B200上跑Attention,MMA矩阵乘法需要2048个cycle,指数运算需要2048个cycle,共享内存搬运需要1536个cycle。三条线几乎挤在一起。

以前做Attention优化的核心矛盾是矩阵乘法太慢。现在矩阵乘法翻倍了,瓶颈转移到了softmax里的  运算和数据在共享内存里的搬进搬出。

FA4的全部工作就是在解决这个"瓶颈转移"的问题。而它的解法,每一个都透着一股把硬件手册翻烂了的狠劲,可以说是非常之硬核。

FA4前向传播流水线设计。两个Q tile(标记为H和L)交替执行:一个tile跑矩阵乘法的时候,另一个tile同时算softmax,实现计算资源的最大化重叠。
FA4前向传播流水线设计。两个Q tile(标记为H和L)交替执行:一个tile跑矩阵乘法的时候,另一个tile同时算softmax,实现计算资源的最大化重叠。

用软件"伪造"硬件指令

FA4处理指数运算瓶颈的方式非常神奇。

GPU上有个专用硬件单元MUFU,专门算  这类超越函数。但这玩意的吞吐量只有Tensor Core的1/512——很离谱,差了五百多倍。在H100时代这个差距勉强能忍,因为矩阵乘法本身也不够快,大家互相等。但B200把矩阵乘法翻倍之后,MUFU就成了整条流水线上最慢的那个工位。

FA4的做法就很邪修——既然硬件指数单元不够快,那就用通用的乘加指令来模拟指数运算。

具体操作是经典的Cody-Waite范围折减:把  拆成 ,整数部分直接用IEEE 754浮点数的指数位做位操作——这基本上是"免费"的;小数部分用一个3次多项式近似,走Horner方法用FMA指令求值。

关于精度论文给了一张表,结论很干脆——在FP32级别上,3次多项式的最大相对误差是 ,确实比硬件MUFU差了600倍。但问题是——谁在乎FP32?

大模型训练用的是BF16。BF16本身的量化误差就有 ,比多项式近似的误差大了整整两个数量级。在BF16精度下,3次多项式和硬件MUFU的结果,99%的输入上完全一致。

根本就看不出区别。

但FA4还有更精明的操作:它不是把所有指数运算都用软件模拟,而是只模拟10%-25%。剩下的还走硬件MUFU。两条通路并行执行,利用的是FMA单元和MUFU本来就是不同的硬件功能单元、可以同时干活这个事实。

这种优化,真的是要对GPU微架构有很深的理解了,否则根本想不到啊。

90%的rescaling都是白干的

下一个FlashAttention的灵魂操作是online softmax——因为Attention矩阵太大放不进显存,所以得分块算,每算完一块就要用新的最大值去"重新缩放"之前所有的结果。这个rescaling操作在热循环里反复执行,每次都是一整行的向量乘法。

但FA4问了一个看似简单的问题——这个rescaling,真的每次都有必要吗?

答案是不需要。只有当新块的行最大值比旧的大超过一个阈值(  )的时候,rescaling才有意义。其余时候,旧的最大值还够用,直接跳过就行——反正最后一步会用真实的全局最大值做一次总的修正。

结果就是约90%的rescaling被跳过了。

这个优化本身的原理不难,难的是"敢跳"。因为online softmax的数值稳定性是FA系列的立身之本,动这个地方需要非常严格的误差分析来保证最终结果正确。论文花了不少篇幅论证这一点——设置阈值为8.0对应的是256倍的缩放因子,在BF16的动态范围内,这个"欠缩放"完全可以被最终的归一化步骤修正回来。

FA4前向传播在B200上的TFLOPS表现(causal attention,头维度128)。FA4(红色)在中长序列上全面领先cuDNN 9.13(蓝色)和Triton(绿色),峰值超过1600 TFLOPS/s。注意曲线随序列变长而上升——FA4越长越快。
FA4前向传播在B200上的TFLOPS表现(causal attention,头维度128)。FA4(红色)在中长序列上全面领先cuDNN 9.13(蓝色)和Triton(绿色),峰值超过1600 TFLOPS/s。注意曲线随序列变长而上升——FA4越长越快。

反向传播的优化

FA4在反向传播上的优化更激进,也更有工程意义。

这里用到了Blackwell的一个新特性:2-CTA MMA模式。简单说,两个CTA协作线程阵列可以配对执行一次矩阵乘法,输出的累加器在M维度上分给两个CTA,而operand B在N维度上各存一半。

这个带来的直接效果就是每个CTA只需要在自己的共享内存里存一半的B,共享内存带宽需求直接减半。论文的roofline分析显示,1-CTA模式下共享内存开销比MMA计算多30%,切到2-CTA之后这个差距缩小到5%。

但更精彩的是  的处理。 的计算需要在KV序列维度上做归约,原本每次迭代都要做全局原子加(atomic add)——这是GPU编程里出了名的慢操作,而且会引入非确定性。FA4利用2-CTA的分工,用分布式共享内存(DSMEM)在两个CTA之间交换各自的半块 ,让每个CTA拿到完整的归约维度后在本地完成计算,原子加的次数直接砍半。

2-CTA模式下反向传播的dQ计算步骤。
2-CTA模式下反向传播的dQ计算步骤。

FA4实现了确定性的backward pass,而且性能损失极小——能达到非确定性版本75%的速度。

这个很重要了,现在大模型训练越来越依赖强化学习——RLHF、GRPO、各种reward model驱动的后训练。这些方法对梯度的可复现性要求极高,你的梯度算一次一个样,debug就是噩梦。FA4的确定性模式通过信号量锁来序列化全局归约,配合精心设计的CTA调度顺序(SPT——最短处理时间优先),把锁等待的开销压到了最低。

这个功能,在当下reasoning model大爆发的背景下,几乎是刚需。

为DeepSeek量身定制的那张性能图

论文里有一张图专门展示了(192, 128)头维度配置下的性能——这就是DeepSeek-V3用的配置:192维的query和128维的key/value。

FA4专门为这种"非标"头维度做了适配和benchmark,这个细节很值得玩味。

DeepSeek-V3的MLA架构对头维度的选择和传统的Transformer不太一样,query和key/value维度不等。这在以前的FA版本里是不太好处理的——tile size和寄存器分配都是按标准维度(64或128)来的。FA4的模块化设计让这种非标配置的支持变得更自然。

说白了,你做一个基础设施级的开源库,客户就是这些大模型团队。你的客户在用什么架构,你就得支持什么架构。DeepSeek-V3是当下最受关注的开源模型之一,FA4不可能不做这个适配。

Python写GPU内核:NVIDIA的"阳谋"

要知道FA4完全用CuTe-DSL编写的——这是NVIDIA CUTLASS团队推出的嵌入Python的DSL。没有一行CUDA C++,编译时间从FA3的55秒降到2.5秒。

把这件事放到NVIDIA的整体生态战略里,画面就不一样了。

过去这些年,CUDA生态的最大壁垒是什么?是CUDA本身?不,是CUDA C++的痛苦开发体验。C++ template metaprogramming写出来的代码,发明者自己三个月后回来可能都看不懂。这个门槛把绝大多数AI研究者挡在了"高性能GPU编程"的门外。

现在NVIDIA通过CuTe-DSL在重新招揽人才。

它解决的更多的就是目前CUDA的生态问题。当越来越多的研究者开始用CuTe-DSL写自定义kernel,他们就更深地嵌入了CUDA生态。AMD的ROCm再怎么追赶CUDA的API兼容性,它追不上的是这种"用Python在NVIDIA硬件上写高性能内核"的开发者体验。

FA4用自己的存在证明了CuTe-DSL的可行性——连FlashAttention这种对性能要求极端苛刻的kernel都能用Python DSL写出来而且跑得更快,那其他kernel还有什么理由继续用C++?

论文里提到,已经有开发者在FA4基础上构建了FlexAttention和block-sparse attention变体。这才刚开始。

FA4越快,Mamba们越尴尬

过去两年,Mamba、RWKV、各种线性注意力架构之所以受到关注,一个核心论点是:Attention的二次复杂度是个根本缺陷,序列越长代价越高,所以需要用线性复杂度的替代方案。

这个论点在理论上没问题。但工程现实是另一回事。

FA4现在证明了虽然Attention的理论复杂度确实是 ,但当实现层面的常数因子被压缩到极致之后,这个二次方的实际开销比你想象的小得多。在32K序列长度上,FA4跑出了接近硬件极限的吞吐——这意味着在目前绝大多数实际应用的上下文长度范围内,Attention的"二次方"还远没有成为真正的痛点。

换句话说,FA系列每快一代,"我们需要用线性注意力替代Transformer"这个叙事的说服力就弱一分。

当然,百万级、千万级的上下文场景下,二次复杂度终究会成为不可逾越的墙。但那个场景什么时候真正成为主流需求?目前看,至少还不是今天。

这就形成了一个有意思的局面——替代架构在等待一个Attention力不从心的场景,而FA系列在不断推迟这个场景到来的时间。

我等着继位,你怎么还活蹦乱跳的?

Mamba和RWKV并不是没有价值。但它们的生存空间,正在被FA4这样的工作一寸一寸地压缩。它们需要找到一个"即使Attention再怎么优化也搞不定"的场景来证明自己——目前来看,那个场景还在地平线之外了。

FA4的benchmark是在B200上跑的。这卡目前大规模可用性还很有限,很多团队连H100的集群都没凑齐呢。FA4相比cuDNN 9.13的1.3倍优势很亮眼,但论文作者自己也说了,他们已经和cuDNN团队合作把FA4的技术融入了cuDNN 9.13/9.14。最新版cuDNN的性能已经非常接近了。所以FA4作为独立库的"性能红利窗口期"可能比前几代更短。

但这恰恰说明了开源的价值——FA4倒逼NVIDIA的闭源库升级,整个行业受益。至于FA4本身是不是"必须单独用",反而不重要了。

再看看FA4的六位作者分别来自Princeton、Meta、Colfax Research、NVIDIA和Georgia Tech,其中三位标注为equal contribution。这不再是Tri Dao的独角戏了——它已经演变成了一个跨机构的协作项目。Meta和NVIDIA的深度参与,既说明了这个方向的重要性,也意味着FA的发展方向可能会越来越受到工业界需求的驱动。

最后说回开头的判断。当Attention的实现效率已经摸到了硬件物理极限的天花板,FA系列的"大版本"优化空间确实在缩小。未来的版本大概率是适配新硬件(B300、下一代Rubin架构)、支持新精度(FP4/FP8)、覆盖新变体(各种稀疏注意力)——重要,但不再是FA1→2→3→4这种"重写流水线"级别的飞跃了。



Share Http URL:  http://www.wittx.cn/get_news_message.do?new_id=1544



请输入评论





























Best Last Month

恒大汽车宣布 40 亿港元融资:腾讯、红杉、云锋入局,折价 19.96%



Optimization Algorithms

Optimization Algorithms

Information industry

by wittx


2020/09/22 金融行情

2020/09/22 金融行情

Information industry

by wittx


2020/12/11 金融行情

2020/12/11 金融行情

Information industry

by wittx


Google Brain的优化器Lion

Google Brain的优化器Lion

Information industry

by wittx


An Introduction To Computational Methods

An Introduction To Computational Methods

Information industry

by wittx


金属材料重要突破进展汇总

金属材料重要突破进展汇总

Information industry

by wittx


2020/12/26 金融行情

2020/12/26 金融行情

Information industry

by wittx


Machine Learning on Graphs

Machine Learning on Graphs

Information industry

by wittx


投资组合与CAPM模型

投资组合与CAPM模型

Information industry

by wittx