詳細解釋
Flash Attention是一種IO感知的注意力算法,通過分塊計算和軟件優化,大幅減少HBM(高帶寬記憶體)訪問,實現2-4倍加速和更少的記憶體使用。
標準注意力的問題:
- 計算:O(N²d),N為序列長度,d為維度
- 記憶體:O(N²)的注意力矩陣存儲
- 實際瓶頸:HBM讀寫,非計算
- Transformer架構 (變換器 / 注意力模型) (Switch Transformer)的記憶體和速度瓶頸
Flash Attention的創新:
- 分塊(Tiling):將Q、K、V分為小塊,在SRAM(快速緩存)中計算
- 在線softmax:不需要存儲完整注意力矩陣
- 重計算:反向傳播時重新計算注意力(計算換記憶體)
- IO感知:最小化HBM訪問次數
優勢:
- 速度:2-4倍加速(A100上)
- 記憶體:O(N)而非O(N²),支持更長序列
- 精確:數值結果與標準注意力相同(非近似)
- 可擴展:支持百萬級token序列
Flash Attention-2改進:
- 更好的工作劃分:減少warp閒置
- 更低的non-matmul FLOPs:優化softmax等操作
- 並行性:序列長度維度並行
Flash Decoding:
- 推理優化:解碼階段的注意力加速
- 分塊KV Cache:減少記憶體讀取
- 推理解碼:2-8倍加速
應用:
- 大型語言模型 (大語言模型 / 大模型)訓練和推理:長上下文支持
- 蛋白質建模:長序列建模
- 音頻/視頻:長時序數據
- 推理服務:更高吞吐更低延遲
整合:
- PyTorch (Torch Compile) 2.0:torch.nn.functional.scaled_dot_product_attention自動使用
- vLLM、TensorRT-LLM (NVIDIA LLM 推論加速):內建支持
- Hugging Face:多模型已整合
Flash Attention是長序列Transformer的革命性優化。