pytorch: Label Smooth
2272 단어 pytorch
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
class LabelSmoothSoftmaxCE(nn.Module):
def __init__(self,
lb_pos=0.9,
lb_neg=0.005,
reduction='mean',
lb_ignore=255,
):
super(LabelSmoothSoftmaxCE, self).__init__()
self.lb_pos = lb_pos
self.lb_neg = lb_neg
self.reduction = reduction
self.lb_ignore = lb_ignore
self.log_softmax = nn.LogSoftmax(1)
def forward(self, logits, label):
logs = self.log_softmax(logits)
ignore = label.data.cpu() == self.lb_ignore
n_valid = (ignore == 0).sum()
label = label.clone()
label[ignore] = 0
lb_one_hot = logits.data.clone().zero_().scatter_(1, label.unsqueeze(1), 1)
label = self.lb_pos * lb_one_hot + self.lb_neg * (1-lb_one_hot)
ignore = ignore.nonzero()
_, M = ignore.size()
a, *b = ignore.chunk(M, dim=1)
label[[a, torch.arange(label.size(1)), *b]] = 0
if self.reduction == 'mean':
loss = -torch.sum(torch.sum(logs*label, dim=1)) / n_valid
elif self.reduction == 'none':
loss = -torch.sum(logs*label, dim=1)
return loss
if __name__ == '__main__':
torch.manual_seed(15)
criteria = LabelSmoothSoftmaxCE(lb_pos=0.9, lb_neg=5e-3)
net1 = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
)
net1.cuda()
net1.train()
net2 = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1),
)
net2.cuda()
net2.train()
with torch.no_grad():
inten = torch.randn(2, 3, 5, 5).cuda()
lbs = torch.randint(0, 3, [2, 5, 5]).cuda()
lbs[1, 3, 4] = 255
lbs[1, 2, 3] = 255
print(lbs)
import torch.nn.functional as F
logits1 = net1(inten)
logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
logits2 = net2(inten)
logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
# loss1 = criteria1(logits1, lbs)
loss = criteria(logits1, lbs)
# print(loss.detach().cpu())
loss.backward()
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
정확도에서 스케일링의 영향데이터셋 스케일링은 데이터 전처리의 주요 단계 중 하나이며, 데이터 변수의 범위를 줄이기 위해 수행됩니다. 이미지와 관련하여 가능한 최소-최대 값 범위는 항상 0-255이며, 이는 255가 최대값임을 의미합니다. 따...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.