Pytorch 에서 torch.nn.softmax 의 dim 매개 변수 용법 설명

6923 단어 PytorchSoftmaxdim
Pytorch 에서 torch.nn.softmax 의 dim 매개 변수 사용 의미
다 차원 tensor 와 관련 되 었 을 때 softmax 의 매개 변수 dim 에 항상 빠 져 있 습 니 다.다음은 하나의 예 로 설명 하 겠 습 니 다.

import torch.nn as nn
m = nn.Softmax(dim=0)
n = nn.Softmax(dim=1)
k = nn.Softmax(dim=2)
input = torch.randn(2, 2, 3)
print(input)
print(m(input))
print(n(input))
print(k(input))
출력:
input
tensor([[[ 0.5450, -0.6264, 1.0446],
[ 0.6324, 1.9069, 0.7158]],
[[ 1.0092, 0.2421, -0.8928],
[ 0.0344, 0.9723, 0.4328]]])
dim=0
tensor([[[0.3860, 0.2956, 0.8741],
[0.6452, 0.7180, 0.5703]],
[[0.6140, 0.7044, 0.1259],
[0.3548, 0.2820, 0.4297]]])
dim=0 시,0 차원 에서 sum=1,즉:
[0][0][0]+[1][0][0]=0.3860+0.6140=1
[0][0][1]+[1][0][1]=0.2956+0.7044=1
… …
dim=1
tensor([[[0.4782, 0.0736, 0.5815],
[0.5218, 0.9264, 0.4185]],
[[0.7261, 0.3251, 0.2099],
[0.2739, 0.6749, 0.7901]]])
dim=1 시,1 차원 에서 sum=1,즉:
[0][0][0]+[0][1][0]=0.4782+0.5218=1
[0][0][1]+[0][1][1]=0.0736+0.9264=1
… …
dim=2
tensor([[[0.3381, 0.1048, 0.5572],
[0.1766, 0.6315, 0.1919]],
[[0.6197, 0.2878, 0.0925],
[0.1983, 0.5065, 0.2953]]])
dim=2 시,2 차원 에서 sum=1,즉:
[0][0][0]+[0][1]+[0][0][2]=0.3381+0.1048+0.5572=1.0001(반올림 문제)
[0][1][0]+[0][1][1]+[0][1][2]=0.1766+0.6315+0.1919=1
… …
그림 으로 223 의 장 량 을 다음 과 같이 표시 한다.
在这里插入图片描述
다 중 분류 문제 torch.nn.softmax 사용
왜 이 문 제 를 이야기 합 니까?제 가 일 을 하 는 과정 에서 의미 분할 예측 수출 특징 도 개 수 는 16 이 고 이른바 16 분류 문제 에 부 딪 혔 기 때 문 입 니 다.
각 채널 의 픽 셀 값 의 크기 는 픽 셀 이 이 채널 에 속 하 는 클래스 의 크기 를 대표 하기 때문에 한 장의 그림 에 다른 색 으로 표시 하기 위해 서 나 는 torch.nn.softmax 의 사용 을 배 워 야 했다.
먼저 간단 한 예 를 들 어 출력 이(3,4,4)이면 4x4 의 특징 도 3 장 입 니 다.

import torch
img = torch.rand((3,4,4))
print(img)
출력:
tensor([[[0.0413, 0.8728, 0.8926, 0.0693],
[0.4072, 0.0302, 0.9248, 0.6676],
[0.4699, 0.9197, 0.3333, 0.4809],
[0.3877, 0.7673, 0.6132, 0.5203]],
[[0.4940, 0.7996, 0.5513, 0.8016],
[0.1157, 0.8323, 0.9944, 0.2127],
[0.3055, 0.4343, 0.8123, 0.3184],
[0.8246, 0.6731, 0.3229, 0.1730]],
[[0.0661, 0.1905, 0.4490, 0.7484],
[0.4013, 0.1468, 0.2145, 0.8838],
[0.0083, 0.5029, 0.0141, 0.8998],
[0.8673, 0.2308, 0.8808, 0.0532]]])
우 리 는 모두 세 장의 특징 도 를 볼 수 있 는데 각 특징 도 에 대응 하 는 값 이 클 수록 이 특징 도 에 대응 하 는 확률 이 높다 는 것 을 설명 한다.

