orbital 0.3.0

  orbital, tidymodels

  Emil Hvitfeldt

We’re thrilled to announce the release of orbital 0.3.0. orbital lets you predict in databases using tidymodels workflows.

You can install it from CRAN with:

install.packages("orbital")

This blog post will cover the highlights, which are classification support and the new augment method.

You can see a full list of changes in the release notes.

Classification support

The biggest improvement in this version is that orbital() now works for supported classification models. See the vignette for a list of all supported models.

Let’s start by fitting a classification model on the penguins data set, using {xgboost} as the engine.

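library(tidymodels)  # attaches recipes, parsnip, workflows, and modeldata (penguins)
library(orbital)

# Preprocessing: handle missing factor levels, dummy-encode, impute, drop zero-variance columns
rec_spec <- recipe(species ~ ., data = penguins) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_zv(all_predictors())

# Boosted tree classifier fit with the xgboost engine
bt_spec <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")

wf_spec <- workflow(rec_spec, bt_spec)
wf_fit <- fit(wf_spec, data = penguins)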

With this fitted workflow object, we can call orbital() on it to create an orbital object.

orbital_obj <- orbital(wf_fit)
orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • island_Dream = as.numeric(island == "Dream")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • sex_unknown = as.numeric(sex == "unknown")
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, bill_l ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.15117, bill_dep ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201, flipp ...
#> • body_mass_g = dplyr::if_else(is.na(body_mass_g), 4202, body_mass_g)
#> • island_Dream = dplyr::if_else(is.na(island_Dream), 0.3604651, island_Dr ...
#> • island_Torgersen = dplyr::if_else(is.na(island_Torgersen), 0.1511628, i ...
#> • sex_male = dplyr::if_else(is.na(sex_male), 0.4883721, sex_male)
#> • sex_unknown = dplyr::if_else(is.na(sex_unknown), 0.03197674, sex_unknow ...
#> • Adelie = 0 + dplyr::case_when((bill_depth_mm < 15.1 | is.na(bill_depth_ ...
#> • Chinstrap = 0 + dplyr::case_when((island_Dream < 0.5 | is.na(island_Dre ...
#> • Gentoo = 0 + dplyr::case_when((bill_depth_mm < 15.95 | is.na(bill_depth ...
#> • .pred_class = dplyr::case_when(Adelie > Chinstrap & Adelie > Gentoo ~ " ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 18 equations in total.

This object contains all the information needed to produce predictions, which we can do with predict().

predict(orbital_obj, penguins)
#> # A tibble: 344 × 1
#>    .pred_class
#>    <chr>      
#>  1 Adelie     
#>  2 Adelie     
#>  3 Adelie     
#>  4 Adelie     
#>  5 Adelie     
#>  6 Adelie     
#>  7 Adelie     
#>  8 Adelie     
#>  9 Adelie     
#> 10 Adelie     
#> # ℹ 334 more rows

The main thing to note here is that orbital returns character vectors instead of factors. This keeps the behavior consistent across backends, since many databases don't have a factor type.
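If downstream R code expects a factor, one way to restore it locally is base R's factor() with the levels supplied explicitly. This is a minimal sketch: the small data frame stands in for actual predict() output, and species_levels is a hypothetical vector standing in for the outcome levels from the training data.

```r
# Stand-in for predict() output: orbital returns .pred_class as character
preds <- data.frame(.pred_class = c("Adelie", "Gentoo", "Adelie"))

# Hypothetical outcome levels, e.g. levels(penguins$species) in the post
species_levels <- c("Adelie", "Chinstrap", "Gentoo")

# Convert back to a factor; explicit levels keep classes that were never predicted
preds$.pred_class <- factor(preds$.pred_class, levels = species_levels)

str(preds)
```

Passing levels explicitly matters: without it, factor() would only keep the classes that happen to appear in the predictions.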

Speaking of databases, you can also call predict() on an orbital object with a database table. Below we create an ephemeral in-memory RSQLite database.

library(DBI)
library(RSQLite)

# dbname = ":memory:" creates an in-memory database that disappears on disconnect
con_sqlite <- dbConnect(SQLite(), dbname = ":memory:")
penguins_sqlite <- copy_to(con_sqlite, penguins, name = "penguins_table")

We can then predict with it as usual. All the calculations are sent to the database for execution.

predict(orbital_obj, penguins_sqlite)
#> # Source:   SQL [?? x 1]
#> # Database: sqlite 3.47.1 []
#>    .pred_class
#>    <chr>      
#>  1 Adelie     
#>  2 Adelie     
#>  3 Adelie     
#>  4 Adelie     
#>  5 Adelie     
#>  6 Adelie     
#>  7 Adelie     
#>  8 Adelie     
#>  9 Adelie     
#> 10 Adelie     
#> # ℹ more rows

This works the same with many types of databases.
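With dplyr, database tables are handled by dbplyr, which translates R expressions into SQL. As a rough, self-contained illustration (hand-written, not output produced by orbital), the sketch below applies one equation of the same form as the imputation equations printed in the orbital object to a toy SQLite table, prints the generated SQL, and collects the result. The table name toy and its values are hypothetical.

```r
library(DBI)
library(RSQLite)
library(dplyr)  # copy_to() on a DBI connection also requires dbplyr to be installed

con <- dbConnect(SQLite(), dbname = ":memory:")

# Hypothetical toy table with a missing value, standing in for penguins_table
toy_tbl <- copy_to(con, data.frame(body_mass_g = c(3750, NA, 3250)), name = "toy")

# Same shape as the imputation equation shown in the orbital object
imputed <- toy_tbl |>
  mutate(body_mass_g = if_else(is.na(body_mass_g), 4202, body_mass_g))

# Print the SQL that dbplyr generates for this lazy query
show_query(imputed)

# Execute the query in SQLite and bring the result back into R
res <- collect(imputed)
dbDisconnect(con)
res
```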

Classification differs from regression partly because it comes with multiple prediction types. The example above used the default, hard classification. You can choose the type of prediction with the type argument of orbital(). For classification models, the possible options are "class" and "prob".

orbital_obj_prob <- orbital(wf_fit, type = c("class", "prob"))
orbital_obj_prob
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • island_Dream = as.numeric(island == "Dream")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • sex_unknown = as.numeric(sex == "unknown")
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, bill_l ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.15117, bill_dep ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201, flipp ...
#> • body_mass_g = dplyr::if_else(is.na(body_mass_g), 4202, body_mass_g)
#> • island_Dream = dplyr::if_else(is.na(island_Dream), 0.3604651, island_Dr ...
#> • island_Torgersen = dplyr::if_else(is.na(island_Torgersen), 0.1511628, i ...
#> • sex_male = dplyr::if_else(is.na(sex_male), 0.4883721, sex_male)
#> • sex_unknown = dplyr::if_else(is.na(sex_unknown), 0.03197674, sex_unknow ...
#> • Adelie = 0 + dplyr::case_when((bill_depth_mm < 15.1 | is.na(bill_depth_ ...
#> • Chinstrap = 0 + dplyr::case_when((island_Dream < 0.5 | is.na(island_Dre ...
#> • Gentoo = 0 + dplyr::case_when((bill_depth_mm < 15.95 | is.na(bill_depth ...
#> • .pred_class = dplyr::case_when(Adelie > Chinstrap & Adelie > Gentoo ~ " ...
#> • norm = exp(Adelie) + exp(Chinstrap) + exp(Gentoo)
#> • .pred_Adelie = exp(Adelie) / norm
#> • .pred_Chinstrap = exp(Chinstrap) / norm
#> • .pred_Gentoo = exp(Gentoo) / norm
#> ────────────────────────────────────────────────────────────────────────────────
#> 22 equations in total.

Notice how we can select both "class" and "prob". The predictions now include both hard and soft class predictions.

predict(orbital_obj_prob, penguins)
#> # A tibble: 344 × 4
#>    .pred_class .pred_Adelie .pred_Chinstrap .pred_Gentoo
#>    <chr>              <dbl>           <dbl>        <dbl>
#>  1 Adelie             0.989         0.00554      0.00560
#>  2 Adelie             0.989         0.00554      0.00560
#>  3 Adelie             0.989         0.00554      0.00560
#>  4 Adelie             0.709         0.0245       0.267  
#>  5 Adelie             0.989         0.00554      0.00560
#>  6 Adelie             0.989         0.00554      0.00560
#>  7 Adelie             0.989         0.00554      0.00560
#>  8 Adelie             0.989         0.00554      0.00560
#>  9 Adelie             0.979         0.00549      0.0158 
#> 10 Adelie             0.980         0.00559      0.0148 
#> # ℹ 334 more rows
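The last four equations in orbital_obj_prob compute these probabilities with a softmax: each class score is exponentiated and divided by norm, the sum of the exponentiated scores. A base R sketch of that arithmetic for a single row, using made-up scores (hypothetical values, not taken from the fitted model):

```r
# Hypothetical raw per-class scores for one row (not from the fitted model)
scores <- c(Adelie = 2.1, Chinstrap = -2.8, Gentoo = -2.8)

# Mirrors: norm = exp(Adelie) + exp(Chinstrap) + exp(Gentoo)
norm <- sum(exp(scores))

# Mirrors: .pred_<class> = exp(<class>) / norm
probs <- exp(scores) / norm

# Hard prediction: the class with the largest raw score
pred_class <- names(scores)[which.max(scores)]

probs
pred_class
```

Because exp() is increasing, the class with the largest raw score also has the largest probability, so the hard and soft predictions agree.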

That works equally well in databases.

predict(orbital_obj_prob, penguins_sqlite)
#> # Source:   SQL [?? x 4]
#> # Database: sqlite 3.47.1 []
#>    .pred_class .pred_Adelie .pred_Chinstrap .pred_Gentoo
#>    <chr>              <dbl>           <dbl>        <dbl>
#>  1 Adelie             0.989         0.00554      0.00560
#>  2 Adelie             0.989         0.00554      0.00560
#>  3 Adelie             0.989         0.00554      0.00560
#>  4 Adelie             0.709         0.0245       0.267  
#>  5 Adelie             0.989         0.00554      0.00560
#>  6 Adelie             0.989         0.00554      0.00560
#>  7 Adelie             0.989         0.00554      0.00560
#>  8 Adelie             0.989         0.00554      0.00560
#>  9 Adelie             0.979         0.00549      0.0158 
#> 10 Adelie             0.980         0.00559      0.0148 
#> # ℹ more rows

New augment method

tidymodels users have found augment() to be a handy tool. It performs predictions and returns them alongside the original data set.

This release adds augment() support for orbital objects.

augment(orbital_obj, penguins)
#> # A tibble: 344 × 8
#>    .pred_class species island    bill_length_mm bill_depth_mm flipper_length_mm
#>    <chr>       <fct>   <fct>              <dbl>         <dbl>             <int>
#>  1 Adelie      Adelie  Torgersen           39.1          18.7               181
#>  2 Adelie      Adelie  Torgersen           39.5          17.4               186
#>  3 Adelie      Adelie  Torgersen           40.3          18                 195
#>  4 Adelie      Adelie  Torgersen           NA            NA                  NA
#>  5 Adelie      Adelie  Torgersen           36.7          19.3               193
#>  6 Adelie      Adelie  Torgersen           39.3          20.6               190
#>  7 Adelie      Adelie  Torgersen           38.9          17.8               181
#>  8 Adelie      Adelie  Torgersen           39.2          19.6               195
#>  9 Adelie      Adelie  Torgersen           34.1          18.1               193
#> 10 Adelie      Adelie  Torgersen           42            20.2               190
#> # ℹ 334 more rows
#> # ℹ 2 more variables: body_mass_g <int>, sex <fct>
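For a local data frame, the result above is the predict() output with the original columns appended after it. A base R sketch of that shape, where toy_data, toy_predict(), and toy_augment() are hypothetical stand-ins (this is not how orbital implements augment()):

```r
# Hypothetical stand-in data (two rows, two columns)
toy_data <- data.frame(
  bill_depth_mm = c(18.7, 13.2),
  island = c("Torgersen", "Biscoe")
)

# Hypothetical stand-in for predict(): a toy rule returning a one-column data frame
toy_predict <- function(data) {
  data.frame(.pred_class = ifelse(data$bill_depth_mm > 15, "Adelie", "Gentoo"))
}

# augment()-style result: prediction columns first, then the original data
toy_augment <- function(data) {
  cbind(toy_predict(data), data)
}

toy_augment(toy_data)
```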

The function works for most databases, but for technical reasons it doesn't work with all of them. It is confirmed not to work with Spark databases or Arrow tables.

augment(orbital_obj, penguins_sqlite)
#> # Source:   SQL [?? x 8]
#> # Database: sqlite 3.47.1 []
#>    .pred_class species island    bill_length_mm bill_depth_mm flipper_length_mm
#>    <chr>       <chr>   <chr>              <dbl>         <dbl>             <int>
#>  1 Adelie      Adelie  Torgersen           39.1          18.7               181
#>  2 Adelie      Adelie  Torgersen           39.5          17.4               186
#>  3 Adelie      Adelie  Torgersen           40.3          18                 195
#>  4 Adelie      Adelie  Torgersen           NA            NA                  NA
#>  5 Adelie      Adelie  Torgersen           36.7          19.3               193
#>  6 Adelie      Adelie  Torgersen           39.3          20.6               190
#>  7 Adelie      Adelie  Torgersen           38.9          17.8               181
#>  8 Adelie      Adelie  Torgersen           39.2          19.6               195
#>  9 Adelie      Adelie  Torgersen           34.1          18.1               193
#> 10 Adelie      Adelie  Torgersen           42            20.2               190
#> # ℹ more rows
#> # ℹ 2 more variables: body_mass_g <int>, sex <chr>

Acknowledgements

A big thank you to all the people who have contributed to orbital since the release of v0.2.0:

@EmilHvitfeldt, @joscani, @jrosell, @npelikan, and @szimmer.