pytorch 응용 프로그램

1315 단어 pytorch
3대 절차: 데이터 읽기, 네트워크 구축, 기타 보조 데이터 읽기, 예를 들어 mnist는 torchvision의 데이터sets 방법으로 바보식 읽기를 하면 분류 문제에 대해 torchvision을 사용할 수 있다.datasets.ImageFolder에서 이미지 및 label 정보를 읽습니다.
data_dir = '/data' 
image_datasets = {x: datasets.ImageFolder( 
os.path.join(data_dir, x), data_transforms[x]
), for x in ['train', 'val']}

torchvision.datasets.ImageFolder는 list만 되돌려줍니다.list는 모델로 입력할 수 없기 때문에 PyTorch에서 다른 종류로 list를 봉인해야 합니다. 그것이 바로 torch입니다.utils.data.DataLoader.torch.utils.data.DataLoader 클래스는 list 유형의 입력 데이터를 Tensor 데이터 형식으로 봉하여 모델에 사용할 수 있습니다.그림% 1개의 캡션을 편집했습니다.
데이터가 하나의 클래스, 하나의 폴더로 저장되지 않을 때, 사용자는 하나의 클래스를 사용자 정의해서 데이터를 읽어야 한다. 사용자 정의의 클래스는 torch에서 계승해야 한다.utils.data.Dataset이라는 기본 클래스는 마지막으로torch를 사용합니다.utils.data.DataLoader는 Tensor로 캡슐화됩니다.
torchvision을 이용하다.모델의 모델은 조건을 충족시킬 수 있고 (pretrain의 매개 변수도 있음) 마지막 분류classes의 수가 같지 않으면 앞의 fc를 추출할 수 있습니다.
# coding=UTF-8
import torchvision.models as models
 
# 
model = models.resnet50(pretrained=True)
# fc 
fc_features = model.fc.in_features
# 9
model.fc = nn.Linear(fc_features, 9)


일반적으로 우리는 분류된 앞쪽 FC를 피처스라고 부른다
자신이 읽은 데이터에 대해 데이터sets를 계승하고 다시 쓰기len__() 및getitem__() 두 가지 방법 def getitem(self, index): index가 모든 데이터를 두루 읽어야 하고 최종적으로 이미지 & label도 모든 데이터를 되돌려야 한다는 것을 알 수 있습니다!

좋은 웹페이지 즐겨찾기