Numpy와 Pytorch가 서로 변환될 때의 구덩이를 해결합니다.
앞말.
최근 Numpy 패키지와 Pytorch를 사용하여 신경 네트워크를 쓸 때 서로 전환해야 하기 때문에 이 노트로 코드 코드를 기록할 때 밟은 구덩이를 인터넷에서 어떤 사람이 말한다.
Pytorch는 GPU 버전의 Numpy라고도 불리는데 양자의 많은 기능은 모두 잘 대응하고 있다.
그러나 사용할 때는 주의를 기울여야 한다. 조심하지 않으면 담배 한 개비, 술 한 잔, 버그 한 개가 하룻밤을 찾을 지경에 빠진다.
1.1、numpy ――> torch
torch를 사용합니다.from_numpy () 변환, 둘 다 메모리를 공유합니다.예는 다음과 같습니다.
import torch
import numpy as np
a = np.array([1,2,3])
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(' a', a)
print(' b', b)
#
a [2 3 4]
b tensor([2, 3, 4], dtype=torch.int32)
1.2、torch――> numpy
사용numpy () 변환, 마찬가지로 메모리를 공유합니다.예는 다음과 같습니다.
import torch
import numpy as np
a = torch.zeros((2, 3), dtype=torch.float)
c = a.numpy()
np.add(c, 1, out=c)
print('a:', a)
print('c:', c)
#
a: tensor([[1., 1., 1.],
[1., 1., 1.]])
c: [[1. 1. 1.]
[1. 1. 1.]]
주의해야 할 것은 프로그램의 np를dd(c,1,out=c)를 c=c+1로 바꾸면 둘이 메모리를 공유하지 않는 것 같지만 사실은 그렇지 않다. 왜냐하면 후자는 c의 저장 주소를 바꾸는 것과 맞기 때문이다.id(c)를 사용하여 c의 메모리 위치가 변한 것을 발견할 수 있습니다.보충:pytorch에서tensor 데이터와numpy 데이터 변환에 주의하는 문제
pytorch에서numpy를.array 데이터가 장량tensor 데이터로 변환되는 상용 함수는torch입니다.from_numpy (array) 또는 torch.Tensor(array), 첫 번째 함수는 더 자주 사용됩니다.
다음은 코드를 통해 차이점을 살펴보겠습니다.
import numpy as np
import torch
a=np.arange(6,dtype=int).reshape(2,3)
b=torch.from_numpy(a)
c=torch.Tensor(a)
a[0][0]=10
print(a,'
',b,'
',c)
[[10 1 2]
[ 3 4 5]]
tensor([[10, 1, 2],
[ 3, 4, 5]], dtype=torch.int32)
tensor([[0., 1., 2.],
[3., 4., 5.]])
c[0][0]=10
print(a,'
',b,'
',c)
[[10 1 2]
[ 3 4 5]]
tensor([[10, 1, 2],
[ 3, 4, 5]], dtype=torch.int32)
tensor([[10., 1., 2.],
[ 3., 4., 5.]])
print(b.type())
torch.IntTensor
print(c.type())
torch.FloatTensor
수정수 그룹 a의 원소 값을 볼 수 있습니다. 장량 b의 원소 값도 바뀌었지만 장량 c는 변하지 않습니다.장량 c의 원소 값을 수정하면 수조 a와 장량 b의 원소 값은 변하지 않습니다.이것은 torch를 설명한다.from_numpy(array)는 수조의 얕은 복사, torch입니다.Tensor(array)는 수조의 깊은 복사본이다.
이상의 개인적인 경험으로 여러분께 참고가 되었으면 좋겠습니다. 또한 많은 응원 부탁드립니다.
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
Python 기반 Numpy의 기본 사용법 상세 정보a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])#1차원 그룹 b = np.array([[1,2],[3,4]])#2차원 그룹 1.2 시퀀스 배열 numpy.random.이 방법...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.