pytorch는 여러 모델의 weights를 평균적으로 실현하고 weights를 수정합니다
2808 단어 #
문서 목록
1. 조작설명
3개의 구조가 같지만 weights의 다른 모델은 하나의list를 구성하고 모델스=[모델1,모델2,모델3], 그리고 중심모델fl모델, 이 네 모델의 구조와 슈퍼 파라미터는 모두 같다.
이런 조작이 필요하다. 평균 모델 안의 세 모델의 weights를 평균 이후의 weights에'값'을 부여한다fl모델의 weights.
2. 코드
tensorflow에서는 모델을 직접 사용할 수 있습니다.get_weights () 및 모델.set_weights()로 하면 직관적이고 편리합니다.피토치 안이 좀 복잡한 것 같아서요.이러한 작업을 수행하는 코드는 다음과 같습니다.
worker_state_dict=[x.state_dict() for x in models]
weight_keys=list(worker_state_dict[0].keys())
fed_state_dict=collections.OrderedDict()
for key in weight_keys:
key_sum=0
for i in range(len(models)):
key_sum=key_sum+worker_state_dict[i][key]
fed_state_dict[key]=key_sum/len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
Rails Turbolinks를 페이지 단위로 비활성화하는 방법원래 Turobolinks란? Turbolinks는 링크를 생성하는 요소인 a 요소의 클릭을 후크로 하고, 이동한 페이지를 Ajax에서 가져옵니다. 그 후, 취득 페이지의 데이터가 천이 전의 페이지와 동일한 것이 있...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.