기계 학습 시각화 도구 Tensorboard를 사용해 보았습니다.

소개



Tensorboard를 처음 사용해 그래프를 써 그 편리함에 감동했으므로, 공유합니다.
Deep Learning 프레임워크는 PyTorch를 사용했습니다.

설치



Anaconda를 사용하고 있으므로 다음 명령으로 Tensorboard를 설치합니다.
conda install tensorboard

코딩



tb.py
import numpy as np
from torch.utils.tensorboard import SummaryWriter#グラフを書くSummaryWriterをimport

np.random.seed(1000)

x = np.random.randn(1000)

writer = SummaryWriter(log_dir="./logs")#インスタンス生成 保存するディレクトリも指定

for i in range(1000):
    writer.add_scalar("x", x[i], i)#値を書き込む
    writer.add_scalar("sin", np.sin(i), i)

writer.close()#閉じる

파일명을 tensorboard.py로 하면 module과 쓰고 ImportError가 되므로 주의합시다.

해설



간단히 말해 위의 코드는 임의의 값을 가진 배열과 sin 함수를 플로팅합니다.

SummaryWriter import


from torch.utils.tensorboard import SummaryWriterTensorboard에서 그래프를 그리는 데 필요한 모듈인 SummaryWriter를 가져옵니다.

인스턴스 생성


writer = SummaryWriter(log_dir="./logs")이렇게하면 현재 디렉토리에 logs 디렉토리가 작성되고 logs에 Tensorboard 용 파일이 저장됩니다.

값 대입


writer.add_scalar("x", x[i], i) 로 배열의 값을 넣습니다.writer.add_scalar(tags, scalar_value, global_step) 이며 tags에서 그래프의 이름을 지정하고 scalar_value로 저장할 값을 할당하고 global_step에서 그래프의 가로축 간격을 지정합니다.

닫다


writer.close() 로 마지막에 닫자.

그래프 보기



tb.py 실행



위의 코드를 실행합시다. 그래프가 그려집니다.
python tb.py

그래프 보기



다음 명령을 실행합시다. --logdir="" 에 저장된 디렉토리를 지정합시다.
이번은 ./logs 입니다.
tensorboard --logdir="./logs"

그러면 다음 문장이 터미널에 출력됩니다.
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

로컬 서버가 시작되기 때문에 브라우저에 http://localhost:8000/ 를 치자.



chrome에서 보면 그래프가 깨끗하게 플롯되어 있는 것을 알 수 있습니다.

ssh 대상 그래프 보기



Deep Learning의 코드는 계산량이 많아 로컬 PC(수중의 PC)에서는 막대한 시간이 걸리므로,
실험실의 서버 GPU로 ssh하고 서버에서 코드를 돌리는 것이 기본값입니다.
그렇다면 원격 서버에서 그린 그래프를 로컬 PC에서 어떻게 볼 수 있습니까?

원격 서버로 ssh



ssh를 할 때 -L 옵션을 사용하여 클라이언트 (로컬 PC)의 localhost : 9000을 원격 서버의 사용자 이름 @ 서버의 IP 주소 : 8000에 연결합니다.

@ 로컬 PC
ssh ユーザ名@サーバーのIPアドレス -L 9000:localhost:8000

원격 서버에서 tb.py 실행



ssh 한 원격 서버에서 그래프를 그리는 코드를 실행합시다.

@ 원격 서버
python tb.py

Tensorboard 실행



ssh 한 원격 서버에서 그래프를 보는 명령을 실행합시다.
ssh했을 때 로컬 PC에 연결한 포트는 8000이므로 --port 옵션으로 8000을 지정하여 실행합시다.

@ 원격 서버
tensorboard --logdir="./logs" --port 8000

다음과 같은 문장이 출력됩니다.

@ 원격 서버
TensorBoard 2.2.1 at http://localhost:8000/ (Press CTRL+C to quit)

그래프 보기



아까는 http://localhost:8000/를 브라우저에 입력하면 그래프가 보였지만 이번에는 볼 수 없습니다.

이번에는 원격 서버의 포트 8000과 로컬 PC의 포트 9000을 연결했기 때문에,
로컬 PC의 브라우저에서 http://localhost:9000/를 입력하면 조금 전과 같은 그래프를 볼 수 있습니다.



요약



PyTorch에서 Tensorboard를 사용하여 그래프를 그렸습니다.
또, ssh처의 리모트 서버로 돌린 코드의 그래프를 로컬 PC로 보는 방법을 소개했습니다.
저도 이 Tensorboard와 ssh -L을 이용하여 Deep Learning에 활용해 가고 싶습니다.

좋은 웹페이지 즐겨찾기