【기계 학습 입문】k-근방법으로 간단한 분류 문제를 풀어 보자

이 기사 개요



가장 간단한 학습 알고리즘인 k-근방법(k-NN: k-nearest neighbor)을 사용하여 간단한 분류 문제를 풀어 실천적으로 사용법을 배웁니다.

목차



1. k-근방법(k-NN)의 개요
2. 데이터 생성
3. k-근방법에 의한 분류
4. 결론

1. k-근방법(k-NN)의 개요



k-최근 방법은 교사 있어 학습의 분류 문제에 이용되는 기계 학습 수법입니다. 최근 이웃의 데이터를 k개 가져와서 그들이 가장 많이 소속하는 클래스로 분류합니다. 요컨대, 다수결을 하고 있는 것입니다. k의 수는 임의로 결정되는 하이퍼 파라미터이며, k의 수가 너무 작으면 이상치(노이즈)에 약하고, 너무 많으면 정밀도가 나빠집니다.

k-근방법의 이미지도입니다. 테스트 데이터는 빨간색 별 플롯으로 표시됩니다.


테스트 데이터에서 트레이닝 데이터까지의 거리를 측정하는 것은 일반적으로 유클리드 거리를 사용합니다. 유클리드 거리는 아래 그림과 같이 사람이 눈금자로 2점 사이를 측정하는 거리를 말합니다.



유클리드 거리를 식으로 나타내면 아래와 같습니다.
p = (p_1, p_2, \cdots , p_n), q = (q_1, q_2, \cdots , q_n)\\
d(p,q)=\sqrt{(q_1-p_1)^2 + (q_2-p_2)^2 + \cdots + (q_n-p_n)^2} = \sqrt{\sum_{i=1}^{n} (q_i-p_i)^2}

위 그림의 플롯 p, q를 계산해 보면,
d(p,q) = \sqrt{(3-(-1))^2+(2-(-1))^2} = 5

네요.

이 거리를 이용하는 것으로 테스트 데이터로부터 각 훈련 데이터의 거리를 측정해, 거리가 가까운 k개의 훈련 데이터의 라벨로 제일 많은 클래스로 분류하는 것이, k-근방법의 알고리즘이 됩니다.

2. 데이터 생성



먼저 numpy를 활용하여 분석하는 데이터를 작성해 봅시다.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# クラス0、1に対応するデータ点の作成
x0 = np.random.normal(size=50).reshape(-1, 2) - 1 #(-1, -1)を中心とした正規分布
x1 = np.random.normal(size=50).reshape(-1, 2) + 1 #(1, 1)を中心とした正規分布
X_train = np.concatenate([x0, x1])

# 教師データの作成
y_train = np.concatenate([np.zeros(25), np.ones(25)]).astype(np.int)

#データ形状の確認
print("x0:{}, x1:{}, X_train:{}, y_train:{}".format(x0.shape, x1.shape, X_train.shape, y_train.shape))
> x0:(25, 2), x1:(25, 2), X_train:(50, 2), y_train:(50,)

클래스 0과 클래스 1의 데이터 포인트가 각각 25점씩 있습니다. 그래프를 시각화해 봅시다.
sns.scatterplot(X_train[:,0], X_train[:,1], hue=y_train)
plt.xlim(-3.5, 3.5)
plt.ylim(-3.5, 3.5)

클래스별로 색으로 구분하고 있습니다.


3. k-근방법에 의한 분류(k-근방법에 의한 분류)



이제 작성한 모델을 학습하고 새로운 점에 대한 분류 문제를 풀어 봅시다.
# 予測したいデータ点の作成
x = np.random.normal(size=2).reshape(-1, 2)
print(x)
> [[-0.56838004 -1.52474231]]

from sklearn.neighbors import KNeighborsClassifier
# kの設定
n_neighbors = 3
# モデルの学習
knc = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X_train, y_train)
# テストデータの分類予測
y_pred = knc.predict(x)
print(y_pred)
> [0] # 予測結果=クラス0

# 可視化
sns.scatterplot(X_train[:,0], X_train[:,1], hue=y_train)
sns.scatterplot(x[:,0], x[:,1], color='r', marker='*', s=300, label='test_data')

예측하고 싶은 데이터 점을 적성으로 플롯하고 있습니다만, 예측 결과(y_pred)가 클래스 0으로 분류되어 있어 그래프로부터도 타당한 결과인 것을 알 수 있습니다.



처음에 하이퍼파라미터 k에 대해 언급했지만 k가 분류 경계에 미치는 영향을 시각화하고 확인해 보겠습니다.
xx0, xx1 = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
X_test = np.array([xx0, xx1]).reshape(2, -1).T

fig, axes = plt.subplots(1, 5, figsize=(25,4))

for i, ax in zip(range(5), axes.flat):
  n_neighbors = i*2 + 1
  knc = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X_train, y_train)
  y_pred = knc.predict(X_test)

  sns.scatterplot(X_train[:, 0], X_train[:, 1], hue=y_train, ax=ax)
  axes[i].contourf(xx0, xx1, y_pred.reshape(100, 100).astype(dtype=np.float), alpha=0.2, levels=np.linspace(0, 1, 3), cmap='bwr')
  axes[i].set_title('n_neighbors={:.0f}'.format(n_neighbors))

plt.show()

k=1~k=9일 때의 분류 경계를 나타냅니다. k=1일 때는 클래스 0 안에 클래스 1의 분류 범위가 있는 것을 알 수 있고, 분류 경계가 복잡해져 훈련 데이터에 과잉 적합하는 결과가 됩니다(범화 성능이 낮다). k를 늘려가면 경계가 매끄럽게 되어 단순한 모델이 되어 간다고 말할 수 있습니다.



4. 결론



이 기사에서는 가장 간단한 학습 알고리즘인 k-근방법(k-NN: k-nearest neighbor)을 이용하여 간단한 분류 문제를 해결했습니다. k의 값을 변경하여 분류 모델의 복잡성을 제어할 수 있다는 것을 배웠습니다.

기계 학습 입문 링크



앞으로도 계속 기계 학습의 구현 입문을 기사로 해 나갈 것입니다.
참고로 부디.

· 선형 회귀 모델
· 비선형 회귀 모델
・k-근방법 본 기사
· 로지스틱 회귀 모델
· 주성분 분석
· K-means
・지원 벡터 머신 ※향후 갱신 예정

좋은 웹페이지 즐겨찾기