Chainer의 extensions.PlotReport를 로그 표시로 변경

Chainer에는 extensions.PlotReport() 라고 하는 extensions.LogReport 로 출력되는 로그 결과를 가시화해 주는 기능이 있다. 그러나 학습이 어느 정도 진행되면 loss도 거의 변화하지 않기 때문에 로그 그래프에서 가시화하고 싶어진다. 최근 릴리스된 chainerui를 이용하는 것도 손이지만, extensions.PlotReport()를 조금 변경하는 것만으로 로그 그래프로 출력할 수 있게 된다.

운영 환경


  • 우분투 16.04.3 LTS
  • Python 3.5.2
  • chainer 3.2

  • 그래프 이미지



    이하에 설명하는 함수를 다른 extensions 마찬가지로 import해 사용하면 이하와 같이 로그 그래프의 PlotReport가 생성된다.



    로그 그래프로 만들기



    소스 코드 얻기



    우선은 extensions.PlotReport()를 chainer의 소스 코드에서 가져온다.

    소스 코드 변경



    변경 전


    def __call__(self, trainer) 의 중간 정도의 이하를
    f = plt.figure()
    a = f.add_subplot(111)
    a.set_xlabel(self._x_key)
    if self._grid:
        a.grid()
    
    for k in keys:
        xy = data[k]
        if len(xy) == 0:
            continue
    
        xy = numpy.array(xy)
        plt.yscale("log")
        a.plot(xy[:, 0], xy[:, 1], marker=self._marker, label=k)
    
    if a.has_data():
        if self._postprocess is not None:
            self._postprocess(f, a, summary)
    

    변경 후



    다음과 같이 수정한다( # 追加 부분). 마지막 annotate() 좋아하는.
    f = plt.figure()
    a = f.add_subplot(111)
    a.set_xlabel(self._x_key)
    if self._grid:
         a.grid(which='major', color='gray', linestyle=':')# 追加
         a.grid(which='minor', color='gray', linestyle=':')# 追加
        # a.grid()は削除
    
    for k in keys:
        xy = data[k]
        if len(xy) == 0:
            continue
    
        xy = numpy.array(xy)
        plt.yscale("log")# 追加
        a.plot(xy[:, 0], xy[:, 1], marker=self._marker, label=k)
    
    if a.has_data():
        if self._postprocess is not None:
            self._postprocess(f, a, summary)
    
        # 追加(validationの最新の値を表示)
        a.annotate('validation\n{0:8.6f}'.format(xy[-1, 1]),
                   xy=(xy[-1]), xycoords='data',
                   xytext=(-90, 75), textcoords='offset points',
                   bbox=dict(boxstyle="round", fc="0.8"),
                   arrowprops=dict(arrowstyle="->",
                                   connectionstyle="arc,angleA=0,armA=50,rad=10"))
    
    

    가능하면 표준으로 구현해 주었으면 한다( PlotReport 에서 flg 지정하는 느낌)

    이상.

    좋은 웹페이지 즐겨찾기