精选文章 pytorch triple-loss

pytorch triple-loss

作者:vieo 时间: 2020-08-05 09:14:01
vieo 2020-08-05 09:14:01

一、Triplet结构

triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor、Positive、Negative:

pytorch triple-loss1


整个训练过程是:

  1. 首先从训练集中随机选一个样本,称为Anchor(记为x_a)。
  2. 然后再随机选取一个和Anchor属于同一类的样本,称为Positive (记为x_p)
  3. 最后再随机选取一个和Anchor属于不同类的样本,称为Negative (记为x_n)

由此构成一个(Anchor,Positive,Negative)三元组。

二、Triplet Loss:

在上一篇讲了Center Loss的原理和实现,会发现现在loss的优化的方向是比较清晰好理解的。在基于能够正确分类的同时,我们更希望模型能够:1、把不同类之间分得很开,也就是更不容易混淆;2、同类之间靠得比较紧密,这个对于模型的鲁棒性的提高也是比较有帮助的(基于此想到Hinton的Distillation中给softmax加的一个T就是人为的对训练过程中加上干扰,让distribution变得更加soft从而去把错误信息放大,这样模型能够不光知道什么是正确还知道什么是错误。即:模型可以从仅仅判断这个最可能是7,变为可以知道这个最可能是7、一定不是8、和2比较相近,论文讲解可以参看Hinton-Distillation)。

回归正题,三元组的三个样本最终得到的特征表达计为:pytorch triple-loss2pytorch triple-loss3pytorch triple-loss4

triplet loss的目的就是让Anchor这个样本的feature和positive的feature直接的距离比和negative的小,即:

pytorch triple-loss5

除了让x_a和x_p特征表达之间的距离尽可能小,而x_a和x_n的特征表达之间的距离尽可能大之外还要让x_a与x_n之间的距离和x_a与x_p之间的距离之间有一个最小的间隔α,于是修改loss为:

pytorch triple-loss6

于是目标函数为:

pytorch triple-loss7

距离用欧式距离度量,+表示[  ***  ]内的值大于零的时候,取该值为损失,小于零的时候,损失为零。

故也可以理解为:

                                                                           L = max([ ] ,  0)

在code中就是这样实现的,利用marginloss,详见下节。

 

三、Code实现:

笔者使用pytorch:

from torch import nn
from torch.autograd import Variable
 
class TripletLoss(object):
  def __init__(self, margin=None):
    self.margin = margin
    if margin is not None:
      self.ranking_loss = nn.MarginRankingLoss(margin=margin)
    else:
      self.ranking_loss = nn.SoftMarginLoss()
 
  def __call__(self, dist_ap, dist_an):
    """
    Args:
      dist_ap: pytorch Variable, distance between anchor and positive sample, 
        shape [N]
      dist_an: pytorch Variable, distance between anchor and negative sample, 
        shape [N]
    Returns:
      loss: pytorch Variable, with shape [1]
    """
    y = Variable(dist_an.data.new().resize_as_(dist_an.data).fill_(1))
    if self.margin is not None:
      loss = self.ranking_loss(dist_an, dist_ap, y)
    else:
      loss = self.ranking_loss(dist_an - dist_ap, y)
    return loss

理解起来非常简单,当margin为空时,使用SoftMarginLoss:

pytorch triple-loss8

当margin不为空时,使用MarginRankingLoss,y中填充的都是1,代表希望dist_an>dist_ap,即anchor到negative样本的距离大于到positive样本的距离,margin为dist_an - dist_ap的值需要大于多少:

pytorch triple-loss9

与我们要得到的loss类似:当与正例距离+固定distance大于负例距离时为正值,则惩罚,否则不惩罚。

 

四、github项目介绍

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.
    
    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    
    Imported from ``_.
    
    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """
    
    def __init__(self, margin=0.3,global_feat, labels):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
 
    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (num_classes).
        """
        n = inputs.size(0)	# batch_size
        
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss

 

