scikit-learn, Spark.ml, TensorFlow에서 선형 회귀 ~ (1) 소개

기계 학습 라이브러리로 유명한 scikit-learn, Spark.ml 및 Tensorflow 선형 회귀 라이브러리를 사용해 보았습니다.
언어는 Python (3.5), 사용한 라이브러리는 상기 이외에서는 numpy, matplotlib, csv.
기계 학습 라이브러리의 버전은
  • scikit-learn (0.18.1)
  • spark (2.1.0)
  • tensorflow (1.1.0)

  • OS는 MacOS 10.12입니다.

    1. 선형 회귀



    선형 회귀는 주어진 데이터 (x1, y1), (x2, y2) ...에서 y = ax + b의 관계를 추정하는 것 (자세한 내용은 선형 회귀(wikipedia) 등 참조)이므로 먼저 대상 데이터를 만듭니다. 합니다.
    ax + b 의 a 와 b 를 주면, y±d 의 범위에서 랜덤하게 값을 생성하는 함수를 numpy 의 rand() 와 행렬 연산을 이용해 작성합니다.

    makeDataLR.py
    import numpy as np
    from numpy.random import rand
    
    def makeDataForLR(a, b, n=100, d=0.1, xs=0, xe=10):
        x = rand(n) * (xe - xs) + xs
        r = rand(n) * 2*d - d
        y = x * a + b + r
        return x,y
    

    numpy.random.rand(n) 은 0.0 ~ 1.0 내에서 n개의 랜덤 숫자를 생성하는 함수로, rand(n) * 3 이라고 하면 0.0 ~ 3.0 에서 rand(n) + 2 로 하면 , 2.0 ~ 3.0 사이의 임의의 숫자를 얻을 수 있습니다.
    >>> rand(10) * (10-5) + 5  # 5.0〜10.0 内のランダム数
    array([ 6.21226444,  6.77468084,  9.36730437,  5.11593757,  5.38383768,
            7.87395788,  9.63988158,  8.28096493,  5.0125407 ,  8.60225573])
    >>> rand(10) * 2*1 - 1   # -1.0〜1.0 内のランダム数
    array([ 0.4029865 , -0.31802214, -0.71503869, -0.71740942,  0.05573439,
           -0.85997408, -0.91677018, -0.72540234,  0.12157467, -0.77786667])
    

    작성한 데이터를 파일에 저장합니다. csv 라이브러리를 사용하면 numpy 행렬 데이터를 그대로 쓸 수 있습니다.

    makeDataLR.py
    import csv
    def writeArrayWithCSV(dataFile, data):
        f = open(dataFile, 'w')
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(data)
        f.close()
    
    # a=0.4, b=0.8, ax+b±0.2 のデータを x=0〜10の範囲で100個作成する                       
    x,y = makeDataForLR(0.4, 0.8, 100, 0.2, 0, 10)
    
    # x,y を結合して、csvファイルとして保存                           
    dataFile = 'sampleLR.csv'
    xy = np.c_[x, y]
    writeArrayWithCSV(dataFile, xy)
    

    자, 실제로 어떤 데이터가 생겼는지 시각화해 봅시다. matplotlib의 산점도를 사용합니다.

    makeDataLR.py
    import matplotlib.pyplot as plt
    def plotXY(title, x, y):
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.scatter(x, y)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        #fig.show()                                                                 
        imageFile = title + '.png'
        fig.savefig(imageFile)
    
    title = 'sampleLR'
    plotXY(title, x, y)
    



    MacOS에서 matplotlib 이 오류가 발생하면 ~/.matplotlib/matplotlibrc 를 다음 내용으로 만들면 사용할 수 있다고 생각합니다.

    ~/.matplotlib/matplotlibrc
    backend : TkAgg
    

    마지막으로 데이터 작성 프로그램의 전체입니다. 실제 선형 회귀 방법은 다음 기사에서.

    makeDataLR.py
    #!/usr/bin/env python                                                                
    
    import numpy as np
    from numpy.random import rand
    
    # xs-xe の x について、ax + b ± d の値を N 個作成する                  
    def makeDataForLR(a, b, n=100, d=0.1, xs=0, xe=10):
        x = rand(n) * (xe - xs) + xs
        r = rand(n) * 2*d - d
        y = x * a + b + r
        return x,y
    
    # numpy.array を csv に書き込む                                              
    import csv
    def writeArrayWithCSV(dataFile, data):
        f = open(dataFile, 'w')
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(data)
        f.close()
    
    # x,y の散布図                                                                
    import matplotlib.pyplot as plt
    def plotXY(title, x, y):
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.scatter(x, y)
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        #fig.show()                                                                      
        imageFile = title + '.png'
        fig.savefig(imageFile)
    
    # a=0.4, b=0.8, ax+b±0.2 のデータを x=0〜10の範囲で100個作成する                     
    x,y = makeDataForLR(0.4, 0.8, 100, 0.2, 0, 10)
    
    # x,y を結合して、csvファイルとして保存                                              
    title = 'sampleLR'
    dataFile = title + '.csv'
    xy = np.c_[x, y]
    writeArrayWithCSV(dataFile, xy)
    
    # データをプロット                                                                   
    plotXY(title, x, y)
    

    좋은 웹페이지 즐겨찾기