pytorch는 여러 모델의 weights를 평균적으로 실현하고 weights를 수정합니다

2808 단어 #

문서 목록

  • 1. 운영 지침
  • 2.코드
  • 1. 조작설명


    3개의 구조가 같지만 weights의 다른 모델은 하나의list를 구성하고 모델스=[모델1,모델2,모델3], 그리고 중심모델fl모델, 이 네 모델의 구조와 슈퍼 파라미터는 모두 같다.
    이런 조작이 필요하다. 평균 모델 안의 세 모델의 weights를 평균 이후의 weights에'값'을 부여한다fl모델의 weights.

    2. 코드


    tensorflow에서는 모델을 직접 사용할 수 있습니다.get_weights () 및 모델.set_weights()로 하면 직관적이고 편리합니다.피토치 안이 좀 복잡한 것 같아서요.이러한 작업을 수행하는 코드는 다음과 같습니다.
    worker_state_dict=[x.state_dict() for x in models]
    weight_keys=list(worker_state_dict[0].keys())
    fed_state_dict=collections.OrderedDict()
    for key in weight_keys:
        key_sum=0
        for i in range(len(models)):
            key_sum=key_sum+worker_state_dict[i][key]
        fed_state_dict[key]=key_sum/len(models)
    #### update fed weights to fl model
    fl_model.load_state_dict(fed_state_dict)
    

    좋은 웹페이지 즐겨찾기