精选文章 pytorch triple-loss

pytorch triple-loss

作者:vieo 时间: 2021-02-05 09:43:11
vieo 2021-02-05 09:43:11
【摘要】一、Triplet结构: 
triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor、Positive、Negative: 
 
 整个训练过程是: 
首先从训练集中随机选一个样本,称为Anchor(记为x_a)。然后再随机选取一个和Anchor属于同一类的样本,称为Positive (记为x_p)最后再随机选取一个和Anchor属于不同类的样本,称为N...

一、Triplet结构

triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor、Positive、Negative:

pytorch triple-loss1


整个训练过程是:

  1. 首先从训练集中随机选一个样本,称为Anchor(记为x_a)。
  2. 然后再随机选取一个和Anchor属于同一类的样本,称为Positive (记为x_p)
  3. 最后再随机选取一个和Anchor属于不同类的样本,称为Negative (记为x_n)

由此构成一个(Anchor,Positive,Negative)三元组。

二、Triplet Loss:

在上一篇讲了Center Loss的原理和实现,会发现现在loss的优化的方向是比较清晰好理解的。在基于能够正确分类的同时,我们更希望模型能够:1、把不同类之间分得很开,也就是更不容易混淆;2、同类之间靠得比较紧密,这个对于模型的鲁棒性的提高也是比较有帮助的(基于此想到Hinton的Distillation中给softmax加的一个T就是人为的对训练过程中加上干扰,让distribution变得更加soft从而去把错误信息放大,这样模型能够不光知道什么是正确还知道什么是错误。即:模型可以从仅仅判断这个最可能是7,变为可以知道这个最可能是7、一定不是8、和2比较相近,论文讲解可以参看Hinton-Distillation)。

回归正题,三元组的三个样本最终得到的特征表达计为:pytorch triple-loss2pytorch triple-loss3pytorch triple-loss4

triplet loss的目的就是让Anchor这个样本的feature和positive的feature直接的距离比和negative的小,即:

pytorch triple-loss5

除了让x_a和x_p特征表达之间的距离尽可能小,而x_a和x_n的特征表达之间的距离尽可能大之外还要让x_a与x_n之间的距离和x_a与x_p之间的距离之间有一个最小的间隔α,于是修改loss为:

pytorch triple-loss6

于是目标函数为:

pytorch triple-loss7

距离用欧式距离度量,+表示[  ***  ]内的值大于零的时候,取该值为损失,小于零的时候,损失为零。

故也可以理解为:

L = max([ ] ,  0)

在code中就是这样实现的,利用marginloss,详见下节。

 

三、Code实现:

笔者使用pytorch:

from torch import nn
from torch.autograd import Variable
 
class TripletLoss(object):
  def __init__(self, margin=None): self.margin = margin if margin is not None: self.ranking_loss = nn.MarginRankingLoss(margin=margin) else: self.ranking_loss = nn.SoftMarginLoss() def __call__(self, dist_ap, dist_an): """ Args: dist_ap: pytorch Variable, distance between anchor and positive sample, shape [N] dist_an: pytorch Variable, distance between anchor and negative sample, shape [N] Returns: loss: pytorch Variable, with shape [1] """ y = Variable(dist_an.data.new().resize_as_(dist_an.data).fill_(1)) if self.margin is not None: loss = self.ranking_loss(dist_an, dist_ap, y) else: loss = self.ranking_loss(dist_an - dist_ap, y) return loss

理解起来非常简单,当margin为空时,使用SoftMarginLoss:

pytorch triple-loss8

当margin不为空时,使用MarginRankingLoss,y中填充的都是1,代表希望dist_an>dist_ap,即anchor到negative样本的距离大于到positive样本的距离,margin为dist_an - dist_ap的值需要大于多少:

pytorch triple-loss9

与我们要得到的loss类似:当与正例距离+固定distance大于负例距离时为正值,则惩罚,否则不惩罚。

 

四、github项目介绍

class TripletLoss(nn.Module): """Triplet loss with hard positive/negative mining. Reference: Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. Imported from ``_. Args: margin (float, optional): margin for triplet. Default is 0.3. """ def __init__(self, margin=0.3,global_feat, labels): super(TripletLoss, self).__init__() self.margin = margin self.ranking_loss = nn.MarginRankingLoss(margin=margin) def forward(self, inputs, targets): """ Args: inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). targets (torch.LongTensor): ground truth labels with shape (num_classes). """ n = inputs.size(0)	# batch_size # Compute pairwise distance, replace by the official when merged dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) dist = dist + dist.t() dist.addmm_(1, -2, inputs, inputs.t()) dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability # For each anchor, find the hardest positive and negative mask = targets.expand(n, n).eq(targets.expand(n, n).t()) dist_ap, dist_an = [], [] for i in range(n): dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) dist_ap = torch.cat(dist_ap) dist_an = torch.cat(dist_an) # Compute ranking hinge loss y = torch.ones_like(dist_an) loss = self.ranking_loss(dist_an, dist_ap, y) return loss

 

