Chainer 사용법 (11)

개요



Chainer의 사용법을 조사해 보았다.
Deconvolution2D를 사용해 보았다.
오토인코더(autoencoder)를 만들어 보았다.

사진





샘플 코드


import numpy as np
from chainer import datasets, iterators
from chainer import optimizers
from chainer import Chain
from chainer import training
from chainer.training import extensions
import chainer.functions as F
import chainer.links as L
import chainer
import matplotlib.pyplot as plt


class Autoencoder(Chain):
    """Minimal convolutional autoencoder: one conv layer, one deconv layer.

    The encoder maps the input to 16 feature maps with a 5x5 kernel and
    stride 1; the decoder maps those 16 feature maps back to a single
    channel with a matching 5x5, stride-1 deconvolution. Both stages are
    followed by ReLU.
    """

    def __init__(self):
        # in_channels=None lets Chainer infer the input channel count on the
        # first forward pass.
        conv = L.Convolution2D(None, 16, ksize=5, stride=1)
        deconv = L.Deconvolution2D(16, 1, ksize=5, stride=1)
        super(Autoencoder, self).__init__(encoder=conv, decoder=deconv)

    def __call__(self, x):
        """Encode ``x`` and return its reconstruction.

        NOTE(review): the caller in this file feeds batches shaped
        (N, 1, 28, 28); other shapes are untested here.
        """
        encoded = F.relu(self.encoder(x))
        reconstructed = F.relu(self.decoder(encoded))
        return reconstructed

def main():
    """Train the autoencoder on MNIST and plot reconstructed test digits.

    The first 1000 training images serve as both input and target. Training
    runs for 80 epochs with Adam in mini-batches of 100 (device=-1). The
    first 25 test images are then reconstructed, drawn in a 5x5 grid titled
    with their digit labels, saved to "auto3.png" and shown on screen.
    """
    def to_image(sample):
        # Give every sample an explicit channel axis: (1, 28, 28).
        pixels, digit = sample
        return pixels.reshape((1, 28, 28)), digit

    raw_train, raw_test = datasets.get_mnist()
    train_set = datasets.TransformDataset(raw_train, to_image)
    test_set = datasets.TransformDataset(raw_test, to_image)

    # An autoencoder's target is its own input, so pair each image with
    # itself and drop the digit label.
    images = [pair[0] for pair in train_set[0:1000]]
    autoencoder_pairs = datasets.TupleDataset(images, images)
    train_iter = iterators.SerialIterator(autoencoder_pairs, 100)
    samples = test_set[0:25]

    # Classifier is used only as a loss wrapper around the predictor;
    # classification accuracy makes no sense for image targets.
    model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
    model.compute_accuracy = False
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    updater = training.StandardUpdater(train_iter, optimizer, device=-1)
    trainer = training.Trainer(updater, (80, 'epoch'), out="result")
    log_report = extensions.LogReport()
    trainer.extend(log_report)
    print_report = extensions.PrintReport(['epoch', 'main/loss'])
    trainer.extend(print_report)
    trainer.run()

    # Reconstruct each test image on its own (batch of one).
    reconstructions = [
        (model.predictor(np.array([image], dtype=np.float32)).data, digit)
        for image, digit in samples
    ]

    for position, (picture, digit) in enumerate(reconstructions, start=1):
        plt.subplot(5, 5, position)
        plt.axis('off')
        plt.imshow(
            picture.reshape(28, 28),
            cmap=plt.cm.gray_r,
            interpolation='nearest',
        )
        plt.title(int(digit), color='red')

    # Save before show(): showing may clear the current figure.
    plt.savefig("auto3.png")
    plt.show()

# Run the training and plotting demo only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()



이상.

좋은 웹페이지 즐겨찾기