ShrinkageTrees {ShrinkageTrees} | R Documentation |
General Shrinkage Regression Trees (ShrinkageTrees)
Description
Fits a Bayesian shrinkage tree model with flexible global-local priors on the
step heights. This function generalizes HorseTrees by letting the user choose
among several global-local shrinkage priors instead of a single horseshoe prior.
Usage
ShrinkageTrees(
y,
status = NULL,
X_train,
X_test = NULL,
outcome_type = "continuous",
timescale = "time",
number_of_trees = 200,
prior_type = "horseshoe",
local_hp = NULL,
global_hp = NULL,
power = 2,
base = 0.95,
p_grow = 0.4,
p_prune = 0.4,
nu = 3,
q = 0.9,
sigma = NULL,
N_post = 1000,
N_burn = 1000,
delayed_proposal = 5,
store_posterior_sample = TRUE,
seed = NULL,
verbose = TRUE
)
Arguments
y |
Outcome vector. Numeric. Can represent continuous outcomes, binary outcomes (0/1), or follow-up times for survival data. |
status |
Optional censoring indicator vector (1 = event occurred,
0 = censored). Required if |
X_train |
Covariate matrix for training. Each row corresponds to an observation, and each column to a covariate. |
X_test |
Optional covariate matrix for test data. If NULL, defaults to the mean of the training covariates. |
outcome_type |
Type of outcome. One of |
timescale |
Indicates the scale of follow-up times. Options are
|
number_of_trees |
Number of trees in the ensemble. Default is 200. |
prior_type |
Type of prior on the step heights. Options include
|
local_hp |
Local hyperparameter controlling shrinkage on individual step heights. Should typically be set smaller than 1 / sqrt(number_of_trees). |
global_hp |
Global hyperparameter controlling overall shrinkage.
Must be specified for Horseshoe-type priors; ignored for prior_type = "half-cauchy", which has no global shrinkage component. |
power |
Power parameter for the tree structure prior. Default is 2.0. |
base |
Base parameter for the tree structure prior. Default is 0.95. |
p_grow |
Probability of proposing a grow move. Default is 0.4. |
p_prune |
Probability of proposing a prune move. Default is 0.4. |
nu |
Degrees of freedom for the error distribution prior. Default is 3. |
q |
Quantile hyperparameter for the error variance prior. Default is 0.90. |
sigma |
Optional known value for error standard deviation. If NULL, estimated from data. |
N_post |
Number of posterior samples to store. Default is 1000. |
N_burn |
Number of burn-in iterations. Default is 1000. |
delayed_proposal |
Number of delayed iterations before proposal. Only for reversible updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples for each iteration. Default is TRUE. |
seed |
Random seed for reproducibility. |
verbose |
Logical; whether to print verbose output. Default is TRUE. |
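To give a feel for what power and base control, the sketch below evaluates a depth-dependent split probability of the form base * (1 + depth)^(-power), which is the usual BART tree structure prior. That this package uses exactly this form is an assumption; the argument descriptions above only name the two parameters. The helper split_probability is hypothetical and not part of the package.

```r
# Assumed BART-style tree structure prior (not confirmed by this help page):
# P(node at depth d splits) = base * (1 + d)^(-power)
split_probability <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

# Split probabilities for the root (depth 0) and the next three levels,
# using the documented defaults base = 0.95 and power = 2.
depths <- 0:3
probs <- split_probability(depths)
print(data.frame(depth = depths, split_probability = probs))
```

Under this assumed form, a larger power or a smaller base makes deep trees less likely a priori, so each tree stays a weak learner.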
Details
This function is a flexible generalization of HorseTrees. Instead of using a
single Horseshoe prior, it allows specifying different global-local shrinkage
configurations for the tree step heights. Currently, four priors have been
implemented.

The horseshoe prior is the fully Bayesian global-local shrinkage prior, where
both the global and local shrinkage parameters are assigned half-Cauchy
distributions with scale hyperparameters global_hp and local_hp, respectively.
The global shrinkage parameter is defined separately for each tree, allowing
adaptive regularization per tree.

The horseshoe_fw prior (forest-wide horseshoe) is similar to horseshoe, except
that the global shrinkage parameter is shared across all trees in the forest.

The horseshoe_EB prior is an empirical Bayes variant of the horseshoe prior.
Here, the global shrinkage parameter (\tau) is not assigned a prior
distribution; the user must supply its value through global_hp, and the
software does not estimate it. The local shrinkage parameters still follow
half-Cauchy priors.

The half-cauchy prior considers only local shrinkage and does not include a
global shrinkage component. It places a half-Cauchy prior on each local
shrinkage parameter with scale hyperparameter local_hp.
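The difference between the four priors is easiest to see by drawing from them. The following is a minimal base R sketch, not the package's sampler: it assumes each step height is Gaussian with standard deviation tau * lambda (and tau = 1 when there is no global component), which this page does not state explicitly. The helpers rhalfcauchy and draw_step_heights are hypothetical names introduced for illustration.

```r
# Half-Cauchy(0, scale) draws: the absolute value of a Cauchy variate.
rhalfcauchy <- function(n, scale) abs(rcauchy(n, location = 0, scale = scale))

# Draw prior step heights for a forest with n_trees trees and n_leaves leaves each.
# Assumption: step height ~ N(0, (tau * lambda)^2), with tau = 1 if no global part.
draw_step_heights <- function(prior_type, n_trees, n_leaves,
                              local_hp, global_hp = NULL) {
  tau <- switch(prior_type,
    "horseshoe"    = rhalfcauchy(n_trees, global_hp),          # one tau per tree
    "horseshoe_fw" = rep(rhalfcauchy(1, global_hp), n_trees),  # one tau for the forest
    "horseshoe_EB" = rep(global_hp, n_trees),                  # tau fixed by the user
    "half-cauchy"  = rep(1, n_trees),                          # no global shrinkage
    stop("unknown prior_type: ", prior_type)
  )
  # One local scale per leaf; rows index trees, columns index leaves.
  lambda <- matrix(rhalfcauchy(n_trees * n_leaves, local_hp), nrow = n_trees)
  # tau is recycled down the rows, so row i is scaled by tau[i].
  heights <- matrix(rnorm(n_trees * n_leaves), nrow = n_trees) * lambda * tau
  list(tau = tau, lambda = lambda, heights = heights)
}

set.seed(1)
number_of_trees <- 5
hp <- 0.1 / sqrt(number_of_trees)
prior_types <- c("horseshoe", "horseshoe_fw", "horseshoe_EB", "half-cauchy")
draws <- lapply(prior_types, function(type) {
  draw_step_heights(type, number_of_trees, n_leaves = 4,
                    local_hp = hp, global_hp = hp)
})
names(draws) <- prior_types

# Compare the global scales: per tree, shared, fixed, and absent.
print(lapply(draws, function(d) d$tau))
```

In this sketch only the way tau is generated changes between prior types; the local half-Cauchy scales are drawn identically in all four cases.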
Value
A named list with the following elements:
- train_predictions
Vector of posterior mean predictions on the training data.
- test_predictions
Vector of posterior mean predictions on the test data (or on the mean covariate vector if X_test is not provided).
- sigma
Vector of posterior samples of the error variance.
- acceptance_ratio
Average acceptance ratio across trees during sampling.
- train_predictions_sample
Matrix of posterior samples of training predictions (iterations in rows, observations in columns). Present only if store_posterior_sample = TRUE.
- test_predictions_sample
Matrix of posterior samples of test predictions. Present only if store_posterior_sample = TRUE.
- train_probabilities
Vector of posterior mean probabilities on the training data (only for outcome_type = "binary").
- test_probabilities
Vector of posterior mean probabilities on the test data (only for outcome_type = "binary").
- train_probabilities_sample
Matrix of posterior samples of training probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).
- test_probabilities_sample
Matrix of posterior samples of test probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).
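Because the sample matrices store iterations in rows and observations in columns, per-observation summaries are column-wise operations. The sketch below shows one way to obtain posterior means and 95% credible intervals. To stay self-contained it uses a simulated matrix, samples, as a hypothetical stand-in for a fitted object's train_predictions_sample element; the numbers carry no meaning.

```r
set.seed(1)
n_iter <- 10   # number of stored posterior draws (rows)
n_obs  <- 50   # number of observations (columns)

# Hypothetical stand-in for fit$train_predictions_sample.
samples <- matrix(rnorm(n_iter * n_obs), nrow = n_iter, ncol = n_obs)

# Posterior mean per observation: average over iterations (rows).
post_mean <- colMeans(samples)

# 95% credible interval per observation: quantiles over iterations.
ci <- apply(samples, 2, quantile, probs = c(0.025, 0.975))

# One row per observation: mean, lower bound, upper bound.
summary_table <- data.frame(mean = post_mean, lower = ci[1, ], upper = ci[2, ])
print(head(summary_table))
```

With a real fit, replacing samples by the stored matrix should give column means comparable to the train_predictions element, since both are described as posterior means over the stored draws.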
See Also
HorseTrees, CausalHorseForest, CausalShrinkageForest
Examples
# Example: Continuous outcome with ShrinkageTrees, two priors
n <- 50
p <- 3
X <- matrix(runif(n * p), ncol = p)
X_test <- matrix(runif(n * p), ncol = p)
y <- X[, 1] + rnorm(n)
# Fit ShrinkageTrees with standard horseshoe prior
fit_horseshoe <- ShrinkageTrees(y = y,
                                X_train = X,
                                X_test = X_test,
                                outcome_type = "continuous",
                                number_of_trees = 5,
                                prior_type = "horseshoe",
                                local_hp = 0.1 / sqrt(5),
                                global_hp = 0.1 / sqrt(5),
                                N_post = 10,
                                N_burn = 5,
                                store_posterior_sample = TRUE,
                                verbose = FALSE,
                                seed = 1)
# Fit ShrinkageTrees with half-Cauchy prior
fit_halfcauchy <- ShrinkageTrees(y = y,
                                 X_train = X,
                                 X_test = X_test,
                                 outcome_type = "continuous",
                                 number_of_trees = 5,
                                 prior_type = "half-cauchy",
                                 local_hp = 1 / sqrt(5),
                                 N_post = 10,
                                 N_burn = 5,
                                 store_posterior_sample = TRUE,
                                 verbose = FALSE,
                                 seed = 1)
# Posterior mean prediction for each training observation
pred_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample)
pred_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample)
# Posterior draws of the average prediction across observations (one value per iteration)
post_mean_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample)
post_mean_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample)
# Posterior mean prediction averages
mean_pred_horseshoe <- mean(post_mean_horseshoe)
mean_pred_halfcauchy <- mean(post_mean_halfcauchy)