diff --git a/DESCRIPTION b/DESCRIPTION index b1409c3..fe079ca 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -69,7 +69,8 @@ Suggests: testthat (>= 3.0.0), TH.data, usethis (>= 1.5.0), - xgboost (>= 1.3.2.1) + xgboost (>= 1.3.2.1), + xrf VignetteBuilder: knitr Config/Needs/website: diff --git a/NAMESPACE b/NAMESPACE index 4d7f1e4..e0aa883 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -30,6 +30,7 @@ S3method(axe_call,survreg.penal) S3method(axe_call,train) S3method(axe_call,train.recipe) S3method(axe_call,xgb.Booster) +S3method(axe_call,xrf) S3method(axe_ctrl,C5.0) S3method(axe_ctrl,default) S3method(axe_ctrl,gam) @@ -91,6 +92,7 @@ S3method(axe_env,terms) S3method(axe_env,train) S3method(axe_env,train.recipe) S3method(axe_env,xgb.Booster) +S3method(axe_env,xrf) S3method(axe_fitted,C5.0) S3method(axe_fitted,KMeansCluster) S3method(axe_fitted,bart) diff --git a/NEWS.md b/NEWS.md index 1c49202..858f535 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # butcher (development version) +* Added butcher methods for `xrf::xrf()` (#242). + * Added butcher methods for `mda::fda()` (#241). * Added butcher methods for `dbarts::bart()` (#240). diff --git a/R/xrf.R b/R/xrf.R new file mode 100644 index 0000000..4e2d934 --- /dev/null +++ b/R/xrf.R @@ -0,0 +1,65 @@ +#' Axing a xrf. +#' +#' @inheritParams butcher +#' +#' @return Axed xrf object. +#' +#' @examplesIf rlang::is_installed("xrf") +#' library(xrf) +#' +#' xrf_big <- function() { +#' boop <- runif(1e6) +#' xrf( +#' mpg ~ ., +#' mtcars, +#' xgb_control = list(nrounds = 2, max_depth = 2), +#' family = 'gaussian' +#' ) +#' } +#' +#' heavy_m <- xrf_big() +#' +#' m <- butcher(heavy_m, verbose = TRUE) +#' +#' weigh(heavy_m) +#' weigh(m) +#' +#' @name axe-xrf +NULL + +#' @rdname axe-xrf +#' @export +axe_call.xrf <- function(x, verbose = FALSE, ...) { + res <- x + res$xgb <- axe_call(res$xgb) + res$glm$model$glmnet.fit$call <- call("dummy_call") + res$glm$model$call <- call("dummy_call") + + add_butcher_attributes( + res, + x, + add_class = TRUE, + verbose = verbose + ) +} + +#' @rdname axe-xrf +#' @export +axe_env.xrf <- function(x, verbose = FALSE, ...) { + res <- x + res$base_formula <- axe_env(res$base_formula, ...) + res$rule_augmented_formula <- axe_env(res$rule_augmented_formula, ...) + res$glm$formula <- axe_env(res$glm$formula, ...) + res$xgb <- axe_env(res$xgb, ...) + + add_butcher_attributes( + res, + x, + add_class = TRUE, + verbose = verbose + ) +} + + + + diff --git a/man/axe-xrf.Rd b/man/axe-xrf.Rd new file mode 100644 index 0000000..18356ea --- /dev/null +++ b/man/axe-xrf.Rd @@ -0,0 +1,49 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/xrf.R +\name{axe-xrf} +\alias{axe-xrf} +\alias{axe_call.xrf} +\alias{axe_env.xrf} +\title{Axing a xrf.} +\usage{ +\method{axe_call}{xrf}(x, verbose = FALSE, ...) + +\method{axe_env}{xrf}(x, verbose = FALSE, ...) +} +\arguments{ +\item{x}{A model object.} + +\item{verbose}{Print information each time an axe method is executed. +Notes how much memory is released and what functions are +disabled. Default is \code{FALSE}.} + +\item{...}{Any additional arguments related to axing.} +} +\value{ +Axed xrf object. +} +\description{ +Axing a xrf. +} +\examples{ +\dontshow{if (rlang::is_installed("xrf")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(xrf) + +xrf_big <- function() { + boop <- runif(1e6) + xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) +} + +heavy_m <- xrf_big() + +m <- butcher(heavy_m, verbose = TRUE) + +weigh(heavy_m) +weigh(m) +\dontshow{\}) # examplesIf} +} diff --git a/tests/testthat/test-xrf.R b/tests/testthat/test-xrf.R new file mode 100644 index 0000000..15f849c --- /dev/null +++ b/tests/testthat/test-xrf.R @@ -0,0 +1,63 @@ +skip_if_not_installed("xrf") + +test_that("xrf + axe_call() works", { + res <- + xrf::xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- axe_call(res) + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$call, rlang::expr(dummy_call())) +}) + +test_that("xrf + axe_env() works", { + res <- + xrf::xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- axe_env(res) + expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) +}) + +test_that("xrf + butcher() works", { + res <- + xrf::xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- butcher(res) + expect_equal(x$xgb$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$glmnet.fit$call, rlang::expr(dummy_call())) + expect_equal(x$glm$model$call, rlang::expr(dummy_call())) + expect_equal(attr(x$base_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$rule_augmented_formula, ".Environment"), rlang::base_env()) + expect_equal(attr(x$glm$formula, ".Environment"), rlang::base_env()) + expect_equal(environment(x$xgb$callbacks[[1]]), rlang::base_env()) +}) + +test_that("xrf + predict() works", { + res <- + xrf::xrf( + mpg ~ ., + mtcars, + xgb_control = list(nrounds = 2, max_depth = 2), + family = 'gaussian' + ) + x <- butcher(res) + expect_equal( + predict(x, newdata = head(mtcars))[1], + predict(res, newdata = head(mtcars))[1] + ) +})