[공부회 정리] Mean teachers are better role models

읽은 논문



"Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results"
논문 링크
2017년에 SoTA를 취하고 있었다.

d-SNE 의 Semi-Supervised Extension에서 사용되고 있는 기법이므로 읽어 보았다.

무슨 일이야?


  • Temporal Ensembling의 큰 데이터 세트를 다루기 어려운 문제를 해결하기 위해 모델 가중치를 평균화하는 방법을 제안합니다
  • Temporal Ensembling보다 적은 라벨이 붙은 데이터로 학습 가능

  • 신규성


  • Temporal Ensembling에서 사용 된 예측 값 대신 가중치 평균을 사용하는 점

  • 기술과 기법



    반교사 있음 학습



    Temporal Ensembling과 Mean Teacher는 Student 모델과 Teacher 모델을 사용한다. Student 모델을 Teacher 모델과 닮은 것으로 반교사 있어 학습을 실시한다.
    Teacher 모델이 대상 라벨을 만들고 Student 모델이 학습합니다.

    Mean Teacher





    마지막 가중치를 직접 사용하는 것보다 학습 단계에서 가중치를 평균화하는 것이 더 좋은 모델을 만들 수 있는 것 같다. 또한, 가중치의 평균화에 의해 중간 표현의 표현 능력도 향상되는 것 같다.
    Student 모델과 Teacher 모델의 가중치를 공유하는 대신 Teacher 모델의 가중치는 Student 모델의 가중치의 지수 이동 평균(Exponential Moving Average)을 사용한다.

    Temporal Ensembling에 비해 실용적인 면에서의 이점이 2가지 들 수 있다.
  • 타겟 라벨의 정확도가 높아지면 Student 모델과 Teacher 모델의 피드백 루프가 빨라지고 추정 정확도가 향상됩니다.
  • 대규모 데이터 세트와 온라인 학습을 지원합니다.

  • Consistency Cost



    consistency cost는 Student 모델의 출력과 Teacher 모델의 출력 사이의 거리로 정의됩니다.
    consistency cost$J$는
    Student 모델의 가중치를 $\theta$, 노이즈를 $\eta$,
    Teacher 모델의 가중치를 $\theta'$, 노이즈를 $\eta'$
    그러면 다음과 같이 쓸 수 있다.

    $${\displaystyle J(\theta) =\mathbb{E}_{x,\eta',\eta}\left[\lVert f(x,\theta',\eta')-f(x,\theta,\eta)\rVert^2\right]}$$

    $\Pi$ 모델과 Temporal Ensembling과 Mean Teacher의 차이는 Teacher 모델의 예측을 만드는 방법이다. $\Pi$ 모델에서는 $\theta'=\theta$를 사용하고 Temporal Ensembling에서는 연속 예측의 가중 평균으로 $f(x,\theta',\eta')$를 근사하지만 Mean Teacher에서는 학습 스텝 $t$에서의 $\theta_t'$를 이하의 식과 같이 했다.

    $${\displaystyle\theta_t'=\alpha\theta_{t-1}'+(1-\alpha)\theta_t}$$

    $\alpha$는 평활화 계수로 하이퍼파라미터이다.
    확률 기울기 강하 (SGD)를 사용하여 각 학습 단계에서 노이즈 $\eta$, $\eta'$를 샘플링하여 일관성 비용을 근사 할 수 있습니다.
    실험에서 평균 제곱 오차 (MSE)는 일관성 비용으로 사용됩니다.

    Mean Teacher 학습 방법



    라벨이있는 데이터로 Student 모델을 학습하여 가중치를 업데이트하고 Student 모델의 가중치에서 Teacher 모델의 가중치를 업데이트하는 절차에 따라 학습합니다 (?)
    Student 모델의 출력과 정답 라벨의 loss에, 정규화항으로서 Teacher 모델과 Student 모델의 출력의 거리를 추가해, 전체의 loss로 하고 있다(?)

    감상



    반교사 있어 학습의 $\Pi$모델이나 Temporal Ensembling등의 수법을 모르기 때문에 Mean Teacher도, 지금 일단 확실히 오지 않았다. Temporal Ensembling과 같은 기술을 이해할 필요가 있습니다.
    (K.S)

    좋은 웹페이지 즐겨찾기