깔끔한 모델로 거대한 호박 무게 예측하기 🎃

이것은 screencasts 패키지를 사용하는 방법을 보여주는 tidymodels 시리즈의 최신작입니다. Tidymodels를 처음 시작했거나 패키지를 많이 사용한 사용자라면 our priorities for 2022에 대한 피드백에 관심이 있습니다. 작년에 실시한 설문조사는 결정을 내리는 데 매우 도움이 된 것으로 나타났습니다. 귀하의 의견에 다시 한 번 감사드립니다!

오늘의 스크린캐스트는 workflowsets , 여러 사전 처리/모델링 조합을 한 번에 처리하기 위한 tidymodels 패키지로 시작하는 사람에게 좋습니다. 이번 주 #TidyTuesday dataset는 대회의 거대한 호박에 관한 것입니다. 🥧



다음은 비디오 대신 또는 비디오에 추가하여 읽기를 선호하는 사람들을 위해 비디오에서 사용한 코드입니다.

데이터 탐색



우리의 모델링 목표는 경쟁 중에 측정된 다른 특성에서 giant pumpkins의 가중치를 예측하는 것입니다.

library(tidyverse)

pumpkins_raw <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-10-19/pumpkins.csv")

pumpkins <-
  pumpkins_raw %>%
  separate(id, into = c("year", "type")) %>%
  mutate(across(c(year, weight_lbs, ott, place), parse_number)) %>%
  filter(type == "P") %>%
  select(weight_lbs, year, place, ott, gpc_site, country)

pumpkins


## # A tibble: 15,965 × 6
## weight_lbs year place ott gpc_site country    
## <dbl> <dbl> <dbl> <dbl> <chr> <chr>      
## 1 2032 2013 1 475 Uesugi Farms Weigh-off United Sta…
## 2 1985 2013 2 453 Safeway World Championship Pumpkin … United Sta…
## 3 1894 2013 3 445 Safeway World Championship Pumpkin … United Sta…
## 4 1874. 2013 4 436 Elk Grove Giant Pumpkin Festival United Sta…
## 5 1813 2013 5 430 The Great Howard Dill Giant Pumpkin… Canada     
## 6 1791 2013 6 431 Elk Grove Giant Pumpkin Festival United Sta…
## 7 1784 2013 7 445 Uesugi Farms Weigh-off United Sta…
## 8 1784. 2013 8 434 Stillwater Harvestfest United Sta…
## 9 1780. 2013 9 422 Stillwater Harvestfest United Sta…
## 10 1766. 2013 10 425 Durham Fair Weigh-Off United Sta…
## # … with 15,955 more rows



여기에서 주요 관계는 호박의 부피/크기("over-the-top 인치"를 통해 측정)와 무게 사이입니다.

pumpkins %>%
  filter(ott > 20, ott < 1e3) %>%
  ggplot(aes(ott, weight_lbs, color = place)) +
  geom_point(alpha = 0.2, size = 1.1) +
  labs(x = "over-the-top inches", y = "weight (lbs)") +
  scale_color_viridis_c()





크고 무거운 호박은 당연히 대회 우승에 더 가까워졌습니다!

시간이 지남에 따라 이 관계에 어떤 변화가 있었습니까?

pumpkins %>%
  filter(ott > 20, ott < 1e3) %>%
  ggplot(aes(ott, weight_lbs)) +
  geom_point(alpha = 0.2, size = 1.1, color = "gray60") +
  geom_smooth(aes(color = factor(year)),
    method = lm, formula = y ~ splines::bs(x, 3),
    se = FALSE, size = 1.5, alpha = 0.6
  ) +
  labs(x = "over-the-top inches", y = "weight (lbs)", color = NULL) +
  scale_color_viridis_d()





말하기 어렵다고 생각합니다.

어느 나라에서 어느 정도 거대한 호박을 생산했습니까?

pumpkins %>%
  mutate(
    country = fct_lump(country, n = 10),
    country = fct_reorder(country, weight_lbs)
  ) %>%
  ggplot(aes(country, weight_lbs, color = country)) +
  geom_boxplot(outlier.colour = NA) +
  geom_jitter(alpha = 0.1, width = 0.15) +
  labs(x = NULL, y = "weight (lbs)") +
  theme(legend.position = "none")





워크플로 세트 구축 및 맞춤



"데이터 예산"을 설정하여 모델링을 시작하겠습니다. 우리는 결과weight_lbs로 계층화할 것입니다.

library(tidymodels)

set.seed(123)
pumpkin_split <- pumpkins %>%
  filter(ott > 20, ott < 1e3) %>%
  initial_split(strata = weight_lbs)

pumpkin_train <- training(pumpkin_split)
pumpkin_test <- testing(pumpkin_split)

set.seed(234)
pumpkin_folds <- vfold_cv(pumpkin_train, strata = weight_lbs)
pumpkin_folds


## # 10-fold cross-validation using stratification 
## # A tibble: 10 × 2
## splits id    
## <list> <chr> 
## 1 <split [8954/996]> Fold01
## 2 <split [8954/996]> Fold02
## 3 <split [8954/996]> Fold03
## 4 <split [8954/996]> Fold04
## 5 <split [8954/996]> Fold05
## 6 <split [8954/996]> Fold06
## 7 <split [8955/995]> Fold07
## 8 <split [8956/994]> Fold08
## 9 <split [8957/993]> Fold09
## 10 <split [8958/992]> Fold10



다음으로 세 가지 데이터 전처리 레시피를 만들어 보겠습니다. 하나는 자주 사용되지 않는 요인 수준만 풀링하고, 하나는 지표 변수도 생성하고, 마지막으로 오버더톱 인치에 대한 스플라인 항도 생성합니다.

