Pytorch×MNIST 필기 숫자 인식 PNG 이미지를 입력으로 예측해 본다

pytorch로 이미지 인식 모델을 만들어. 테스트 데이터를 평가한다, 라고 기사는 많이 있었습니다만, JPEG라든지 PNG라든지의 화상을 실제로 읽어 예측해 보았던 기사가 그다지 없는 생각이 들었으므로, 정리해 보았습니다.

이번 목표



PNG 이미지를 PyTorch로 만든 학습 모델을 통해 예측해 봅니다.
모델은 MNIST의 필기 숫자 인식을 사용합니다.

학습 모델 구축



Google Colaboratory에서 PyTorch에서 MNIST를 학습한 모델을 저장하고 읽고 사용하는 간단한 샘플 - 인공 지능 프로그래밍 블로그
이 기사를 참고로 학습 모델을 만듭니다.
움직이면 1,725,616바이트의 mnist_cnn.pt가 생겼습니다.

기계 학습 모델을 사용하여 예측



PyTorch 1.1 Tutorials : 이미지 : PyTorch를 사용한 화풍 변환 – PyTorch
이 기사를 참고로 코드를 작성했습니다.
mnist_cnn.pt 및 필기 숫자 이미지 파일을 준비한 후 다음 코드를 실행합니다.

모델 정의 및 로드



먼저 모델을 정의하고 로드합니다.
# 必要なモジュールを読み込む
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from PIL import Image, ImageOps

# モデルの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

device = torch.device("cpu")
model = 0
model = Net().to(device)
# 学習モデルをロードする
model.load_state_dict(torch.load("/適当なPath/mnist_cnn.pt", map_location=lambda storage, loc: storage))
model = model.eval()

이미지 로드



그런 다음 이미지를 로드합니다. JPEG에서도 PNG에서도 괜찮을 것입니다.
# 画像ファイルを読み込む
image = Image.open("/適当なPath/mnist_9_70x70.png")
# convert('L')でグレースケールに変換する。
# そして画像のサイズを28ピクセル四方にリサイズします。
# さらにinvertで白黒変換する。画像は文字部分が0(黒)、背景が白(1)で学習元のデータと反対のため。
image = ImageOps.invert(image.convert('L')).resize((28,28))
# データの前処理の定義
# transforms.Normalize((0.1307,), (0.3081,)は学習元データと同様の正規化を行ってる。
# 0.1307を平均、0.3081を標準偏差に指定しています
transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])
# 元のモデルに合わせて次元を追加
image = transform(image).unsqueeze(0)

예측



마지막으로 예측(Predict)
# 予測を実施
output = model(image)
_, prediction = torch.max(output, 1)
# 結果を出力
print("result=" + str(prediction[0].item()))

이렇게하면 결과가 표시됩니다.

입력 이미지


결과
result=9

요약



Chainer를 만진 적이 있지만, PyTorch는 사용하기 시작한 이틀 초보자이지만 구구는 조사하면서 이미지 인식 코드를 쓰고 움직일 수있었습니다.
똑똑하지 않다고 하고 있을지도 모르기 때문에, 지적해 주시면 기쁩니다.

덧붙여서, Flask의 어드벤트 캘린더로 화상 인식 앱의 재료를 쓰려고 생각해 PyTorch의 부분의 코드 써 있으면 한 기사 정도 정도의 양이 되었으므로 PyTorch의 어드벤트 캘린더에도 등록해 보았습니다.

다음은 YOLO를 움직이고 싶다!

좋은 웹페이지 즐겨찾기