matplotlib로 예측 결과 3D 표시

소개



심층 학습에서 예측한 결과 중 오차가 큰 것과 오차가 작은 것에 특징은 있는지 가시화해 확인하고 싶었으므로 3D 그래프에 플롯해 보았습니다.
* 예측 실패: 빨간색
* 예측에 성공: 파랑
로 표시됩니다.

사용한 라이브러리


  • matplotlib
  • pandas
  • Pillow

  • 3D 디스플레이



    Matplotlib 공식이 색으로 구분 된 산점도를 표시하는 방법을 썼기 때문에 참고했습니다.
    mplot3d tutorial — Matplotlib 2.0.2 documentation

    사용할 데이터



    x, y, z, TF로 레이블을 지정했습니다.


    그리기



    plot3d.py
    import pandas as pd
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from io import BytesIO
    from PIL import Image
    
    import_path = './test.csv'
    
    
    def main():
        fig = plt.figure(figsize=(10.0, 9.0))
        ax = fig.add_subplot(111, projection='3d')
    
        data = pd.read_csv(import_path, header=0)
        print(data)
        for index, row in data.iterrows():
            print('\rindex = {}/{}, TorF = {}  '.format(index,len(data), str(row['TF'])),end="")
            color = 'r'
            marker = 'x'
            if str(row['TF']) == 'True':
                color = 'b'
                marker = 'o'
            ax.scatter(row['x'], row['y'], row['z'], c = color, marker = marker)
    
        print('\n---Finished Loading---')
        ax.set_xlabel("x", fontsize=20)
        ax.set_ylabel("y", fontsize=20)
        ax.set_zlabel("z", fontsize=20)
        ax.view_init(30, 0)
        plt.savefig('./output2D-1.png')
        ax.view_init(30, 90)
        plt.savefig('./output2D-2.png')
        ax.view_init(30, 180)
        plt.savefig('./output2D-3.png')
    
    
    if __name__ == '__main__':
        main()
    
    

    그린 결과







    3D 디스플레이 회전



    GIF 애니메이션으로 돌리면서 회전했습니다.
    이 기사를 참고로했습니다.
    3D 산점도를 회전 GIF 애니메이션으로 만들기 - Qiita

    plot3d.py
    import pandas as pd
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from io import BytesIO
    from PIL import Image
    import time
    import datetime
    
    import_path = './test.csv'
    
    
    def  render_frame(ax, angle):
        print('\rangle = {}'.format(angle),end="")
        ax.view_init(30, angle)
        buf = BytesIO()
        plt.savefig(buf, bbox_inches='tight', pad_inches=0.0)
        return Image.open(buf)
    
    
    def main():
        start = time.time()
        fig = plt.figure(figsize=(10.0, 9.0))
        ax = fig.add_subplot(111, projection='3d')
    
        data = pd.read_csv(import_path, header=0)
        for index, row in data.iterrows():
            print('\rindex = {}/{}, TorF = {}  '.format(index,len(data), str(row['TF'])),end="")
            color = 'r'
            marker = 'x'
            if str(row['TF']) == 'True':
                color = 'b'
                marker = 'o'
            ax.scatter(row['x'], row['y'], row['z'], c = color, marker = marker)
    
        ax.set_xlabel("x", fontsize=20)
        ax.set_ylabel("y", fontsize=20)
        ax.set_zlabel("z", fontsize=20)
    
        print('\n---Finished Loading---\n')
        images = [render_frame(ax, angle) for angle in range(50)]
        images[0].save('output3D.gif', save_all=True, append_images=images[1:], duration=100, loop=0)
    
        # 経過時間の集計
        process_time = time.time() - start
        td = datetime.timedelta(seconds=process_time)
        print('PROCESS TIME = {}'.format(td))
    
    
    if __name__ == '__main__':
        main()
    
    

    GIF 애니메이션



    좋은 웹페이지 즐겨찾기