PyTorch – 이미지 분류 데이터 세트(Fashion-MNIST)

9187 단어 PyTorch학습 과정
PyTorch – 이미지 분류 데이터 세트(Fashion-MNIST)
  • 0, 선언
  • 1. 데이터 세트 획득
  • 2. 소량 데이터 읽기
  • 셋째, 그림

  • 본고는'손학 심도 학습(pytorch)'의'3.5 이미지 분류 데이터 집합(Fashion-MNIST)'을 학습한 필기로서 구체적으로 설명하려면 원문을 참고하십시오.
    선언
    사용한 가방은 주로torchvision인데 주로 다음과 같은 몇 부분으로 구성되어 있다.
  • torchvision.datasets: 데이터를 불러오는 함수와 자주 사용하는 데이터 집합 인터페이스;
  • torchvision.models: 자주 사용하는 모델 구조(예훈련 모델 포함)를 포함한다. 예를 들어 AlexNet, VGG,ResNet 등이다.
  • torchvision.transforms: 자주 사용하는 이미지 변환, 예를 들어 재단, 회전 등;
  • torchvision.utils: 다른 유용한 방법들.

  • 1. 데이터 세트 가져오기
    1. 다운로드한 데이터를 사용하지 않을 때transform=torchvision.transforms.ToTensor() 얻은 데이터는 크기(H)입니다.×W×C) 데이터가 [0, 255] 사이에 있는 PIL 이미지나 데이터 유형이 unit8인 Numpy 배열
    이 문장은 상기 형식의 데이터를 사이즈로 변환합니다 (C)×H×W) 데이터 유형이 torch입니다.float32 및 [0.0, 1.0] 사이에 있는 Tensor.
    import torchvision
    
    mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())
    

    2. 데이터 태그 얻기
    데이터를 다운로드한 후 데이터에 대응하는 라벨을 찾을 수 있어야 한다.다음 함수는 수치 탭을 해당하는 텍스트 탭으로 바꿀 수 있습니다.
    def get_fashion_mnist_labels(labels):
    	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    	return [text_labels[int(i)] for i in labels]
    

    2. 소량의 데이터 읽기torch.utilsdata의 한 방법DataLoader은 크기의 데이터를 쉽게 읽을 수 있다. 세 가지 자주 사용하는 세 가지 파라미터는 각각 batch_size이다.
    import torch.utils.data as Data
    
    batch_size = 256
    train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
    test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)
    

    3. 그림 그리기
    다음은 한 줄에 여러 장의 그림과 탭을 그릴 수 있는 함수를 정의합니다.
    #        d2lzh        
    def show_fashion_mnist(images, labels):
        d2l.use_svg_display()
        #    _      (   )   
        _, figs = plt.subplots(1, len(images), figsize=(12, 12))
        for f, img, lbl in zip(figs, images, labels):
            f.imshow(img.view((28, 28)).numpy())
            f.set_title(lbl)
            f.axes.get_xaxis().set_visible(False)
            f.axes.get_yaxis().set_visible(False)
        plt.show()
    

    좋은 웹페이지 즐겨찾기