[PyTorch]Seed 고정

3936 단어 PyTorchPyTorch

실험을 위한 Randomness 제어 방법

  1. Pytorch
import torch
torch.manula_seed(random_seed)
  1. CuDNN
    : 딥러닝에 특화된 CUDA library. 딥러닝 프레임워크에서 필수적으로 사용되는 라이브러리
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

*주의사항 : 연산 속도가 느려짐 (연구 실험 초기단계보다는 후반 단계에서 사용하는 것을 권장)

  1. Numpy
    : numpy로 데이터를 받아오고, metric 계산
import numpy as np
np.random.seed(random_seed)
  1. Random
    : torchvision trasnforms는 python random 라이브러리로 randomness가 제어됨
import random
random.seed(random_seed)
  1. 동일한 조건에서 학습시 weight가 변화하지 않게 하는 옵션
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
  • 결론
torch.manual_seed(random_seed) # torch 
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True # cudnn
torch.backends.cudnn.benchmark = False # cudnn
np.random.seed(random_seed) # numpy
random.seed(random_seed) # random

좋은 웹페이지 즐겨찾기