
Update axe methods for xgboost #218

Merged: 3 commits merged into main from keep-raw-in-xgboost on Mar 18, 2022
Conversation

juliasilge (Member) commented Mar 17, 2022

Closes #147

library(butcher)
library(xgboost)

more_cars <- mtcars[rep(1:32, each = 1000),]
xgb_mod <- xgboost(
  data = as.matrix(more_cars[, -6]),
  label = more_cars[["vs"]],
  nrounds = 10)
#> [1]  train-rmse:0.350062 
#> [2]  train-rmse:0.245052 
#> [3]  train-rmse:0.171519 
#> [4]  train-rmse:0.120046 
#> [5]  train-rmse:0.084052 
#> [6]  train-rmse:0.058839 
#> [7]  train-rmse:0.041190 
#> [8]  train-rmse:0.028835 
#> [9]  train-rmse:0.020180 
#> [10] train-rmse:0.014130

preds <- predict(xgb_mod, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712

butchered <- butcher::butcher(xgb_mod)
preds <- predict(butchered, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712

saveRDS(butchered, "butchered.rds")
butchered_serialized <- readRDS("butchered.rds")
identical(
  predict(xgb_mod, as.matrix(more_cars[, -6])),
  predict(butchered_serialized, as.matrix(more_cars[, -6]))
)
#> [1] TRUE


sum(weigh(xgb_mod)$size)
#> [1] 0.060928
sum(weigh(butchered)$size)
#> [1] 0.022376

weigh(xgb_mod)
#> # A tibble: 11 × 2
#>    object                            size
#>    <chr>                            <dbl>
#>  1 callbacks.cb.evaluation.log   0.0354  
#>  2 callbacks.cb.print.evaluation 0.0146  
#>  3 raw                           0.00793 
#>  4 call                          0.00151 
#>  5 feature_names                 0.000736
#>  6 handle                        0.000312
#>  7 evaluation_log.iter           0.000176
#>  8 evaluation_log.train_rmse     0.000176
#>  9 niter                         0.000056
#> 10 params.validate_parameters    0.000056
#> 11 nfeatures                     0.000056
weigh(butchered)
#> # A tibble: 10 × 2
#>    object                            size
#>    <chr>                            <dbl>
#>  1 raw                           0.00793 
#>  2 callbacks.cb.print.evaluation 0.00773 
#>  3 callbacks.cb.evaluation.log   0.00510 
#>  4 feature_names                 0.000736
#>  5 handle                        0.000312
#>  6 evaluation_log.iter           0.000176
#>  7 evaluation_log.train_rmse     0.000176
#>  8 call                          0.000112
#>  9 niter                         0.000056
#> 10 nfeatures                     0.000056

Created on 2022-03-17 by the reprex package (v2.0.1)

juliasilge (Member) commented:
Taking out the axe_ctrl() makes almost no difference in the resulting size:

library(butcher)
library(xgboost)

more_cars <- mtcars[rep(1:32, each = 1000),]
xgb_mod <- xgboost(
  data = as.matrix(more_cars[, -6]),
  label = more_cars[["vs"]],
  nrounds = 10)
#> [1]  train-rmse:0.350062 
#> [2]  train-rmse:0.245052 
#> [3]  train-rmse:0.171519 
#> [4]  train-rmse:0.120046 
#> [5]  train-rmse:0.084052 
#> [6]  train-rmse:0.058839 
#> [7]  train-rmse:0.041190 
#> [8]  train-rmse:0.028835 
#> [9]  train-rmse:0.020180 
#> [10] train-rmse:0.014130

preds <- predict(xgb_mod, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712

butchered <- butcher::butcher(xgb_mod)
preds <- predict(butchered, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712

saveRDS(butchered, "butchered.rds")
butchered_serialized <- readRDS("butchered.rds")
identical(
  predict(xgb_mod, as.matrix(more_cars[, -6])),
  predict(butchered_serialized, as.matrix(more_cars[, -6]))
)
#> [1] TRUE


sum(weigh(xgb_mod)$size)
#> [1] 0.060928
sum(weigh(butchered)$size)
#> [1] 0.022432

weigh(xgb_mod)
#> # A tibble: 11 × 2
#>    object                            size
#>    <chr>                            <dbl>
#>  1 callbacks.cb.evaluation.log   0.0354  
#>  2 callbacks.cb.print.evaluation 0.0146  
#>  3 raw                           0.00793 
#>  4 call                          0.00151 
#>  5 feature_names                 0.000736
#>  6 handle                        0.000312
#>  7 evaluation_log.iter           0.000176
#>  8 evaluation_log.train_rmse     0.000176
#>  9 niter                         0.000056
#> 10 params.validate_parameters    0.000056
#> 11 nfeatures                     0.000056
weigh(butchered)
#> # A tibble: 11 × 2
#>    object                            size
#>    <chr>                            <dbl>
#>  1 raw                           0.00793 
#>  2 callbacks.cb.print.evaluation 0.00773 
#>  3 callbacks.cb.evaluation.log   0.00510 
#>  4 feature_names                 0.000736
#>  5 handle                        0.000312
#>  6 evaluation_log.iter           0.000176
#>  7 evaluation_log.train_rmse     0.000176
#>  8 call                          0.000112
#>  9 niter                         0.000056
#> 10 params.validate_parameters    0.000056
#> 11 nfeatures                     0.000056

Created on 2022-03-17 by the reprex package (v2.0.1)

And it means that tidymodels can work:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(butcher)
data(ames)

ames <-
  ames %>%
  select(
    Sale_Price,
    Neighborhood,
    Gr_Liv_Area,
    Year_Built,
    Bldg_Type,
    Latitude,
    Longitude
  ) %>%
  mutate(Sale_Price = log10(Sale_Price))

xgb_spec <- 
  boost_tree() %>%
  set_engine("xgboost") %>%
  set_mode("regression")

xgb_rec <- 
  recipe(Sale_Price ~ ., data = ames) %>%
  step_dummy(all_nominal_predictors()) 

xgb_fit <- workflow(xgb_rec, xgb_spec) %>% fit(ames)

predict(xgb_fit, ames[22,])
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.16

butchered <- butcher::butcher(xgb_fit)
predict(butchered, ames[22,])
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  5.16

sum(weigh(xgb_fit)$size)
#> [1] 1.540464
sum(weigh(butchered)$size)
#> [1] 0.386416

weigh(xgb_fit)
#> # A tibble: 227 × 2
#>    object                                           size
#>    <chr>                                           <dbl>
#>  1 pre.actions.recipe.recipe.steps.terms          0.126 
#>  2 pre.mold.blueprint.recipe.steps.terms          0.126 
#>  3 fit.fit.fit.raw                                0.0400
#>  4 fit.fit.fit.callbacks.cb.evaluation.log        0.0354
#>  5 pre.actions.recipe.blueprint.forge.process     0.0290
#>  6 pre.mold.blueprint.forge.process               0.0290
#>  7 pre.mold.predictors.Neighborhood_College_Creek 0.0243
#>  8 pre.mold.predictors.Neighborhood_Old_Town      0.0243
#>  9 pre.mold.predictors.Neighborhood_Edwards       0.0243
#> 10 pre.mold.predictors.Neighborhood_Somerset      0.0243
#> # … with 217 more rows
weigh(butchered)
#> # A tibble: 190 × 2
#>    object                                          size
#>    <chr>                                          <dbl>
#>  1 fit.fit.fit.raw                               0.0400
#>  2 pre.actions.recipe.blueprint.forge.process    0.0290
#>  3 pre.mold.blueprint.forge.process              0.0290
#>  4 pre.actions.recipe.blueprint.mold.process     0.0240
#>  5 pre.mold.blueprint.mold.process               0.0240
#>  6 pre.actions.recipe.recipe.template.Latitude   0.0235
#>  7 pre.actions.recipe.recipe.template.Longitude  0.0235
#>  8 pre.actions.recipe.recipe.template.Sale_Price 0.0235
#>  9 pre.actions.recipe.blueprint.forge.clean      0.0158
#> 10 pre.mold.blueprint.forge.clean                0.0158
#> # … with 180 more rows

Created on 2022-03-17 by the reprex package (v2.0.1)

juliasilge (Member) commented:
These two changes will allow xgboost (and tidymodels xgboost) models to be used for prediction after butchering.

These changes do NOT address the serialization issue, which we need to look at. Eventually we would like a way to serialize, for example, a tidymodels xgboost model and use the native xgb.save() / xgb.load(). We will work on the serialization issues separately.
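As a rough sketch of the native serialization route mentioned above, xgboost ships its own `xgb.save()` / `xgb.load()` pair, which writes the booster in xgboost's binary format rather than as an R serialized object. (This assumes the `xgb_mod` booster from the reprex above; the file name is arbitrary, and wiring this up for tidymodels objects is exactly the open work described here.)

```r
library(xgboost)

# Save the booster in xgboost's own binary format instead of saveRDS();
# this is the format xgboost itself guarantees compatibility for.
xgb.save(xgb_mod, "xgb_mod.model")

# Reload as a working xgb.Booster handle
reloaded <- xgb.load("xgb_mod.model")
```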

@juliasilge juliasilge marked this pull request as ready for review March 17, 2022 20:27
@juliasilge juliasilge requested a review from DavisVaughan March 17, 2022 20:27
@@ -93,21 +76,3 @@ axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) {
verbose = verbose
)
}
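For context on what is being removed or kept here: an axe method in butcher generally replaces a heavy component of a fitted model with a lightweight stub so the object serializes smaller while still predicting. The following is a hypothetical illustration of that shape, not the package's actual internals; `axe_call.my_model`, `my_model`, and `dummy_call` are made-up names.

```r
library(rlang)

# Hypothetical axe method: swap the stored call (which can capture large
# environments) for a small placeholder expression, then return the object.
axe_call.my_model <- function(x, verbose = FALSE, ...) {
  x$call <- rlang::expr(dummy_call())
  x
}
```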

Member commented:

I think we should probably remove axe_env.xgb.Booster() too. Callbacks seem to be an optional feature that is only used at training time, but I can't confirm this. Since they are optional and probably don't take up space by default, I think removing this potentially harmful method is the safer thing to do.

Member replied:

Hmm, in the weigh() results I see

#>  1 callbacks.cb.evaluation.log   0.0354  
#>  2 callbacks.cb.print.evaluation 0.0146  

and that component does get smaller when you axe it.

I guess those are default callbacks that xgboost adds, so we should leave this method in after all.

@juliasilge juliasilge merged commit 987610d into main Mar 18, 2022
@juliasilge juliasilge deleted the keep-raw-in-xgboost branch March 18, 2022 17:00
github-actions bot commented Apr 2, 2022

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Apr 2, 2022
Development: Successfully merging this pull request may close the issue "Using {xgboost} predict after butcher()".
2 participants