streamlit로 머신러닝 웹 응용 프로그램 만들기

의 목적


https://github.com/komo135/trade-rl
streamlit으로trade-rl의 간단한 웹 애플리케이션을 제작합니다.streamlit는 HTML, CSS, 자바스크립트 지식이 전혀 없어도 웹 애플리케이션을 만들 수 있다.자세한 내용은 home page를 참조하십시오.

그림 완성


image

1. 필요한 패키지 설치


다음 명령을 입력하여 필요한 패키지를 디바이스에 설치합니다.
git clone https://github.com/komo135/trade-rl.git
cd trade-rl
pip install .

pip install streamlit

웹 응용 프로그램 실행


실행하기 전에 필요한 포장을 설치하세요.
streamlit run https://raw.githubusercontent.com/komo135/traderl-web-app/main/github_app.py [ARGUMENTS]

생성 방법


2. 매크로 패키지 가져오기

  • streamlit->웹 응용 프로그램 만들기에 필요한
  • agent->streamlit에 맞춤형(print와 matplotlib의 부분만 변경하고 그대로 유지하면 화면에 나타나지 않음)
  • pandas->csv 파일 읽기
  • gc->메모리 열기
  • import streamlit as st
    from agent import dqn, qrdqn
    from traderl import data, nn
    import pandas as pd
    import gc
    

    3. 사이드바 만들기


    st.sidebar.라디오 사이드바에 라디오 단추를 만듭니다.
    def sidebar():
        return st.sidebar.radio("", ("Home", "select data", "create agent", "training",
                                     "show results", "save model", "initialize"))
    

    주체의 제작


    전역 코드
    class App:
        def __init__(self):
    
            self.df = None
            self.agent = None
            self.model_name = ""
    
        def select_data(self):
            file = None
    
            select = st.selectbox("", ("forex", "stock", "url or path", "file upload"))
            col1, col2 = st.columns(2)
            load_file = st.button("load file")
    
            if select == "forex":
                symbol = col1.selectbox("", ("AUDJPY", "AUDUSD", "EURCHF", "EURGBP", "EURJPY", "EURUSD",
                                             "GBPJPY", "GBPUSD", "USDCAD", "USDCHF", "USDJPY", "XAUUSD"))
                timeframe = col2.selectbox("", ("m15", "m30", "h1", "h4", "d1"))
                if load_file:
                    self.df = data.get_forex_data(symbol, timeframe)
            elif select == "stock":
                symbol = col1.text_input("", help="enter a stock symbol name")
                if load_file:
                    self.df = data.get_stock_data(symbol)
            elif select == "url or path":
                file = col1.text_input("", help="enter url or local file path")
            elif select == "file upload":
                file = col1.file_uploader("", "csv")
    
            if load_file and file:
                st.write(file)
                self.df = pd.read_csv(file)
    
            if load_file:
                st.write("Data selected")
    
        def check_data(self):
            f"""
            # Select Data
            """
            if isinstance(self.df, pd.DataFrame):
                st.write("Data already exists")
                if st.button("change data"):
                    st.warning("data and agent have been initialized")
                    self.df = None
                    self.agent = None
    
            if not isinstance(self.df, pd.DataFrame):
                self.select_data()
    
        def create_agent(self, agent_name, args):
            agent_dict = {"dqn": dqn.DQN, "qrdqn":qrdqn.QRDQN}
            self.agent = agent_dict[agent_name](**args)
    
        def agent_select(self):
            if not isinstance(self.df, pd.DataFrame):
                st.warning("data does not exist.\n"
                           "please select data")
                return None
    
            agent_name = st.selectbox("", ("dqn", "qrdqn"), help="select agent")
    
            """
            # select Args
            """
            col1, col2 = st.columns(2)
            network = col1.selectbox("select network", (nn.available_network))
            network_level = col2.selectbox("select network level", (f"b{i}" for i in range(8)))
            network += "_" + network_level
            self.model_name = network
    
            col1, col2, col3, col4 = st.columns(4)
            lr = float(col1.text_input("lr", "1e-4"))
            n = int(col2.text_input("n", "3"))
            risk = float(col3.text_input("risk", "0.01"))
            pip_scale = int(col4.text_input("pip scale", "25"))
            col1, col2 = st.columns(2)
            gamma = float(col1.text_input("gamma", "0.99"))
            use_device = col2.selectbox("use device", ("cpu", "gpu", "tpu"))
            train_spread = float(col1.text_input("train_spread", "0.2"))
            spread = int(col2.text_input("spread", "10"))
    
            kwargs = {"df": self.df, "model_name": network, "lr": lr, "pip_scale": pip_scale, "n": n,
                      "use_device": use_device, "gamma": gamma, "train_spread": train_spread,
                      "spread": spread, "risk": risk}
    
            if st.button("create agent"):
                self.create_agent(agent_name, kwargs)
                st.write("Agent created")
    
        def agent_train(self):
            if self.agent:
                if st.button("training"):
                    self.agent.train()
            else:
                st.warning("agent does not exist.\n"
                           "please create agent")
    
        def show_result(self):
            if self.agent:
                self.agent.plot_result(self.agent.best_w)
            else:
                st.warning("agent does not exist.\n"
                           "please create agent")
    
        def model_save(self):
            if self.agent:
                save_name = st.text_input("save name", self.model_name)
                if st.button("model save"):
                    self.agent.model.save(save_name)
                    st.write("Model saved.")
            else:
                st.warning("agent does not exist.\n"
                           "please create agent")
    
        @staticmethod
        def clear_cache():
            if st.button("initialize"):
                st.experimental_memo.clear()
                del st.session_state["app"]
                gc.collect()
    
                m = """
                    **Initialized.**
                    """
                st.markdown(m)
    

    1. 데이터 선택

  • check_data에서 데이터 존재 여부 확인
  • 데이터가 이미 존재하는 경우st.button("change data") 데이터 변경 여부를 확인하고 클릭한 후 데이터를 초기화한다.
  • 데이터가 존재하지 않는 상황에서 st.selectbox("", ("forex", "stock", "url or path", "file upload"))에서 데이터를 어디서 얻었는지 확인하고 선택한 데이터에 따라 마지막으로 클릭st.button("load file")하여 데이터를 읽는다.
  • 2. 에이전트의 제작

  • 데이터가 존재하는지 확인하고 존재하지 않을 때 경고합니다.
  • 데이터가 존재할 때 사용하는 에이전트,network, 파라미터st.button("create agent")를 선택하고 에이전트를 만들려면 누르십시오.
  • 3. 모델 트레이닝

  • 에이전트가 존재하는지 확인하고 존재하지 않을 때 경고합니다.
  • 존재하면 훈련 기종.
  • 4. 훈련 결과의 표시

  • 에이전트가 존재하는지 확인하고 존재하지 않을 때 경고합니다.
  • 존재 시 훈련 결과가 표시됩니다.
  • 5. 모델 저장

  • 에이전트가 존재하는지 확인하고 존재하지 않을 때 경고합니다.
  • 저장된 파일 이름을 입력하여 모델을 저장합니다.
  • 6. 초기화


    모두 초기화합니다.
    다음 코드는 세션 상태에 저장된 클래스를 삭제합니다.
    del st.session_state["app"]
    

    실행


    사이드바의 라디오 단추에서 선택한 요소에 따라 함수를 실행합니다
    if __name__ == "__main__":
        
        #ページのレイアウトをワイドに変更する。
        st.set_page_config(layout="wide", )
    
        if "app" in st.session_state:
            app = st.session_state["app"]
        else:
            app = App()
    
        select = sidebar()
    
        # マークダウンの表示
        if select == "Home":
            home()
    
        elif select == "select data":
            app.check_data()
        elif select == "create agent":
            app.agent_select()
        elif select == "training":
            app.agent_train()
        elif select == "save model":
            app.model_save()
        elif select == "show results":
            app.show_result()
    
        st.session_state["app"] = app
        if select == "initialize":
            app.clear_cache()
    
    아래 코드는 st.session_state에'app'가 존재하는 경우st.session_state에'app'를 불러오고 없으면'app'를 만듭니다.
    마지막으로 st.session_state["app"]에 덮어씁니다.
    왜 그랬을까
    매번 선택 단추를 다시 선택할 때마다 처음부터 실행되며 데이터와 에이전트는 저장되지 않습니다.따라서 선택 단추를 다시 선택해도 바뀌지 않습니다. st.session_state 'app' 를 저장하고 불러옵니다.
    사이드바 함수를 구분하는 이유는 'app' 이기 때문에, 뒷사이드바를 불러오면 표시되지 않습니다.
    마지막으로 초기화한 이유는 삭제st.session_state["app"]하더라도 이 코드st.session_state["app"] = app까지는 의미가 없다는 것이다.
        if "app" in st.session_state:
            app = st.session_state["app"]
        else:
            app = App()
    
        st.session_state["app"] = app
        if select == "initialize":
            app.clear_cache()
    

    github


    https://github.com/komo135/traderl-web-app
    전체 코드는 여기 있습니다.

    좋은 웹페이지 즐겨찾기