詳細解釋
三元組損失(Triplet Loss)是度量學習(Metric Learning)的損失函數,學習一個嵌入空間,使相似樣本距離近,不相似樣本距離遠。
基本單位:三元組(Triplet)
- 錨點(Anchor):參考樣本
- 正樣本(Positive):與錨點同類別的樣本
- 負樣本(Negative):與錨點不同類別的樣本
損失定義:
L = max(0, d(a,p) - d(a,n) + margin)
- d(a,p):錨點與正樣本的距離
- d(a,n):錨點與負樣本的距離
- margin:間隔超參數(通常0.2)
目標:
- d(a,p) < d(a,n) - margin
- 正樣本比負樣本更接近錨點至少margin
- 否則產生損失
距離度量:
- 歐氏距離(L2):||a - b||₂
- 餘弦距離:1 - 餘弦相似度
- 其他:馬氏距離等
困難樣本挖掘(Hard Negative Mining):
- 簡單三元組:隨機選擇,可能太簡單,模型學不到東西
- 困難負樣本:d(a,n)很小的負樣本(看似相似但不同類)
- 半困難負樣本:d(a,n) > d(a,p)但仍產生損失的負樣本
- 策略:每個批次挖掘困難樣本
應用場景:
- 人臉識別:FaceNet使用三元組損失
- 圖像檢索:找到與查詢圖像相似的圖片
- 文本相似度:學習語義嵌入
- 推薦系統:用戶-物品匹配
優勢:
- 直接優化相對距離
- 適合檢索和識別任務
- 學習有判別性的特徵
挑戰:
- 三元組選擇策略影響大
- 收斂可能比較慢
- 大規模數據集需要高效挖掘
三元組損失是學習判別性嵌入的強大工具。