torch.repeat()가 뭐야

class Decoder(nn.Module):

    def __init__(self, seq_len, input_dim=64, n_features=1):
        super(Decoder, self).__init__()

        self.seq_len, self.input_dim = seq_len, input_dim
        self.hidden_dim, self.n_features = 2 * input_dim, n_features

        self.rnn1 = nn.LSTM(
            input_size=input_dim,
            hidden_size=input_dim,
            num_layers=1,
            batch_first=True
        )

        self.rnn2 = nn.LSTM(
            input_size=input_dim,
            hidden_size=self.hidden_dim,
            num_layers=1,
            batch_first=True
        )

        self.output_layer = nn.Linear(self.hidden_dim, n_features)

    def forward(self, x):
        x = x.repeat(self.seq_len, self.n_features)
        x = x.reshape((self.n_features, self.seq_len, self.input_dim))

        x, (hidden_n, cell_n) = self.rnn1(x)
        x, (hidden_n, cell_n) = self.rnn2(x)
        x = x.reshape((self.seq_len, self.hidden_dim))

        return self.output_layer(x)

decoder의 forward에서 x.repeat라는걸 쓴다

뭘까?


https://seducinghyeok.tistory.com/9
이 블로그에서는 torch.repeat()와 torch.expand()를 비교하여 보여줬다

차원을 늘리는거란다

Decoder에 딱 맞는 코드구만

덤으로

expand 또한 텐서를 반복시키는건데 1차원에서만 쓰는거라고 한다

끝.

좋은 웹페이지 즐겨찾기