pytorch 상용 데이터 형식이 차지하는 바이트 대조표 일람

PyTorch의 일반적인 데이터 유형은 다음과 같습니다.


Data type
dtype
CPU tensor
GPU tensor
Size/bytes
32-bit floating
torch.float32 or torch.float
torch.FloatTensor
torch.cuda.FloatTensor

64-bit floating
torch.float64 or torch.double
torch.DoubleTensor
torch.cuda.DoubleTensor

16-bit floating
torch.float16or torch.half
torch.HalfTensor
torch.cuda.HalfTensor
-
8-bit integer (unsigned)
torch.uint8
torch.ByteTensor
torch.cuda.ByteTensor

8-bit integer (signed)
torch.int8
torch.CharTensor
torch.cuda.CharTensor
-
16-bit integer (signed)
torch.int16or torch.short
torch.ShortTensor
torch.cuda.ShortTensor

32-bit integer (signed)
torch.int32 or torch.int
torch.IntTensor
torch.cuda.IntTensor

64-bit integer (signed)
torch.int64 or torch.long
torch.LongTensor
torch.cuda.LongTensor

위의 PyTorch의 데이터 형식과numpy의 상대적인 바이트 크기도 같다
추가:pytorchtensor 비교 크기 데이터 형식 주의

아래와 같다


a = torch.tensor([[0, 0], [0, 0]])
print(a>=0.5)
출력
tensor([[1, 1],
[1, 1]], dtype=torch.uint8)
결과가 현저히 틀렸다. 분석 원인은 a는long 유형이고 0.5는float이기 때문이다.0.5는 롱으로 바뀌고 0으로 바뀐다.따라서 결과가 틀릴 수 있으므로 다음과 같은 수정을 하면 정답을 얻을 수 있다

올바른 사용법:


a = torch.tensor([[0, 0], [0, 0]]).float()
print(a>=0.5)
이상의 개인적인 경험으로 여러분께 참고가 되었으면 좋겠습니다. 또한 많은 응원 부탁드립니다.

좋은 웹페이지 즐겨찾기