pytorch의 CrossEntropyLoss에 대한 weight 매개 변수

우선 이 weight 매개 변수는 생각보다 많이 고려했으니 다음 코드를 시험해 보세요
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.4803)
여기의 수동 계산은 loss1 = 0 + ln(e0 + e0 + e0 + e0) = 1.098 loss2 = 0 + ln(e1 + e0 + e1) = 1.86 구평균=(loss1 + loss2 *1)/2 = 1.4803 xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)

tensor(1.6075)
손수 계산해 보니 단순한 그 권중상승이 아니다:loss1=0+ln(e0+e0+e0)=1.098loss2=0+ln(e1+e0+e1)=1.86구평균=(loss1*1+loss2*2)/2=2.4113
대신 loss1 = 0 + ln(e0 + e0 + e0) = 1.098 loss2 = 0 + ln(e1 + e0 + e1) = 1.86 구평균 = (loss1 * 1 + loss2 *2)/3 = 1.6075
발견했느냐, 가중된 후에 가중된 합을 제외하고는 수량의 합이 아니다.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)

tensor(1.5472)
수산:loss1 = 0 + ln(e0 + e0 + e0 + e0) = 1.098 loss2 = 0 + ln(e1 + e0 + e1) = 1.86 loss3 = 0 + ln(e2 + e0 + e0) = 2.2395 loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943 구평균 = (loss1 * 1 + loss2 * 2 * 2 + loss3 * 3 + loss4 * 3 * 3 + loss4 * 3)/9 = 1.5472
누군가가 로스의 CE 계산 과정에 대해 의문이 있을 수 있습니다. 저는 여기에 교차 엔트로피의 계산 과정을 세밀하게 쓰고 마지막 예의 로스4의 계산 설명을 드리겠습니다.
4
  • loss4의 inputs(i4)(i 4)(i4)와 targets(t4)(t 4)(t4)(t4)가 i 4 = i n p u s t s [0,:,3] = [0,0,0.5] i 로 우선 결정4=inpusts[0,,3]=[0,0,0.5] i4 = inpusts[0,:,3]=[0,0,0.5]와 t4 = o u t p u t s [0,3] = 2 t4=outputs[0,3]=2 t4​=outputs[0,3]=2

  • 교차 엔트로피 공식에 따라 4, l (e 4 [t 4] e 4 [0] + e i 4 [1] + e 4 [2] = i4 [1] + e 4 [2] = i 4 [t 4] + l n (e i 4 [t 4] e 4 [0] + e i 4 [2] = 0.5 + l n (e0 + e 0 + e 0 + e 0.5) = 0.79943 -ln ({{{e^i 4 [t 4]} 4 [1] + e 4 [1] + e i 4 [1] + e 4 [1 + e 4 [2]) = = {0.5 0 + e 4 [4]} {4} {4} {4} 4} {4} 4} {4} 4} {4}-i4[t_4]+ln(e^{i_4[0]}+e^{i_4[1]}+e^{i_4[2]})=-0.5+ln(e^{0}+e^{0}+e^{0.5})=0.7943 −ln(ei4​[0]+ei4​[1]+ei4​[2]ei4​[t4​]​)=−i4​[t4​]+ln(ei4​[0]+ei4​[1]+ei4​[2])=−0.5+ln(e0+e0+e0.5)=0.7943

    좋은 웹페이지 즐겨찾기