Pytorch 학습 노트(1):pytoch에서 훈련 데이터를 어떻게 불러오는지

1.로드 메서드를 직접 쓸 필요가 없는 이유


pytorch에서 두 가지 종류의 트레이닝 데이터를 불러옵니다. 각각 torch입니다.utils.data.Dataset 및 torch.utils.data.DataLoader .torchvision에서 자주 사용하는 컴퓨터 시각의 자주 사용하는 데이터 집합과 달리 음악 정보 검색에 있어 데이터 집합은 자체적으로 불러오는 방법을 설계해야 한다.만약 매번 다른 데이터 집합이 스스로 함수를 써서 불러온다면,
  • 매번 읽기 코드를 다시 사용할 수 없고 데이터 읽기 코드가 다르다
  • 자신이 쓴 로드 함수도 여러 가지 문제가 있다. 예를 들어 데이터 읽기 속도를 제한하거나 데이터 집합이 너무 크면 사전이나 목록에 직접 로드하면 메모리를 많이 차지하고 데이터 읽기 단계도 많은 시간을 차지한다
  • 단일 스레드로만 데이터를 읽을 수 있음
  • 이번에 내가 한 실험은 노래의 멜 스펙트럼을 탑재해야 한다. 각 노래의 부분은 30초이고 대략 1290*128 크기의 행렬이다.그래서 이번에는pytorch의Dataset 클래스를 사용하여 데이터를 불러오기로 결정했습니다.

    2. Dataset 클래스


    class torch.utils.data.Dataset
    이 추상 클래스는 데이터 집합을 대표한다. 우리가 디자인한 데이터 집합 클래스는 모두 이 클래스의 하위 클래스가 되어야 한다. 이 클래스를 계승하고 다시 쓰기len_() 방법, 이 방법은 데이터 집합의 크기를 얻는 데 사용되며getitem__() 방법, 이 방법은 데이터 집중 인덱스 값이 0에서 렌 (dataset) 인 요소를 되돌려줍니다.
  • def __getitem__(self, index): 이 함수를 실현하면 색인 값을 통해 훈련 샘플 데이터
  • 를 되돌려줄 수 있습니다.
  • def __len__(self): 이 함수를 실현하고 데이터 집합의 크기
  • 를 되돌려줍니다.
    class Dataset(object):
        """An abstract class representing a Dataset.
    
        All other datasets should subclass it. All subclasses should override
        ``__len__``, that provides the size of the dataset, and ``__getitem__``,
        supporting integer indexing in range from 0 to len(self) exclusive.
        """
    
        def __getitem__(self, index):
            raise NotImplementedError
    
        def __len__(self):
            raise NotImplementedError
    
        def __add__(self, other):
            return ConcatDataset([self, other])
    

    이 두 개인 함수를 다시 쓰지 않으면 오류가 발생합니다.

    3. 자신의 데이터 세트 클래스 정의


    그래서 나는 자신의 수요에 맞추어 다음과 같은 유형을 실현했다.
    class Fma_dataset(Dataset):
        # root         , mode      train,test,validation,          
        def __init__(self, root, mode): 
            self.mode = mode
            self.root = root + "/fma_" + self.mode
            self.mel_cepstrum_path = self.get_sample(self.root)
    
        def __getitem__(self, index):
            sample = np.load(self.mel_cepstrum_path[index])
            data = torch.from_numpy(sample[0])
            target = torch.from_numpy(sample[1].astype(np.float32))
            return data, target
    
        def __len__(self):
            if self.mode == "train":
                return 23733  #      
            elif self.mode == "validation":
                return 6780  #      
            elif self.mode == "test":
                return 3390  #      
    
        def get_sample(self, root):
            cepstrum = []
            for entry in os.scandir(root):
                if entry.is_file():
                    cepstrum.append(entry.path)
            return cepstrum
    

    4. DataLoader 클래스


    class torch.utils.data.DataLoader (dataset, batch_size=1**,** shuffle=False**,** sampler=None**,** batch_sampler=None**,** num_workers=0**,** collate_fn=, pin_memory=False**,** drop_last=False**,** timeout=0**,** worker_init_fn=None**)**
    색인을 통해 트레이닝 데이터를 되돌려주는 것만으로는 부족하며 DataLoad 클래스의 확대 기능도 필요합니다.
  • 일괄 읽기 가능:batch-size
  • 데이터에 대해 shuffle 조작 가능
  • 여러 스레드로 데이터를 읽을 수 있음
  • 이 종류는 우리가 코드를 실현할 필요가 없다. 직접 호출하고 파라미터를 설정하면 된다.

    좋은 웹페이지 즐겨찾기