kelpnet의 작법 7

개요


나는 켈프넷의 방법을 조사했다.
선로를 추적해 보세요.
공부를 해보다.

결과



예제 코드

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using KelpNet;
using System.IO;

namespace ConsoleApp4
{
    class Program
    {
        const int EPOCH = 3001;       
        static void Main(string[] args)
        {
            int N = 2000;
            Real[][] trainData = new Real[N][];
            Real[][] trainLabel = new Real[N][];
            int[] Label = new int[N];
            int i = 0;
            Console.WriteLine("Read Start...");
            using (StreamReader sr = new StreamReader(@"test.csv"))
            {
                while (!sr.EndOfStream)
                {
                    string line = sr.ReadLine();
                    string[] items = line.Split(',');
                    trainData[i] = new Real[] {
                        double.Parse(items[0]),
                        double.Parse(items[1]),
                        double.Parse(items[2]),
                        double.Parse(items[3]),
                        double.Parse(items[4])
                    };
                    trainLabel[i] = new Real[] {
                        int.Parse(items[5])
                    };
                    Label[i] = int.Parse(items[5]);
                    i++;
                }
            }
            N = i;
            FunctionStack nn = new FunctionStack(new Linear(5, 20, name: "in"), new Sigmoid(name: "act"), new Linear(20, 3, name: "out"));
            nn.SetOptimizer(new SGD());
            Console.WriteLine("Train Start...");
            for (i = 0; i < EPOCH; i++)
            {
                Real loss = 0;
                for (int j = 0; j < N; j++)
                {
                    loss += Trainer.Train(nn, trainData[j], trainLabel[j], new SoftmaxCrossEntropy());
                }
                if (i % 100 == 0)
                {
                    Console.WriteLine("loss:" + loss / N);
                }
            }
            Console.WriteLine("Save Start...");
            ModelIO.Save(nn, "test.nn");
            Console.WriteLine("Load Start...");
            Function testnn = ModelIO.Load("test.nn");
            Console.WriteLine("Test Start...");
            float ac = 0;
            for (int j = 0; j < N; j++)
            {
                NdArray result = testnn.Predict(trainData[j])[0];
                int resultIndex = Array.IndexOf(result.Data, result.Data.Max());
                if (resultIndex == Label[j]) ac++;
            }
            Console.Write("正解率 ");
            Console.WriteLine(ac / N * 100);
            Console.WriteLine("Press any key to exit.");
            Console.ReadKey();
        }
    }
}

이상

좋은 웹페이지 즐겨찾기