KelpNet 사용법

2844 단어 · 태그: KelpNet, XOR, C#
개요
KelpNet 사용법을 알아봤어요.
구성 요소
손실 함수(Loss) - SoftmaxCrossEntropy, MeanSquaredError

옵티마이저(Optimizer) - MomentumSGD, SGD, Adam, AdaGrad

활성화 함수(Activation) - Sigmoid, ReLU, TanhActivation

샘플 코드
XOR 학습 예제
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using KelpNet;

namespace ConsoleApp1
{
    /// <summary>
    /// Minimal KelpNet sample: trains a small network on XOR, prints its predictions,
    /// saves the model to disk, reloads it, and prints the reloaded model's predictions.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            Console.WriteLine("Hello World!");

            const int learningCount = 10000;   // number of passes over the four XOR samples
            const int logInterval = 1000;      // print the average loss every this many passes
            const string modelPath = "test.nn";

            // The four XOR input pairs.
            Real[][] trainData = {
                new Real[] { 0, 0 },
                new Real[] { 1, 0 },
                new Real[] { 0, 1 },
                new Real[] { 1, 1 }
            };

            // One label per input: the expected XOR result (0 or 1). It is used as a class
            // index over the network's two outputs (presumably what SoftmaxCrossEntropy
            // expects; the argmax in PrintPredictions reads the outputs the same way).
            Real[][] trainLabel = {
                new Real[] { 0 },
                new Real[] { 1 },
                new Real[] { 1 },
                new Real[] { 0 }
            };

            // 2 inputs -> Linear(2, 2) -> tanh -> Linear(2, 2): two output units, one per class.
            FunctionStack nn = new FunctionStack(new Linear(2, 2, name: "l1"), new TanhActivation(name: "act"), new Linear(2, 2, name: "l2"));
            nn.SetOptimizer(new SGD());
            Console.WriteLine("1 Training...");

            // Create the loss function once instead of allocating a new one for every sample.
            var lossFunction = new SoftmaxCrossEntropy();

            for (int epoch = 0; epoch < learningCount; epoch++)
            {
                Real loss = 0;
                for (int j = 0; j < trainData.Length; j++)
                {
                    // Trainer.Train returns this sample's loss, which is summed here. No separate
                    // update call is made, so presumably it also applies the optimizer step.
                    loss += Trainer.Train(nn, trainData[j], trainLabel[j], lossFunction);
                }

                if (epoch % logInterval == 0)
                {
                    // Average over the actual number of samples rather than a hard-coded 4.
                    Console.WriteLine("loss: " + loss / trainData.Length);
                }
            }

            PrintPredictions(trainData, input => nn.Predict(input)[0]);

            // Round-trip the trained model through disk and check that it still predicts correctly.
            ModelIO.Save(nn, modelPath);
            Function testnn = ModelIO.Load(modelPath);
            PrintPredictions(trainData, input => testnn.Predict(input)[0]);

            // Console.ReadKey throws InvalidOperationException when stdin is redirected
            // (piped or CI runs), so only wait for a key on an interactive console.
            if (!Console.IsInputRedirected)
            {
                Console.WriteLine("Press any key to exit.");
                Console.ReadKey();
            }
        }

        /// <summary>
        /// Prints one line per input with the predicted class and the raw network output.
        /// </summary>
        /// <param name="inputs">Input pairs to evaluate. Each must have at least two elements.</param>
        /// <param name="predict">Returns the network output for a single input.</param>
        /// <remarks>
        /// The predicted class is the index of the largest output value. When several outputs
        /// tie for the maximum, the first such index is reported.
        /// </remarks>
        static void PrintPredictions(Real[][] inputs, Func<Real[], NdArray> predict)
        {
            Console.WriteLine("Test Start...");
            foreach (Real[] input in inputs)
            {
                NdArray result = predict(input);
                int resultIndex = Array.IndexOf(result.Data, result.Data.Max());
                Console.WriteLine(input[0] + " xor " + input[1] + " = " + resultIndex + " " + result);
            }
        }
    }
}

결과

이상.

좋은 웹페이지 즐겨찾기