PyTorch: nn.ModuleList
nn.ModuleList
Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.
사용법
Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.
Python List를 nn.ModuleList()로 감싸 주면 된다!
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
Python List와의 차이점
nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다. 만약 nn.ModuleList에 넣어 주지 않고, Python List에만 Module들을 넣어 준다면, PyTorch는 모델 파라미터의 존재를 알 수 없다! 때문에 optimizer 선언 시 model.parameter()를 사용하여 파라미터를 넘겨주려 할 때 에러가 발생한다. 따라서 Module들을 Python List에 넣어 보관하는 경우에는 마지막에 nn.ModuleList로 wrapping을 해줘야 한다!
Author And Source
이 문제에 관하여(PyTorch: nn.ModuleList), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@danbibibi/PyTorch-nn.ModuleList저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념
(Collection and Share based on the CC Protocol.)