Example code to understand PyTorch padding
## Test: running a GRU on padded vs. packed input
```python
import torch
import torch.nn as nn

batch_size = 3
max_length = 3

# A zero-padded batch of token ids (0 is used as the padding id).
batch_in = torch.LongTensor(batch_size, max_length).zero_()
batch_in[0] = torch.LongTensor([1, 2, 3])
batch_in[1] = torch.LongTensor([4, 5, 0])
batch_in[2] = torch.LongTensor([6, 0, 0])

# Real sequence lengths, i.e. the number of non-padding (non-zero)
# elements in each row of batch_in. pack_padded_sequence expects them
# sorted in descending order (or pass enforce_sorted=False).
seq_lengths = [3, 2, 1]
print(batch_in)

rnn = nn.GRU(10, 16, batch_first=True)
emb = nn.Embedding(20, 10)

t1 = emb(batch_in)        # (batch, seq, emb) = (3, 3, 10)
print(t1.shape)
t2 = torch.nn.utils.rnn.pack_padded_sequence(t1, seq_lengths, batch_first=True)

y1, _ = rnn(t1)           # run the GRU on the padded batch directly
y2, _ = rnn(t2)           # run the GRU on the packed batch
# Note: pad_packed_sequence defaults to batch_first=False, so y2 comes
# back time-major with shape (seq, batch, hidden).
y2, _ = torch.nn.utils.rnn.pad_packed_sequence(y2)
print(y1.shape)
print(y2.shape)
```
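As an aside, it helps to peek at what `pack_padded_sequence` actually returns before comparing the outputs. A `PackedSequence` stores only the non-padding elements, flattened time-major, together with a `batch_sizes` vector. A minimal sketch using the `t2` built above (the printed shapes follow from this example's sizes):

```python
# t2 is a PackedSequence, not a padded tensor.
# .data holds only the real (non-padding) tokens, flattened time-major:
# 3 + 2 + 1 = 6 embedded tokens of dimension 10.
print(t2.data.shape)    # torch.Size([6, 10])
# .batch_sizes says how many sequences are still "alive" at each time step.
print(t2.batch_sizes)   # tensor([3, 2, 1])
```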
The outputs are as follows. Note that in y2 the positions corresponding to padding are set to 0, while in y1 they are not. Also note the layout: y1 is batch-first with shape (batch, seq, hidden), whereas y2, padded back with the default batch_first=False, is time-major with shape (seq, batch, hidden).
```
print(y1)
tensor([[[-0.2798, -0.3856, -0.0821,  0.1419, -0.0711,  0.3442, -0.2280,
          -0.3013,  0.4323, -0.0061, -0.0059,  0.0153,  0.0043, -0.0575,
          -0.0899,  0.0167],
         [-0.3635, -0.1256, -0.1606,  0.0027, -0.1943,  0.2244, -0.2939,
          -0.4380,  0.0758,  0.0657,  0.2524,  0.1803,  0.3322, -0.1517,
          -0.0686,  0.2098],
         [-0.5649,  0.1938,  0.1958, -0.1223,  0.2357,  0.2236, -0.1336,
          -0.4538, -0.0373,  0.0701,  0.2694,  0.1607,  0.0233, -0.4146,
          -0.2436, -0.1231]],

        [[-0.1592,  0.1595, -0.1047,  0.1702,  0.2267,  0.0230, -0.1799,
           0.0939,  0.0281,  0.3224,  0.0434, -0.0119, -0.1329, -0.3400,
          -0.1222, -0.1626],
         [-0.3771, -0.0135, -0.3905, -0.2723, -0.0649,  0.0405, -0.4896,
          -0.3519,  0.5476,  0.5207,  0.1581, -0.3331, -0.2819, -0.4226,
          -0.5203, -0.1131],
         [-0.4858,  0.2086, -0.2146, -0.2873,  0.1970, -0.0200, -0.3474,
          -0.3273, -0.1428,  0.3200,  0.3431, -0.1556,  0.0167, -0.4716,
          -0.4962,  0.0401]],

        [[-0.2618, -0.0852,  0.3207,  0.1903, -0.1152,  0.0338, -0.0622,
           0.2939,  0.2542,  0.2764,  0.1221,  0.2355, -0.2659,  0.3058,
           0.0794, -0.1431],
         [-0.4471,  0.2109,  0.1460, -0.1279,  0.2117, -0.0292, -0.0887,
           0.1132, -0.2051,  0.2217,  0.2989,  0.0810,  0.0741, -0.0756,
          -0.0479,  0.1462],
         [-0.5627,  0.3191,  0.0562, -0.2493,  0.3844, -0.0904, -0.1191,
          -0.0179, -0.4021,  0.1875,  0.3719, -0.0419,  0.1960, -0.2602,
          -0.1424,  0.2273]]], grad_fn=<...>)

print(y2)
tensor([[[-0.2798, -0.3856, -0.0821,  0.1419, -0.0711,  0.3442, -0.2280,
          -0.3013,  0.4323, -0.0061, -0.0059,  0.0153,  0.0043, -0.0575,
          -0.0899,  0.0167],
         [-0.1592,  0.1595, -0.1047,  0.1702,  0.2267,  0.0230, -0.1799,
           0.0939,  0.0281,  0.3224,  0.0434, -0.0119, -0.1329, -0.3400,
          -0.1222, -0.1626],
         [-0.2618, -0.0852,  0.3207,  0.1903, -0.1152,  0.0338, -0.0622,
           0.2939,  0.2542,  0.2764,  0.1221,  0.2355, -0.2659,  0.3058,
           0.0794, -0.1431]],

        [[-0.3635, -0.1256, -0.1606,  0.0027, -0.1943,  0.2244, -0.2939,
          -0.4380,  0.0758,  0.0657,  0.2524,  0.1803,  0.3322, -0.1517,
          -0.0686,  0.2098],
         [-0.3771, -0.0135, -0.3905, -0.2723, -0.0649,  0.0405, -0.4896,
          -0.3519,  0.5476,  0.5207,  0.1581, -0.3331, -0.2819, -0.4226,
          -0.5203, -0.1131],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000]],

        [[-0.5649,  0.1938,  0.1958, -0.1223,  0.2357,  0.2236, -0.1336,
          -0.4538, -0.0373,  0.0701,  0.2694,  0.1607,  0.0233, -0.4146,
          -0.2436, -0.1231],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000]]], grad_fn=<...>)
```
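To check the two observations above in code, here is a minimal sketch reusing `rnn`, `t1`, `y1`, and `seq_lengths` from the script (the names `packed`, `out_packed`, `y2_bf`, and `lens` are just illustrative). Passing `batch_first=True` to `pad_packed_sequence` returns the padded output in the same (batch, seq, hidden) layout as y1, so the two runs can be compared position by position:

```python
# Re-pack, run, and unpack with batch_first=True so the layout matches y1.
packed = torch.nn.utils.rnn.pack_padded_sequence(t1, seq_lengths, batch_first=True)
out_packed, _ = rnn(packed)
y2_bf, lens = torch.nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)

print(lens)  # tensor([3, 2, 1]) -- the true lengths are returned as well
for i, n in enumerate(seq_lengths):
    # The valid time steps match the plain padded run...
    assert torch.allclose(y1[i, :n], y2_bf[i, :n], atol=1e-6)
    # ...and every step past the true length is exactly zero.
    assert (y2_bf[i, n:] == 0).all()
```

This confirms the point of packing: the RNN produces identical outputs on the valid steps, but never wastes computation on (or leaks state through) the padded positions.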