Chainer의 작법 (12)

2489 단어 · Chainer, Kaggle

개요



Chainer의 작법을 조사해 보았다.
Kaggle의 cat & dog(Dogs vs. Cats) 분류를 해 보았다.
Convolution2D를 사용해 보았다.

결과





샘플 코드


import glob
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training, datasets, iterators, Chain, optimizers, serializers
from chainer.training import extensions
from PIL import Image
import numpy as np


class CNN(Chain):
    """Small convolutional classifier.

    Three 3x3 convolutions with stride 2 (each followed by ReLU), then a
    100-unit fully connected layer with ReLU, then a linear output layer.

    Args:
        n_out: number of output classes (size of the final score vector).
    """

    def __init__(self, n_out):
        super(CNN, self).__init__()
        # Register child links inside init_scope(); passing them as keyword
        # arguments to Chain.__init__ is deprecated since Chainer v2.
        # Attribute names are unchanged, so saved .npz files stay compatible.
        with self.init_scope():
            # in_channels=None: inferred from the first input batch.
            self.conv1 = L.Convolution2D(None, 16, 3, 2)
            self.conv2 = L.Convolution2D(16, 32, 3, 2)
            self.conv3 = L.Convolution2D(32, 32, 3, 2)
            # in_size=None: the flattened conv3 output size is inferred lazily.
            self.fc4 = L.Linear(None, 100)
            self.fc5 = L.Linear(100, n_out)

    def __call__(self, x):
        """Return unnormalized class scores for the input batch ``x``.

        No softmax is applied here; the loss (e.g. L.Classifier's default
        softmax cross-entropy) is expected to handle normalization.
        """
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.fc4(h))
        return self.fc5(h)

def main(data_dir='../cats/dogs'):
    """Train the CNN on Kaggle cat/dog images and save it to 'cats1.model'.

    Image files are looked up under ``data_dir``: names starting with 'cat'
    get label 0, names starting with 'dog' get label 1. A fixed random 10%
    of the images is held out and used only for validation.

    Args:
        data_dir: directory containing the cat*/dog* image files.

    Raises:
        RuntimeError: if no cat or no dog images are found in ``data_dir``.
    """
    # Sorted so that the seeded train/validation split is reproducible.
    cats = sorted(glob.glob(data_dir + '/cat*'))
    dogs = sorted(glob.glob(data_dir + '/dog*'))
    if not cats or not dogs:
        raise RuntimeError(
            'no cat*/dog* images found under {!r}'.format(data_dir))
    # Label 0 = cat, 1 = dog.
    pairs = [(path, 0) for path in cats] + [(path, 1) for path in dogs]
    dataset = datasets.LabeledImageDataset(pairs)

    def transform(inputs):
        """Resize one (image, label) pair to a 3x32x32 float32 image in [0, 1]."""
        img, label = inputs
        # The transpose below implies a (C, H, W) layout. Grayscale (1 channel)
        # or gray+alpha (2 channels) images would make Image.fromarray fail
        # after slicing, so replicate the first channel to RGB.
        if img.shape[0] < 3:
            img = np.repeat(img[:1], 3, axis=0)
        # Drop any alpha channel (e.g. RGBA -> RGB).
        img = img[:3, ...]
        img = img.astype(np.uint8)
        img = Image.fromarray(img.transpose(1, 2, 0))
        img = img.resize((32, 32), Image.BICUBIC)
        img = np.array(img, dtype=np.float32).transpose(2, 0, 1) / 255
        return img, label

    dataset = datasets.TransformDataset(dataset, transform)
    # Evaluate on held-out images. Evaluating on the training images makes
    # the validation metrics meaningless.
    n_train = int(len(dataset) * 0.9)
    train, test = datasets.split_dataset_random(dataset, n_train, seed=0)

    train_iter = iterators.SerialIterator(train, 100)
    test_iter = iterators.SerialIterator(test, 100, repeat=False, shuffle=False)
    # Exactly two classes (cat, dog), so two output units.
    model = L.Classifier(CNN(2))
    optimizer = optimizers.Adam()
    optimizer.setup(model)
    updater = training.StandardUpdater(train_iter, optimizer, device=-1)
    trainer = training.Trainer(updater, (5, 'epoch'), out='result')
    trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
    trainer.extend(extensions.LogReport())
    # Include the Evaluator's 'validation/...' entries; otherwise they are
    # computed but never shown.
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'main/accuracy',
        'validation/main/loss', 'validation/main/accuracy', 'elapsed_time',
    ]))
    trainer.run()
    serializers.save_npz('cats1.model', model)

if __name__ == '__main__':
    # Start training only when executed as a script; importing this module
    # does not run main().
    main()




이상.

좋은 웹페이지 즐겨찾기