We’re thrilled to announce the release of orbital 0.3.0. orbital lets you predict in databases using tidymodels workflows.
You can install it from CRAN with:
install.packages("orbital")
This blog post will cover the highlights, which are classification support and the new augment method.
You can see a full list of changes in the release notes.
Classification support
The biggest improvement in this version is that orbital() now works for supported classification models. See the vignette for a list of all supported models.
Let’s start by fitting a classification model on the penguins data set, using {xgboost} as the engine.
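library(tidymodels) # attaches recipes, parsnip, workflows and modeldata (penguins)
library(orbital)

# Preprocessing: recode missing categories as "unknown", dummy-encode,
# impute numeric predictors with the mean, and drop zero-variance columns
rec_spec <- recipe(species ~ ., data = penguins) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_zv(all_predictors())

# Boosted-tree classifier using the xgboost engine
bt_spec <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")

wf_spec <- workflow(rec_spec, bt_spec)
wf_fit <- fit(wf_spec, data = penguins)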
With this fitted workflow object, we can call orbital() on it to create an orbital object.
orbital_obj <- orbital(wf_fit)
orbital_obj
#>
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • island_Dream = as.numeric(island == "Dream")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • sex_unknown = as.numeric(sex == "unknown")
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, bill_l ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.15117, bill_dep ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201, flipp ...
#> • body_mass_g = dplyr::if_else(is.na(body_mass_g), 4202, body_mass_g)
#> • island_Dream = dplyr::if_else(is.na(island_Dream), 0.3604651, island_Dr ...
#> • island_Torgersen = dplyr::if_else(is.na(island_Torgersen), 0.1511628, i ...
#> • sex_male = dplyr::if_else(is.na(sex_male), 0.4883721, sex_male)
#> • sex_unknown = dplyr::if_else(is.na(sex_unknown), 0.03197674, sex_unknow ...
#> • Adelie = 0 + dplyr::case_when((bill_depth_mm < 15.1 | is.na(bill_depth_ ...
#> • Chinstrap = 0 + dplyr::case_when((island_Dream < 0.5 | is.na(island_Dre ...
#> • Gentoo = 0 + dplyr::case_when((bill_depth_mm < 15.95 | is.na(bill_depth ...
#> • .pred_class = dplyr::case_when(Adelie > Chinstrap & Adelie > Gentoo ~ " ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 18 equations in total.
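Each printed line is an ordinary column expression, and the equations are applied in order, either overwriting an existing column or adding a new one. The base R sketch below replays a handful of them on a single made-up row to show the idea; the toy values are illustrative only, and `ifelse()` stands in for `dplyr::if_else()`. This is not how orbital evaluates the equations internally.

```r
# A toy row with missing values (illustrative, not taken from penguins)
row <- data.frame(island = NA_character_, sex = "male",
                  body_mass_g = NA_real_, stringsAsFactors = FALSE)

# Replay a subset of the printed equations, in order
row$island <- ifelse(is.na(row$island), "unknown", row$island)        # step_unknown()
row$sex <- ifelse(is.na(row$sex), "unknown", row$sex)                 # step_unknown()
row$island_Dream <- as.numeric(row$island == "Dream")                 # step_dummy()
row$sex_male <- as.numeric(row$sex == "male")                         # step_dummy()
row$body_mass_g <- ifelse(is.na(row$body_mass_g), 4202, row$body_mass_g) # step_impute_mean()

row
```

Because every step reduces to expressions like these, the same logic can be translated to SQL and run inside a database.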
This object contains all the information needed to produce predictions, which we can compute with predict().
predict(orbital_obj, penguins)
#> # A tibble: 344 × 1
#> .pred_class
#> <chr>
#> 1 Adelie
#> 2 Adelie
#> 3 Adelie
#> 4 Adelie
#> 5 Adelie
#> 6 Adelie
#> 7 Adelie
#> 8 Adelie
#> 9 Adelie
#> 10 Adelie
#> # ℹ 334 more rows
The main thing to note here is that orbital produces character vectors instead of factors. This is a deliberate choice for consistency, since many databases don’t have a factor type.
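If you need factors back locally, for example because yardstick classification metrics expect factor columns, you can convert the character column yourself. The sketch below uses a small stand-in data frame in place of real predict() output; with the real data, the levels could come from `levels(penguins$species)`.

```r
# Stand-in for predict(orbital_obj, penguins): a character .pred_class column
# (values are illustrative only)
preds <- data.frame(.pred_class = c("Adelie", "Gentoo", "Chinstrap", "Adelie"))

# Class levels in a fixed order, matching the outcome's levels
species_levels <- c("Adelie", "Chinstrap", "Gentoo")

# Convert back to a factor so the level order matches the training outcome
preds$.pred_class <- factor(preds$.pred_class, levels = species_levels)
str(preds)
```

Fixing the levels explicitly, rather than letting `factor()` infer them, keeps classes that never appear in the predictions as valid levels.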
Speaking of databases, you can also call predict() on an orbital object using database tables. Below we create an ephemeral in-memory RSQLite database.
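library(DBI)
library(RSQLite)
# dbname = ":memory:" creates a temporary database that lives only in RAM
con_sqlite <- dbConnect(SQLite(), dbname = ":memory:")
# copy_to() (from dplyr) uploads the data frame as a database table
penguins_sqlite <- copy_to(con_sqlite, penguins, name = "penguins_table")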
We can then predict with it as usual. All the calculations are sent to the database for execution.
predict(orbital_obj, penguins_sqlite)
#> # Source: SQL [?? x 1]
#> # Database: sqlite 3.47.1 []
#> .pred_class
#> <chr>
#> 1 Adelie
#> 2 Adelie
#> 3 Adelie
#> 4 Adelie
#> 5 Adelie
#> 6 Adelie
#> 7 Adelie
#> 8 Adelie
#> 9 Adelie
#> 10 Adelie
#> # ℹ more rows
This works the same with many types of databases.
Classification differs from regression in part because it comes with multiple prediction types. The example above showed the default, which is hard classification. You can choose the type of prediction with the type argument to orbital(). For classification models, the possible options are "class" and "prob".
orbital_obj_prob <- orbital(wf_fit, type = c("class", "prob"))
orbital_obj_prob
#>
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • island = dplyr::if_else(is.na(island), "unknown", island)
#> • sex = dplyr::if_else(is.na(sex), "unknown", sex)
#> • island_Dream = as.numeric(island == "Dream")
#> • island_Torgersen = as.numeric(island == "Torgersen")
#> • sex_male = as.numeric(sex == "male")
#> • sex_unknown = as.numeric(sex == "unknown")
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, bill_l ...
#> • bill_depth_mm = dplyr::if_else(is.na(bill_depth_mm), 17.15117, bill_dep ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201, flipp ...
#> • body_mass_g = dplyr::if_else(is.na(body_mass_g), 4202, body_mass_g)
#> • island_Dream = dplyr::if_else(is.na(island_Dream), 0.3604651, island_Dr ...
#> • island_Torgersen = dplyr::if_else(is.na(island_Torgersen), 0.1511628, i ...
#> • sex_male = dplyr::if_else(is.na(sex_male), 0.4883721, sex_male)
#> • sex_unknown = dplyr::if_else(is.na(sex_unknown), 0.03197674, sex_unknow ...
#> • Adelie = 0 + dplyr::case_when((bill_depth_mm < 15.1 | is.na(bill_depth_ ...
#> • Chinstrap = 0 + dplyr::case_when((island_Dream < 0.5 | is.na(island_Dre ...
#> • Gentoo = 0 + dplyr::case_when((bill_depth_mm < 15.95 | is.na(bill_depth ...
#> • .pred_class = dplyr::case_when(Adelie > Chinstrap & Adelie > Gentoo ~ " ...
#> • norm = exp(Adelie) + exp(Chinstrap) + exp(Gentoo)
#> • .pred_Adelie = exp(Adelie) / norm
#> • .pred_Chinstrap = exp(Chinstrap) / norm
#> • .pred_Gentoo = exp(Gentoo) / norm
#> ────────────────────────────────────────────────────────────────────────────────
#> 22 equations in total.
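The last four equations turn the per-class scores (the Adelie, Chinstrap and Gentoo columns) into probabilities with a softmax, while .pred_class picks the class with the largest score. The base R sketch below mirrors that arithmetic for one observation. The three scores are made-up numbers, not values from the fitted model, and `which.max()` ignores the tie handling of the printed `case_when()`.

```r
# Hypothetical raw scores for one penguin (made-up values)
scores <- c(Adelie = 2.1, Chinstrap = -1.4, Gentoo = -0.9)

# Mirror the printed equations: norm is the sum of exponentiated scores,
# and each probability is exp(score) / norm
norm <- sum(exp(scores))
probs <- exp(scores) / norm

# Mirror .pred_class: the class with the largest raw score wins
pred_class <- names(scores)[which.max(scores)]

round(probs, 3)
pred_class
```

Since `exp()` is monotone, the class with the largest score also has the largest probability, so the hard and soft predictions always agree.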
Notice that we can select both "class" and "prob". The predictions now include both hard and soft class predictions.
predict(orbital_obj_prob, penguins)
#> # A tibble: 344 × 4
#> .pred_class .pred_Adelie .pred_Chinstrap .pred_Gentoo
#> <chr> <dbl> <dbl> <dbl>
#> 1 Adelie 0.989 0.00554 0.00560
#> 2 Adelie 0.989 0.00554 0.00560
#> 3 Adelie 0.989 0.00554 0.00560
#> 4 Adelie 0.709 0.0245 0.267
#> 5 Adelie 0.989 0.00554 0.00560
#> 6 Adelie 0.989 0.00554 0.00560
#> 7 Adelie 0.989 0.00554 0.00560
#> 8 Adelie 0.989 0.00554 0.00560
#> 9 Adelie 0.979 0.00549 0.0158
#> 10 Adelie 0.980 0.00559 0.0148
#> # ℹ 334 more rows
That works equally well in databases.
predict(orbital_obj_prob, penguins_sqlite)
#> # Source: SQL [?? x 4]
#> # Database: sqlite 3.47.1 []
#> .pred_class .pred_Adelie .pred_Chinstrap .pred_Gentoo
#> <chr> <dbl> <dbl> <dbl>
#> 1 Adelie 0.989 0.00554 0.00560
#> 2 Adelie 0.989 0.00554 0.00560
#> 3 Adelie 0.989 0.00554 0.00560
#> 4 Adelie 0.709 0.0245 0.267
#> 5 Adelie 0.989 0.00554 0.00560
#> 6 Adelie 0.989 0.00554 0.00560
#> 7 Adelie 0.989 0.00554 0.00560
#> 8 Adelie 0.989 0.00554 0.00560
#> 9 Adelie 0.979 0.00549 0.0158
#> 10 Adelie 0.980 0.00559 0.0148
#> # ℹ more rows
New augment method
Users of tidymodels have found the augment() function to be a handy tool. It performs predictions and returns them alongside the original data set. This release adds augment() support for orbital objects.
augment(orbital_obj, penguins)
#> # A tibble: 344 × 8
#> .pred_class species island bill_length_mm bill_depth_mm flipper_length_mm
#> <chr> <fct> <fct> <dbl> <dbl> <int>
#> 1 Adelie Adelie Torgersen 39.1 18.7 181
#> 2 Adelie Adelie Torgersen 39.5 17.4 186
#> 3 Adelie Adelie Torgersen 40.3 18 195
#> 4 Adelie Adelie Torgersen NA NA NA
#> 5 Adelie Adelie Torgersen 36.7 19.3 193
#> 6 Adelie Adelie Torgersen 39.3 20.6 190
#> 7 Adelie Adelie Torgersen 38.9 17.8 181
#> 8 Adelie Adelie Torgersen 39.2 19.6 195
#> 9 Adelie Adelie Torgersen 34.1 18.1 193
#> 10 Adelie Adelie Torgersen 42 20.2 190
#> # ℹ 334 more rows
#> # ℹ 2 more variables: body_mass_g <int>, sex <fct>
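For a local data frame, the shape of the result can be pictured as the prediction columns bound in front of the original columns. The base R sketch below illustrates only that shape, using made-up toy data and `cbind()`; it is not how orbital implements augment().

```r
# Toy stand-ins for penguins and predict(orbital_obj, penguins)
new_data <- data.frame(species = c("Adelie", "Gentoo"),
                       bill_length_mm = c(39.1, 46.1))
preds <- data.frame(.pred_class = c("Adelie", "Gentoo"))

# Prediction columns first, then every original column, row for row
augmented <- cbind(preds, new_data)
augmented
```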
The function works with most databases, but for technical reasons it doesn’t work with all of them. It has been confirmed not to work with Spark databases or Arrow tables.
augment(orbital_obj, penguins_sqlite)
#> # Source: SQL [?? x 8]
#> # Database: sqlite 3.47.1 []
#> .pred_class species island bill_length_mm bill_depth_mm flipper_length_mm
#> <chr> <chr> <chr> <dbl> <dbl> <int>
#> 1 Adelie Adelie Torgersen 39.1 18.7 181
#> 2 Adelie Adelie Torgersen 39.5 17.4 186
#> 3 Adelie Adelie Torgersen 40.3 18 195
#> 4 Adelie Adelie Torgersen NA NA NA
#> 5 Adelie Adelie Torgersen 36.7 19.3 193
#> 6 Adelie Adelie Torgersen 39.3 20.6 190
#> 7 Adelie Adelie Torgersen 38.9 17.8 181
#> 8 Adelie Adelie Torgersen 39.2 19.6 195
#> 9 Adelie Adelie Torgersen 34.1 18.1 193
#> 10 Adelie Adelie Torgersen 42 20.2 190
#> # ℹ more rows
#> # ℹ 2 more variables: body_mass_g <int>, sex <chr>
Acknowledgements
A big thank you to all the people who have contributed to orbital since its previous release:
@EmilHvitfeldt, @joscani, @jrosell, @npelikan, and @szimmer.