三元組損失

Triplet Loss

對比學習的常見損失

詳細解釋

三元組損失(Triplet Loss)是度量學習(Metric Learning)的損失函數,學習一個嵌入空間,使相似樣本距離近,不相似樣本距離遠。

基本單位:三元組(Triplet)

  • 錨點(Anchor):參考樣本
  • 正樣本(Positive):與錨點同類別的樣本
  • 負樣本(Negative):與錨點不同類別的樣本

損失定義:

L = max(0, d(a,p) - d(a,n) + margin)

  • d(a,p):錨點與正樣本的距離
  • d(a,n):錨點與負樣本的距離
  • margin:間隔超參數(通常0.2)

目標:

  • d(a,p) < d(a,n) - margin
  • 正樣本比負樣本更接近錨點至少margin
  • 否則產生損失

距離度量:

  • 歐氏距離(L2):||a - b||₂
  • 餘弦距離:1 - 餘弦相似度
  • 其他:馬氏距離等

困難樣本挖掘(Hard Negative Mining):

  • 簡單三元組:隨機選擇,可能太簡單,模型學不到東西
  • 困難負樣本:d(a,n)很小的負樣本(看似相似但不同類)
  • 半困難負樣本:d(a,n) > d(a,p)但仍產生損失的負樣本
  • 策略:每個批次挖掘困難樣本

應用場景:

  • 人臉識別:FaceNet使用三元組損失
  • 圖像檢索:找到與查詢圖像相似的圖片
  • 文本相似度:學習語義嵌入
  • 推薦系統:用戶-物品匹配

優勢:

  • 直接優化相對距離
  • 適合檢索和識別任務
  • 學習有判別性的特徵

挑戰:

  • 三元組選擇策略影響大
  • 收斂可能比較慢
  • 大規模數據集需要高效挖掘

三元組損失是學習判別性嵌入的強大工具。

探索更多AI詞彙

查看所有分類,繼續學習AI知識