orbital 0.4.0

  orbital, tidymodels

  Emil Hvitfeldt

We’re over the moon to announce the release of orbital 0.4.0. orbital lets you predict in databases using tidymodels workflows.

You can install it from CRAN with:

install.packages("orbital")

This blog post covers the two highlights: post-processing support and the new show_query() method.

You can see a full list of changes in the release notes.

Post processing support

The biggest improvement in this version is that orbital() now works with workflows that include supported tailor adjustments. See the vignette for a list of all supported post-processors.

Let’s start by fitting a classification model on the penguins data set, using {xgboost} as the engine. The adjustment we want to showcase only works for binary classification, so we first recode species to have the two levels "Adelie" and "not_Adelie".

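library(tidymodels) # recipes, parsnip, workflows, and the penguins data
library(tailor)     # adjust_equivocal_zone()
library(orbital)
# Collapse species into a binary outcome: "Adelie" vs. "not_Adelie"
penguins$species <- forcats::fct_recode(
  penguins$species,
  not_Adelie = "Chinstrap", not_Adelie = "Gentoo"
)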

After we have modified the data, we set up a simple workflow, with a preprocessor using recipes and the model specification using parsnip.

We also set up a post-processor using the tailor package. It contains a single adjustment, added with adjust_equivocal_zone(). This applies an equivocal zone to our binary classification model: predictions that fall too close to the threshold are labeled "[EQ]" instead of being assigned a class. Setting the argument value = 0.2 means that any observation with a predicted probability between 0.3 and 0.7 will be predicted as "[EQ]".
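To make the rule concrete, here is a minimal base R sketch of the equivocal-zone logic, assuming a threshold of 0.5 and that the probability passed in is the one for the first level. The helper equivocal_label() is hypothetical and written only for illustration; it is not part of tailor or orbital.

```r
# Hypothetical helper illustrating the equivocal-zone rule.
# `prob` is the predicted probability of the first level in `levels`.
equivocal_label <- function(prob, threshold = 0.5, value = 0.2,
                            levels = c("Adelie", "not_Adelie")) {
  ifelse(
    prob > threshold + value, levels[1],
    # Below the lower bound: second level. Anything in between: "[EQ]".
    ifelse(prob < threshold - value, levels[2], "[EQ]")
  )
}

equivocal_label(c(0.845, 0.483, 0.291))
```

With value = 0.2, a probability of 0.845 is labeled "Adelie", 0.483 falls inside the zone and becomes "[EQ]", and 0.291 is labeled "not_Adelie".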

rec_spec <- recipe(species ~ ., data = penguins) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_zv(all_predictors())

lr_spec <- boost_tree(tree_depth = 1, trees = 5) |>
  set_mode("classification") |>
  set_engine("xgboost")

tlr_spec <- tailor() |>
  adjust_equivocal_zone(value = 0.2)

wf_spec <- workflow(rec_spec, lr_spec, tlr_spec)
wf_fit <- fit(wf_spec, data = penguins)

With this fitted workflow object, we can call orbital() on it to create an orbital object. Notice that we need to set type = c("class", "prob"), since adjust_equivocal_zone() requires both the class prediction and the class probabilities.

orbital_obj <- orbital(wf_fit, type = c("class", "prob"))
orbital_obj
#> 
#> ── orbital Object ───────────────────────────────────────────────────────
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201 ...
#> • .pred_class = dplyr::case_when(1 - 1/(1 + exp(dplyr::case_when(b ...
#> • .pred_Adelie = 1 - 1/(1 + exp(dplyr::case_when(bill_length_mm < ...
#> • .pred_not_Adelie = 1 - (1 - 1/(1 + exp(dplyr::case_when(bill_len ...
#> • .pred_class = dplyr::case_when( .pred_Adelie > 0.5 + 0.2 ~ 'Adel ...
#> ─────────────────────────────────────────────────────────────────────────
#> 6 equations in total.

This object contains all the information needed to produce predictions, which we can generate with predict().

preds <- predict(orbital_obj, penguins)
preds
#> # A tibble: 344 × 3
#>    .pred_class .pred_Adelie .pred_not_Adelie
#>    <chr>              <dbl>            <dbl>
#>  1 Adelie             0.845            0.155
#>  2 Adelie             0.845            0.155
#>  3 Adelie             0.845            0.155
#>  4 not_Adelie         0.291            0.709
#>  5 Adelie             0.845            0.155
#>  6 Adelie             0.845            0.155
#>  7 Adelie             0.845            0.155
#>  8 Adelie             0.845            0.155
#>  9 Adelie             0.845            0.155
#> 10 Adelie             0.845            0.155
#> # ℹ 334 more rows

The predictions are working; however, the first rows show no evidence that adjust_equivocal_zone() had any effect. A call to count() reveals that a handful of observations land in the equivocal zone.

count(preds, .pred_class)
#> # A tibble: 3 × 2
#>   .pred_class     n
#>   <chr>       <int>
#> 1 Adelie        144
#> 2 [EQ]           15
#> 3 not_Adelie    185

We can further verify that they are correct: every "[EQ]" row has predicted probabilities between 0.3 and 0.7.

filter(preds, .pred_class == '[EQ]')
#> # A tibble: 15 × 3
#>    .pred_class .pred_Adelie .pred_not_Adelie
#>    <chr>              <dbl>            <dbl>
#>  1 [EQ]               0.483            0.517
#>  2 [EQ]               0.483            0.517
#>  3 [EQ]               0.483            0.517
#>  4 [EQ]               0.483            0.517
#>  5 [EQ]               0.483            0.517
#>  6 [EQ]               0.483            0.517
#>  7 [EQ]               0.483            0.517
#>  8 [EQ]               0.348            0.652
#>  9 [EQ]               0.348            0.652
#> 10 [EQ]               0.348            0.652
#> 11 [EQ]               0.348            0.652
#> 12 [EQ]               0.348            0.652
#> 13 [EQ]               0.483            0.517
#> 14 [EQ]               0.483            0.517
#> 15 [EQ]               0.483            0.517

New show_query method

One of the main purposes of orbital is to allow for predictions in databases. To demonstrate, we set up a temporary SQLite database and copy the penguins data into it.

library(DBI)
library(RSQLite)

# RSQLite's argument is `dbname`; ":memory:" requests an in-memory database
con_sqlite <- dbConnect(SQLite(), dbname = ":memory:")
penguins_sqlite <- copy_to(con_sqlite, penguins, name = "penguins_table")

Having set up a database, we could have used orbital_sql() to show what the SQL query would look like. While that is useful for quick testing, its output isn’t immediately ready to be pasted into its own file because of the <SQL> fragments within it.

The show_query() method has been implemented to see exactly what the generated SQL looks like.

show_query(orbital_obj, con_sqlite)
#> CASE WHEN ((`bill_length_mm` IS NULL)) THEN 43.9219298245614 WHEN NOT ((`bill_length_mm` IS NULL)) THEN `bill_length_mm` END AS bill_length_mm
#> CASE WHEN ((`flipper_length_mm` IS NULL)) THEN 201.0 WHEN NOT ((`flipper_length_mm` IS NULL)) THEN `flipper_length_mm` END AS flipper_length_mm
#> CASE
#> WHEN ((1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) > 0.5) THEN 'Adelie'
#> ELSE 'not_Adelie'
#> END AS .pred_class
#> 1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047)))) AS .pred_Adelie
#> 1.0 - (1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) AS .pred_not_Adelie
#> CASE
#> WHEN (`.pred_Adelie` > (0.5 + 0.2)) THEN 'Adelie'
#> WHEN (`.pred_Adelie` < (0.5 - 0.2)) THEN 'not_Adelie'
#> ELSE '[EQ]'
#> END AS .pred_class
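The generated expression for .pred_Adelie can be checked by hand. The following base R sketch re-implements it using the split points and leaf values copied from the query above. The helpers stump() and pred_adelie() are hypothetical names used only for illustration, the input measurements are made up, and the sketch assumes non-missing inputs (the imputation at the top of the query handles missing values first).

```r
# One depth-1 tree: return `left` when x is below the split, else `right`.
stump <- function(x, split, left, right) {
  ifelse(x < split, left, right)
}

# Sum the five trees plus the base score on the logit scale,
# then map the margin to a probability as in the generated SQL.
pred_adelie <- function(bill_length_mm, flipper_length_mm) {
  margin <-
    stump(bill_length_mm, 42.4000015, 0.627138138, -0.449751347) +
    stump(bill_length_mm, 43.2999992, 0.425288886, -0.398178101) +
    stump(bill_length_mm, 42.4000015, 0.380251437, -0.306771189) +
    stump(bill_length_mm, 44.4000015, 0.286071777, -0.330096036) +
    stump(flipper_length_mm, 203.0, 0.209298179, -0.348002464) +
    log(0.44186047 / (1 - 0.44186047))
  1 - 1 / (1 + exp(margin))
}

round(pred_adelie(39.1, 181), 3) # short bill, short flipper
round(pred_adelie(43.0, 190), 3) # bill between the split points
```

The two calls round to 0.845 and 0.483, values that also appear in the prediction tables earlier in the post. The second one lies between 0.3 and 0.7, so the final CASE statement would label it "[EQ]".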

Acknowledgements

A big thank you to all the people who have contributed to orbital since the previous release:

@EmilHvitfeldt, @frankiethull, @jeroenjanssens, and @topepo.