pytorch에서 index_select () 의 사용법 상세 설명

3066 단어 pytorchindex select()
pytorch에서 index_select () 의 사용 방법

index_select(input, dim, index)
기능: 지정한 차원dim에서 데이터를 선택하는 것보다 일부 줄, 열을 선택하는 것이 낫다
매개 변수 소개
  • 첫 번째 인자 input는 색인을 위한 대상입니다
  • 두 번째 매개 변수dim는 찾아야 할 차원이다. 일반적인 상황에서 우리가 사용하는 것은 모두 2차원 장량이기 때문에 간단하게 기억할 수 있다. 0은 행을 대표하고 1은 열을 대표한다
  • 세 번째 매개 변수 index는 당신이 인덱스하려는 서열입니다. 이것은tensor 대상입니다
  • Pytorch를 배우기 시작했는데 index_를 만났어요.select (), 처음에는 몇 개의 매개 변수의 뜻을 잘 몰랐는데, 나중에 자료를 찾아보니 조금 이해한 셈이다.
    
    a = torch.linspace(1, 12, steps=12).view(3, 4)
    print(a)
    b = torch.index_select(a, 0, torch.tensor([0, 2]))
    print(b)
    print(a.index_select(0, torch.tensor([0, 2])))
    c = torch.index_select(a, 1, torch.tensor([1, 3]))
    print(c)
    
    먼저tensor를 정의했습니다. 여기에는linspace와view 방법이 사용됩니다.
    첫 번째 매개 변수는 인덱스의 대상이다. 두 번째 매개 변수 0은 줄 인덱스를 나타내고 1은 열에 따라 인덱스를 한다. 세 번째 매개 변수는 하나의tensor이다. 바로 인덱스의 번호이다. 예를 들어 b안에tensor[0,2]는 0줄과 2줄을 나타내고 c안에tensor[1,3]는 1열과 3열을 나타낸다.
    출력 결과는 다음과 같습니다.
    tensor([[ 1.,  2.,  3.,  4.],
            [ 5.,  6.,  7.,  8.],
            [ 9., 10., 11., 12.]])
    tensor([[ 1.,  2.,  3.,  4.],
            [ 9., 10., 11., 12.]])
    tensor([[ 1.,  2.,  3.,  4.],
            [ 9., 10., 11., 12.]])
    tensor([[ 2.,  4.],
            [ 6.,  8.],
            [10., 12.]])
    예제 2
    
    import torch
     
    x = torch.Tensor([[[1, 2, 3],
              [4, 5, 6]],
     
             [[9, 8, 7],
              [6, 5, 4]]])
    print(x)
    print(x.size())
    index = torch.LongTensor([0, 0, 1])
    print(torch.index_select(x, 0, index))
    print(torch.index_select(x, 0, index).size())
    print(torch.index_select(x, 1, index))
    print(torch.index_select(x, 1, index).size())
    print(torch.index_select(x, 2, index))
    print(torch.index_select(x, 2, index).size())
    
    input의 장량 형상은 2×이×3, index는 [0, 0, 1]의 벡터이다
    각각 0, 1, 2 세 차원에서 index_select () 함수, 결과와 모양을 출력합니다. 차원이 2보다 크면 오류가 발생합니다. input는 최대 세 차원만 있기 때문입니다.
    출력:
    tensor([[[1., 2., 3.],
             [4., 5., 6.]],
     
            [[9., 8., 7.],
             [6., 5., 4.]]])
    torch.Size([2, 2, 3])
    tensor([[[1., 2., 3.],
             [4., 5., 6.]],
     
            [[1., 2., 3.],
             [4., 5., 6.]],
     
            [[9., 8., 7.],
             [6., 5., 4.]]])
    torch.Size([3, 2, 3])
    tensor([[[1., 2., 3.],
             [1., 2., 3.],
             [4., 5., 6.]],
     
            [[9., 8., 7.],
             [9., 8., 7.],
             [6., 5., 4.]]])
    torch.Size([2, 3, 3])
    tensor([[[1., 1., 2.],
             [4., 4., 5.]],
     
            [[9., 9., 8.],
             [6., 6., 5.]]])
    torch.Size([2, 2, 3])
    결과 분석:
    index는 크기가 3인 벡터입니다. 입력한 장량 모양은 2입니다.×이×삼
    dim = 0 시 출력의 장량 모양은 3×이×삼
    dim = 1 시 출력의 장량 모양은 2×삼×삼
    dim = 2 시 출력의 장량 형태는 2×이×삼
    출력 장량 차원의 변화와 index 크기의 관계를 주의하고 출력의 장량과 원시 장량을 결합하여 index_ 분석select () 함수의 역할
    이 pytorch에 대한 index_select()의 용법에 대한 상세한 설명은 여기까지입니다.select () 내용은 저희의 이전 글을 검색하거나 아래의 관련 글을 계속 훑어보시기 바랍니다. 앞으로 많은 응원 부탁드립니다!

    좋은 웹페이지 즐겨찾기