Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Add print() and summary() methods for Booster #4686

Merged
merged 18 commits into from
Nov 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
S3method(summary,lgb.Booster)
export(get_field)
export(getinfo)
export(lgb.Dataset)
Expand Down
59 changes: 59 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,65 @@ predict.lgb.Booster <- function(object,
)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
# nolint start
handle <- x$.__enclos_env__$private$handle
handle_is_null <- lgb.is.null.handle(handle)

if (!handle_is_null) {
ntrees <- x$current_iter()
if (ntrees == 1L) {
cat("LightGBM Model (1 tree)\n")
} else {
cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
}
} else {
cat("LightGBM Model\n")
}

if (!handle_is_null) {
obj <- x$params$objective
if (obj == "none") {
obj <- "custom"
}
if (x$.__enclos_env__$private$num_class == 1L) {
cat(sprintf("Objective: %s\n", obj))
} else {
cat(sprintf("Objective: %s (%d classes)\n"
, obj
, x$.__enclos_env__$private$num_class))
}
} else {
cat("(Booster handle is invalid)\n")
}

if (!handle_is_null) {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
cat(sprintf("Fitted to dataset with %d columns\n", ncols))
}
# nolint end

return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
print(object)
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
Expand Down
19 changes: 19 additions & 0 deletions R-package/man/print.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R-package/man/summary.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
R_API_END();
}

SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out = 0;
CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &out));
return Rf_ScalarInteger(out);
R_API_END();
}

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
Expand Down Expand Up @@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down
9 changes: 9 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
SEXP out
);

/*!
* \brief Get number of features.
* \param handle Booster handle
* \return Total number of features, as R integer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumFeature_R(
SEXP handle
);

/*!
* \brief update the model in one round
* \param handle Booster handle
Expand Down
113 changes: 113 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1041,3 +1041,116 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
preds2 <- predict(bst2, X)
expect_identical(preds, preds2)
})

test_that("Booster's print, show, and summary work correctly", {
.have_same_handle <- function(model, other_model) {
expect_equal(
model$.__enclos_env__$private$handle
, other_model$.__enclos_env__$private$handle
)
}

.check_methods_work <- function(model) {

# should work for fitted models
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)

# should not fail for finalized models
model$finalize()
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)
}

data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)


# with custom objective
.logregobj <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels
hess <- preds * (1.0 - preds)
return(list(grad = grad, hess = hess))
}

.evalerror <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(
name = "error"
, value = err
, higher_better = FALSE
))
}

model <- lgb.train(
data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(iris$Species == "virginica")
)
, obj = .logregobj
, eval = .evalerror
, verbose = 0L
, nrounds = 5L
)

.check_methods_work(model)
})

test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(mtcars) - 1L)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L)
})