Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify concept of outcome type in the package #14

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: container
Title: Sandbox for a postprocessor object
Version: 0.0.0.9000
Version: 0.0.0.9001
Authors@R: c(
person("Simon", "Couch", , "[email protected]", role = "aut"),
person("Hannah", "Frick", , "[email protected]", role = "aut"),
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_equivocal_zone(value = 1 / 4)
#'
#'
Expand Down Expand Up @@ -43,7 +43,6 @@ adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
21 changes: 10 additions & 11 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param type Character. One of `"linear"`, `"isotonic"`, or
#' @param method Character. One of `"linear"`, `"isotonic"`, or
#' `"isotonic_boot"`, corresponding to the function from the \pkg{probably}
#' package [probably::cal_estimate_linear()],
#' [probably::cal_estimate_isotonic()], or
Expand All @@ -19,21 +19,21 @@
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(type = "linear")
#' container() %>%
#' adjust_numeric_calibration(method = "linear")
#'
#' # train container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr_trained, dat)
#' @export
adjust_numeric_calibration <- function(x, type = NULL) {
adjust_numeric_calibration <- function(x, method = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "numeric")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
# wait to `check_method()` until `fit()` time
if (!is.null(method)) {
arg_match0(
type,
method,
c("linear", "isotonic", "isotonic_boot")
)
}
Expand All @@ -43,13 +43,12 @@ adjust_numeric_calibration <- function(x, type = NULL) {
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(type = type),
arguments = list(method = method),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand All @@ -67,13 +66,13 @@ print.numeric_calibration <- function(x, ...) {

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
method <- check_method(object$method, container$type)
# todo: adjust_numeric_calibration() should take arguments to pass to
# cal_estimate_* via dots
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
paste0("cal_estimate_", method),
.data = data,
truth = container$columns$outcome,
estimate = container$columns$estimate,
Expand Down
1 change: 0 additions & 1 deletion R/adjust-numeric-range.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ adjust_numeric_range <- function(x, lower_limit = -Inf, upper_limit = Inf) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-predictions-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_equivocal_zone() %>%
#' adjust_predictions_custom(linear_predictor = binomial()$linkfun(Class2))
#'
Expand Down Expand Up @@ -39,7 +39,6 @@ adjust_predictions_custom <- function(x, ..., .pkgs = character(0)) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
17 changes: 8 additions & 9 deletions R/adjust-probability-calibration.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#' Re-calibrate classification probability predictions
#'
#' @param x A [container()].
#' @param type Character. One of `"logistic"`, `"multinomial"`,
#' @param method Character. One of `"logistic"`, `"multinomial"`,
#' `"beta"`, `"isotonic"`, or `"isotonic_boot"`, corresponding to the
#' function from the \pkg{probably} package [probably::cal_estimate_logistic()],
#' [probably::cal_estimate_multinomial()], etc., respectively.
#' @export
adjust_probability_calibration <- function(x, type = NULL) {
adjust_probability_calibration <- function(x, method = NULL) {
# to-do: add argument specifying `prop` in initial_split
check_container(x, calibration_type = "probability")
# wait to `check_type()` until `fit()` time
if (!is.null(type)) {
# wait to `check_method()` until `fit()` time
if (!is.null(method)) {
arg_match(
type,
method,
c("logistic", "multinomial", "beta", "isotonic", "isotonic_boot")
)
}
Expand All @@ -22,13 +22,12 @@ adjust_probability_calibration <- function(x, type = NULL) {
"probability_calibration",
inputs = "probability",
outputs = "probability_class",
arguments = list(type = type),
arguments = list(method = method),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand All @@ -46,14 +45,14 @@ print.probability_calibration <- function(x, ...) {

#' @export
fit.probability_calibration <- function(object, data, container = NULL, ...) {
type <- check_type(object$type, container$type)
method <- check_method(object$method, container$type)
# todo: adjust_probability_calibration() should take arguments to pass to
# cal_estimate_* via dots
# to-do: add argument specifying `prop` in initial_split
fit <-
eval_bare(
call2(
paste0("cal_estimate_", type),
paste0("cal_estimate_", method),
.data = data,
# todo: make getters for the entries in `columns`
truth = container$columns$outcome,
Expand Down
3 changes: 1 addition & 2 deletions R/adjust-probability-threshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' container() %>%
#' adjust_probability_threshold(threshold = .1)
#'
#' two_class_example %>% count(predicted)
Expand Down Expand Up @@ -39,7 +39,6 @@ adjust_probability_threshold <- function(x, threshold = 0.5) {
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
Expand Down
20 changes: 5 additions & 15 deletions R/container.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#' Declare post-processing for model predictions
#'
#' @param mode The model's mode, one of `"classification"`, or `"regression"`.
#' Modes of `"censored regression"` are not currently supported.
#' @param type The model sub-type. Possible values are `"unknown"`, `"regression"`,
#' `"binary"`, or `"multiclass"`.
#' @param outcome The name of the outcome variable.
Expand All @@ -14,9 +12,9 @@
#' @param time The name of the predicted event time. (not yet supported)
#' @examples
#'
#' container(mode = "regression")
#' container()
#' @export
container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
container <- function(type = "unknown", outcome = NULL, estimate = NULL,
probabilities = NULL, time = NULL) {
columns <-
list(
Expand All @@ -28,7 +26,6 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
)

new_container(
mode,
type,
operations = list(),
columns = columns,
Expand All @@ -37,13 +34,7 @@ container <- function(mode, type = "unknown", outcome = NULL, estimate = NULL,
)
}

new_container <- function(mode, type, operations, columns, ptype, call) {
mode <- arg_match0(mode, c("regression", "classification"))

if (mode == "regression") {
type <- "regression"
}

new_container <- function(type, operations, columns, ptype, call) {
type <- arg_match0(type, c("unknown", "regression", "binary", "multiclass"))

if (!is.list(operations)) {
Expand All @@ -58,11 +49,11 @@ new_container <- function(mode, type, operations, columns, ptype, call) {
}

# validate operation order and check duplicates
validate_order(operations, mode, call)
validate_order(operations, type, call)

# check columns
res <- list(
mode = mode, type = type, operations = operations,
type = type, operations = operations,
columns = columns, ptype = ptype
)
class(res) <- "container"
Expand Down Expand Up @@ -120,7 +111,6 @@ fit.container <- function(object, .data, outcome, estimate, probabilities = c(),
object <- set_container_type(object, .data[[columns$outcome]])

object <- new_container(
object$mode,
object$type,
operations = object$operations,
columns = columns,
Expand Down
37 changes: 18 additions & 19 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ check_container <- function(x, calibration_type = NULL, call = caller_env(), arg
# check that the type of calibration ("numeric" or "probability") is
# compatible with the container type
if (!is.null(calibration_type)) {
container_type <- x$type
type <- x$type
switch(
container_type,
type,
regression =
check_calibration_type(calibration_type, "numeric", container_type, call = call),
binary = , multinomial =
check_calibration_type(calibration_type, "probability", container_type, call = call)
check_calibration_type(calibration_type, "numeric", type, call = call),
binary = , multiclass =
check_calibration_type(calibration_type, "probability", type, call = call)
)
}

Expand All @@ -90,54 +90,53 @@ types_binary <- c("logistic", "beta", "isotonic", "isotonic_boot")
types_multiclass <- c("multinomial", "beta", "isotonic", "isotonic_boot")
# a check function to be called when a container is being `fit()`ted.
# by the time a container is fitted, we have:
# * `adjust_type`, the `type` argument passed to an `adjust_*` function
# * `method`, the `method` argument passed to an `adjust_*` function
# * this argument has already been checked to agree with the kind of
# `adjust_*()` function via `arg_match0()`.
# * `container_type`, the `type` argument either specified in `container()`
# or inferred in `fit.container()`.
check_type <- function(adjust_type,
container_type,
arg = caller_arg(adjust_type),
check_method <- function(method,
type,
arg = caller_arg(method),
call = caller_env()) {
# if no `adjust_type` was supplied, infer a reasonable one based on the
# `container_type`
if (is.null(adjust_type)) {
# if no `method` was supplied, infer a reasonable one based on the `type`
if (is.null(method)) {
switch(
container_type,
type,
regression = return("linear"),
binary = return("logistic"),
multiclass = return("multinomial")
)
}

switch(
container_type,
type,
regression = arg_match0(
adjust_type,
method,
types_regression,
arg_nm = arg,
error_call = call
),
binary = arg_match0(
adjust_type,
method,
types_binary,
arg_nm = arg,
error_call = call
),
multiclass = arg_match0(
adjust_type,
method,
types_multiclass,
arg_nm = arg,
error_call = call
),
arg_match0(
adjust_type,
method,
unique(c(types_regression, types_binary, types_multiclass)),
arg_nm = arg,
error_call = call
)
)

adjust_type
method
}

31 changes: 26 additions & 5 deletions R/validation-rules.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
validate_order <- function(ops, mode, call) {
validate_order <- function(ops, type, call = caller_env()) {
orderings <-
tibble::new_tibble(list(
name = purrr::map_chr(ops, ~ class(.x)[1]),
Expand All @@ -13,12 +13,17 @@ validate_order <- function(ops, mode, call) {
return(invisible(orderings))
}

if (mode == "classification") {
check_classification_order(orderings, call)
} else {
check_regression_order(orderings, call)
if (type == "unknown") {
type <- infer_type(orderings)
}

switch(
type,
regression = check_regression_order(orderings, call),
binary = , multiclass = check_classification_order(orderings, call),
invisible()
)

invisible(orderings)
}

Expand Down Expand Up @@ -83,3 +88,19 @@ check_duplicates <- function(x, call) {
}
invisible(x)
}

infer_type <- function(orderings) {
if (all(orderings$output_all)) {
return("unknown")
}

if (all(orderings$output_numeric | orderings$output_all)) {
return("regression")
}

if (all(orderings$output_prob | orderings$output_class | orderings$output_all)) {
return("binary")
}

"unknown"
}
2 changes: 1 addition & 1 deletion inst/examples/container_regression_example.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ We could manually use `cal_apply()` to adjust predictions, but instead, we'll ad
#| label: post-1

post_obj <-
container(mode = "regression") %>%
container() %>%
adjust_numeric_calibration(bst_cal)
post_obj
```
Expand Down
Loading