pytorch에서 torch.Tensor.scatter 사용법
먼저 공식 문서의 정의를 살펴보자. (건너뛰기 가능)
===========================pytorch docs 분할선 =================================================
scatter_
(dim, index, src) → Tensor Writes all values from the tensor
src
into self
at the indices specified in the index
tensor. For each value in src
, its output index is specified by its index in src
for dimension != dim
and by the corresponding value in index
for dimension = dim
. For a 3-D tensor,
self
is updated as: self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
This is the reverse operation of the manner described in
gather()
. self
, index
and src
(if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d)
for all dimensions d
, and that index.size(d) <= self.size(d)
for all dimensions d != dim
. Moreover, as for
gather()
, the values of index
must be between 0
and self.size(dim) - 1
inclusive, and all values in a row along the specified dimension dim
must be unique. Parameters
자신의 실험을 통해 나는 이 함수의 목적을 이해했고 다음에 설명을 하겠다.
먼저 이 함수의 인터페이스를 보려면 세 개의 입력이 필요합니다. 1)차원dim2) 인덱스 그룹 index3) 원수 그룹 src는 이해하기 편리하도록 다음 글에서 src를 input 표시로 바꿉니다.최종 출력은 새 output 그룹입니다.
다음은 다음과 같습니다.
1) 차원dim: 정수, 0,1,2,3 가능...
2) 인덱스 그룹 index: 인덱스 그룹은 tensor이고 그 중의 데이터 형식은 정수이며 위치를 표시한다
3) 원수 그룹 input: 또한 tensor이며 그 중의 데이터 형식은 임의로
먼저 이 함수가 무엇인지 말해 보자. 내가 보기에 이 scatter 함수는 input 수조의 데이터를 재분배하는 것이다.index에서 원수 그룹의 데이터를 output 그룹에 분배할 위치를 표시합니다. 지정하지 않으면 0을 채웁니다.
예를 들어 다음 코드는 다음과 같습니다.
import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)
실행 결과는 다음과 같습니다.
tensor([[-0.2558, -1.8930, -0.7831, 0.6100], [ 0.3246, 2.1289, 0.5887, 1.5588]]) tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000], [ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
다음은 왜 이런 결과가 나왔는지 자세히 말씀드리겠습니다.
앞에서 말했듯이 scatter는 input 수조이다. index 수조에 따라 input 수조의 데이터를 재분배한다. 분배 과정이 어떠한지 살펴보자.
input:
tensor([[-0.2558, -1.8930, -0.7831, 0.6100], [ 0.3246, 2.1289, 0.5887, 1.5588]])
index:
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output:
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000], [ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
우선 input[0][0]를 재분배한다.기호->는 지정을 나타냅니다.scatter 방법의 1차원dim=1이기 때문에 input수 그룹의 데이터는 1차원에서만 재분배되고 0차원에서는 변하지 않는다.2차원수조로 예를 들면 첫 번째 줄의 데이터를 다시 분배한 후에 반드시 첫 번째 줄에 있고 두 번째 줄로 갈 수 없다.
input[0][0] -> output[0][index[0][0]] = output[0][3]
데이터 위치의 변화는 모두 1차원에서, 0차원에서는 변하지 않는다.
input[0][1] -> output[0][index[0][1]] = output[0][1]
input[0][2] -> output[0][index[0][2]] = output[0][2]
input[0][3] -> output[0][index[0][3]] = output[0][0]
주의해야 할 것은,
이해하기 편리하도록 우리는 input에서 데이터의 순서에 따라 인덱스하지만,pytorch에서는 index[0][0]에서 index[0][3]까지의 순서에 따라 인덱스를 인덱스합니다. 인덱스의 input 위치와 output 위치는 반드시 존재해야 합니다. 그렇지 않으면 오류가 발생할 수 있습니다.그러나 모든 input 데이터가 output에 나누어지는 것은 아니다. output도 모든 위치에 대응하는 input이 있는 것은 아니다. output에 대응하는 input이 없을 때 자동으로 0을 채운다.
일반 scatter는 다음과 같이 onehot 벡터를 생성하는 데 사용됩니다.
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
출력 결과:
tensor([[0., 1., 0., 0.], [0., 0., 1., 0.], [1., 0., 0., 0.], [0., 0., 0., 1.]])
만약 input이 하나의 숫자라면, 이것은 output에 분배되는 숫자가 얼마인지를 나타낸다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
로마 숫자를 정수로 또는 그 반대로 변환그 중 하나는 로마 숫자를 정수로 변환하는 함수를 만드는 것이었고 두 번째는 그 반대를 수행하는 함수를 만드는 것이었습니다. 문자만 포함합니다'I', 'V', 'X', 'L', 'C', 'D', 'M' ; 문자열이 ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.