R에서 베이즈 선형 회귀

18256 단어 R

기사의 목적



선형 회귀 파라미터를 베이즈 추정합니다.
참고 : 베이즈 추론에 의한 기계 학습 입문

목차



0. 모델 설명
1. 라이브러리
2. 추정할 분포
3. 사전 분포
4. 사후 분포
5. 예측 분포

0. 모델 설명





1. 라이브러리


library(dplyr)
library(mvtnorm)
library(scatterplot3d)
set.seed(100)

2. 추정할 분포



y=w1*x1+w2*x2를 생각합니다. x1은 체중, x2는 키, y는 달의 식비로 가정합니다. w1, w2를 500,100으로 하고, 이것을 추정하는 것이 목적입니다.
여기서 체중은 평균 50으로 분산 10, 신장은 평균 170으로 분산 10의 정규 분포를 가정하고 있습니다.
체중은 인과관계가 반대인 것 같지만 신경쓰지 마세요. (웃음)
w <- c(500, 100)
lambda <- 0.01
X.true <- rmvnorm(1000, c(50, 170), diag(2)*100)
y.true <- rnorm(1000, w[1]*X.true[,1] + w[2]*X.true[,2], 1/sqrt(lambda))
scatterplot3d(X.true[,1], X.true[, 2], y.true/10000,
              xlab="体重", ylab="身長", zlab="食費(万)",
              xlim=c(10,80), ylim=c(120, 210))



3. 사전 분포



w1,w2의 사전분포로서 각각 표준정규분포를 가정합니다.
m0 <- c(0, 0)
lambda0 <- diag(2)
w.pre <- rmvnorm(1000, m0, solve(lambda0))
plot(w.pre[,1], w.pre[,2], xlab="体重の影響", ylab="身長の影響")



4. 사후 분포



아래 그림에서 500, 100을 잘 추정할 수 있음을 알 수 있습니다.
X <- rmvnorm(100, c(50, 170), diag(2)*100)
y <- rnorm(100, w[1]*X[,1] + w[2]*X[,2], 1/sqrt(lambda))
sum.tmp <- 0
for(i in 1:nrow(X)){
  sum.tmp <- sum.tmp + X[i, ] %*% t(X[i, ])
}
lambda.post <- lambda*sum.tmp + solve(lambda0)
m.post <- solve(lambda.post) %*% (lambda*apply(y*X,2,sum) + lambda0%*%m0)
w.post <- rmvnorm(1000, m.post, solve(lambda.post))
plot(w.post[,1], w.post[,2], xlab="体重の影響", ylab="身長の影響",
     xlim=c(400,600), ylim=c(50, 150), col="green")



5. 예측 분포



예측 분포도 진정한 분포와 비슷하며 잘 추정된다는 것을 알 수 있습니다.
X.sample <- rmvnorm(1000, c(50, 170), diag(2)*100)
y.sample <- {}
for(i in 1:1000){
  u.post <- t(m.post) %*% X.sample[i, ]
  lambda.post.inv <- 1/lambda + t(X.sample[i, ]) %*% solve(lambda.post) %*% X.sample[i,]
  y.sample[i] <- rnorm(1, u.post, sqrt(lambda.post.inv))
}
scatterplot3d(X.sample[,1], X.sample[, 2], y.sample/10000,
              xlab="体重", ylab="身長", zlab="食費(万)",
              xlim=c(10,80), ylim=c(120, 210), color="blue")

좋은 웹페이지 즐겨찾기