# Flash Attention[1]

传统 Transformer 模型在成处理长序列,
Flash Attention 是一种在 IO 上优化的推理加速策略。通过算子融合将 Attention 操作合并,引用分块技术 Tiling 计算注意力矩阵和 Online Softmax , 可实现每次从 SRAM 去读取数据而避免频繁的 HBM 的 IO 延迟.

# Motivation

当把 Transformer 模型的上下文拓展到更长时是非常困难的,主要是因为:

  • self-attention 模块在计算复杂度和空间复杂度均是N2(N指代上下文的长度)N^{2}(N指代上下文的长度)
  • 现有的一些方法主要的关注点在于去减少计算量,即 flops 而不是去考虑减少它的 Memory 开销。但是往往内存的搬运占用 attention 非常大的时间[2](很有可能成为整个 attentionbound

# Standard Attention

在一般的注意力机制的实现中(如下图算法 0),对Q,K,VRseq_len×hidden_sizeQ,K,V \in \mathcal R^{seq\_len × hidden\_size} 三个输入需要做如下处理:

  • HBM 中加载 Q、K 矩阵,计算S=QKTRseq_len×seq_lenS=QK^{T} \in \mathcal R^{seq\_len × seq\_len},并将结果写入 HBM
  • HBM 中读取 S 矩阵,计算P=softmax(S)Rseq_len×seq_lenP=softmax(S) \in \mathcal R^{seq\_len × seq\_len},并将矩阵 P 写入 HBM
  • HBM 中读取 S 矩阵,计算O=PVRseq_len×hidden_sizeO=PV \in \mathcal R^{seq\_len × hidden\_size},并将矩阵 O 写入 HBM

由于输入矩阵 Q、K和V 矩阵和产生的中间矩阵 S和P 非常大,无法将其完整的保留在 SRAM 中,而被迫需要将其存如相对而言高内存空间的 HBM . 上面的三个操作中容易看出,分别涉及到四次 HBM 的读和写,是一个大的 IO 开销.

# Flash Attention

# 核心思想

Flash Attention 的核心思想是将所有的运算操作所需要的数

# 矩阵分块计算 Tiling


# 增量求解 softmax

如上 矩阵分块Tiling 思想,在做 QKV Attention 矩阵相乘的操作时,可以将矩阵进行分块分别得到多个 output 矩阵,最终对所以的 output 矩阵进行 element add 操作即可得到最后的 output 结果。由于 Attention 架构内有 softmax 操作( softmax 需要全局行信息),固在得到outputioutput_{i} 的时候需要做增量的更新操作,而不是简单的矩阵 element add 操作。以下推导增量 softmax 操作:

不妨设输入向量为x={x1,x2,,xk,xk+1,,xn},p1={x1,x2,,xk},p2={xk+1,xk+2,,xn}x=\{x_{1},x_{2}, \cdots, x_{k},x_{k+1}, \cdots, x_{n}\}, p_{1}=\{x_{1},x_{2},\cdots,x_{k}\}, p_{2}=\{x_{k+1},x_{k+2},\cdots,x_{n}\},需要计算P=softmax(x)=softmax({p1,p2})P=softmax(x)=softmax(\{p_{1}, p_{2}\}),以下将表述如何通过计算局部softmax(p1)softmax(p2)softmax({p_{1}})、softmax({p_{2}}) 更新得到全局softmax(x)softmax(x)

  • 计算局部s(p1)=softmax(x1,x2,,xk)s({p_1})=softmax({x_1},x_{2},\cdots,x_{k})

    • m(p1)=max({x1,x2,,xk})m({p_1})=max(\{x_{1}, x_{2}, \cdots, x_{k}\})
    • f(p1)=[ex1m(p1),ex2m(p1),,exkm(p1)]f(p_1)=[e^{x_{1}-m(p_1)},e^{x_{2}-m(p_1)},\cdots,e^{x_{k}-m(p_1)}]
    • l(p1)=i=1kexim(p1)l(p_{1})=\sum_{i=1}^{k}{e^{x_{i}-m(p_1)}}
    • s(p1)=[ex1m(p1)l(p1),ex2m(p1)l(p1),,exkm(p1)l(p1)]s(p_1)=[\frac{e^{x_{1}-m(p_1)}}{l(p_1)},\frac{e^{x_{2}-m(p_1)}}{l(p_1)},\cdots,\frac{e^{x_{k}-m(p_1)}}{l(p_1)}]
  • 同理计算局部s(p2)=softmax(xk+1,xk+2,,xn)s({p_2})=softmax({x_{k+1}},x_{k+2},\cdots,x_{n})

    • m(p2)=max({xk+1,xk+2,,xn})m({p_2})=max(\{x_{k+1}, x_{k+2}, \cdots, x_{n}\})
    • f(p2)=[exk+1m(p1),exk+2m(p1),,exnm(p1)]f(p_2)=[e^{x_{k+1}-m(p_1)},e^{x_{k+2}-m(p_1)},\cdots,e^{x_{n}-m(p_1)}]
    • l(p2)=i=k+1nexim(p1)l(p_{2})=\sum_{i=k+1}^{n}{e^{x_{i}-m(p_1)}}
    • s(p2)=[exk+1m(p2)l(p2),exk+2m(p2)l(p2),,exnm(p2)l(p2)]s(p_2)=[\frac{e^{x_{k+1}-m(p_2)}}{l(p_2)},\frac{e^{x_{k+2}-m(p_2)}}{l(p_2)},\cdots,\frac{e^{x_{n}-m(p_2)}}{l(p_2)}]
  • 易知m(x)=max({m(p1),m(p2)})m(x)=max(\{m(p_1), m(p_2)\})

  • 更新局部f(p1)f(p_1) 至全局f(p1)globalf(p_{1})^{global},即

    f(p1)global=[ex1m(p1)em(p1)m(x),ex2m(p1)em(p1)m(x),,exkm(p1)em(p1)m(x)]=[ex1m(x),ex2m(x),,exkm(x)]=f(p1)em(p1)m(x)f(p_{1})^{global}=[e^{x_{1}-m(p_1)}\cdot e^{m({p_1})-m(x)},e^{x_{2}-m(p_1)}\cdot e^{m({p_1})-m(x)},\cdots,e^{x_{k}-m(p_1)}\cdot e^{m({p_1})-m(x)}]=[e^{x_{1}-m(x)}, e^{x_{2}-m(x)},\cdots,e^{x_{k}-m(x)}]=f(p_{1}) \cdot e^{m({p_1})-m(x)}

  • 更新局部f(p2)f(p_2) 至全局f(p2)globalf(p_{2})^{global},即

    f(p2)global=[exk+1m(p2)em(p2)m(x),exk+2m(p2)em(p2)m(x),,exnm(p2)em(p2)m(x)]=[exk+1m(x),exk+2m(x),,exnm(x)]=f(p2)em(p2)m(x)f(p_{2})^{global}=[e^{x_{k+1}-m(p_2)}\cdot e^{m({p_2})-m(x)},e^{x_{k+2}-m(p_2)}\cdot e^{m({p_2})-m(x)},\cdots,e^{x_{n}-m(p_2)}\cdot e^{m({p_2})-m(x)}]=[e^{x_{k+1}-m(x)}, e^{x_{k+2}-m(x)},\cdots,e^{x_{n}-m(x)}]=f(p_{2}) \cdot e^{m({p_2})-m(x)}

  • 更新局部l(p1)l({p_1}) 至全局l(p1)globall(p_{1})^{global},即

    l(p1)global=i=1k(exim(p1)em(p1)m(x))=i=1kexim(x)=l(p1)em(p1)m(x)l(p_{1})^{global}=\sum_{i=1}^{k}(e^{x_{i}-m(p_1)} \cdot e^{m(p_{1})-m(x)})=\sum_{i=1}^{k}e^{x_{i}-m(x)}=l(p_{1}) \cdot e^{m(p_{1})-m(x)}

  • 更新局部l(p2)l({p_2}) 至全局l(p2)globall(p_{2})^{global},即

    l(p2)global=i=k+1n(exim(p1)em(p1)m(x))=i=k+1kexim(x)=l(p2)em(p2)m(x)l(p_{2})^{global}=\sum_{i=k+1}^{n}(e^{x_{i}-m(p_1)} \cdot e^{m(p_{1})-m(x)})=\sum_{i=k+1}^{k}e^{x_{i}-m(x)}=l(p_{2}) \cdot e^{m(p_{2})-m(x)}

  • 故全局l(x)=l(p1)global+l(p2)globall(x)=l(p_{1})^{global}+l(p_{2})^{global}.

  • 更新s(p1)globals(p2)globals(p_1)^{global}、s(p_2)^{global}:

    s(p1)global=l(p1)globall(x)(广播机制)s(p2)global=l(p2)globall(x)(广播机制)s(p_1)^{global}=\frac{l(p_1)^{global}}{l(x)}(广播机制)\\ s(p_2)^{global}=\frac{l(p_2)^{global}}{l(x)}(广播机制)

  • 全局softmax(x)={s(p1)global,s(p2)global}softmax(x)=\{s(p_{1})^{global},s(p_2)^{global}\}

总而言之,通过分块的思想分别计算向量x={x1,x2,,xk,xk+1,,xn}x=\{x_{1},x_{2}, \cdots, x_{k},x_{k+1}, \cdots, x_{n}\} 的每个 block 的局部 softmax ,并通过增量的方式来更新至全局 softmax . 在 SRAM 中维护两个标量l(x)m(x)l(x)和m(x) 并更新.




# 代码实现

  • BcB_{c}K 按列分块,每块的列数.
  • BrB_{r}Q 按行分块,每块的行数

# 参考文章


  1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness ↩︎

  2. Data Movement Is All You Need: A Case Study on Optimizing Transformers ↩︎

Edited on Views times

Give me a cup of [coffee]~( ̄▽ ̄)~*

Value WeChat Pay

WeChat Pay