勿删,copyright占位
分享文章到微博
分享文章到朋友圈

上一篇:51单片机硬件基础知识

下一篇:SIP语音环境中十大经典问题及解决办法

您可能感兴趣

  • 【收藏】万字综述,核心开发者全面解读PyTorch内部机制

    (给机器学习算法与Python学习加星标,提升AI技能) 选自ezyang博客 作者:Edward Z. Yang 本文由机器之心(nearhuman2014)整理 斯坦福大学博士生与 Facebook 人工智能研究所研究工程师 Edward Z. Yang 是 PyTorch 开源项目的核心开发者之一。他在 5 月 14 日的 PyTorch 纽约聚会上做了一个有关 ,本文是他有关PyTo...

  • pytorch教程之损失函数详解——多种定义损失函数的方法

    转载自:https://blog.csdn.net/qq_27825451/article/details/95165265 不管是定义层还是损失函数,方法有很多,但是通过统一的接口nn.Module是最便于查看的,这也是pytorch的优点之一,模型、层、损失函数的定义具有统一性,都是通过Module类来完成。 一、回顾 1.1 关于nn.Module 和 nn.functional 区别...

  • 【深度学习】深度学习之Pytorch基础教程!

    作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现。尤其是近两年,Google、Facebook、Microsoft等巨头都围绕深度学习重点投资了一系列新兴项目,他们也一直在支持一些开源的深度学习框架。目前研究人员正在使用的深度学习框架不尽相同,有 TensorFlow 、Pytorch、Caffe、Theano、Keras等。 这其中,Te...

  • tensorboard显示过程踩到的巨坑!!!

    今天从早到晚,搞tensorboard的显示问题........................ 搞得我实在没脾气了...................... 去求助了一下别人.....................

  • pytorch 入坑 安装+基本概念+入门demo

    windows版本的安装 先来到官网 https://pytorch.org/get-started/locally/ 选择合适的选项会得到一条pip命令,允许这条命令即可下载合适版本。 若出现网速不给力的情况,可以手动将pip输出的信息中的url粘贴到下载工具中下载,然后使用pip instal **.whl 直接安装这个包。 如果需要使用cuda加速,则还需要安装对应版本的cuda和“足...

  • CNN如何用于NLP任务?一文简述文本分类任务的7个模型

    点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要20分钟 跟随小博主,每天进步一丢丢 选自 | Ahmed BESBES 作者 | Ahmed Besbes 转自 | 机器之心 本文介绍了用于文本分类任务的 7 个模型,包括传统的词袋模型、循环神经网络,也有常用于计算机视觉任务的卷积神经网络,以及 RNN + CNN。 本文是我之前写过的一篇基于推特数据进行情感分析的文章,那时我建立...

  • 使用PyTorch的TensorBoard-可视化深度学习指标 | PyTorch系列(二十五)

    点击上方“AI算法与图像处理”,选择加"星标"或“置顶” 重磅干货,第一时间送达 文 |AI_study 原标题:TensorBoard With PyTorch - Visualize Deep Learning Metrics 在本系列的这一点上,我们刚刚完成了训练过程中的网络运行。现在,我们希望获得有关此过程的更多指标,以更好地了解幕后情况。 鸟瞰我们在训练过程中所处的位置。 准备数据...

  • pytorch loss反向传播出错(转载)

    今天在使用pytorch进行训练,在运行 loss.backward() 误差反向传播时出错 : RuntimeError: grad can be implicitly created only for scalar outputs File "train.py", line 143, in train loss.backward() File "/usr/local/lib/python...

华为云40多款云服务产品0元试用活动

免费套餐,马上领取!
CSDN

CSDN

中国开发者社区CSDN (Chinese Software Developer Network) 创立于1999年,致力为中国开发者提供知识传播、在线学习、职业发展等全生命周期服务。