3中級訓練與優化
梯度檢查點 (Gradient Checkpointing)
Gradient Checkpointing (Activation Checkpointing)
前向時只存部分激活為檢查點,反向時再重算其餘激活,以時間換取 反向傳播 (倒傳遞) 記憶體,可訓練更大模型。
詳細解釋
梯度檢查點(activation checkpointing)將計算圖分段,前向時只保留分段邊界的激活(檢查點),其餘丟棄;反向時需要中間激活再從最近檢查點重算該段前向,從而大幅降低激活記憶體(約 50–80%),代價是訓練變慢約 20–30%。深層 大型語言模型 (大語言模型 / 大模型)、Transformer架構 (變換器 / 注意力模型) (Switch Transformer) 與大 卷積神經網絡 (CNN) 常用以在有限 圖形處理單元 (GPU / 圖形處理器) 上加大 batch 或加深模型。PyTorch (Torch Compile) 的 torch.utils.checkpoint 與 JAX 的 remat 均支援。與 反向傳播 (倒傳遞)、PyTorch (Torch Compile)、大型語言模型 (大語言模型 / 大模型) 相關。