Update axe methods for xgboost #218
Conversation
Taking out the …

library(butcher)
library(xgboost)
more_cars <- mtcars[rep(1:32, each = 1000),]
xgb_mod <- xgboost(
data = as.matrix(more_cars[, -6]),
label = more_cars[["vs"]],
nrounds = 10)
#> [1] train-rmse:0.350062
#> [2] train-rmse:0.245052
#> [3] train-rmse:0.171519
#> [4] train-rmse:0.120046
#> [5] train-rmse:0.084052
#> [6] train-rmse:0.058839
#> [7] train-rmse:0.041190
#> [8] train-rmse:0.028835
#> [9] train-rmse:0.020180
#> [10] train-rmse:0.014130
preds <- predict(xgb_mod, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712
butchered <- butcher::butcher(xgb_mod)
preds <- predict(butchered, as.matrix(more_cars[, -6]))
head(preds)
#> [1] 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712 0.01412712
saveRDS(butchered, "butchered.rds")
butchered_serialized <- readRDS("butchered.rds")
identical(
predict(xgb_mod, as.matrix(more_cars[, -6])),
predict(butchered_serialized, as.matrix(more_cars[, -6]))
)
#> [1] TRUE
sum(weigh(xgb_mod)$size)
#> [1] 0.060928
sum(weigh(butchered)$size)
#> [1] 0.022432
weigh(xgb_mod)
#> # A tibble: 11 × 2
#> object size
#> <chr> <dbl>
#> 1 callbacks.cb.evaluation.log 0.0354
#> 2 callbacks.cb.print.evaluation 0.0146
#> 3 raw 0.00793
#> 4 call 0.00151
#> 5 feature_names 0.000736
#> 6 handle 0.000312
#> 7 evaluation_log.iter 0.000176
#> 8 evaluation_log.train_rmse 0.000176
#> 9 niter 0.000056
#> 10 params.validate_parameters 0.000056
#> 11 nfeatures 0.000056
weigh(butchered)
#> # A tibble: 11 × 2
#> object size
#> <chr> <dbl>
#> 1 raw 0.00793
#> 2 callbacks.cb.print.evaluation 0.00773
#> 3 callbacks.cb.evaluation.log 0.00510
#> 4 feature_names 0.000736
#> 5 handle 0.000312
#> 6 evaluation_log.iter 0.000176
#> 7 evaluation_log.train_rmse 0.000176
#> 8 call 0.000112
#> 9 niter 0.000056
#> 10 params.validate_parameters 0.000056
#> 11 nfeatures                     0.000056

Created on 2022-03-17 by the reprex package (v2.0.1)

And it means that tidymodels can work:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(butcher)
data(ames)
ames <-
ames %>%
select(
Sale_Price,
Neighborhood,
Gr_Liv_Area,
Year_Built,
Bldg_Type,
Latitude,
Longitude
) %>%
mutate(Sale_Price = log10(Sale_Price))
xgb_spec <-
boost_tree() %>%
set_engine("xgboost") %>%
set_mode("regression")
xgb_rec <-
recipe(Sale_Price ~ ., data = ames) %>%
step_dummy(all_nominal_predictors())
xgb_fit <- workflow(xgb_rec, xgb_spec) %>% fit(ames)
predict(xgb_fit, ames[22,])
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.16
butchered <- butcher::butcher(xgb_fit)
predict(butchered, ames[22,])
#> # A tibble: 1 × 1
#> .pred
#> <dbl>
#> 1 5.16
sum(weigh(xgb_fit)$size)
#> [1] 1.540464
sum(weigh(butchered)$size)
#> [1] 0.386416
weigh(xgb_fit)
#> # A tibble: 227 × 2
#> object size
#> <chr> <dbl>
#> 1 pre.actions.recipe.recipe.steps.terms 0.126
#> 2 pre.mold.blueprint.recipe.steps.terms 0.126
#> 3 fit.fit.fit.raw 0.0400
#> 4 fit.fit.fit.callbacks.cb.evaluation.log 0.0354
#> 5 pre.actions.recipe.blueprint.forge.process 0.0290
#> 6 pre.mold.blueprint.forge.process 0.0290
#> 7 pre.mold.predictors.Neighborhood_College_Creek 0.0243
#> 8 pre.mold.predictors.Neighborhood_Old_Town 0.0243
#> 9 pre.mold.predictors.Neighborhood_Edwards 0.0243
#> 10 pre.mold.predictors.Neighborhood_Somerset 0.0243
#> # … with 217 more rows
weigh(butchered)
#> # A tibble: 190 × 2
#> object size
#> <chr> <dbl>
#> 1 fit.fit.fit.raw 0.0400
#> 2 pre.actions.recipe.blueprint.forge.process 0.0290
#> 3 pre.mold.blueprint.forge.process 0.0290
#> 4 pre.actions.recipe.blueprint.mold.process 0.0240
#> 5 pre.mold.blueprint.mold.process 0.0240
#> 6 pre.actions.recipe.recipe.template.Latitude 0.0235
#> 7 pre.actions.recipe.recipe.template.Longitude 0.0235
#> 8 pre.actions.recipe.recipe.template.Sale_Price 0.0235
#> 9 pre.actions.recipe.blueprint.forge.clean 0.0158
#> 10 pre.mold.blueprint.forge.clean 0.0158
#> # … with 180 more rows

Created on 2022-03-17 by the reprex package (v2.0.1)
These two changes allow xgboost (and tidymodels xgboost) models to be used for prediction after butchering. They do NOT address the serialization issue, which we still need to look at. Eventually we would like a way to serialize, for example, a tidymodels xgboost model and use the native …
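For reference, a sketch of what the native route might look like, using xgboost's own `xgb.save()`/`xgb.load()` rather than `saveRDS()`. The model and data mirror the reprex above; this sidesteps rather than solves the butcher serialization question, since the restored object is a bare booster with none of the R-side attributes:

```r
library(xgboost)

# Same toy model as in the reprex above.
more_cars <- mtcars[rep(1:32, each = 1000), ]
xgb_mod <- xgboost(
  data = as.matrix(more_cars[, -6]),
  label = more_cars[["vs"]],
  nrounds = 10,
  verbose = 0
)

# Serialize with xgboost's own on-disk format instead of saveRDS():
tmp <- tempfile(fileext = ".bin")
xgb.save(xgb_mod, tmp)
restored <- xgb.load(tmp)

# The restored object has lost the call, callbacks, etc. by
# construction, but it predicts identically:
identical(
  predict(xgb_mod, as.matrix(more_cars[, -6])),
  predict(restored, as.matrix(more_cars[, -6]))
)
```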
@@ -93,21 +76,3 @@ axe_env.xgb.Booster <- function(x, verbose = FALSE, ...) {
    verbose = verbose
  )
}
I think we should probably remove axe_env.xgb.Booster too. Callbacks seem to be an optional feature, and they appear to be used only at training time, but I can't confirm this. Since they are optional and probably don't take up space by default, removing this potentially harmful method seems like the safer thing to do.
Hmm, I see
#> 1 callbacks.cb.evaluation.log 0.0354
#> 2 callbacks.cb.print.evaluation 0.0146
in the weigh results, and that does get smaller when you axe it.
I guess those are default callbacks that get added by xgboost, so we should leave this method in there then.
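They do seem to be added by default — a quick sketch, assuming the pre-2.0 xgboost R interface where the fitted object carries a `$callbacks` list:

```r
library(xgboost)

# A minimal fit with no callbacks requested explicitly.
fit <- xgboost(
  data = as.matrix(mtcars[, -8]),
  label = mtcars[["vs"]],
  nrounds = 2
)

# xgboost() adds its default callbacks anyway; these are the cb.*
# entries that show up in weigh():
names(fit$callbacks)

# Each callback is a closure, so it drags its enclosing environment
# along when the model is saved; count the objects each one captures:
vapply(fit$callbacks, function(cb) length(ls(environment(cb))), integer(1))
```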
This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue. |
Closes #147