신경망으로 손글씨 숫자 인식하기

1. MNIST 인식 모델 Predict 예제

1-1. MNIST 소개

MNIST는 기계학습 공부, 연구 등 아주 다양한 분야에서 흔하게 사용되는 수기 숫자 데이터셋이다. 6만 장의 학습용 이미지, 1만 장의 테스트 이미지로 구성되어 있다.

  • 학습용 이미지는 기계학습 모델을 학습(훈련) 시키는 데 사용되며,
  • 테스트 이미지는 기계학습 모델이 얼마나 잘 훈련되었는지 테스트하는 데 사용된다.

MNIST의 각 이미지는 28*28 픽셀로 되어 있으며, 1채널(회색조, 흑백)이다. 각 픽셀의 값은 0~255까지의 값을 가지고 있다. 이 값은 연하게 칠해졌을수록(배경에 가까운) 작게 나타나고, 진하게 칠해졌을수록 크게 나타난다.

또, 각 이미지에는 해당 이미지가 어떤 숫자를 쓴 이미지인지 레이블이 붙어 있다.

(* 데이터-레이블 쌍이라는 표현을 쓴다. 문제집에 비유하자면 데이터는 '문제'이고, 레이블은 '답지'이다. 데이터를 보고 결과를 예측하도록 하고, 결과가 레이블과 얼마나 차이가 나는지 확인하여 신경망의 가중치와 편향을 조절하는 것이다. 대부분의 경우 데이터와 그 데이터에 대응하는 레이블이 있어야 기계 학습을 돌릴 수가 있다.)

1-2. 이번 목표

이미 학습이 어느 정도 되어 있는 모델의 가중치와 편향을 그대로 Load하여 신경망의 predict만을 테스트 하는 것이 목표이다.

신경망을 '학습' 시키는 부분은 추후에 다룰 예정이다. 그 때에는 본서의 저자가 제공하는 MNIST 데이터를 사용하지 않고, 직접 MNIST 데이터를 다운로드 받아 사용해볼 계획이다.

1-3. dataset.mnist 모듈

https://www.hanbit.co.kr/store/books/look.php?p_code=B8475831198

위 주소의 '부록/예제소스' 탭에 들어가면 예제 소스와 dataset 패키지가 압축된 파일을 내려받을 수 있다.

dataset.mnist 모듈은 load_mnist() 함수를 이용해 MNIST 데이터와 레이블을 가져오도록 할 수 있다. import 방법은 아래와 같다.

from dataset.mnist import load_mnist

load_mnist() 함수는 MNIST 데이터를 아래와 같이 반환한다.

(training_image, training_label), (test_image, test_label) = load_mnist()

load_mnist() 함수의 인수로는 normalize:bool, flatten:bool, one_hot_label:bool 세 가지를 넘겨줄 수 있다.

  • normalize = True: MNIST 데이터셋의 데이터를 0.0~1.0 사이로 정규화
  • flatten = True: MNIST 데이터셋의 데이터를 1차원 배열로 reshape하여 반환. 28*28 Image이므로, 28*28 = 784개의 원소로 이뤄진 1차원 배열을 반환함.
  • one_hot_label = True: MNIST 데이터셋의 레이블을 *One-Hot 방식으로 인코딩하여 내놓음. Default는 False. False인 경우 레이블은 0~9 사이의 숫자를 반환한다.

(* One-Hot Encoding이란, 정답을 뜻하는 원소만을 1로, 나머지는 0으로 설정한 배열을 내놓는 것을 말한다. 예를 들어, 레이블이 3이라면 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]을 레이블로 내놓는 것이다. 다중 Class 분류 시 주로 사용되는 방식이다.)

먼저, 데이터 이미지 중 하나를 화면에 띄워서 데이터를 확인해보자!

import sys, os
sys.path.append(os.pardir)
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image

(x_train, t_train), (x_text, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[10]
label = t_train[10]
print(label)

print(img.shape)
img = img.reshape(28, 28)
print(img.shape)

pil_img = Image.fromarray(np.uint8(img))
pil_img.show()

1-4. 신경망을 이용한 추론 처리

import sys, os
sys.path.append(os.pardir)
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test

def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p = np.argmax(y)
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

  1. get_data() 함수에서 load_mnist() 함수를 호출해 테스트 데이터셋을 얻어온다. 이번 목표는 이미 학습된 모델의 가중치와 편향을 그대로 이용해 신경망의 추론 처리를 구현하는 것이지, 모델을 학습 시킬 것이 아니기 때문에 Training data는 얻어오지 않고 Test data만 얻어오도록 했다.
  2. init_network() 함수에서는 Pickle 패키지를 이용해 파일로 저장된, 이미 학습되어 있는 모델의 Network를 load한다.
  3. predict() 함수에서는 Neural Network와 데이터(x)를 인자로 받아 추론 처리를 진행하고, 출력층에서 Softmax 함수를 이용해 각 레이블에 해당할 확률이 저장된 길이 10의 배열을 반환한다.
  4. for문 내에서는 network와 각 Test Data를 predict() 함수에 하나씩 넘겨주며 추론 처리를 진행한다. 그리고 np.argmax() 함수를 이용해 가장 큰 확률(최댓값)이 저장된 인덱스 번호를 얻어낸다.
  5. 인덱스 번호와 레이블(정답)이 일치하면 정확히 추론한 것이므로, 정확도 계산을 위한 accuracy_cnt 변수를 1 증가시킨다.

1-5. 정규화와 전처리

MNIST의 경우 각 픽셀이 가진 값이 0~255 사이이므로, 굳이 0.0~1.0 사이의 숫자로 변환하지 않고 그대로 모델을 학습시키거나 추론을 해도 올바른 결과가 나올 것이다. 그러나 실전적인 모델을 개발할 때에는 그렇지 않다. 데이터 한 조각의 크기가 엄청나게 크거나 작을 수 있다. 이런 경우에는 모든 값을 어떤 숫자로 나누는 등 처리를 하여 일정 범위 내에 집어넣어야 효율적으로 학습, 추론을 할 수 있다. 이렇게 데이터를 특정 범위로 변환하는 처리를 정규화(normalization)라고 하며, 신경망의 입력 데이터에 어떠한 변환을 가하는 것을 전처리(preprocessing)라고 한다.

예제의 경우 "입력 이미지 데이터에 대한 전처리로 정규화를 수행했다"고 표현할 수 있다.

좋은 웹페이지 즐겨찾기