Datasaurus Dozen의 클래스 멤버십 예측 🦖

이것은 screencasts 패키지를 사용하는 방법을 보여 주는 제 tidymodels 시리즈 중 첫 번째 모델링 단계부터 시작하여 더 복잡한 모델 튜닝에 이르는 최신 시리즈입니다. 오늘의 스크린캐스트에서는 더 작은 데이터 세트를 사용하지만 이번 주 #TidyTuesday datasetDatasaurus Dozen을 사용하여 모델링에서 몇 가지 중요한 기술을 시험해 볼 수 있습니다.



다음은 비디오 대신 또는 비디오에 추가하여 읽기를 선호하는 사람들을 위해 비디오에서 사용한 코드입니다.

데이터 탐색



Datasaurus Dozen dataset은 요약 통계가 매우 유사하지만 플롯할 때 매우 다르게 보이는 x/y 데이터의 13개 세트 모음입니다. 우리의 모델링 목표는 각 포인트가 "다스"에 속하는 멤버를 예측하는 것입니다.

datasauRus 패키지에서 데이터를 읽는 것으로 시작하겠습니다.

library(tidyverse)
library(datasauRus)

datasaurus_dozen


## # A tibble: 1,846 x 3
## dataset x y
## <chr> <dbl> <dbl>
## 1 dino 55.4 97.2
## 2 dino 51.5 96.0
## 3 dino 46.2 94.5
## 4 dino 42.8 91.4
## 5 dino 40.8 88.3
## 6 dino 38.7 84.9
## 7 dino 35.6 79.9
## 8 dino 33.1 77.6
## 9 dino 29.0 74.5
## 10 dino 26.2 71.4
## # … with 1,836 more rows



이 데이터 세트는 서로 매우 다릅니다!

datasaurus_dozen %>%
  ggplot(aes(x, y, color = dataset)) +
  geom_point(show.legend = FALSE) +
  facet_wrap(~dataset, ncol = 5)





그러나 요약 통계는 매우 유사합니다.

datasaurus_dozen %>%
  group_by(dataset) %>%
  summarise(across(c(x, y), list(mean = mean, sd = sd)),
    x_y_cor = cor(x, y)
  )


## # A tibble: 13 x 6
## dataset x_mean x_sd y_mean y_sd x_y_cor
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 away 54.3 16.8 47.8 26.9 -0.0641
## 2 bullseye 54.3 16.8 47.8 26.9 -0.0686
## 3 circle 54.3 16.8 47.8 26.9 -0.0683
## 4 dino 54.3 16.8 47.8 26.9 -0.0645
## 5 dots 54.3 16.8 47.8 26.9 -0.0603
## 6 h_lines 54.3 16.8 47.8 26.9 -0.0617
## 7 high_lines 54.3 16.8 47.8 26.9 -0.0685
## 8 slant_down 54.3 16.8 47.8 26.9 -0.0690
## 9 slant_up 54.3 16.8 47.8 26.9 -0.0686
## 10 star 54.3 16.8 47.8 26.9 -0.0630
## 11 v_lines 54.3 16.8 47.8 26.9 -0.0694
## 12 wide_lines 54.3 16.8 47.8 26.9 -0.0666
## 13 x_shape 54.3 16.8 47.8 26.9 -0.0656



포인트가 속한 데이터 세트를 예측하기 위해 모델링을 사용할 수 있는지 살펴보겠습니다. 이것은 클래스 수(13!)에 비해 큰 데이터 세트가 아니므로 전체적으로 예측 모델링 워크플로에 대한 모범 사례를 보여주는 자습서는 아니지만 다중 클래스 모델을 평가하는 방법과 약간의 정보를 보여줍니다. 랜덤 포레스트 모델이 작동하는 방식에 대해 설명합니다.

datasaurus_dozen %>%
  count(dataset)


## # A tibble: 13 x 2
## dataset n
## <chr> <int>
## 1 away 142
## 2 bullseye 142
## 3 circle 142
## 4 dino 142
## 5 dots 142
## 6 h_lines 142
## 7 high_lines 142
## 8 slant_down 142
## 9 slant_up 142
## 10 star 142
## 11 v_lines 142
## 12 wide_lines 142
## 13 x_shape 142



