범주형 분포 : 디리클레 분포 (베이즈 추정)

19526 단어 R

기사의 목적



범주 분포와 공액 사전 분포의 Dirichlet 분포를 사용하고 R을 사용하여 베이지안 추정을 수행합니다.
어떤 상품 A의 사이즈 S, 사이즈 M, 사이즈 L이 선택될 확률을 추정합니다.
참고 : 베이즈 추론에 의한 기계 학습 입문

목차



0. 모델 설명
1. 라이브러리
2. 추정할 분포
3. 사전 분포
4. 사후 분포
5. 예측 분포

0. 모델 설명





1. 라이브러리


library(dplyr)
library(MCMCpack)
library(ggplot2)
library(gganimate)
set.seed(100)

2. 추정할 분포



상품 A의 사이즈 S가 선택될 확률은 0.1, 사이즈 M이 선택될 확률은 0.6, 사이즈 L이 선택될 확률은 0.3입니다만, 우리는 그것을 모릅니다.
이 진정한 확률 0.1, 0.6, 0.3을 사후 분포로 추정합니다.
pi.true <- c(0.1, 0.6, 0.3)
NULL %>% ggplot(aes(x = c("Sサイズ", "Mサイズ", "Lサイズ"), y = pi.true)) + 
  geom_bar(stat = "identity") + ylim(0,1) + 
  labs(x="Sサイズ/Mサイズ/Lサイズ", y="選択確率", title="推定する分布") +
  scale_x_discrete(limits=c("Sサイズ", "Mサイズ", "Lサイズ"))



3. 사전 분포



사전분포로 범주형 분포의 공액 사전분포인 디리클레 분포를 지정합니다.
이 사전 분포는 어떤 확률이 나와도 이상하지 않은 상태를 나타냅니다.
K <- 3
alpha0 <- rep(1/K, K)
X.pre <- rdirichlet(1000, alpha0)[,1:2]
NULL %>% ggplot(aes(x = X.pre[,1], y = X.pre[,2])) + geom_point() +
  labs(x="Sサイズを選択する確率", y="Mサイズを選択する確率", title="事前分布")



4. 사후 분포



아래 그림은 진정한 분포에서 샘플 데이터를 점진적으로 늘리고 사후 분포를 추정하는 흐름을 보여줍니다.
최종적으로는 사이즈 s가 선택될 확률이 0.1, 사이즈 M이 선택될 확률이 0.6 부근으로 추정되어 있습니다.
#ハイパーパラメータの初期値設定
alpha <- alpha0
#可視化のためのデータ
Data <- rdirichlet(1000, alpha)[,1:2]
Data <- data.frame(Data, iter=rep(0,1000))
Data.pi <- data.frame(pi = alpha, iter=rep(0, 3))
#事後分布
for(t in 1:3){
  #データ発生
  X <- rmultinom(10^t, 1, pi.true) %>% apply(2, which.max)
  #パラメータ更新
  n <- X %>% factor() %>% summary()
  alpha <- n+alpha0
  #可視化用データ取得
  Data.tmp <- rdirichlet(1000, alpha)[,1:2]
  Data.tmp <- data.frame(Data.tmp, iter=rep(t,1000))
  Data <- rbind(Data, Data.tmp)
  Data.pi.tmp <- data.frame(pi = alpha/sum(alpha), iter=rep(t,3))
  Data.pi <- rbind(Data.pi, Data.pi.tmp)
}
#事後分布可視化
Data %>% ggplot(aes(x=X1, y=X2)) +
  geom_point(alpha=0.5, col="green") + 
  labs(x="Sサイズを選択する確率", y="Mサイズを選択する確率", title="事前分布") +
  transition_states(iter, transition_length = 2, state_length = 1)



5. 예측 분포



다음 그림은 사후 분포와 마찬가지로 추정 흐름을 나타냅니다.
궁극적으로는 확률의 값을 잘 예측할 수 있습니다.
Data.pi$label <- rep(c("Sサイズ", "Mサイズ", "Lサイズ"), 4)
Data.pi %>% ggplot(aes(x=label, y=pi)) +
  geom_bar(stat = "identity", alpha=0.5, fill="blue") +
  ylim(0,1)+
  labs(x="Sサイズ/Mサイズ/Lサイズ", y="選択確率", title="予測k分布") +
  scale_x_discrete(limits=c("Sサイズ", "Mサイズ", "Lサイズ")) +
  transition_states(iter, transition_length = 2, state_length = 1)

좋은 웹페이지 즐겨찾기