Flux.jl에서 Sin 파도 학습

배경



Julia에는 몇 가지 신경망 라이브러리가 있습니다.
나는 옛날 Mocha.jl을 사용하고 있었지만 Julia1.0 이후라고 지금의 소비 추천이라고 하는 것이므로 Flux.jl에 이민했다.
자신 가운데 새로운 신경망 라이브러리를 사용할 때 Sin파를 학습한다는 의식이 있기 때문에 의식을 실시했다.

코드



using Plots
gr()

using Flux
using Statistics
using Flux.Tracker: TrackedReal, data
using Flux: mse
using Base.Iterators: repeated, flatten

# 訓練データ
N = 100
X = range(0, stop = pi, length = N)
Y = sin.(X)

# 訓練データをプロットしておく
plot(X, Y)

data_x = [[x] for x in X]
data_y = [[y] for y in Y]
# batch処理すべきだが、めんどうなのでrepeatedでごまかした。
# Model-Zooもrepeatedでなんとかしてる奴あるしいいよね。
data_xf = Iterators.flatten(repeated(data_x, 100))
data_yf = Iterators.flatten(repeated(data_y, 100))
入力データは[(入力の配列, 出力の配列)]な形式
dataset = zip(data_xf, data_yf)

# モデル
m = Chain(
  Dense(1, 20, relu),
  Dense(20, 1, σ))

loss(x, y) = mse(m(x), y) 

opt = Descent()
Flux.train!(loss, params(m), dataset, opt)

Nt = 100
Xt = range(0, stop = pi, length = Nt)
input_xt = [[x] for x in Xt]
expect_yt = m.(input_xt)

Yt = collect(Iterators.flatten(expect_yt))
# 結果はTrackedRealという型に入ってくるため数字だけ抜き出す。
Yt2 = data.(Yt)

plot!(Xt, Yt2)
png("result.png")

결과





아무래도 좋은 이야기



그녀를 원하기 때문에 여성 소개하십시오.

좋은 웹페이지 즐겨찾기