We’re over the moon to announce the release of orbital 0.4.0. orbital lets you predict in databases using tidymodels workflows.
You can install it from CRAN with:
install.packages("orbital")
This blog post will cover the highlights, which are post-processing support and the new show_query() method.
You can see a full list of changes in the release notes.
Post processing support
The biggest improvement in this version is that orbital() now works with supported tailor methods. See the vignette for a list of all supported post-processors.
Let’s start by fitting a classification model on the penguins data set, using {xgboost} as the engine. We will showcase an adjustment that only works for binary classification, so we first recode species to have the two levels "Adelie" and "not_Adelie".
library(tidymodels)
library(tailor)
library(orbital)
penguins$species <- forcats::fct_recode(
penguins$species,
not_Adelie = "Chinstrap", not_Adelie = "Gentoo"
)
After we have modified the data, we set up a simple workflow with a preprocessor using recipes and a model specification using parsnip.
We also set up a post-processor using the tailor package, with a single adjustment added by adjust_equivocal_zone(). This applies an equivocal zone to our binary classification model, withholding predictions that are too close to the threshold of 0.5 by labeling them "[EQ]". Setting the argument value = 0.2 means that any observation with a predicted probability between 0.3 and 0.7 will be predicted as "[EQ]" instead.
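As a rough illustration of the rule itself, here is a base-R sketch. The helper `equivocal_class()` is hypothetical and is not how tailor implements the adjustment; it assumes the default threshold of 0.5 and uses strict inequalities at the zone edges, mirroring the `CASE WHEN` expression that orbital generates for `.pred_class` in the `show_query()` output.

```r
# Hypothetical helper (not part of tailor or orbital): label a binary
# prediction, or return "[EQ]" when the probability of the first level
# falls inside the equivocal zone [threshold - value, threshold + value].
equivocal_class <- function(prob_first, value = 0.2, threshold = 0.5,
                            levels = c("Adelie", "not_Adelie")) {
  ifelse(
    prob_first > threshold + value, levels[1],
    ifelse(prob_first < threshold - value, levels[2], "[EQ]")
  )
}

equivocal_class(c(0.845, 0.483, 0.348, 0.291))
# -> "Adelie" "[EQ]" "[EQ]" "not_Adelie"
```

The example probabilities are values that appear among the predictions in this post.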
rec_spec <- recipe(species ~ ., data = penguins) |>
step_unknown(all_nominal_predictors()) |>
step_dummy(all_nominal_predictors()) |>
step_impute_mean(all_numeric_predictors()) |>
step_zv(all_predictors())
lr_spec <- boost_tree(tree_depth = 1, trees = 5) |>
set_mode("classification") |>
set_engine("xgboost")
tlr_spec <- tailor() |>
adjust_equivocal_zone(value = 0.2)
wf_spec <- workflow(rec_spec, lr_spec, tlr_spec)
wf_fit <- fit(wf_spec, data = penguins)
With this fitted workflow object, we can call orbital() on it to create an orbital object. Note that for adjust_equivocal_zone() to work, we need to set type = c("class", "prob"), because the transformation requires both the class and the probability predictions.
orbital_obj <- orbital(wf_fit, type = c("class", "prob"))
orbital_obj
#>
#> ── orbital Object ───────────────────────────────────────────────────────
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201 ...
#> • .pred_class = dplyr::case_when(1 - 1/(1 + exp(dplyr::case_when(b ...
#> • .pred_Adelie = 1 - 1/(1 + exp(dplyr::case_when(bill_length_mm < ...
#> • .pred_not_Adelie = 1 - (1 - 1/(1 + exp(dplyr::case_when(bill_len ...
#> • .pred_class = dplyr::case_when( .pred_Adelie > 0.5 + 0.2 ~ 'Adel ...
#> ─────────────────────────────────────────────────────────────────────────
#> 6 equations in total.
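Each line above is an R expression. Reading the `.pred_Adelie` equation together with its full SQL translation from `show_query()` later in this post, it sums the leaf values of five single-split trees, adds the log-odds of a base rate of 0.44186047, and applies `1 - 1/(1 + exp(x))`, which is the logistic function. The hypothetical `pred_adelie()` below hand-translates that expression into plain R to make the arithmetic visible. The constants are copied from the generated SQL, and it is an illustration only, not something orbital exports.

```r
# Hypothetical hand translation of the .pred_Adelie equation.
# All constants are copied from the show_query() output for this workflow.
pred_adelie <- function(bill_length_mm, flipper_length_mm) {
  # Mean imputation from step_impute_mean() in the recipe
  bill_length_mm <- ifelse(is.na(bill_length_mm), 43.9219298245614, bill_length_mm)
  flipper_length_mm <- ifelse(is.na(flipper_length_mm), 201.0, flipper_length_mm)

  # Five depth-1 trees (tree_depth = 1, trees = 5); each adds one leaf value
  margin <-
    ifelse(bill_length_mm < 42.4000015, 0.627138138, -0.449751347) +
    ifelse(bill_length_mm < 43.2999992, 0.425288886, -0.398178101) +
    ifelse(bill_length_mm < 42.4000015, 0.380251437, -0.306771189) +
    ifelse(bill_length_mm < 44.4000015, 0.286071777, -0.330096036) +
    ifelse(flipper_length_mm < 203.0, 0.209298179, -0.348002464) +
    log(0.44186047 / (1 - 0.44186047)) # base score on the log-odds scale

  # 1 - 1 / (1 + exp(margin)) is the logistic function of the margin
  1 - 1 / (1 + exp(margin))
}

round(pred_adelie(c(40, 43, 43, NA), c(190, 190, 210, NA)), 3)
# -> 0.845 0.483 0.348 0.291
```

Because of the imputation step, an observation with both measurements missing is scored as bill_length_mm = 43.92 and flipper_length_mm = 201. The `.pred_not_Adelie` equation is simply one minus this same expression.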
This object contains all the information needed to produce predictions, which we can do with predict().
preds <- predict(orbital_obj, penguins)
preds
#> # A tibble: 344 × 3
#> .pred_class .pred_Adelie .pred_not_Adelie
#> <chr> <dbl> <dbl>
#> 1 Adelie 0.845 0.155
#> 2 Adelie 0.845 0.155
#> 3 Adelie 0.845 0.155
#> 4 not_Adelie 0.291 0.709
#> 5 Adelie 0.845 0.155
#> 6 Adelie 0.845 0.155
#> 7 Adelie 0.845 0.155
#> 8 Adelie 0.845 0.155
#> 9 Adelie 0.845 0.155
#> 10 Adelie 0.845 0.155
#> # ℹ 334 more rows
The predictions are working; however, the first rows show no evidence that adjust_equivocal_zone() is having an effect. A call to count() reveals that 15 observations land in the equivocal zone.
count(preds, .pred_class)
#> # A tibble: 3 × 2
#> .pred_class n
#> <chr> <int>
#> 1 Adelie 144
#> 2 [EQ] 15
#> 3 not_Adelie 185
We can further verify that they are correct: every "[EQ]" prediction has a .pred_Adelie value between 0.3 and 0.7.
filter(preds, .pred_class == '[EQ]')
#> # A tibble: 15 × 3
#> .pred_class .pred_Adelie .pred_not_Adelie
#> <chr> <dbl> <dbl>
#> 1 [EQ] 0.483 0.517
#> 2 [EQ] 0.483 0.517
#> 3 [EQ] 0.483 0.517
#> 4 [EQ] 0.483 0.517
#> 5 [EQ] 0.483 0.517
#> 6 [EQ] 0.483 0.517
#> 7 [EQ] 0.483 0.517
#> 8 [EQ] 0.348 0.652
#> 9 [EQ] 0.348 0.652
#> 10 [EQ] 0.348 0.652
#> 11 [EQ] 0.348 0.652
#> 12 [EQ] 0.348 0.652
#> 13 [EQ] 0.483 0.517
#> 14 [EQ] 0.483 0.517
#> 15 [EQ] 0.483 0.517
New show_query method
One of the main purposes of orbital is to allow for predictions in databases.
library(DBI)
library(RSQLite)
con_sqlite <- dbConnect(SQLite(), dbname = ":memory:")
penguins_sqlite <- copy_to(con_sqlite, penguins, name = "penguins_table")
With a database set up, we could previously have used orbital_sql() to see what the SQL query would look like. However, for quick testing, its output isn’t immediately ready to be pasted into its own file because of the <SQL> fragments it contains.
The new show_query() method prints exactly what the generated SQL looks like, without those fragments.
show_query(orbital_obj, con_sqlite)
#> CASE WHEN ((`bill_length_mm` IS NULL)) THEN 43.9219298245614 WHEN NOT ((`bill_length_mm` IS NULL)) THEN `bill_length_mm` END AS bill_length_mm
#> CASE WHEN ((`flipper_length_mm` IS NULL)) THEN 201.0 WHEN NOT ((`flipper_length_mm` IS NULL)) THEN `flipper_length_mm` END AS flipper_length_mm
#> CASE
#> WHEN ((1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) > 0.5) THEN 'Adelie'
#> ELSE 'not_Adelie'
#> END AS .pred_class
#> 1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047)))) AS .pred_Adelie
#> 1.0 - (1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) AS .pred_not_Adelie
#> CASE
#> WHEN (`.pred_Adelie` > (0.5 + 0.2)) THEN 'Adelie'
#> WHEN (`.pred_Adelie` < (0.5 - 0.2)) THEN 'not_Adelie'
#> ELSE '[EQ]'
#> END AS .pred_class
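Since `show_query()` prints plain SQL, its lines should be usable directly in a database. As a small check, the sketch below pastes the first two lines of the output above (the mean imputation for bill_length_mm and flipper_length_mm) into a `SELECT` against a throwaway in-memory SQLite table. The table name `measurements` and its values are made up for illustration.

```r
library(DBI)
library(RSQLite)

con <- dbConnect(SQLite(), dbname = ":memory:")
# Hypothetical table: one complete row and one row with missing values
dbWriteTable(con, "measurements", data.frame(
  bill_length_mm = c(39.1, NA),
  flipper_length_mm = c(181, NA)
))

# The first two lines of the show_query() output, pasted into a SELECT
result <- dbGetQuery(con, "
  SELECT
    CASE WHEN ((`bill_length_mm` IS NULL)) THEN 43.9219298245614 WHEN NOT ((`bill_length_mm` IS NULL)) THEN `bill_length_mm` END AS bill_length_mm,
    CASE WHEN ((`flipper_length_mm` IS NULL)) THEN 201.0 WHEN NOT ((`flipper_length_mm` IS NULL)) THEN `flipper_length_mm` END AS flipper_length_mm
  FROM measurements
  ORDER BY rowid
")
dbDisconnect(con)

result
```

The complete row passes through unchanged, while the missing row comes back with the training means, matching the if_else() equations in the printed orbital object.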
Acknowledgements
A big thank you to all the people who have contributed to orbital since the previous release:
@EmilHvitfeldt, @frankiethull, @jeroenjanssens, and @topepo.