pytorch 생성 데이터 집합

3803 단어 pytorch
import torch
import torchvision
from torchvision import datasets,transforms
dataroot = "data/celeba"  #  
#  
dataset = datasets.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

1)torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
여기에는 주로 네 개의 매개변수가 있습니다.
root: root에서 지정한 경로에서 그림transform 찾기: PIL Image에 대한 변환 작업, transform의 입력은loader로 그림을 읽는 반환 대상 targettransform: label에 대한 변환 loader: 경로를 지정한 후 이미지를 읽는 방법, 기본적으로 RGB 형식의 PIL Image 객체 원본 링크로 읽는 방법:https://blog.csdn.net/weixin_40123108/article/details/85099449
Pytorch의 torchvision 모듈에는 mnist,coco,imagenet, 일반적인 데이터 로더인 ImageFolder와 같은 기본적인 데이터 세트가 포함되어 있습니다.서로 다른 폴더 아래의 그림은 서로 다른 종류로 간주되어 선천적으로 이미지 분류 작업에 사용된다.
imagefolder에는 세 개의 구성원 변수가 있습니다: x가train일 때의 image데이터sets의 속성
  • self.classes: 클래스 이름을 하나의list로 저장합니다. 폴더의 이름입니다.예를 들어 ['green','normal','out','right']
  • self.class_to_idx: 클래스에 대응하는 색인으로 0, 1, 2, 3 등으로 이해할 수 있습니다.예를 들어 {'out':2,'green':0,'right':3,'normal':1}
  • self.imgs: 저장 (imgpath,class) 은 그림과 클래스의 그룹입니다.예를 들어 [('datasets/test True TrainTest/train/green/0000000012200roi.jpg', 0), ('datasets/test True TrainTest/train/right/0000012980roi.jpg', 3)]
  • 2)torchvision.transforms
    torchvision.transforms 모듈은 일반적인 이미지 변환 조작 클래스를 제공합니다.
  • class torchvision.transforms.ToTensor

  • shape=(H x W x C)의 픽셀 값을 [0, 255]로 하는 PIL.Image 및 numpy.ndarray는shape=(C x H x W)의 픽셀 값 범위가 [0.0, 1.0]인 torch로 변환됩니다.FloatTensor.
  • class torchvision.transforms.Normalize(mean, std)

  • 이 변환 클래스는 torch.*Tensor.균일치(R, G, B)와 표준차(R, G, B)를 정하고 공식 채널 =(channel - mean)/std로 규범화한다.텍스트 링크:https://blog.csdn.net/wsp_1138886114/article/details/83620869
     
    3)torch.utils.data.DataLoader
    batch크기가 Tensor로 캡슐화됨
  • 1.데이터셋(Dataset), 데이터 읽기 인터페이스(예를 들어 torchvision.datasets.ImageFolder) 또는 사용자 정의 데이터 인터페이스의 출력, 이 출력은 torch입니다.utils.data.Dataset 클래스의 객체(또는 클래스의 사용자 정의 클래스를 상속하는 객체)입니다.
  • 2.batch_size(int,optional), 일괄 트레이닝 데이터량의 크기는 구체적인 상황에 따라 설정하면 됩니다.(기본값: 1)
  • 3.shuffle(bool,optional), 데이터를 어지럽히는 것은 일반적으로 훈련 데이터에서 사용된다.(기본값: False)
  • 4.샘플을 데이터 세트에서 추출하는 Sampler(Sampler, optional) 정책지정하면 "shuffle"은false여야 합니다.일반적으로 묵인하면 된다.
  • 5.batch_sampler(Sampler, optional), 그리고batch크기, shuffle 등 매개 변수는 서로 배척되며, 일반적으로 기본값을 사용합니다.
  • 6.num_workers, 이 매개 변수는 0보다 커야 합니다. 다른 0보다 큰 수는 여러 프로세스를 통해 데이터를 가져오면 데이터 가져오는 속도를 높일 수 있음을 나타냅니다.(기본값: 0)
  • 7.collate_fn(callable,optional), 샘플 목록을 합쳐서 소량을 형성합니다.서로 다른 상황에서 입력한 데이터셋을 처리하는 데 사용되는 봉인은 기본적으로 사용되며, 사용자가 정의한 데이터 읽기 출력이 매우 드물지 않으면 됩니다.
  • 8.pin_memory (bool, optional): 데이터 로더는 장량을 CUDA 메모리에 복사한 다음 되돌려줍니다.즉, 데이터 복제의 문제입니다.
  • 9.drop_st (bool, optional): 데이터 집합의 크기가 일괄 크기로 제거되지 않으면 마지막 미완성 일괄 크기를 제거하기 위해 "true"로 설정합니다."false"라면 마지막 부분은 더 작아집니다.(기본값: false)
  • 10.시간 아웃 (numeric, optional): 데이터 읽기 시간 초과를 설정하지만, 이 시간을 초과하여 데이터를 읽지 못하면 오류가 발생합니다.(기본값: 0)
  • 11.worker_init_fn(callable, optional): none이 아니면 피드 설정 후 데이터가 불러오기 전에 모든 작업 프로세스에서 그것을 사용하고 작업 프로세스 ID([0,num workers-1)(기본값: None)
  • 를 입력합니다
    텍스트 링크:https://blog.csdn.net/wsp_1138886114/article/details/84146704

    좋은 웹페이지 즐겨찾기