TensorFlow.js 학습 메모 ① k 근방법(k-nearlest neighbor)으로 좌표로부터 집의 가격을 예측

소개



일의 관계에서 기계 학습에 대해 학습할 필요성이 나왔기 때문에, 알고리즘을 몇개인가 가볍게 공부하기로 했습니다.
Python이 얽히면 학습 부하가 단번에 높아져 버리는 생각이 들었으므로, 학습은 JavaScript(TensorFlow.js)로 실시했습니다.

먼저 좌표와 바닥 면적을 입력하여 k 근방법으로 집 가격을 예측하는 알고리즘을 구현했습니다.

거의 학습 메모 같은 느낌이므로, 깊은 내용을 요구하시는 분은 그 점 양해 바랍니다.

TensorFlow.js



TensorFlow.js는 Python의 numpy와 같은 조작을 할 수있는 라이브러리입니다.
JS로 쉽게 행렬 계산을 할 수 있습니다.

k근방법(k-nearest neighbor)



입력값에 대한 예측값을 내는 알고리즘의 하나입니다.
지정한 좌표(lat, long)에서 집 가격(price)을 예측하는 경우를 생각하면 예측까지 다음과 같은 흐름이 됩니다.
  • 좌표와 집 가격에 대한 실제 데이터를 많이 모으기 (Training data)
  • 가격을 알고 싶은 좌표를 지정하여 이전 단계에서 수집 한 좌표 데이터와의 차이를 각각 취한다.
  • 차이가 작은 순서로 데이터를 정렬합니다.
  • 차등이 작은 순서로 지정된 개수 (k) 분의 데이터를 취득한다
  • 취득한 데이터의 가격의 평균값을 취한다 (예측치)

  • 이상적인 k 값을 찾으려면 Training data와 Test data를 각각 준비합니다.
    knn에서 Test data의 좌표에서 가격을 내고 Test data의 가격과의 차이가 작아지도록 k를 조정합니다.

    또한 좌표뿐만 아니라 바닥 면적(sqft_lot)이나 거실 면적(sqft_living) 등의 여러 요소를 고려하면 보다 정밀한 가격 예측을 할 수 있을 것입니다.

    전제 지식



    구현시에 전제가 되는 초기본 지식을 정리했습니다.

    분류 및 회귀



    값 예측을 위해 분류(Classification)와 회귀(Regression)에 대해 이해해야 합니다.
    입력값에 대해 그 값이 합격인지 불합격인지를 예측하고 싶은 경우에는 분류, 어떠한 값이 되는지를 예측하고 싶은 경우에는 회귀를 사용합니다.

    정규화(Normalization) 및 표준화(Standardization)



    데이터를 최대값이 1, 최소값이 0인 데이터가 되도록 변환하는 것을 정규화, 원래 데이터의 평균을 0, 표준 편차가 1인 것으로 변환하는 것을 표준화라고 합니다.

    최대치 및 최소치가 정해져 있는 경우(화상 처리라든지) 등에는 정규화를 이용합니다.
    한편, 최대치 및 최소치가 정해져 있지 않은 경우나 이상값이 존재하는 경우에는, 가중을 학습하기 쉽게 하기 위해서 표준화를 이용합니다.

    정규화 공식


    표준화 공식 (μ : 평균, σ : 표준 편차)




    구현



    구현의 흐름을 정리했습니다.

    데이터 준비



    예측에 사용할 데이터( kc_house_data.csv )를 준비합니다.
    Training data와 Test data를 이중에서 추출합니다.


    index.js 만들기



    이 중에 데이터의 예측 처리를 작성합니다.

    라이브러리 로드



    index.js
    require('@tensorflow/tfjs-node');
    const tf = require('@tensorflow/tfjs');
    const loadCSV = require('./load-csv');
    

    knn 처리



    index.js
    //一番近いpriceを探索するknn
    function knn(features, labels, predictionPoint, k) {
      const { mean, variance } = tf.moments(features, 0); //平均値と分散の取得、第2引数でrowかcolumnを指定
    
      const scaledPrediction = predictionPoint.sub(mean).div(variance.pow(0.5)); //入力値の標準化 (入力値-平均値)/標準偏差
    
      return (
        features
          .sub(mean)
          .div(variance.pow(0.5)) //Training dataの標準化
          .sub(scaledPrediction) //入力値(標準化済)との差
          .pow(2) //各要素を2乗
          .sum(1) //各要素の和
          .pow(0.5) //ルートをとる(distance算出)
          .expandDims(1) //連結するためにdimを操作
          .concat(labels, 1) //labels(price)と連結
          .unstack() //sort, sliceの操作を行うためにobj化
          .sort((a, b) => (a.get(0) > b.get(0) ? 1 : -1)) //distance小さい順にソート
          .slice(0, k) //k個データを取得
          .reduce((acc, pair) => acc + pair.get(1), 0) / k //priceの平均値(knnの結果)
      );
    }
    

    데이터 검색


    splitTest:10 에서 Test data(testFeatures, testLabels)를 무작위로 가져오고, 그 외를 Training data(features, labels)로 가져옵니다.

    index.js
    let { features, labels, testFeatures, testLabels } = loadCSV(
      'kc_house_data.csv',
      {
        shuffle: true,
        splitTest: 10, //testFeatures, testLabelsの数を指定
        dataColumns: ['lat', 'long', 'sqft_lot', 'sqft_living'], //featureカラムを指定 指定数増やすと精度上がる
        labelColumns: ['price'], //labelカラムを指定
      }
    );
    
    features = tf.tensor(features);
    labels = tf.tensor(labels);
    

    예측값 출력



    10개의 Test data(testFeatures)를 knn으로 처리한 결과(Guess)와 testLabels와의 괴리 정도(Error)를 출력했습니다.

    index.js
    testFeatures.forEach((testPoint, i) => {
      //testFeature10個それぞれのerrをみる
      const result = knn(features, labels, tf.tensor(testPoint), 10); //knnの結果(予測値)
    
      const err = (testLabels[i][0] - result) / testLabels[i][0]; //Test dataのpriceと予測値の乖離度合
    
      console.log('Error', err * 100);
      console.log('Guess', result, testLabels[i][0]);
    });
    

    출력 결과
    Error -15.323502304147466
    Guess 1251260 1085000
    Error -11.344580119965723
    Guess 519756.5 466800
    Error -2.047058823529412
    Guess 433700 425000
    Error 19.327433628318584
    Guess 455800 565000
    Error 7.806324110671936
    Guess 699750 759000
    Error -14.106372465729613
    Guess 584260 512031
    Error -8.782552083333334
    Guess 835450 768000
    Error 13.227406199021207
    Guess 1329790 1532500
    Error -36.336911441815076
    Guess 279422.5 204950
    Error 7.381578947368421
    Guess 228767.5 247000
    

    결론



    정말 만지기 때문에 수식을 알면 거기까지 구현은 어렵지 않다는 인상.

    참고 자료

    좋은 웹페이지 즐겨찾기