torch.repeat()가 뭐야
class Decoder(nn.Module):
def __init__(self, seq_len, input_dim=64, n_features=1):
super(Decoder, self).__init__()
self.seq_len, self.input_dim = seq_len, input_dim
self.hidden_dim, self.n_features = 2 * input_dim, n_features
self.rnn1 = nn.LSTM(
input_size=input_dim,
hidden_size=input_dim,
num_layers=1,
batch_first=True
)
self.rnn2 = nn.LSTM(
input_size=input_dim,
hidden_size=self.hidden_dim,
num_layers=1,
batch_first=True
)
self.output_layer = nn.Linear(self.hidden_dim, n_features)
def forward(self, x):
x = x.repeat(self.seq_len, self.n_features)
x = x.reshape((self.n_features, self.seq_len, self.input_dim))
x, (hidden_n, cell_n) = self.rnn1(x)
x, (hidden_n, cell_n) = self.rnn2(x)
x = x.reshape((self.seq_len, self.hidden_dim))
return self.output_layer(x)
decoder의 forward에서 x.repeat라는걸 쓴다
뭘까?
https://seducinghyeok.tistory.com/9
이 블로그에서는 torch.repeat()와 torch.expand()를 비교하여 보여줬다
차원을 늘리는거란다
Decoder에 딱 맞는 코드구만
덤으로
expand 또한 텐서를 반복시키는건데 1차원에서만 쓰는거라고 한다
끝.
Author And Source
이 문제에 관하여(torch.repeat()가 뭐야), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@khj9204/torch.repeat가-뭐야저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)