一、Triplet结构:
triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor、Positive、Negative:
整个训练过程是:
- 首先从训练集中随机选一个样本,称为Anchor(记为x_a)。
- 然后再随机选取一个和Anchor属于同一类的样本,称为Positive (记为x_p)
- 最后再随机选取一个和Anchor属于不同类的样本,称为Negative (记为x_n)
由此构成一个(Anchor,Positive,Negative)三元组。
二、Triplet Loss:
在上一篇讲了Center Loss的原理和实现,会发现现在loss的优化的方向是比较清晰好理解的。在基于能够正确分类的同时,我们更希望模型能够:1、把不同类之间分得很开,也就是更不容易混淆;2、同类之间靠得比较紧密,这个对于模型的鲁棒性的提高也是比较有帮助的(基于此想到Hinton的Distillation中给softmax加的一个T就是人为的对训练过程中加上干扰,让distribution变得更加soft从而去把错误信息放大,这样模型能够不光知道什么是正确还知道什么是错误。即:模型可以从仅仅判断这个最可能是7,变为可以知道这个最可能是7、一定不是8、和2比较相近,论文讲解可以参看Hinton-Distillation)。
回归正题,三元组的三个样本最终得到的特征表达计为:
triplet loss的目的就是让Anchor这个样本的feature和positive的feature直接的距离比和negative的小,即:
除了让x_a和x_p特征表达之间的距离尽可能小,而x_a和x_n的特征表达之间的距离尽可能大之外还要让x_a与x_n之间的距离和x_a与x_p之间的距离之间有一个最小的间隔α,于是修改loss为:
于是目标函数为:
距离用欧式距离度量,+表示[ *** ]内的值大于零的时候,取该值为损失,小于零的时候,损失为零。
故也可以理解为:
L = max([ ] , 0)
在code中就是这样实现的,利用marginloss,详见下节。
三、Code实现:
笔者使用pytorch:
from torch import nn
from torch.autograd import Variable
class TripletLoss(object):
def __init__(self, margin=None):
self.margin = margin
if margin is not None:
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
else:
self.ranking_loss = nn.SoftMarginLoss()
def __call__(self, dist_ap, dist_an):
"""
Args:
dist_ap: pytorch Variable, distance between anchor and positive sample,
shape [N]
dist_an: pytorch Variable, distance between anchor and negative sample,
shape [N]
Returns:
loss: pytorch Variable, with shape [1]
"""
y = Variable(dist_an.data.new().resize_as_(dist_an.data).fill_(1))
if self.margin is not None:
loss = self.ranking_loss(dist_an, dist_ap, y)
else:
loss = self.ranking_loss(dist_an - dist_ap, y)
return loss
理解起来非常简单,当margin为空时,使用SoftMarginLoss:
当margin不为空时,使用MarginRankingLoss,y中填充的都是1,代表希望dist_an>dist_ap,即anchor到negative样本的距离大于到positive样本的距离,margin为dist_an - dist_ap的值需要大于多少:
与我们要得到的loss类似:当与正例距离+固定distance大于负例距离时为正值,则惩罚,否则不惩罚。
四、github项目介绍
class TripletLoss(nn.Module):
"""Triplet loss with hard positive/negative mining.
Reference:
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
Imported from ``_.
Args:
margin (float, optional): margin for triplet. Default is 0.3.
"""
def __init__(self, margin=0.3,global_feat, labels):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
def forward(self, inputs, targets):
"""
Args:
inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
targets (torch.LongTensor): ground truth labels with shape (num_classes).
"""
n = inputs.size(0) # batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist.addmm_(1, -2, inputs, inputs.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
# For each anchor, find the hardest positive and negative
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = torch.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return loss