base_rec <-
  recipe(weight_lbs ~ ott + year + country + gpc_site,
    data = pumpkin_train
  ) %>%
  step_other(country, gpc_site, threshold = 0.02)

ind_rec <-
  base_rec %>%
  step_dummy(all_nominal_predictors())

spline_rec <-
  ind_rec %>%
  step_bs(ott)



그런 다음 Random Forest 모델, MARS 모델 및 선형 모델의 세 가지 모델 사양을 만들어 보겠습니다.

rf_spec <-
  rand_forest(trees = 1e3) %>%
  set_mode("regression") %>%
  set_engine("ranger")

mars_spec <-
  mars() %>%
  set_mode("regression") %>%
  set_engine("earth")

lm_spec <- linear_reg()



이제 전처리와 모델을 workflow_set()에 함께 넣을 시간입니다.

pumpkin_set <-
  workflow_set(
    list(base_rec, ind_rec, spline_rec),
    list(rf_spec, mars_spec, lm_spec),
    cross = FALSE
  )

pumpkin_set


## # A workflow set/tibble: 3 × 4
## wflow_id info option result    
## <chr> <list> <list> <list>    
## 1 recipe_1_rand_forest <tibble [1 × 4]> <opts[0]> <list [0]>
## 2 recipe_2_mars <tibble [1 × 4]> <opts[0]> <list [0]>
## 3 recipe_3_linear_reg <tibble [1 × 4]> <opts[0]> <list [0]>


cross = FALSE를 사용하는 이유는 이러한 구성 요소의 모든 조합을 원하지 않고 세 가지 옵션만 시도하기를 원하기 때문입니다. 이러한 가능한 후보를 우리의 리샘플링에 맞춰서 어떤 후보가 가장 잘 수행되는지 확인합시다.

doParallel::registerDoParallel()
set.seed(2021)

pumpkin_rs <-
  workflow_map(
    pumpkin_set,
    "fit_resamples",
    resamples = pumpkin_folds
  )

pumpkin_rs


## # A workflow set/tibble: 3 × 4
## wflow_id info option result   
## <chr> <list> <list> <list>   
## 1 recipe_1_rand_forest <tibble [1 × 4]> <opts[1]> <rsmp[+]>
## 2 recipe_2_mars <tibble [1 × 4]> <opts[1]> <rsmp[+]>
## 3 recipe_3_linear_reg <tibble [1 × 4]> <opts[1]> <rsmp[+]>



워크플로우 세트 평가



세 명의 후보는 어떻게 되었습니까?

autoplot(pumpkin_rs)





세 가지 옵션 간에는 큰 차이가 없으며 스플라인 피쳐 엔지니어링을 사용한 선형 모델이 더 나을 수도 있습니다. 심플한 모델이라서 좋네요!

collect_metrics(pumpkin_rs)


## # A tibble: 6 × 9
## wflow_id .config preproc model .metric .estimator mean n std_err
## <chr> <chr> <chr> <chr> <chr> <chr> <dbl> <int> <dbl>
## 1 recipe_1_r… Preprocess… recipe rand_… rmse standard 86.1 10 1.10e+0
## 2 recipe_1_r… Preprocess… recipe rand_… rsq standard 0.969 10 9.97e-4
## 3 recipe_2_m… Preprocess… recipe mars rmse standard 83.8 10 1.92e+0
## 4 recipe_2_m… Preprocess… recipe mars rsq standard 0.969 10 1.67e-3
## 5 recipe_3_l… Preprocess… recipe linea… rmse standard 82.4 10 2.27e+0
## 6 recipe_3_l… Preprocess… recipe linea… rsq standard 0.970 10 1.97e-3



사용하려는 워크플로를 추출하여 교육 데이터에 맞출 수 있습니다.

final_fit <-
  extract_workflow(pumpkin_rs, "recipe_3_linear_reg") %>%
  fit(pumpkin_train)



우리는 predict(final_fit, pumpkin_test) 와 같은 테스트 데이터와 같이 예측하기 위해 이와 같은 개체를 사용하거나 모델 매개변수를 검사할 수 있습니다.

tidy(final_fit) %>%
  arrange(-abs(estimate))


## # A tibble: 15 × 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -9731. 675. -14.4 1.30e- 46
## 2 ott_bs_3 2585. 25.6 101. 0        
## 3 ott_bs_2 450. 11.9 37.9 2.75e-293
## 4 ott_bs_1 -345. 36.3 -9.50 2.49e- 21
## 5 gpc_site_Ohio.Valley.Giant.Pumpkin.Gr… 21.1 7.80 2.70 6.89e- 3
## 6 country_United.States 11.9 5.66 2.11 3.53e- 2
## 7 gpc_site_Stillwater.Harvestfest 11.6 7.87 1.48 1.40e- 1
## 8 country_Germany -11.5 6.68 -1.71 8.64e- 2
## 9 country_other -10.7 6.33 -1.69 9.13e- 2
## 10 country_Canada 9.29 6.12 1.52 1.29e- 1
## 11 country_Italy 8.12 7.02 1.16 2.47e- 1
## 12 gpc_site_Elk.Grove.Giant.Pumpkin.Fest… -7.81 7.70 -1.01 3.10e- 1
## 13 year 4.89 0.334 14.6 5.03e- 48
## 14 gpc_site_Wiegemeisterschaft.Berlin.Br… 1.51 8.07 0.187 8.51e- 1
## 15 gpc_site_other 1.41 5.60 0.251 8.02e- 1



스플라인 항이 가장 중요하지만 특정 사이트와 국가에서 무게(위 또는 아래)를 예측할 수 있는 증거와 연도에 따라 호박이 더 무거워지는 작은 경향을 볼 수 있습니다.

잊지 말고 tidymodels survey for 2022 priorities 가져가세요!

좋은 웹페이지 즐겨찾기