NumPyro의 미분방정식의 매개 변수에 근거하여 추정(스프링 질계)

29856 단어 Pythonjaxnumpyrotech

미분 방정식


지난번 소개된 감염증의 모델(SIR-모델)에 이어 이번에는 스프링체 저항계로 불리는 운동방정식에서 지난번과 같이 미분방정식의 매개 변수 추정을 하려고 한다.
https://zenn.dev/eota/articles/numpyro_ode_sir_model
기계적인 상황에서 측정 데이터는 대부분 일정한 샘플링 간격으로 얻어지기 때문에 미분방정식은 이산화되고 보통 상태 공간 모델로 처리하는 것이 비교적 편리하지만 측정 데이터가 반드시 등거리인 것은 아니다이 경우에도 효과적으로 활용할 수 있는 경우가 있다고 생각해 소개하기로 했다.

Mass-Spring-Dumper System


스프링 받침대 감진기계는 제어공학에서 흔히 볼 수 있는 예제로 테이프를 자르면 이런 느낌을 준다.
png
기본적으로 질점(Mass)과 스프링(Spring)과 감쇠기(Dumper)를 결합한 시스템이다.내가 가장 쉽게 연상할 수 있는 것은 자동차의 브래킷이라고 생각하지만, 이 용수철과 감쇠기의 조합에 따라 자동차의 탑승 감각에 큰 변화가 있을 것이다. (아마도 현대의 자동차는 내가 상상한 것보다 더 복잡할 것이다.)
스프링 저항계라고 불리는 이 미분방정식은 공식적으로 쓰면 다음과 같은 느낌을 준다.
\begin{aligned}
m\ddot{y}(t) + c\dot{y}(t) + k y(t) = f(t)
\end{aligned}
c 감쇠기의 계수에서 감쇠 계수라고 부른다.k는 용수철 상수다.이 상수들은 미지수이며 측정 데이터에서 이 상수들을 추측하는 것이 앞으로 해야 할 일이다.

Install Packages


넘피로는 JAX 버전에 다소 민감해 조금 오래된 JAX 장착 포장을 선택했다.
!pip install --upgrade jax==0.2.17 jaxlib==0.1.71+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install numpyro==0.7.2
!pip install arviz==0.11.2
!pip install japanize_matplotlib

Import Packages


JAX와 넘프로 등의 포장을 도입한다.
import jax
import jax.numpy as jnp
import jax.experimental.ode as ode

import numpyro
import numpyro.distributions as dist

import arviz as az
import numpy as np

import matplotlib.pyplot as plt
import japanize_matplotlib
plt.rcParams['font.size'] = 14
numpyro.set_platform('cpu')
numpyro.set_host_device_count(4)

Generate Data


이번에도 시뮬레이션을 통해 데이터를 생성할 것이다.또한 미분 방정식이 원형을 유지하면 odedint가 풀리지 않기 때문에 다음과 같은 몇 가지로 나뉜다.
\begin{aligned}
\dot{y}(t) &= v(t)\\
\dot{v}(t) &= -k y(t) - c\dot{v(t)} + f(t)\\
\end{aligned}
질점의 속도와 위치는 처음에는 0으로 설정되었으나, 최초의 2초는 1[N]의 힘으로 줄어들었다.그러나 외력이 다음과 같은 방식으로 가해지지 않으면 미분방정식에 잘 입력되지 않기 때문에 제어시스템의 해석 등에 사용하면 좀 실용적이지 못할 수 있다.
def f(t, F=1):
    
    return jnp.where(t < 2, F, 0)
def dz_dt(z, t, m, c, k):
    
    y = z[0]
    v = z[1]
    
    dy_dt = v
    dv_dt = (- k * y - c * v + f(t)) / m
        
    return jnp.stack([dy_dt, dv_dt])
m = 1.0 # 質量
c = 1.0 # 減衰係数
k = 1.0 # バネ定数

t_true = jnp.arange(0, 10, 0.1).astype(float)
z_init = jnp.array([0, 0]).astype(float)

z = ode.odeint(dz_dt, z_init, t_true, m, c, k)

y_true = z[:, 0]
v_true = z[:, 1]
plt.plot(t_true, y_true)
plt.title('Mass-Spring-Dumper System')
plt.xlabel('時間 [s]')
plt.ylabel('変位 [m]');
png
정방향에서 2초 전의 위치가 발생한 후에 다시 원래의 방향으로 돌아왔다.이 미분 방정식의 해에 소음을 약간 더해서 측정 데이터를 날조하다.
np.random.seed(0)

t_observed = t_true[::5]
y_observed = np.random.normal(y_true[::5], 0.05)
plt.plot(t_true, y_true)
plt.plot(t_observed, y_observed, 'o')

plt.title('Mass-Spring-Dumper System')
plt.xlabel('時間 [s]')
plt.ylabel('変位 [m]');
png
위에서 말한 바와 같이 용수철의 측정 데이터를 조작할 수 있다.

Define Model & Inference