import torch.nn as nn
sogtmax = nn.Softmax(dim=0)
img = sogtmax(img)
print(img)
출력:
tensor([[[0.2780, 0.4107, 0.4251, 0.1979],
[0.3648, 0.2297, 0.3901, 0.3477],
[0.4035, 0.4396, 0.2993, 0.2967],
[0.2402, 0.4008, 0.3273, 0.4285]],
[[0.4371, 0.3817, 0.3022, 0.4117],
[0.2726, 0.5122, 0.4182, 0.2206],
[0.3423, 0.2706, 0.4832, 0.2522],
[0.3718, 0.3648, 0.2449, 0.3028]],
[[0.2849, 0.2076, 0.2728, 0.3904],
[0.3627, 0.2581, 0.1917, 0.4317],
[0.2543, 0.2898, 0.2175, 0.4511],
[0.3880, 0.2344, 0.4278, 0.2686]]])
위의 코드 는 각 특징 도 에 대응 하 는 위치의 픽 셀 값 을 Softmax 함수 로 처리 하고 그림 에 빨간색 위치 더하기=1,같은 이치 로 파란색 위치 더하기=1 을 표시 하 는 것 을 볼 수 있 습 니 다.
우 리 는 Softmax 함수 가 원래 특징 그림 의 모든 픽 셀 의 값 을 대응 차원(여기 dim=0,즉 1 차원)에서 계산 하여 0~1 사이 로 처리 하고 크기 가 고정 되 지 않 는 것 을 보 았 다.

print(torch.max(img,0))
출력:
torch.return_types.max(
values=tensor([[0.4371, 0.4107, 0.4251, 0.4117],
[0.3648, 0.5122, 0.4182, 0.4317],
[0.4035, 0.4396, 0.4832, 0.4511],
[0.3880, 0.4008, 0.4278, 0.4285]]),
indices=tensor([[1, 0, 0, 1],
[0, 1, 1, 2],
[0, 0, 1, 2],
[2, 0, 2, 0]]))
여기 서 3x4 x4 가 1x4 x4 로 바 뀌 었 고 해당 위치 에 있 는 값 은 픽 셀 로 각 채널 의 최대 값 에 대응 하 며 indices 는 대응 하 는 분류 임 을 알 수 있 습 니 다.
위의 절 차 를 똑똑히 이해 하면 우 리 는 쉽게 처리 할 수 있다.
구체 적 인 사례 를 보면 여기 출력 output 의 크기 는 16x416 x416 입 니 다.

output = torch.tensor(output)
sm = nn.Softmax(dim=0)
output = sm(output)
mask = torch.max(output,0).indices.numpy()
 
#       RGB   ,      
rgb_img = np.zeros((output.shape[1], output.shape[2], 3))
for i in range(len(mask)):
    for j in range(len(mask[0])):
        if mask[i][j] == 0:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 255
        if mask[i][j] == 1:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 0
        if mask[i][j] == 2:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 180
        if mask[i][j] == 3:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 255
        if mask[i][j] == 4:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 180
        if mask[i][j] == 5:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 0
        if mask[i][j] == 6:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 180
        if mask[i][j] == 7:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 255
        if mask[i][j] == 8:
            rgb_img[i][j][0] = 255
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
        if mask[i][j] == 9:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
        if mask[i][j] == 10:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 255
        if mask[i][j] == 11:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 180
        if mask[i][j] == 12:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 255
        if mask[i][j] == 13:
            rgb_img[i][j][0] = 180
            rgb_img[i][j][1] = 255
            rgb_img[i][j][2] = 180
        if mask[i][j] == 14:
            rgb_img[i][j][0] = 0
            rgb_img[i][j][1] = 180
            rgb_img[i][j][2] = 255
        if mask[i][j] == 15:
            rgb_img[i][j][0] = 0
            rgb_img[i][j][1] = 0
            rgb_img[i][j][2] = 0
 
cv2.imwrite('output.jpg', rgb_img)
마지막 으로 저 장 된 그림 은:

이상 은 개인 적 인 경험 이 므 로 여러분 에 게 참고 가 되 기 를 바 랍 니 다.여러분 들 도 저 희 를 많이 응원 해 주시 기 바 랍 니 다.

좋은 웹페이지 즐겨찾기