勿删,copyright占位
分享文章到微博
分享文章到朋友圈

上一篇:实验室作业之英文文本聚类

下一篇:51单片机硬件基础知识

您可能感兴趣

  • Windows10下安装pytorch并导入pycharm

    1.安装Anaconda 下载:https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/ 安装Anaconda3,最新版本的就可以了,我安装的是5.3.1。安装完这个python也一起安装了。如果之前安装过python的,为了避免版本冲突的问题,建议把之前安装的给卸载。 2.添加路径到path,添加后可以使用conda命令了 ...

  • 强烈推荐的TensorFlow、Pytorch和Keras的样例资源(深度学习初学者必须收藏)

    TensorFlow、Keras和Pytorch是目前深度学习的主要框架,也是入门深度学习必须掌握的三大框架,但是官方文档相对内容较多,初学者往往无从下手。本人从github里搜到三个非常不错的学习资源,并对资源目录进行翻译,强烈建议初学者下载学习,这些资源包含了大量的代码示例(含数据集),个人认为,只要把以上资源运行一次,不懂的地方查官方文档,很快就能理解和运用这三大框架。 ...

  • 如何手动安装pytorch whl文件

    2019独角兽企业重金招聘Python工程师标准>>> pytorch 官方网站。 官方网站提供多种安装方式,例如pip和conda等,可以参考官方安装文档,文档很详细,建议仔细查看和使用。但安装过程中可能存在网络不稳定导致无法正常安装等情况,需要使用离线方式下载whl文件安装。 pip安装软件可以使用国内安装源加速安装,可以参考安装教程。 首先不推荐pip安装方式。Windows ...

  • Pytorch | 入门之框架介绍

    题外话: 为了监督自己可以好好地学习Pytorch框架,我准备提前挖几个坑,然后凭借自觉力(填坑强迫症)把任务完成。预祝自己成功2333333   回归正题,Pytorch入门系列的文章是基于 廖星宇 所著的《深度学习入门之PyTorch》对应章节的学习笔记。本篇文章就第二章《深度学习框架》进行学习笔记。之所以就Pytorch进行学习,第一是之前用过一段时间,发现代码可读性很强;第二是...

  • Pytorch(笔记6)--nn.Module功能详解

    在接触了之前所说的Conv,pool,Batchnorm,ReLU等方法都是神经网络中常见的操作,我们可以根据这些方法来自定义网络模型,也可以根据需求对经典模型进行调整,他们都继承共同的抽象类nn.Module,其中包含好了很多函数。 1.nn.Sequential    这个方法中可以封装多个子类,注意,一定继承nn.Module的类,在调用的时候,可以使用下面的方法 net = n...

  • pytorch实现Dropout与正则化防止过拟合

    numpy实现dropout与L1,L2正则化请参考我另一篇博客 https://blog.csdn.net/fanzonghao/article/details/81079757 pytorch使用dropout与L2  import torch import matplotlib.pyplot as plt torch.manual_seed(1) # Sets the seed ...

  • Pytorch(笔记8)--手写自己设计的神经网络

    在训练过程中经常做的一件事儿,就是拿已有网络模型ResNet,DenseNet等迁移到自己的数据集进行finetune,之后调整各个层级的输入输出等,我们先拿经典的lenet来说如何用Pytorch实现一个网络模型。(本人推荐使用Jupyter,方便调试,也可以保存成脚本)   添加依赖库 import torch import torch.nn as nn import torchvi...

  • Pytorch入门与实践——神经网络工具箱

    import torch as t from torch import nn from torch.autograd import Variable as V from torch.nn import functional as F from PIL import Image from torchvision.transforms import ToTensor, ToPILImage f...

CSDN

CSDN

中国开发者社区CSDN (Chinese Software Developer Network) 创立于1999年,致力为中国开发者提供知识传播、在线学习、职业发展等全生命周期服务。

华为云40多款云服务产品0元试用活动

免费套餐,马上领取!
pytorch triple-loss介绍:华为云为您免费提供pytorch triple-loss在博客、论坛、帮助中心等栏目的相关文章,同时还可以通过 站内搜索 查询更多pytorch triple-loss的相关内容。| 移动地址: pytorch triple-loss | 写博客