# Flash Attention
传统 Transformer 模型在成处理长序列,
Flash Attention 是一种在 IO 上优化的推理加速策略。通过算子融合将 Attention 操作合并,引用分块技术 Tiling
计算注意力矩阵和 Online Softmax
, 可实现每次从 SRAM 去读取数据而避免频繁的 HBM 的 IO 延迟.
# Motivation
当把 Transformer
模型的上下文拓展到更长时是非常困难的,主要是因为:
self-attention
模块在计算复杂度和空间复杂度均是N2(N指代上下文的长度)- 现有的一些方法主要的关注点在于去减少计算量,即
flops
而不是去考虑减少它的 Memory
开销。但是往往内存的搬运占用 attention
非常大的时间(很有可能成为整个 attention
的 bound
)
# Standard Attention
在一般的注意力机制的实现中(如下图算法 0),对Q,K,V∈Rseq_len×hidden_size 三个输入需要做如下处理:
- 从
HBM
中加载 Q、K
矩阵,计算S=QKT∈Rseq_len×seq_len,并将结果写入 HBM
- 从
HBM
中读取 S
矩阵,计算P=softmax(S)∈Rseq_len×seq_len,并将矩阵 P
写入 HBM
- 从
HBM
中读取 S
矩阵,计算O=PV∈Rseq_len×hidden_size,并将矩阵 O
写入 HBM
由于输入矩阵 Q、K和V
矩阵和产生的中间矩阵 S和P
非常大,无法将其完整的保留在 SRAM 中,而被迫需要将其存如相对而言高内存空间的 HBM
. 上面的三个操作中容易看出,分别涉及到四次 HBM
的读和写,是一个大的 IO 开销.
# Flash Attention
# 核心思想
Flash Attention 的核心思想是将所有的运算操作所需要的数
# 矩阵分块计算 Tiling
# 增量求解 softmax
如上 矩阵分块Tiling
思想,在做 QKV Attention
矩阵相乘的操作时,可以将矩阵进行分块分别得到多个 output
矩阵,最终对所以的 output
矩阵进行 element add
操作即可得到最后的 output
结果。由于 Attention
架构内有 softmax
操作( softmax
需要全局行信息),固在得到outputi 的时候需要做增量的更新操作,而不是简单的矩阵 element add
操作。以下推导增量 softmax
操作:
不妨设输入向量为x={x1,x2,⋯,xk,xk+1,⋯,xn},p1={x1,x2,⋯,xk},p2={xk+1,xk+2,⋯,xn},需要计算P=softmax(x)=softmax({p1,p2}),以下将表述如何通过计算局部softmax(p1)、softmax(p2) 更新得到全局softmax(x)
计算局部s(p1)=softmax(x1,x2,⋯,xk)
- m(p1)=max({x1,x2,⋯,xk})
- f(p1)=[ex1−m(p1),ex2−m(p1),⋯,exk−m(p1)]
- l(p1)=∑i=1kexi−m(p1)
- 故s(p1)=[l(p1)ex1−m(p1),l(p1)ex2−m(p1),⋯,l(p1)exk−m(p1)]
同理计算局部s(p2)=softmax(xk+1,xk+2,⋯,xn)
- m(p2)=max({xk+1,xk+2,⋯,xn})
- f(p2)=[exk+1−m(p1),exk+2−m(p1),⋯,exn−m(p1)]
- l(p2)=∑i=k+1nexi−m(p1)
- 故s(p2)=[l(p2)exk+1−m(p2),l(p2)exk+2−m(p2),⋯,l(p2)exn−m(p2)]
易知m(x)=max({m(p1),m(p2)})
更新局部f(p1) 至全局f(p1)global,即
f(p1)global=[ex1−m(p1)⋅em(p1)−m(x),ex2−m(p1)⋅em(p1)−m(x),⋯,exk−m(p1)⋅em(p1)−m(x)]=[ex1−m(x),ex2−m(x),⋯,exk−m(x)]=f(p1)⋅em(p1)−m(x)
更新局部f(p2) 至全局f(p2)global,即
f(p2)global=[exk+1−m(p2)⋅em(p2)−m(x),exk+2−m(p2)⋅em(p2)−m(x),⋯,exn−m(p2)⋅em(p2)−m(x)]=[exk+1−m(x),exk+2−m(x),⋯,exn−m(x)]=f(p2)⋅em(p2)−m(x)
更新局部l(p1) 至全局l(p1)global,即
l(p1)global=i=1∑k(exi−m(p1)⋅em(p1)−m(x))=i=1∑kexi−m(x)=l(p1)⋅em(p1)−m(x)
更新局部l(p2) 至全局l(p2)global,即
l(p2)global=i=k+1∑n(exi−m(p1)⋅em(p1)−m(x))=i=k+1∑kexi−m(x)=l(p2)⋅em(p2)−m(x)
故全局l(x)=l(p1)global+l(p2)global.
更新s(p1)global、s(p2)global:
s(p1)global=l(x)l(p1)global(广播机制)s(p2)global=l(x)l(p2)global(广播机制)
全局softmax(x)={s(p1)global,s(p2)global}
总而言之,通过分块的思想分别计算向量x={x1,x2,⋯,xk,xk+1,⋯,xn} 的每个 block
的局部 softmax
,并通过增量的方式来更新至全局 softmax
. 在 SRAM
中维护两个标量l(x)和m(x) 并更新.
# 代码实现
- Bc:
K
按列分块,每块的列数. - Br:
Q
按行分块,每块的行数
# 参考文章