모델 구축



Datasaurus Dozen의 부트스트랩 리샘플을 만들어 시작해 보겠습니다. 테스트 세트와 훈련 세트로 분할하지 않으므로 새 데이터에 대한 편향되지 않은 성능 추정치가 없습니다. 대신 이러한 리샘플링을 사용하여 데이터 세트와 다중 클래스 모델을 더 잘 이해할 것입니다.

library(tidymodels)

set.seed(123)
dino_folds <- datasaurus_dozen %>%
  mutate(dataset = factor(dataset)) %>%
  bootstraps()

dino_folds


## # Bootstrap sampling 
## # A tibble: 25 x 2
## splits id         
## <list> <chr>      
## 1 <split [1.8K/672]> Bootstrap01
## 2 <split [1.8K/689]> Bootstrap02
## 3 <split [1.8K/680]> Bootstrap03
## 4 <split [1.8K/674]> Bootstrap04
## 5 <split [1.8K/692]> Bootstrap05
## 6 <split [1.8K/689]> Bootstrap06
## 7 <split [1.8K/689]> Bootstrap07
## 8 <split [1.8K/695]> Bootstrap08
## 9 <split [1.8K/664]> Bootstrap09
## 10 <split [1.8K/671]> Bootstrap10
## # … with 15 more rows



랜덤 포레스트 모델을 만들고 모델 및 수식 전처리기로 모델 워크플로를 설정해 보겠습니다. 우리는 datasetx 에서 y 클래스(공룡 vs. 원 vs. 땡기 vs. …)를 예측하고 있습니다. 랜덤 포레스트 모델은 종종 예측 변수의 복잡한 상호 작용을 잘 학습할 수 있습니다.

rf_spec <- rand_forest(trees = 1000) %>%
  set_mode("classification") %>%
  set_engine("ranger")

dino_wf <- workflow() %>%
  add_formula(dataset ~ x + y) %>%
  add_model(rf_spec)

dino_wf


## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: rand_forest()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## dataset ~ x + y
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
## 
## Main Arguments:
## trees = 1000
## 
## Computational engine: ranger



랜덤 포레스트 모델을 부트스트랩 리샘플에 맞추겠습니다.

doParallel::registerDoParallel()
dino_rs <- fit_resamples(
  dino_wf,
  resamples = dino_folds,
  control = control_resamples(save_pred = TRUE)
)

dino_rs


## # Resampling results
## # Bootstrap sampling 
## # A tibble: 25 x 5
## splits id .metrics .notes .predictions     
## <list> <chr> <list> <list> <list>           
## 1 <split [1.8K/672… Bootstrap01 <tibble [2 × … <tibble [0 × … <tibble [672 × 1…
## 2 <split [1.8K/689… Bootstrap02 <tibble [2 × … <tibble [0 × … <tibble [689 × 1…
## 3 <split [1.8K/680… Bootstrap03 <tibble [2 × … <tibble [0 × … <tibble [680 × 1…
## 4 <split [1.8K/674… Bootstrap04 <tibble [2 × … <tibble [0 × … <tibble [674 × 1…
## 5 <split [1.8K/692… Bootstrap05 <tibble [2 × … <tibble [0 × … <tibble [692 × 1…
## 6 <split [1.8K/689… Bootstrap06 <tibble [2 × … <tibble [0 × … <tibble [689 × 1…
## 7 <split [1.8K/689… Bootstrap07 <tibble [2 × … <tibble [0 × … <tibble [689 × 1…
## 8 <split [1.8K/695… Bootstrap08 <tibble [2 × … <tibble [0 × … <tibble [695 × 1…
## 9 <split [1.8K/664… Bootstrap09 <tibble [2 × … <tibble [0 × … <tibble [664 × 1…
## 10 <split [1.8K/671… Bootstrap10 <tibble [2 × … <tibble [0 × … <tibble [671 × 1…
## # … with 15 more rows



우리는 해냈다!

모델 평가



이 모델들은 전반적으로 어떻게 작동했습니까?

collect_metrics(dino_rs)


## # A tibble: 2 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy multiclass 0.449 25 0.00337
## 2 roc_auc hand_till 0.846 25 0.00128



