PyTorch label smoothing 코드

13731 단어 LDR2HDR
PyTorch label smoothing 코드
import torch
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
 
 
class LabelSmoothing(nn.Module):
    """Label-smoothing loss computed as a summed KL divergence.

    For every sample the target is a smoothed one-hot distribution:
    ``1 - smoothing`` on the true class and ``smoothing / (size - 1)`` on
    every other class.

    Args:
        size: Number of classes; ``x.size(1)`` must equal this.
        smoothing: Probability mass moved off the true class.
    """

    def __init__(self, size, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # reduction="sum" is the non-deprecated spelling of size_average=False.
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        # Smoothed target built by the most recent forward() call, kept for inspection.
        self.true_dist = None

    def forward(self, x, target):
        """Return the summed KL divergence between the smoothed target and ``x``.

        Args:
            x: Tensor of shape ``(M, size)`` holding class *probabilities*
                (not log-probabilities); the log is taken here, so zero
                entries may produce a non-finite loss.
            target: Integer class indices of shape ``(M,)``.

        Returns:
            Scalar tensor: KL divergence summed over all samples and classes.
        """
        assert x.size(1) == self.size
        log_x = x.log()
        with torch.no_grad():
            # The target distribution is a constant for backprop.
            true_dist = torch.full_like(log_x, self.smoothing / (self.size - 1))
            # Put the remaining mass on each sample's true class.
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        self.true_dist = true_dist
        return self.criterion(log_x, true_dist)



class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    ``pred`` holds unnormalized scores; ``log_softmax`` is applied along
    ``dim``. The target puts ``1 - smoothing`` on the true class and
    ``smoothing / (classes - 1)`` on every other class. The loss is the
    per-position cross-entropy averaged over all non-class positions.

    Args:
        classes: Number of classes (size of ``pred`` along ``dim``).
        smoothing: Probability mass moved off the true class.
        dim: Class axis of ``pred``.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean label-smoothed cross-entropy.

        Args:
            pred: Scores with the class axis at ``self.dim``.
            target: Integer class indices shaped like ``pred`` without its
                class axis.

        Returns:
            Scalar tensor.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (self.cls - 1))
            # Scatter along the class axis (self.dim) rather than a hard-coded
            # axis 1, so inputs whose class axis is not 1 are handled correctly.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))


if __name__ == "__main__":
    # Example of label smoothing: 3 samples over 5 classes.
    n_classes = 5
    smoothing = 0.1
    crit = LabelSmoothingLoss(classes=n_classes, smoothing=smoothing)
    # Raw scores of shape (3, 5); LabelSmoothingLoss applies log_softmax itself.
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.9, 0.2, 0.1, 0],
                            [1, 0.2, 0.7, 0.1, 0]])
    target = torch.tensor([2, 1, 0])

    v = crit(predict, target)
    print(v)

    # Show the smoothed target distribution the loss trains against. It is
    # rebuilt here with the same formula LabelSmoothingLoss.forward uses
    # (smoothing / (classes - 1) off-target, 1 - smoothing on the target),
    # rather than relying on the loss module to keep it as an attribute.
    true_dist = torch.full_like(predict, smoothing / (n_classes - 1))
    true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    plt.imshow(true_dist)
    plt.show()


좋은 웹페이지 즐겨찾기