詳細解釋
多頭注意力(Multi-Head Attention)是並行執行多個注意力計算,讓模型在不同子空間捕捉不同類型的關係,是Transformer架構 (變換器 / 注意力模型) (Switch Transformer)的標準組件。
概念:
- 單頭:一組Q、K、V投影
- 多頭:h組獨立的Q、K、V投影
- 並行:每個頭獨立計算注意力
- 拼接:所有頭的輸出拼接後線性變換
計算:
- head_i = Attention(XW_i^q, XW_i^k, XW_i^v)
- MultiHead = Concat(head_1,...,head_h)W^o
- 每個頭維度:d_model / h
為何需要多頭:
- 不同關係:
- 頭1:句法關係(主語-動詞)
- 頭2:語義關係(同義詞)
- 頭3:指代關係(代詞-先行詞)
- 表達力:增加模型容量
- 並行:計算效率好
頭數選擇:
- 原始Transformer:8頭
- BERT-base:12頭
- GPT-3:96頭
- 維度:d_model / heads 通常64
可視化研究:
- 特定頭學習特定語言現象
- 有些頭:位置相關
- 有些頭:內容相關
- 冗餘:部分頭可被剪枝
計算優化:
與單頭的對比:
- 單大頭:一個更大的注意力
- 多頭:多個較小的注意力
- 實證:多頭通常更好
多頭注意力是Transformer表達力的關鍵。