pytorch에서 torch.Tensor.scatter 사용법

4665 단어 pythonpytorch
오늘 코드를 읽다가 어떤 사람이 torch를 쓰는 것을 보았다.Tensor.scatter라는 함수.이 함수는 전에도 본 적이 있지만, 무슨 용도로 쓰는지 이해하지 못해서, 오늘 나는 이해했다.
먼저 공식 문서의 정의를 살펴보자. (건너뛰기 가능)
===========================pytorch docs 분할선 =================================================scatter_ (dim, index, src) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src , its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim .
For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

This is the reverse operation of the manner described in gather() . self , index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d , and that index.size(d) <= self.size(d) for all dimensions d != dim .
Moreover, as for gather() , the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.
Parameters
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified
  • ===============================================================================

  • 자신의 실험을 통해 나는 이 함수의 목적을 이해했고 다음에 설명을 하겠다.
    먼저 이 함수의 인터페이스를 보려면 세 개의 입력이 필요합니다. 1)차원dim2) 인덱스 그룹 index3) 원수 그룹 src는 이해하기 편리하도록 다음 글에서 src를 input 표시로 바꿉니다.최종 출력은 새 output 그룹입니다.
    다음은 다음과 같습니다.
    1) 차원dim: 정수, 0,1,2,3 가능...
    2) 인덱스 그룹 index: 인덱스 그룹은 tensor이고 그 중의 데이터 형식은 정수이며 위치를 표시한다
    3) 원수 그룹 input: 또한 tensor이며 그 중의 데이터 형식은 임의로
    먼저 이 함수가 무엇인지 말해 보자. 내가 보기에 이 scatter 함수는 input 수조의 데이터를 재분배하는 것이다.index에서 원수 그룹의 데이터를 output 그룹에 분배할 위치를 표시합니다. 지정하지 않으면 0을 채웁니다.
    예를 들어 다음 코드는 다음과 같습니다.
    import torch
    
    input = torch.randn(2, 4)
    print(input)
    output = torch.zeros(2, 5)
    index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
    output = output.scatter(1, index, input)
    print(output)

     
    실행 결과는 다음과 같습니다.
    tensor([[-0.2558, -1.8930, -0.7831,  0.6100],         [ 0.3246,  2.1289,  0.5887,  1.5588]]) tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],         [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])
    다음은 왜 이런 결과가 나왔는지 자세히 말씀드리겠습니다.
    앞에서 말했듯이 scatter는 input 수조이다. index 수조에 따라 input 수조의 데이터를 재분배한다. 분배 과정이 어떠한지 살펴보자.
    input:
    tensor([[-0.2558, -1.8930, -0.7831,  0.6100],              [ 0.3246,  2.1289,  0.5887,  1.5588]])
    index:
    index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
    output:
    tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],              [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])
    우선 input[0][0]를 재분배한다.기호->는 지정을 나타냅니다.scatter 방법의 1차원dim=1이기 때문에 input수 그룹의 데이터는 1차원에서만 재분배되고 0차원에서는 변하지 않는다.2차원수조로 예를 들면 첫 번째 줄의 데이터를 다시 분배한 후에 반드시 첫 번째 줄에 있고 두 번째 줄로 갈 수 없다.
    input[0][0] -> output[0][index[0][0]] = output[0][3]
    데이터 위치의 변화는 모두 1차원에서, 0차원에서는 변하지 않는다.
    input[0][1] -> output[0][index[0][1]] = output[0][1]
    input[0][2] -> output[0][index[0][2]] = output[0][2]
    input[0][3] -> output[0][index[0][3]] = output[0][0]
    주의해야 할 것은,
    이해하기 편리하도록 우리는 input에서 데이터의 순서에 따라 인덱스하지만,pytorch에서는 index[0][0]에서 index[0][3]까지의 순서에 따라 인덱스를 인덱스합니다. 인덱스의 input 위치와 output 위치는 반드시 존재해야 합니다. 그렇지 않으면 오류가 발생할 수 있습니다.그러나 모든 input 데이터가 output에 나누어지는 것은 아니다. output도 모든 위치에 대응하는 input이 있는 것은 아니다. output에 대응하는 input이 없을 때 자동으로 0을 채운다.
    일반 scatter는 다음과 같이 onehot 벡터를 생성하는 데 사용됩니다.
    index = torch.tensor([[1], [2], [0], [3]])
    onehot = torch.zeros(4, 4)
    onehot.scatter_(1, index, 1)
    print(onehot)
    

    출력 결과:
    tensor([[0., 1., 0., 0.],              [0., 0., 1., 0.],              [1., 0., 0., 0.],              [0., 0., 0., 1.]])
    만약 input이 하나의 숫자라면, 이것은 output에 분배되는 숫자가 얼마인지를 나타낸다.
     
     

    좋은 웹페이지 즐겨찾기