정확도는 좋지 않습니다. 이와 같은 다중 클래스 문제, 특히 너무 많은 클래스가 있는 문제는 이진 분류 문제보다 어렵습니다. 가능한 오답이 너무 많습니다!

예측을 save_pred = TRUE로 저장했으므로 다른 성능 지표를 계산할 수 있습니다. 기본적으로 양의 예측 값(예: 정확도)은 다중 클래스 문제에 대해 거시적으로 가중치가 적용됩니다.

dino_rs %>%
  collect_predictions() %>%
  group_by(id) %>%
  ppv(dataset, .pred_class)


## # A tibble: 25 x 4
## id .metric .estimator .estimate
## <chr> <chr> <chr> <dbl>
## 1 Bootstrap01 ppv macro 0.428
## 2 Bootstrap02 ppv macro 0.431
## 3 Bootstrap03 ppv macro 0.436
## 4 Bootstrap04 ppv macro 0.418
## 5 Bootstrap05 ppv macro 0.445
## 6 Bootstrap06 ppv macro 0.413
## 7 Bootstrap07 ppv macro 0.420
## 8 Bootstrap08 ppv macro 0.423
## 9 Bootstrap09 ppv macro 0.393
## 10 Bootstrap10 ppv macro 0.429
## # … with 15 more rows



다음으로 각 클래스에 대한 ROC 곡선을 계산해 보겠습니다.

dino_rs %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(dataset, .pred_away:.pred_x_shape) %>%
  ggplot(aes(1 - specificity, sensitivity, color = id)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  facet_wrap(~.level, ncol = 5) +
  coord_equal()





이 플롯에는 각 클래스와 각 리샘플에 대한 ROC 곡선이 있습니다. 포인트 데이터셋은 모델이 식별하기 쉬웠지만 디노 데이터셋은 매우 어려웠습니다. 모델은 공룡에 대해 추측하는 것보다 간신히 더 잘했습니다!

혼동 행렬을 계산할 수도 있습니다. tune::conf_mat_resampled()를 사용할 수 있지만 클래스당 예제가 너무 적고 클래스가 균형을 이루었으므로 모든 리샘플을 함께 살펴보겠습니다.

dino_rs %>%
  collect_predictions() %>%
  conf_mat(dataset, .pred_class)


## Truth
## Prediction away bullseye circle dino dots h_lines high_lines slant_down slant_up star v_lines wide_lines x_shape
## away 220 78 50 59 9 55 78 130 96 58 4 118 83
## bullseye 125 470 17 97 3 38 101 74 109 31 40 93 55
## circle 99 16 852 105 4 34 147 49 98 85 6 62 30
## dino 54 65 16 142 5 42 82 153 114 50 23 66 49
## dots 22 20 22 33 1221 39 57 47 34 15 11 28 16
## h_lines 52 81 37 60 26 897 37 42 54 34 4 56 36
## high_lines 111 105 69 145 8 27 381 95 125 58 34 73 77
## slant_down 137 55 24 158 10 30 69 318 114 33 41 89 27
## slant_up 81 82 37 144 1 30 64 107 264 30 13 96 49
## star 60 52 37 77 19 28 62 73 37 755 0 34 87
## v_lines 32 66 30 69 7 9 45 78 56 20 1133 32 14
## wide_lines 175 134 55 137 0 56 69 102 193 53 21 390 147
## x_shape 158 102 65 79 4 27 121 67 44 92 1 136 648



이러한 개수는 시각화에서 더 쉽게 이해할 수 있습니다.

dino_rs %>%
  collect_predictions() %>%
  conf_mat(dataset, .pred_class) %>%
  autoplot(type = "heatmap")





대각선에는 공룡에서 점까지 10배의 차이가 있는 실제 가변성이 있습니다.

대각선을 모두 0으로 설정하면 어떤 클래스가 서로 혼동될 가능성이 가장 높은지 확인할 수 있습니다.

dino_rs %>%
  collect_predictions() %>%
  filter(.pred_class != dataset) %>%
  conf_mat(dataset, .pred_class) %>%
  autoplot(type = "heatmap")





공룡 데이터 세트는 다른 많은 데이터 세트와 혼동되었으며 wide_lines는 종종 slant_up와 혼동되었습니다.

좋은 웹페이지 즐겨찾기