다음에 이 측정 데이터로부터 미분 방정식의 매개 변수를 추측해 봅시다.아날로그로 데이터를 생성할 때, 우리는 이를 이미 알고 있는 감쇠계수 (c) 와 용수철 상수 (k) 의 매개 변수로 계산했지만, 이번에는 이 매개 변수들이 알 수 없는지, 반대로 데이터에서 얻을 수 있는지 고려할 것이다.반대로 생각하겠다는 얘기다.
NumPyro의 경우 모델이 함수로 결정됩니다.매개 변수에 미리 분포한 다음에 함수에 순서대로 데이터를 관측하기 전의 과정을 기술한다.
def model(t, y_observed=None):
    
    # 変位(y)と速度(v)の初期値に関する事前分布
    y_init = numpyro.sample('y_init', dist.Normal(0, 10))
    v_init = numpyro.sample('v_init', dist.Normal(0, 10))
    
    z_init = jnp.stack([y_init, v_init])
    
    # 減衰係数(c)とバネ定数(k)に関する事前分布
    c = numpyro.sample('c', dist.HalfNormal(10))
    k = numpyro.sample('k', dist.HalfNormal(10))
    
    # 微分方程式のソルバー
    z = ode.odeint(dz_dt, z_init, t, m, c, k)
    
    # 観測データからの尤度の計算
    sd_y = numpyro.sample('sigma', dist.HalfNormal(10))
    numpyro.sample('y', dist.Normal(z[:, 0], sd_y), obs=y_observed)
이어 마르코프 연쇄 몬테카로법(MCC)으로 불리는 방법을 사용해 역으로 파라미터를 추측한다.즉 위에서 매개 변수의 사전 확률 분포 모델을 설정할 수 있기 때문에 이번에는 데이터를 바탕으로 매개 변수의 사후 확률 분포(Bays 추정)를 계산한다.
사실 이 검증 확률 분포는 공식으로 직접 계산할 수 있었으면 좋겠지만, 이번 모델이 어려워 마르코프 프랜차이즈 몬테카로법(MCMC)으로 불리는 방법을 사용했다.마르코프 프랜차이즈 몬테칼로법(MCMC)은 후험 확률 분포에서 많은 양의 샘플을 생성하는 방법 중 하나로, 이를 통해 후험 확률 분포에 대한 다양한 정보를 조사할 수 있다.
%%time

nuts = numpyro.infer.NUTS(model, target_accept_prob=0.95)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=2000, num_samples=1000, num_chains=4, progress_bar=False)

mcmc.run(jax.random.PRNGKey(0), t_observed, y_observed=y_observed)
mcmc_samples = mcmc.get_samples()

idata = az.from_numpyro(mcmc)
CPU times: user 19.2 s, sys: 44 ms, total: 19.3 s
Wall time: 13.3 s
az.plot_trace(idata);
png
az.summary(idata)
mean
sd
hdi_3%
hdi_97%
mcse_mean
mcse_sd
ess_bulk
ess_tail
r_hat
0
c
1.072
0.055
0.97
1.176
0.001
0.001
2083
2204
1
1
k
0.985
0.028
0.934
1.04
0.001
0
2141
2420
1
2
sigma
0.047
0.009
0.032
0.064
0
0
1898
2058
1
3
v_init
0.113
0.084
-0.042
0.273
0.002
0.002
1488
2107
1
4
y_init
0.045
0.044
-0.038
0.127
0.001
0.001
1591
1771
1
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

az.plot_posterior(idata, var_names=['y_init', 'v_init', 'c', 'k'], ax=axes)

fig.subplots_adjust(hspace=0.4)
png

Check Prediction


이상의 계산을 통해 매개 변수의 백업 확률 분포를 알 수 있고 다음에 미분 방정식을 통해 실험 시 측정 값이 얼마나 편차가 있을지 예측할 수 있다.이 미분방정식의 경우 초기값을 결정하면 이 해는 거의 유일한 결정이지만 관측할 때 잡음 등이 있기 때문에 측정값의 예측 범위가 넓다.또 추정되는 미분방정식의 매개 변수도 확대(분포)돼 예측도 확대될 것으로 보인다.
t_pred = jnp.arange(0, 10, 0.1).astype(float)
predictive = numpyro.infer.Predictive(model, mcmc_samples)
ppc_samples = predictive(jax.random.PRNGKey(2), t_pred)

y_pred = ppc_samples['y']
mu_pred = jnp.mean(y_pred, 0)
pi_pred = jnp.percentile(y_pred, (5, 95), 0)
plt.figure(figsize=(8, 6))

plt.plot(t_observed, y_observed, 'o', color='C1', label='観測値')
plt.plot(t_true, y_true, '--', color='C2', label='真値')

plt.plot(t_pred, mu_pred, '-.', color='C0', label='予測値 (平均)')
plt.fill_between(t_pred, pi_pred[0, :], pi_pred[1, :], color='C0', alpha=0.2)

plt.title('事後予測分布')
plt.xlabel('時間 [s]')
plt.ylabel('変位 [m]')

plt.legend();
png
이 도표에는 얇은 남색 띠에 따라 후예측 분포의 양 끝을 각각 5%씩 자르는 베이즈 예측 구간(※)이 나와 있지만, 실제 측정 데이터도 대체로 이 주파수대에 수납돼 있다.
※ 또 용어 사용법과 관련해서는 송포 선생님의'스탠과 R의 베이즈 통계 모델링'(즉 오리본)의 2.5장'베이즈 신뢰 구간과 베이즈 예측 구간'을 참고했습니다.

Summary


이번에 우리는 제어 등 방면에서 자주 사용하는 용수철 감진기계의 미분 방정식을 통해 측정 데이터에서 미분 방정식의 파라미터를 추측했다.지난번에 감염증 모델(SIR-모델)의 매개 변수 추정을 해 보았지만, 어느 모델이든 비교적 간단한 절차로 매개 변수를 비교적 신속하게 추정할 수 있다고 생각합니다.
NumPyro는 Windows에서는 이동이 어렵지만 Google Colab 등이라면 간단하게 이동할 수 있으니 관심 있는 분들은 꼭 시도해 보세요.

관련 정보


https://note.com/ds_kotaro/n/n22a43c709bad

좋은 웹페이지 즐겨찾기