PyTorch: nn.ModuleList

3402 단어 PyTorchPyTorch

nn.ModuleList

Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.

사용법

Python List를 nn.ModuleList()로 감싸 주면 된다!

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

Python List와의 차이점

nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다. 만약 nn.ModuleList에 넣어 주지 않고, Python List에만 Module들을 넣어 준다면, PyTorch는 모델 파라미터의 존재를 알 수 없다! 때문에 optimizer 선언 시 model.parameter()를 사용하여 파라미터를 넘겨주려 할 때 에러가 발생한다. 따라서 Module들을 Python List에 넣어 보관하는 경우에는 마지막에 nn.ModuleList로 wrapping을 해줘야 한다!

좋은 웹페이지 즐겨찾기