Tensorflow + matplotlib에서 추론 및 결과보기

1. 소개



Tensorflow나 PyTorch, Chainer에서 모델의 평가는 할 수 있었지만, Deep Learning을 많이 모르는 사람에게 Accuracy나 Loss의 그래프만 보여도…
또한 손쉽게이 이미지는 정답! 이것은 부정해! 라는 것이 보이면 기쁜 장면도 있을 것입니다.
그런 소원을 이루기 위해, 복수 줄지어진 화상에 판정 결과를 실어 가는, 라고 하는 것을 Matplotlib로 실현해 보고 싶습니다.

「―――기쁘게 소년. 네 소원은 드디어 이루어진다」


또, 이번은 예로서 Tensorflow를 이용합니다만, 화상의 표시 부분은 프레임워크가 무엇이어도 괜찮습니다.

2. 구현 & 해설



예를 들어 Tensorflow를 사용하여 MNIST를 학습한 모델을 준비했습니다.
이번 이미지는 40장 사용합니다.

validation.py
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf

# 表示画像枚数の設定
row = 4
col = 10

# データのロード
mnist = tf.keras.datasets.mnist
(_, _), (x_test, y_test) = mnist.load_data()
x_test = np.asarray(x_test[0:row*col])
y_test = np.asarray(y_test[0:row*col])

# モデルのロード
path = 'mnist.h5' # 学習済みモデルのパス
model = tf.keras.models.load_model(path)

# 推論
x_test_flat = x_test.reshape(-1, 784) / 255.0
result = model.predict(x_test_flat)

# 画像の整列
plt.figure(figsize=(5, 5))
image_array = []
for i, image in enumerate(x_test):
    image_array.append(plt.subplot(row, col, i + 1))
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.pause(0.1)

# ラベルの配置
for j, image in enumerate(x_test):
    bg_color = 'skyblue' if y_test[j] == np.argmax(result[j]) else 'red'
    image_array[j].text(0, 0, np.argmax(result[j]), color='black', backgroundcolor=bg_color)
    plt.pause(0.1)

# 画像全体の保存
plt.savefig('judge_result.png')

추론에 대해



학습·평가를 하고 인식률의 그래프를 출력하는 곳까지는 여러가지 사이트에 써 있습니다.
그러나, 학습한 모델로 추론해 뭔가 한다, 라고 하는 곳을 써 있는 것은 의외로 적거나 합니다. (자신의 체감이지만…)
x_test_flat = x_test.reshape(-1, 784) / 255.0
result = model.predict(x_test_flat)
tensorflow.keras.models 에는 predict() 라는 메소드가 있어 추론에서는 이것을 사용합니다.
이 메소드에, 추론하고 싶은 이미지의 배열을 건네줍니다. 모델의 입력이 1차원이므로 reshape(-1, 784) 에서 1차원 배열로 변환했습니다.
이번에는 40장을 한 번에 처리하므로 (40, 784) 의 배열을 건네줍니다만, 1장만 처리할 때도 (1, 784) 로서 건네줄 필요가 있습니다.

Chainer에서는 result = model.predictor(test_x).data[0] 라고 기술하는 것으로 추론이 가능합니다.

이미지 표시 정보



matplotlib에서는 객체 지향적으로 작성할 수 있습니다.
plt.figure(figsize=(5, 5))
image_array = []
for i, image in enumerate(x_test):
    image_array.append(plt.subplot(row, col, i + 1))
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.pause(0.05)

첫 번째 for문에서 plt.subplot 를 사용하여 이미지를 정렬합니다.plt.subplotfigure 에 대한 자식 요소입니다. 인수에 (세로, 가로, 몇 번째)를 전달합니다.
숫자를 나타내는 숫자는 1부터 계산합니다. (0부터가 아니므로 주의)
먼저 전체 이미지를 표시하고 나중에 자식 요소에 라벨을 추가하기 때문에 조작할 수 있도록 append() 해 둡시다.
(단지 표시하는 것만이라면 배열에 넣어 둘 필요는 없습니다만, 나중에 라벨을 추가하고 싶으므로, 이번은 이렇게 해 둡니다.)
또한 이번에는 그래프가 아닌 이미지이므로 plt.axis('off') 로 좌표 표시를 지우고 있습니다.
한 번에 늘어놓을 수 있는 화상의 준비가 끝나면, plt.pause() 로 화상을 표시합니다.plt.show() 로 해 버리면 거기에서 처리가 멈추어 버리므로 plt.pause() 를 사용하고 있습니다.
for j, image in enumerate(x_test):
    bg_color = 'skyblue' if y_test[j] == np.argmax(result[j]) else 'red'
    image_array[j].text(0, 0, np.argmax(result[j]), color='black', backgroundcolor=bg_color)
    plt.pause(0.05)

두 번째 for문에서는 image_array의 요소에 하나씩 레이블을 추가합니다.image_array[j].text(0, 0, np.argmax(result[j]) 에서 추론 결과를 이미지에 추가합니다.
추론 결과 np.argmax(result[j]) 와 정답 라벨 y_test[j] 가 일치하면 배경을 파란색으로, 불일치라면 빨강으로 해 보았습니다.plt.pause() 에서 화면에 레이블을 표시합니다. 인수는 이미지를 표시하는 시간이며, 이 수치를 바꾸는 것으로 표시의 갱신 속도가 바뀝니다. 어디까지나 "표시"의 갱신 속도이며, 모델의 처리 속도가 아니기 때문에 주의합시다.

3. 정리



이번에 작성한 것입니다만, 학생실험의 프리젠테이션으로 사용하고 싶었지만, 당시의 자신은 실장할 수 없고…
지금 생각하면 간단해도, 학습 단계에 따라서는 어려운 일이 있지요.
Deep Learning을 시작한지 ​​얼마 안된 사람이나 프로그래밍 강의를 받았지만 싫어서 별로 이해할 수 없다는 사람에게 도착하면 다행입니다.

좋은 웹페이지 즐겨찾기