02. Linear Model

Sung Kim님의 유투브 강의 자료인 PytorchZeroToAll를 바탕으로 공부한 내용에 대한 글입니다.

Machine Learning


01.Overview에서 다룬 4시간 공부한 학생이 받게될 성적을 예측하는 모델을 만든다고 생각해보자. Label을 가진 training dataset으로 모델을 학습하고, test dataset으로 모델의 성능을 검증해본다. 이러한 종류의 머신러닝을 Supervised Learning(지도 학습)이라고 한다.

Model design


모델을 설계하기 위해 먼저 해야할 일은 입력과 출력 데이터의 관계를 파악하는 것이다. 위와 같은 데이터의 경우 입력과 출력 데이터는 linear(선형) 관계를 가지고 있다. 선형 모델은 y^=wx+b\hat y = w * x + b

1. Linear Regression


지금 문제와 같이 주어진 선형 데이터가 Linear Regression의 가장 간단한 예이다. 예측 모델의 식은 y^=wx\hat y = w * x

2. Training Loss(error)


loss=(y^y)2=(xwy)2loss = (\hat y-y)^2 = (x*w - y)^2

Linear Regression에서 loss는 실제 값인 yy와 예측값인 y^\hat y

loss=1Nn=1N(y^nyn)2loss = {1\over N}\sum_{n=1}^N(\hat y_n-y_n)^2

Training loss를 구한 방식은 모든 점에서의 예측값과 실제값의 차를 제곱해서 그 개수만큼 나누므로 평균을 구하는 것과 같아서 이 방법을 MSE(Mean Square Error)라고 부른다.

3. Loss graph


ww에 따른 loss를 그래프에 찍은 후 그 점들을 이어보면 위와 같은 loss graph를 그려볼 수 있다. 이 그래프를 통해 loss가 가장 작은 값을 갖게하는 weight를 구할 수 있다.

4. Code


import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# our model for the forward pass
def forward(x):
    return x * w
    
# Loss function
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)
    
# List of weights/Mean square Error (Mse) for each input
w_list = []
mse_list = []

for w in np.arange(0.0, 4.1, 0.1):
    # Print the weights and initialize the lost
    print("w=", w)
    l_sum = 0

    for x_val, y_val in zip(x_data, y_data):
        # For each input and output, calculate y_hat
        # Compute the total loss and add to the total error
        y_pred_val = forward(x_val)
        l = loss(x_val, y_val)
        l_sum += l
        print("\t%.1f %.1f %.1f %.1f" %(x_val, y_val, y_pred_val, l))
    # Now compute the Mean squared error (mse) of each
    # Aggregate the weight/mse from this run
    print("MSE=%5.3f" %(l_sum/len(x_data)))
    w_list.append(w)
    mse_list.append(l_sum / len(x_data))
    
# Plot it all
plt.plot(w_list, mse_list, '-b')
plt.plot(w_list, mse_list, '.r')
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

좋은 웹페이지 즐겨찾기