From bf368e378b9610c6a380bf76dc62a3a852c591db Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Thu, 22 Dec 2022 11:58:54 -0500 Subject: [PATCH 1/3] add butcher methods for `xrf::xrf()` --- DESCRIPTION | 3 +- NAMESPACE | 2 ++ R/xrf.R | 65 +++++++++++++++++++++++++++++++++++++++ man/axe-xrf.Rd | 49 +++++++++++++++++++++++++++++ tests/testthat/test-xrf.R | 63 +++++++++++++++++++++++++++++++++++++ 5 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 R/xrf.R create mode 100644 man/axe-xrf.Rd create mode 100644 tests/testthat/test-xrf.R diff --git a/DESCRIPTION b/DESCRIPTION index b2642cd..a0a43a0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -68,7 +68,8 @@ Suggests: testthat (>= 3.0.0), TH.data, usethis (>= 1.5.0), - xgboost (>= 1.3.2.1) + xgboost (>= 1.3.2.1), + xrf VignetteBuilder: knitr Config/Needs/website: diff --git a/NAMESPACE b/NAMESPACE index 0b22277..920e363 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -28,6 +28,7 @@ S3method(axe_call,survreg.penal) S3method(axe_call,train) S3method(axe_call,train.recipe) S3method(axe_call,xgb.Booster) +S3method(axe_call,xrf) S3method(axe_ctrl,C5.0) S3method(axe_ctrl,default) S3method(axe_ctrl,gam) @@ -88,6 +89,7 @@ S3method(axe_env,terms) S3method(axe_env,train) S3method(axe_env,train.recipe) S3method(axe_env,xgb.Booster) +S3method(axe_env,xrf) S3method(axe_fitted,C5.0) S3method(axe_fitted,KMeansCluster) S3method(axe_fitted,default) diff --git a/R/xrf.R b/R/xrf.R new file mode 100644 index 0000000..4e2d934 --- /dev/null +++ b/R/xrf.R @@ -0,0 +1,65 @@ +#' Axing a xrf. +#' +#' @inheritParams butcher +#' +#' @return Axed xrf object. +#' +#' @examplesIf rlang::is_installed("xrf") +#' library(xrf) +#' +#' xrf_big <- function() { +#' boop <- runif(1e6) +#' xrf( +#' mpg ~ ., +#' mtcars, +#' xgb_control = list(nrounds = 2, max_depth = 2), +#' family = 'gaussian' +#' ) +#' } +#' +#' heavy_m <- xrf_big() +#' +#' m <- butcher(heavy_m, verbose = TRUE) +#' +#' weigh(heavy_m) +#' weigh(m) +#' +#' @name axe-xrf +NULL + +#' @rdname axe-xrf +#' @export +axe_call.xrf <- function(x, verbose = FALSE, ...) { + res <- x + res$xgb <- axe_call(res$xgb) + res$glm$model$glmnet.fit$call <- call("dummy_call") + res$glm$model$call <- call("dummy_call") + + add_butcher_attributes( + res, + x, + add_class = TRUE, + verbose = verbose + ) +} + +#' @rdname axe-xrf +#' @export +axe_env.xrf <- function(x, verbose = FALSE, ...) { + res <- x + res$base_formula <- axe_env(res$base_formula, ...) + res$rule_augmented_formula <- axe_env(res$rule_augmented_formula, ...) + res$glm$formula <- axe_env(res$glm$formula, ...) + res$xgb <- axe_env(res$xgb, ...) + + add_butcher_attributes( + res, + x, + add_class = TRUE, + verbose = verbose + ) +} + + + + diff --git a/man/axe-xrf.Rd b/man/axe-xrf.Rd new file mode 100644 index 0000000..18356ea --- /dev/null +++ b/man/axe-xrf.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/xrf.R +\name{axe-xrf} +\alias{axe-xrf} +\alias{axe_call.xrf} +\alias{axe_env.xrf} +\title{Axing a xrf.} +\usage{ +\method{axe_call}{xrf}(x, verbose = FALSE, ...) + +\method{axe_env}{xrf}(x, verbose = FALSE, ...) +} +\arguments{ +\item{x}{A model object.} + +\item{verbose}{Print information each time an axe method is executed. +Notes how much memory is released and what functions are +disabled. Default is \code{FALSE}.} + +\item{...}{Any additional arguments related to axing.} +} +\value{ +Axed xrf object. +} +\description{ +Axing a xrf. +} +\examples{ +\dontshow{if (rlang::is_installed("xrf")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(xrf) + +xrf_big <- function() { + boop <- runif(1e6) + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) +} + +heavy_m <- xrf_big() + +m <- butcher(heavy_m, verbose = TRUE) + +weigh(heavy_m) +weigh(m) +\dontshow{\}) # examplesIf} +} diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R new file mode 100644 index 0000000..9ee9595 --- /dev/null +++ b/tests/testthat/test-xrf.R @@ -0,0 +1,63 @@ +skip_if_not_installed("xrf") + +test_that("xrf + axe_call() works", { + res <- + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- axe_call(res) + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$call, rlang::expr(dummy_call())) +}) + +test_that("xrf + axe_env() works", { + res <- + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- axe_env(res) + expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) +}) + +test_that("xrf + butcher() works", { + res <- + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- butcher(res) + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$call, rlang::expr(dummy_call())) + expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) +}) + +test_that("xrf + predict() works", { + res <- + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- butcher(res) + expect_equal( + predict(x, newdata = head(mtcars))[1], + predict(res, newdata = head(mtcars))[1] + ) +}) From 1845183800283f2927220142d7ae30cd5fa274f7 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 17 Jan 2023 12:54:44 -0500 Subject: [PATCH 2/3] update NEWS --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 1c49202..858f535 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # butcher (development version) +* Added butcher methods for `xrf::xrf()` (#242). + * Added butcher methods for `mda::fda()` (#241). * Added butcher methods for `dbarts::bart()` (#240). From 1265ba20454189c3a76f51b9607d3566dc62da64 Mon Sep 17 00:00:00 2001 From: simonpcouch Date: Tue, 17 Jan 2023 13:00:48 -0500 Subject: [PATCH 3/3] namespace xrf --- tests/testthat/test-xrf.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R index 9ee9595..15f849c 100644 --- a/tests/testthat/test-xrf.R +++ b/tests/testthat/test-xrf.R @@ -2,7 +2,7 @@ skip_if_not_installed("xrf") test_that("xrf + axe_call() works", { res <- - xrf( + xrf::xrf( mpg ~ ., mtcars, xgb_control = list(nrounds = 2, max_depth = 2), @@ -16,7 +16,7 @@ test_that("xrf + axe_call() works", { test_that("xrf + axe_env() works", { res <- - xrf( + xrf::xrf( mpg ~ ., mtcars, xgb_control = list(nrounds = 2, max_depth = 2), @@ -31,7 +31,7 @@ test_that("xrf + axe_env() works", { test_that("xrf + butcher() works", { res <- - xrf( + xrf::xrf( mpg ~ ., mtcars, xgb_control = list(nrounds = 2, max_depth = 2), @@ -49,7 +49,7 @@ test_that("xrf + butcher() works", { test_that("xrf + predict() works", { res <- - xrf( + xrf::xrf( mpg ~ ., mtcars, xgb_control = list(nrounds = 2, max_depth = 2),