Deconvolution 2D 학습 경과를 jupyter에서 그리기 (Bokeh)

소개



딥 러닝의 학습 경과의 이미지를 jupyter상에서 실시간으로 표시할 수 있는 것이 없을까 찾고 있었는데, 보케 라는 라이브러리가 있었으므로 시험해 보았습니다.

학습 내용



Chainer를 사용하여 1층의 Deconvolution 2D(간단한 필터 같은 것)를 학습시켜 보았습니다. Deconvolution에 대해서는 이전에 여기에서 소개한 것 등 참고해 주세요.
아래의 gif에서는 점 ⇒ 구 모양으로 변화하도록 학습시키고 있습니다.



Bokeh에서는 이미지를 마우스 휠로 확대 축소 등 할 수 있으므로 학습 확인에 편리합니다.

출처



이전 Chainer에서는 numpy를 한 번 Variable로 변환해야 했지만 Ver 1.17의 현재는 자동으로 Variable로 변환 해주는 것 같습니다.
Bokeh의 jupyter상에서의 표시에 대해서는 여기 등 참고로 하고 있습니다.
import chainer.links as L
import chainer.functions as F
from chainer import Variable, optimizers
import numpy as np
import math
import time

#1つの球状の模様を作成(ガウスですが)
def make_one_core():            
    max_xy=15    
    sig=5.0
    sig2=sig*sig
    c_xy=7
    core=np.zeros((max_xy, max_xy), dtype= np.float32)
    for px in range(0, max_xy):
        for py in range(0, max_xy):
            r2=(px-c_xy)*(px-c_xy)+(py-c_xy)*(py-c_xy)
            core[py][px]=math.exp(-r2/sig2)*1
    return core.reshape((1, 1, core.shape[0], core.shape[1]))

#点と球状のimageを作成
def get_image(N=1, img_w=128, img_h=128):

    #ランダムに0.1%の点を作る
    img_p = np.random.randint(0, 10000, size = N*img_w*img_h)
    img_p[img_p < 9990]=0
    img_p[img_p >= 9990]=255

    img_p = img_p.reshape((N,1,img_h, img_w)).astype(np.float32)

    decon_core = L.Deconvolution2D(1, 1, 15, stride=1, pad=7)
    #Wに球状の模様をあてる
    decon_core.W.data = make_one_core()

    #点⇒球に変換
    img_core = decon_core(img_p)#Variableに変換なしでもOK

    return img_p, img_core.data    


#初期描画
from bokeh.plotting import figure
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import gridplot

output_notebook()

palette_256 = ['#%02x%02x%02x' %(i,i,i) for i in range(256)] #256段階で白黒表示用

img_p, img_core = get_image()#点と球状のimageを取得

img_h = img_p.shape[2]
img_w = img_p.shape[3]

plt1 = figure(title = 'epoch = --', x_range=[0, img_w], y_range=[0, img_h])
rend1 = plt1.image(image=[img_p[0][0]],x=[0], y=[0], dw=[img_w], dh=[img_h], palette=palette_256)

plt2 = figure(title = 'loss  = 0', x_range=plt1.x_range, y_range=plt1.y_range)
rend2 = plt2.image(image=[img_core[0][0]],x=[0], y=[0], dw=[img_w], dh=[img_h], palette=palette_256)

plts = gridplot([[plt1,plt2]], plot_width=300, plot_height=300)
handle = show(plts, notebook_handle=True)

#モデル・オプティマイザ設定
model =  L.Deconvolution2D(1, 1, 15, stride=1, pad=7)#1層のDeconvolution
optimizer = optimizers.SGD(lr=0.001)#大きいと発散する
optimizer.setup(model)

#計算
for epoch in range(0,31):    

    #1層のDeconvolutionを通してロスを計算しアップデート
    model.cleargrads()
    img_y = model(img_p)
    loss = F.mean_squared_error(img_y, img_core)
    loss.backward()
    optimizer.update()

    #画像・ロスデータをセット
    rend1.data_source.data['image'] = [img_p[0][0]]
    rend2.data_source.data['image'] = [img_y.data[0][0]]
    plt1.title.text='epoch = '+str(epoch)
    plt2.title.text='loss  = '+str(loss.data)
    push_notebook(handle = handle)#表示をアップデート
    time.sleep(0.5)

좋은 웹페이지 즐겨찾기