梯度檢查點 (Gradient Checkpointing)

Gradient Checkpointing (Activation Checkpointing)

前向時只存部分激活為檢查點,反向時再重算其餘激活,以時間換取 反向傳播 (倒傳遞) 記憶體,可訓練更大模型。

詳細解釋

梯度檢查點(activation checkpointing)將計算圖分段,前向時只保留分段邊界的激活(檢查點),其餘丟棄;反向時需要中間激活再從最近檢查點重算該段前向,從而大幅降低激活記憶體(約 50–80%),代價是訓練變慢約 20–30%。深層 大型語言模型 (大語言模型 / 大模型)Transformer架構 (變換器 / 注意力模型) (Switch Transformer) 與大 卷積神經網絡 (CNN) 常用以在有限 圖形處理單元 (GPU / 圖形處理器) 上加大 batch 或加深模型。PyTorch (Torch Compile) 的 torch.utils.checkpoint 與 JAX 的 remat 均支援。與 反向傳播 (倒傳遞)PyTorch (Torch Compile)大型語言模型 (大語言模型 / 大模型) 相關。

探索更多AI詞彙

查看所有分類,繼續學習AI知識