Optuna로 Prophet을 교잡해 봤어요.

  • 제조업 출신 데이터 과학자가 보낸 보도
  • 이번에 우리는 Optuna로 시간 서열 분석 방법의 Prophet를 교잡했다
  • 입문


    과거에 시간 서열 분석에 대해 몇 가지를 정리했으니 관심 있는 사람은 참조하세요.
  • 시간 시퀀스 분석
  • 상태 공간 모델
  • 시간 시퀀스 예측을 위해 GBDT를 사용해 보십시오.
  • 시간 시퀀스 분석 방법을 시도한 SARIMA 모델
  • Prophet


    Prophet는 페이스북이 공개한 오픈 타임 시퀀스 분석 라이브러리다.
    Prophet의 장점은 대체로 다음과 같다.
    1. 모델 디자인의 유연성 높음
    2. 샘플 간격 유연성 높음
    3. 빠른 계산 가능
    4. 매개 변수는 이해하기 쉽다
    자세한 내용은 Prophet 공식 문서 을 참조하십시오.

    Prophet 설치


    이번에 사용한 데이터도 매달 비행기의 승객 수 데이터다.
    
    # ライブラリーのインポート
    import pandas as pd
    import numpy as np
    import seaborn as sns
    import optuna
    from matplotlib import pylab as plt
    %matplotlib inline
    
    from fbprophet import Prophet
    from fbprophet.plot import add_changepoints_to_plot
    
    import warnings
    warnings.filterwarnings("ignore")
    
    # https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/AirPassengers.html
    df = pd.read_csv('../data/AirPassengers.csv')
    
    # float型に変換
    df['#Passengers'] = df['#Passengers'].astype('float64')
    df = df.rename(columns={'#Passengers': 'Passengers'})
    
    # datetime型に変換にする
    df.Month = pd.to_datetime(df.Month)
    
    # データの中身を確認
    df.head()
    
    # データの可視化
    fig, ax = plt.subplots()
    a = sns.lineplot(x="Month", y="Passengers", data=df)
    plt.show()
    

    Prophet에서 모델을 구축합니다.
    def objective_variable(train, valid):
    
        def objective(trial):
                params = {
                        'changepoint_range' : trial.suggest_discrete_uniform('changepoint_range', 0.8, 0.95, 0.001),
                        'n_changepoints' : trial.suggest_int('n_changepoints', 20, 35),
                        'changepoint_prior_scale' : trial.suggest_discrete_uniform('changepoint_prior_scale',0.001, 0.5, 0.001),
                        'seasonality_prior_scale' : trial.suggest_discrete_uniform('seasonality_prior_scale',1, 25, 0.1),
                }
    
                # fit_model
                model = Prophet(
                    changepoint_range = params['changepoint_prior_scale'],
                    n_changepoints=params['n_changepoints'],
                    changepoint_prior_scale=params['changepoint_prior_scale'],
                    seasonality_prior_scale = params['seasonality_prior_scale'],
                )
    
                model.fit(train)
                future = model.make_future_dataframe(periods=len(valid))
    
                forecast = model.predict(future)
                valid_forecast = forecast.tail(len(valid))
    
                val_mape = np.mean(np.abs((valid_forecast.yhat-valid.y)/valid.y))*100
    
                return val_mape
    
        return objective
    
    def optuna_parameter(train, valid):
        study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=10))
        study.optimize(objective_variable(train, valid), timeout=500)
        optuna_best_params = study.best_params
    
        return study
    
    df = df.rename(columns={'Month':'ds','Passengers':'y'})
    df = df[['ds','y']]
    df_train = df[df['ds'] < '1956-04-01']
    df_valid = df[(df['ds'] >= '1956-04-01')&(df['ds'] < '1957-04-01')]
    df_test = df[df['ds'] >= '1957-04-01']
    
    study = optuna_parameter(df_train, df_valid)
    
    그리고 조정 결과의 초 파라미터를 이용하여 모델의 학습과 예측을 한다.
    # fit_model
    best_model = Prophet(
        changepoint_range = study.best_params['changepoint_prior_scale'],
        n_changepoints=study.best_params['n_changepoints'],
        seasonality_prior_scale = study.best_params['seasonality_prior_scale'],
        changepoint_prior_scale=study.best_params['changepoint_prior_scale'],
    )
    
    best_model.fit(df_train)
    feature_test = best_model.make_future_dataframe(periods=len(df_valid)+len(df_test), freq='M')
    
    forecast_test = best_model.predict(feature_test)
    forecast_test_plot = best_model.plot(forecast_test)
    
    fig, ax = plt.subplots()
    df.y.plot(ax=ax, label='Original', linestyle="dashed")
    forecast_test.yhat.plot(ax=ax, label='Predict')
    ax.legend()
    

    마지막


    끝까지 읽어줘서 고마워요.
    이번에 우리는 Optuna로 시간 서열 분석 방법인 Prophet에 대해 높은 매개 변수를 조정했다.
    정정 요구가 있으면 연락 주세요.

    좋은 웹페이지 즐겨찾기