We’re chuffed to announce the release of a new interface to validation splits in rsample 1.2.0 and tune 1.1.2. The rsample package makes it easy to create resamples for assessing model performance. The tune package facilitates hyperparameter tuning for the tidymodels packages.
You can install the new versions from CRAN with:
install.packages(c("rsample", "tune"))
This blog post will walk you through how to make a validation split and use it for tuning.
You can see a full list of changes in the release notes for rsample and tune.
Let’s start by loading the tidymodels package, which will load, among others, both rsample and tune.
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
#> ✔ broom 1.0.5 ✔ recipes 1.0.7
#> ✔ dials 1.2.0 ✔ rsample 1.2.0
#> ✔ dplyr 1.1.2 ✔ tibble 3.2.1
#> ✔ ggplot2 3.4.3 ✔ tidyr 1.3.0
#> ✔ infer 1.0.4 ✔ tune 1.1.2
#> ✔ modeldata 1.2.0 ✔ workflows 1.1.3
#> ✔ parsnip 1.1.1 ✔ workflowsets 1.0.1
#> ✔ purrr 1.0.2 ✔ yardstick 1.2.0
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter() masks stats::filter()
#> ✖ dplyr::lag() masks stats::lag()
#> ✖ recipes::step() masks stats::step()
#> • Use suppressPackageStartupMessages() to eliminate package startup messages
The new functions
You can now make a three-way split of your data instead of doing a sequence of two binary splits:

- initial_validation_split() with variants initial_validation_time_split() and group_initial_validation_split() for the initial three-way split
- validation_set() to create the rset for tuning, containing the analysis (= training) and assessment (= validation) set
- training(), validation(), and testing() for access to the separate subsets
- last_fit() (and fit_best()) now also work on the initial three-way split
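Here is a minimal sketch of how these functions fit together, using the built-in mtcars data purely for illustration (any data frame works):

library(rsample)

set.seed(1)
# Three-way split: by default 60% training, 20% validation, 20% testing
car_split <- initial_validation_split(mtcars)

# Access the individual subsets
car_train <- training(car_split)
car_val <- validation(car_split)
car_test <- testing(car_split)

# Bundle the training and validation data into a single-resample rset
car_set <- validation_set(car_split)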
The new functions in action
To illustrate how to use the new functions, we’ll replicate an analysis of childcare costs from a Tidy Tuesday done by Julia Silge in one of her screencasts.
We are modeling mcsa, the median weekly price for school-aged kids in childcare centers, and are thus removing the other variables containing different variants of median prices (e.g., for different age groups). We are also removing the FIPS code identifying the county, as we are including various characteristics of the counties instead of their ID.
library(readr)
#>
#> Attaching package: 'readr'
#> The following object is masked from 'package:yardstick':
#>
#> spec
#> The following object is masked from 'package:scales':
#>
#> col_factor
childcare_costs <- read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2023/2023-05-09/childcare_costs.csv')
#> Rows: 34567 Columns: 61
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> dbl (61): county_fips_code, study_year, unr_16, funr_16, munr_16, unr_20to64...
#>
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
childcare_costs <- childcare_costs |>
select(-matches("^mc_|^mfc")) |>
select(-county_fips_code) |>
drop_na()
glimpse(childcare_costs)
#> Rows: 23,593
#> Columns: 53
#> $ study_year <dbl> 2008, 2009, 2010, 2011, 2012, 2013, 2014, 20…
#> $ unr_16 <dbl> 5.42, 5.93, 6.21, 7.55, 8.60, 9.39, 8.50, 7.…
#> $ funr_16 <dbl> 4.41, 5.72, 5.57, 8.13, 8.88, 10.31, 9.18, 8…
#> $ munr_16 <dbl> 6.32, 6.11, 6.78, 7.03, 8.29, 8.56, 7.95, 6.…
#> $ unr_20to64 <dbl> 4.6, 4.8, 5.1, 6.2, 6.7, 7.3, 6.8, 5.9, 4.4,…
#> $ funr_20to64 <dbl> 3.5, 4.6, 4.6, 6.3, 6.4, 7.6, 6.8, 6.1, 4.6,…
#> $ munr_20to64 <dbl> 5.6, 5.0, 5.6, 6.1, 7.0, 7.0, 6.8, 5.9, 4.3,…
#> $ flfpr_20to64 <dbl> 68.9, 70.8, 71.3, 70.2, 70.6, 70.7, 69.9, 68…
#> $ flfpr_20to64_under6 <dbl> 66.9, 63.7, 67.0, 66.5, 67.1, 67.5, 65.2, 66…
#> $ flfpr_20to64_6to17 <dbl> 79.59, 78.41, 78.15, 77.62, 76.31, 75.91, 75…
#> $ flfpr_20to64_under6_6to17 <dbl> 60.81, 59.91, 59.71, 59.31, 58.30, 58.00, 57…
#> $ mlfpr_20to64 <dbl> 84.0, 86.2, 85.8, 85.7, 85.7, 85.0, 84.2, 82…
#> $ pr_f <dbl> 8.5, 7.5, 7.5, 7.4, 7.4, 8.3, 9.1, 9.3, 9.4,…
#> $ pr_p <dbl> 11.5, 10.3, 10.6, 10.9, 11.6, 12.1, 12.8, 12…
#> $ mhi_2018 <dbl> 58462.55, 60211.71, 61775.80, 60366.88, 5915…
#> $ me_2018 <dbl> 32710.60, 34688.16, 34740.84, 34564.32, 3432…
#> $ fme_2018 <dbl> 25156.25, 26852.67, 27391.08, 26727.68, 2796…
#> $ mme_2018 <dbl> 41436.80, 43865.64, 46155.24, 45333.12, 4427…
#> $ total_pop <dbl> 49744, 49584, 53155, 53944, 54590, 54907, 55…
#> $ one_race <dbl> 98.1, 98.6, 98.5, 98.5, 98.5, 98.6, 98.7, 98…
#> $ one_race_w <dbl> 78.9, 79.1, 79.1, 78.9, 78.9, 78.3, 78.0, 77…
#> $ one_race_b <dbl> 17.7, 17.9, 17.9, 18.1, 18.1, 18.4, 18.6, 18…
#> $ one_race_i <dbl> 0.4, 0.4, 0.3, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4,…
#> $ one_race_a <dbl> 0.4, 0.6, 0.7, 0.7, 0.8, 1.0, 0.9, 1.0, 0.8,…
#> $ one_race_h <dbl> 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1,…
#> $ one_race_other <dbl> 0.7, 0.7, 0.6, 0.5, 0.4, 0.7, 0.7, 0.9, 1.4,…
#> $ two_races <dbl> 1.9, 1.4, 1.5, 1.5, 1.5, 1.4, 1.3, 1.6, 2.0,…
#> $ hispanic <dbl> 1.8, 2.0, 2.3, 2.4, 2.4, 2.5, 2.5, 2.6, 2.6,…
#> $ households <dbl> 18373, 18288, 19718, 19998, 19934, 20071, 20…
#> $ h_under6_both_work <dbl> 1543, 1475, 1569, 1695, 1714, 1532, 1557, 13…
#> $ h_under6_f_work <dbl> 970, 964, 1009, 1060, 938, 880, 1191, 1258, …
#> $ h_under6_m_work <dbl> 22, 16, 16, 106, 120, 161, 159, 211, 109, 10…
#> $ h_under6_single_m <dbl> 995, 1099, 1110, 1030, 1095, 1160, 954, 883,…
#> $ h_6to17_both_work <dbl> 4900, 5028, 5472, 5065, 4608, 4238, 4056, 40…
#> $ h_6to17_fwork <dbl> 1308, 1519, 1541, 1965, 1963, 1978, 2073, 20…
#> $ h_6to17_mwork <dbl> 114, 92, 113, 246, 284, 354, 373, 551, 322, …
#> $ h_6to17_single_m <dbl> 1966, 2305, 2377, 2299, 2644, 2522, 2269, 21…
#> $ emp_m <dbl> 27.40, 29.54, 29.33, 31.17, 32.13, 31.74, 32…
#> $ memp_m <dbl> 24.41, 26.07, 25.94, 26.97, 28.59, 27.44, 28…
#> $ femp_m <dbl> 30.68, 33.40, 33.06, 35.96, 36.09, 36.61, 37…
#> $ emp_service <dbl> 17.06, 15.81, 16.92, 16.18, 16.09, 16.72, 16…
#> $ memp_service <dbl> 15.53, 14.16, 15.09, 14.21, 14.71, 13.92, 13…
#> $ femp_service <dbl> 18.75, 17.64, 18.93, 18.42, 17.63, 19.89, 20…
#> $ emp_sales <dbl> 29.11, 28.75, 29.07, 27.56, 28.39, 27.22, 25…
#> $ memp_sales <dbl> 15.97, 17.51, 17.82, 17.74, 17.79, 17.38, 15…
#> $ femp_sales <dbl> 43.52, 41.25, 41.43, 38.76, 40.26, 38.36, 36…
#> $ emp_n <dbl> 13.21, 11.89, 11.57, 10.72, 9.02, 9.27, 9.38…
#> $ memp_n <dbl> 22.54, 20.30, 19.86, 18.28, 16.03, 16.79, 17…
#> $ femp_n <dbl> 2.99, 2.52, 2.45, 2.09, 1.19, 0.77, 0.58, 0.…
#> $ emp_p <dbl> 13.22, 14.02, 13.11, 14.38, 14.37, 15.04, 16…
#> $ memp_p <dbl> 21.55, 21.96, 21.28, 22.80, 22.88, 24.48, 24…
#> $ femp_p <dbl> 4.07, 5.19, 4.13, 4.77, 4.84, 4.36, 6.07, 7.…
#> $ mcsa <dbl> 80.92, 83.42, 85.92, 88.43, 90.93, 93.43, 95…
Even after omitting rows with missing values, we are left with 23,593 observations. That is plenty to work with! We are likely to get a reliable estimate of model performance from a validation set without having to fit and evaluate the model multiple times, as with, for example, v-fold cross-validation.
We are creating a three-way split of the data into a training, a validation, and a test set with the new initial_validation_split() function, stratifying based on our outcome mcsa. The default of prop = c(0.6, 0.2) means that 60% of the data is allocated to the training set and 20% to the validation set, with the remaining 20% going into the test set.
set.seed(123)
childcare_split <- childcare_costs |>
initial_validation_split(strata = mcsa)
childcare_split
#> <Training/Validation/Testing/Total>
#> <14155/4718/4720/23593>
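If the default proportions don’t suit your application, you can supply your own. For example, a 70/15/15 split (these proportions are purely illustrative):

set.seed(123)
# prop sets the training and validation proportions; the remainder is testing
childcare_split_70 <- childcare_costs |>
  initial_validation_split(prop = c(0.7, 0.15), strata = mcsa)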
You can access the subsets of the data with the familiar training() and testing() as well as the new validation():
validation(childcare_split)
#> # A tibble: 4,718 × 53
#> study_year unr_16 funr_16 munr_16 unr_20to64 funr_20to64 munr_20to64
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 2013 9.39 10.3 8.56 7.3 7.6 7
#> 2 2011 13.0 12.4 13.6 13.2 12.4 13.9
#> 3 2008 3.85 4.4 3.43 3.7 3.9 3.6
#> 4 2015 8.31 11.8 5.69 7.8 11.7 4.9
#> 5 2015 7.67 6.92 8.27 7.6 6.7 8.3
#> 6 2016 5.95 6.33 5.66 5.7 5.9 5.5
#> 7 2009 10.7 15.9 7.06 8.7 16.8 2.9
#> 8 2010 11.2 15.2 7.89 10.9 14.7 7.8
#> 9 2013 15.0 17.0 13.4 15.2 18.1 13
#> 10 2014 17.4 16.3 18.2 17.2 17.7 16.9
#> # ℹ 4,708 more rows
#> # ℹ 46 more variables: flfpr_20to64 <dbl>, flfpr_20to64_under6 <dbl>,
#> # flfpr_20to64_6to17 <dbl>, flfpr_20to64_under6_6to17 <dbl>,
#> # mlfpr_20to64 <dbl>, pr_f <dbl>, pr_p <dbl>, mhi_2018 <dbl>, me_2018 <dbl>,
#> # fme_2018 <dbl>, mme_2018 <dbl>, total_pop <dbl>, one_race <dbl>,
#> # one_race_w <dbl>, one_race_b <dbl>, one_race_i <dbl>, one_race_a <dbl>,
#> # one_race_h <dbl>, one_race_other <dbl>, two_races <dbl>, hispanic <dbl>, …
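training() and testing() work the same way. A quick check of the subset sizes matches the split object printed above:

nrow(training(childcare_split))
#> [1] 14155
nrow(validation(childcare_split))
#> [1] 4718
nrow(testing(childcare_split))
#> [1] 4720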
You may want to extract the training data to do some exploratory data analysis, but here we are going to rely on xgboost to figure out patterns in the data, so we can breeze straight to tuning a model.
xgb_spec <-
  boost_tree(
    trees = 500,
    min_n = tune(),
    mtry = tune(),
    stop_iter = tune(),
    learn_rate = 0.01
  ) |>
  # validation is an xgboost engine argument: the proportion of the data
  # held out to assess performance for early stopping via stop_iter
  set_engine("xgboost", validation = 0.2) |>
  set_mode("regression")
xgb_wf <- workflow(mcsa ~ ., xgb_spec)
xgb_wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Formula
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> mcsa ~ .
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#>
#> Main Arguments:
#> mtry = tune()
#> trees = 500
#> min_n = tune()
#> learn_rate = 0.01
#> stop_iter = tune()
#>
#> Engine-Specific Arguments:
#> validation = 0.2
#>
#> Computational engine: xgboost
We give this workflow object with the model specification to tune_grid() to try multiple combinations of the hyperparameters we tagged for tuning (min_n, mtry, and stop_iter).
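If you want to double-check which parameters are tagged for tuning, one option is to extract the parameter set from the workflow (a quick sanity check, not a required step):

# Parameters marked with tune(); mtry requires finalization because its
# upper bound depends on the number of predictors in the data
extract_parameter_set_dials(xgb_wf)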
During tuning, the model should not have access to the test data, only to the data used to fit the model (the analysis set) and the data used to assess the model (the assessment set). Each pair of analysis and assessment sets forms a resample. For 10-fold cross-validation, we’d have 10 resamples. With a validation split, we have just one resample, with the training set functioning as the analysis set and the validation set as the assessment set. The tidymodels tuning functions all expect a set of resamples (which can be of size one), and the corresponding objects are of class rset.
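For comparison, an rset for 10-fold cross-validation on the training data would contain 10 such analysis/assessment pairs instead of one (shown only to illustrate the rset structure; we won’t use it below):

vfold_cv(training(childcare_split), v = 10)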
To remove the test data from the initial three-way split and create such an rset object for tuning, use validation_set().
set.seed(234)
childcare_set <- validation_set(childcare_split)
childcare_set
#> # A tibble: 1 × 2
#> splits id
#> <list> <chr>
#> 1 <split [14155/4718]> validation
We are going to try 15 different parameter combinations and pick the one with the smallest RMSE.
set.seed(234)
xgb_res <- tune_grid(xgb_wf, childcare_set, grid = 15)
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> Warning in `[.tbl_df`(x, is.finite(x <- as.numeric(x))): NAs introduced by coercion
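Before committing to a single configuration, you might peek at the top candidates with show_best() (output omitted here):

# Top parameter combinations ranked by RMSE
show_best(xgb_res, metric = "rmse")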
best_parameters <- select_best(xgb_res, "rmse")
childcare_wflow <- finalize_workflow(xgb_wf, best_parameters)
last_fit() then lets you fit your model on the training data and calculate performance on the test data. If you provide it with a three-way split, you can choose whether you want your model to be fitted on the training data only or on the combination of the training and validation sets. You can specify this with the add_validation_set argument.
childcare_fit <- last_fit(childcare_wflow, childcare_split, add_validation_set = TRUE)
collect_metrics(childcare_fit)
#> # A tibble: 2 × 4
#> .metric .estimator .estimate .config
#> <chr> <chr> <dbl> <chr>
#> 1 rmse standard 21.4 Preprocessor1_Model1
#> 2 rsq standard 0.610 Preprocessor1_Model1
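If you want to use the final model for prediction, you can extract the fitted workflow from the last_fit() result. A short sketch, predicting on a few test set rows:

# The workflow fitted via last_fit(), here on training + validation data
final_wf <- extract_workflow(childcare_fit)

# Predict on new data, e.g., the first rows of the test set
predict(final_wf, new_data = head(testing(childcare_split)))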
This takes you through the important changes for validation sets in the tidymodels framework!
Acknowledgements
Many thanks to the people who contributed since the last releases!
For rsample: @afrogri37, @AngelFelizR, @bschneidr, @erictleung, @exsell-jc, @hfrick, @jrosell, @MasterLuke84, @MichaelChirico, @mikemahoney218, @rdavis120, @sametsoekel, @Shafi2016, @simonpcouch, @topepo, and @trevorcampbell.
For tune: @blechturm, @cphaarmeyer, @EmilHvitfeldt, @forecastingEDs, @hfrick, @kjbeath, @mikemahoney218, @rdavis120, @simonpcouch, and @topepo.