PyTorch: nn.ModuleList
nn.ModuleList
Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.
사용법
Python List와 마찬가지로 nn.Module을 저장하는 역할을 하며, index 접근도 가능하다.
Python List를 nn.ModuleList()
로 감싸 주면 된다!
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
Python List와의 차이점
nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다. 만약 nn.ModuleList에 넣어 주지 않고, Python List에만 Module들을 넣어 준다면, PyTorch는 모델 파라미터의 존재를 알 수 없다! 때문에 optimizer 선언 시 model.parameter()
를 사용하여 파라미터를 넘겨주려 할 때 에러가 발생한다. 따라서 Module들을 Python List에 넣어 보관하는 경우에는 마지막에 nn.ModuleList로 wrapping을 해줘야 한다!
Author And Source
이 문제에 관하여(PyTorch: nn.ModuleList), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다 https://velog.io/@danbibibi/PyTorch-nn.ModuleList저자 귀속: 원작자 정보가 원작자 URL에 포함되어 있으며 저작권은 원작자 소유입니다.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)