ShrinkageTrees {ShrinkageTrees}    R Documentation

General Shrinkage Regression Trees (ShrinkageTrees)

Description

Fits a Bayesian Shrinkage Trees model with flexible global-local shrinkage priors on the step heights. It generalizes HorseTrees, which uses a horseshoe prior, by letting the user choose among several global-local shrinkage priors.

Usage

ShrinkageTrees(
  y,
  status = NULL,
  X_train,
  X_test = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees = 200,
  prior_type = "horseshoe",
  local_hp = NULL,
  global_hp = NULL,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 1000,
  N_burn = 1000,
  delayed_proposal = 5,
  store_posterior_sample = TRUE,
  seed = NULL,
  verbose = TRUE
)

Arguments

y

Numeric outcome vector: continuous outcomes, binary outcomes (coded 0/1), or follow-up times for survival data.

status

Optional censoring indicator vector (1 = event occurred, 0 = censored). Required if outcome_type = "right-censored".

X_train

Covariate matrix for training. Each row corresponds to an observation, and each column to a covariate.

X_test

Optional covariate matrix for test data, with the same columns as X_train. If NULL, predictions are made at the vector of column means of X_train.

outcome_type

Type of outcome. One of "continuous", "binary", or "right-censored".

timescale

Scale of the follow-up times in y: either "time" (nonnegative follow-up times on the original scale, log-transformed internally) or "log" (follow-up times already on the log scale). Used only when outcome_type = "right-censored".

number_of_trees

Number of trees in the ensemble. Default is 200.

prior_type

Type of prior on the step heights. One of "horseshoe", "horseshoe_fw", "horseshoe_EB", or "half-cauchy"; see Details.

local_hp

Scale hyperparameter for the local shrinkage of individual step heights. It should typically be set smaller than 1 / sqrt(number_of_trees).

global_hp

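Global hyperparameter controlling overall shrinkage. Must be specified for the horseshoe-type priors ("horseshoe", "horseshoe_fw", "horseshoe_EB"); for "horseshoe_EB" it is the fixed value of the global shrinkage parameter. Ignored for prior_type = "half-cauchy".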

power

Power parameter for the tree structure prior. Default is 2.0.

base

Base parameter for the tree structure prior. Default is 0.95.

p_grow

Probability of proposing a grow move. Default is 0.4.

p_prune

Probability of proposing a prune move. Default is 0.4.

nu

Degrees of freedom of the prior on the error variance. Default is 3.

q

Quantile hyperparameter for the error variance prior. Default is 0.90.

sigma

Optional known value of the error standard deviation. If NULL, it is estimated from the data.

N_post

Number of posterior samples to store. Default is 1000.

N_burn

Number of burn-in iterations. Default is 1000.

delayed_proposal

Number of delayed iterations before a proposal; used only for reversible updates. Default is 5.

store_posterior_sample

Logical; whether to store posterior samples for each iteration. Default is TRUE.

seed

Random seed for reproducibility.

verbose

Logical; whether to print verbose output. Default is TRUE.
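
The y, status and timescale arguments together describe right-censored data. A minimal sketch of preparing such data under the documented conventions (status = 1 for an event, 0 for censoring) follows; the exponential event and censoring times are purely illustrative, and the commented-out call assumes a covariate matrix X with one row per observation.

```r
# Simulated right-censored data in the documented format:
# y holds observed follow-up times, status is 1 for an event and 0 if censored.
set.seed(1)
n <- 100
event_time  <- rexp(n, rate = 1)    # latent event times (illustrative)
censor_time <- rexp(n, rate = 0.5)  # independent censoring times (illustrative)
y      <- pmin(event_time, censor_time)
status <- as.integer(event_time <= censor_time)

# Either pass y with timescale = "time" (log-transformed internally),
# or pass log(y) with timescale = "log".
y_log <- log(y)

# fit <- ShrinkageTrees(y = y_log, status = status, X_train = X,
#                       outcome_type = "right-censored", timescale = "log",
#                       prior_type = "horseshoe",
#                       local_hp = 0.1 / sqrt(200), global_hp = 0.1 / sqrt(200))
```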

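The power, base, nu and q arguments share their names with the hyperparameters of BART (Chipman, George and McCulloch, 2010). In BART, a node at depth d is split with prior probability base * (1 + d)^(-power), and the error variance has a scaled inverse chi-squared prior, sigma^2 ~ nu * lambda / chi^2_nu, whose scale lambda is chosen so that sigma falls below a rough data-based estimate sigma_hat with prior probability q. The sketch below assumes ShrinkageTrees uses these same forms, which this page does not confirm; split_prob, calibrate_lambda and the choice sigma_hat = sd(y) are illustrative only.

```r
# Hypothetical helpers (not part of ShrinkageTrees) illustrating the standard
# BART priors that the power, base, nu and q arguments are named after.

# Prior probability that a node at depth d is split: base * (1 + d)^(-power)
split_prob <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

# Scale lambda of the prior sigma^2 ~ nu * lambda / chisq_nu, chosen so that
# P(sigma < sigma_hat) = q. Since P(sigma^2 < sigma_hat^2) =
# P(chisq_nu > nu * lambda / sigma_hat^2), the quantity nu * lambda / sigma_hat^2
# must equal the (1 - q) quantile of chisq_nu.
calibrate_lambda <- function(sigma_hat, nu = 3, q = 0.9) {
  sigma_hat^2 * qchisq(1 - q, df = nu) / nu
}

# Split probabilities at depths 0 to 4 under the defaults
probs <- split_prob(0:4)
print(round(probs, 3))

# Calibrate lambda from a crude data-based estimate of sigma
set.seed(1)
y <- rnorm(100, sd = 2)
sigma_hat <- sd(y)
lambda <- calibrate_lambda(sigma_hat, nu = 3, q = 0.9)

# Monte Carlo check: the fraction of prior draws with sigma < sigma_hat is about q
sigma2_draws <- 3 * lambda / rchisq(1e5, df = 3)
print(mean(sqrt(sigma2_draws) < sigma_hat))
```

Under the split-probability form, raising power or lowering base favours shallower trees; under the variance prior, a larger q places more prior mass below sigma_hat, that is, it favours smaller error variances.
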
Details

This function generalizes HorseTrees: instead of a single horseshoe prior, it supports several global-local shrinkage configurations for the tree step heights. Four priors are currently implemented.

The horseshoe prior is the fully Bayesian global-local shrinkage prior, where both the global and local shrinkage parameters are assigned half-Cauchy distributions with scale hyperparameters global_hp and local_hp, respectively. The global shrinkage parameter is defined separately for each tree, allowing adaptive regularization per tree.

The horseshoe_fw prior (forest-wide horseshoe) is similar to horseshoe, except that the global shrinkage parameter is shared across all trees in the forest simultaneously.

The horseshoe_EB prior is an empirical Bayes variant of the horseshoe prior. The global shrinkage parameter τ is not assigned a prior distribution; instead it is fixed at the value supplied through global_hp, while the local shrinkage parameters still follow half-Cauchy priors. Note that τ must be provided by the user; it is not estimated by the software.

The half-cauchy prior considers only local shrinkage and does not include a global shrinkage component. It places a half-Cauchy prior on each local shrinkage parameter with scale hyperparameter local_hp.
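
To make the differences between the four priors concrete, the sketch below draws step heights from each prior for a small forest. It assumes the usual horseshoe hierarchy in which a step height is normal with standard deviation lambda * tau (lambda local, tau global), and that "half-cauchy" corresponds to standard deviation lambda alone; the exact parameterization inside ShrinkageTrees (for example, whether sigma also enters the scale) is not stated on this page. rhalfcauchy and draw_step_heights are hypothetical helpers, and every tree is given the same number of leaves purely for simplicity.

```r
# Hypothetical helpers (not part of ShrinkageTrees). Draws prior step heights
# for n_trees trees with n_leaves leaves each, assuming
#   h | lambda, tau ~ N(0, (lambda * tau)^2)
# with half-Cauchy priors on lambda and, where applicable, on tau.
rhalfcauchy <- function(n, scale) abs(rcauchy(n, location = 0, scale = scale))

draw_step_heights <- function(prior_type, n_trees, n_leaves,
                              local_hp, global_hp = NULL) {
  n <- n_trees * n_leaves
  # One local scale per step height (rows = trees, columns = leaves)
  lambda <- matrix(rhalfcauchy(n, local_hp), nrow = n_trees)
  tau <- switch(prior_type,
    "horseshoe"    = rhalfcauchy(n_trees, global_hp),          # one tau per tree
    "horseshoe_fw" = rep(rhalfcauchy(1, global_hp), n_trees),  # one tau shared by all trees
    "horseshoe_EB" = rep(global_hp, n_trees),                  # tau fixed by the user
    "half-cauchy"  = rep(1, n_trees),                          # no global component
    stop("unknown prior_type: ", prior_type)
  )
  # tau is recycled down the columns, so row t (tree t) is scaled by tau[t]
  matrix(rnorm(n, mean = 0, sd = lambda * tau), nrow = n_trees)
}

set.seed(1)
h_hs <- draw_step_heights("horseshoe", n_trees = 200, n_leaves = 4,
                          local_hp = 0.1 / sqrt(200), global_hp = 0.1 / sqrt(200))
h_hc <- draw_step_heights("half-cauchy", n_trees = 200, n_leaves = 4,
                          local_hp = 1 / sqrt(200))
dim(h_hs)  # 200 trees by 4 leaves
```

Comparing summaries such as quantile(abs(h_hs)) and quantile(abs(h_hc)) gives a rough feel for how strongly each configuration shrinks step heights a priori under these assumptions.
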

Value

A named list with the following elements:

train_predictions

Vector of posterior mean predictions on the training data.

test_predictions

Vector of posterior mean predictions on the test data (or at the vector of training-covariate means if X_test is not provided).

sigma

Vector of posterior samples of the error variance.

acceptance_ratio

Average acceptance ratio across trees during sampling.

train_predictions_sample

Matrix of posterior samples of training predictions (iterations in rows, observations in columns). Present only if store_posterior_sample = TRUE.

test_predictions_sample

Matrix of posterior samples of test predictions. Present only if store_posterior_sample = TRUE.

train_probabilities

Vector of posterior mean probabilities on the training data (only for outcome_type = "binary").

test_probabilities

Vector of posterior mean probabilities on the test data (only for outcome_type = "binary").

train_probabilities_sample

Matrix of posterior samples of training probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).

test_probabilities_sample

Matrix of posterior samples of test probabilities (only for outcome_type = "binary" and if store_posterior_sample = TRUE).
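
train_predictions_sample stores one posterior draw per row and one observation per column, so pointwise summaries are column-wise operations. The sketch below computes posterior means and equal-tailed 95% credible intervals; summarize_draws is a hypothetical helper, and a synthetic matrix stands in for fit$train_predictions_sample so the snippet runs without fitting a model.

```r
# Hypothetical helper (not part of ShrinkageTrees): pointwise posterior summaries
# for a draws matrix with iterations in rows and observations in columns.
summarize_draws <- function(draws, level = 0.95) {
  alpha <- (1 - level) / 2
  data.frame(
    mean  = colMeans(draws),
    lower = apply(draws, 2, quantile, probs = alpha, names = FALSE),
    upper = apply(draws, 2, quantile, probs = 1 - alpha, names = FALSE)
  )
}

# Synthetic stand-in for fit$train_predictions_sample: 1000 draws for 5 observations,
# where column j is centred at j
set.seed(1)
draws <- matrix(rnorm(1000 * 5, mean = rep(1:5, each = 1000)), nrow = 1000)
summ <- summarize_draws(draws)
summ
```

With a real fit, summarize_draws(fit_horseshoe$train_predictions_sample) would return one row per training observation; its mean column should agree with train_predictions if both are computed from the same stored draws.
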

See Also

HorseTrees, CausalHorseForest, CausalShrinkageForest

Examples

# Example: Continuous outcome with ShrinkageTrees, two priors
n <- 50
p <- 3
X <- matrix(runif(n * p), ncol = p)
X_test <- matrix(runif(n * p), ncol = p)
y <- X[, 1] + rnorm(n)

# Fit ShrinkageTrees with standard horseshoe prior
fit_horseshoe <- ShrinkageTrees(y = y,
                                X_train = X,
                                X_test = X_test,
                                outcome_type = "continuous",
                                number_of_trees = 5,
                                prior_type = "horseshoe",
                                local_hp = 0.1 / sqrt(5),
                                global_hp = 0.1 / sqrt(5),
                                N_post = 10,
                                N_burn = 5,
                                store_posterior_sample = TRUE,
                                verbose = FALSE,
                                seed = 1)

# Fit ShrinkageTrees with half-Cauchy prior
fit_halfcauchy <- ShrinkageTrees(y = y,
                                 X_train = X,
                                 X_test = X_test,
                                 outcome_type = "continuous",
                                 number_of_trees = 5,
                                 prior_type = "half-cauchy",
                                 local_hp = 1 / sqrt(5),
                                 N_post = 10,
                                 N_burn = 5,
                                 store_posterior_sample = TRUE,
                                 verbose = FALSE,
                                 seed = 1)

# Posterior mean predictions
pred_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample)
pred_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample)

# Posterior draws of the average prediction across the training observations
post_mean_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample)
post_mean_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample)

# Posterior mean of the average prediction
mean_pred_horseshoe <- mean(post_mean_horseshoe)
mean_pred_halfcauchy <- mean(post_mean_halfcauchy)


[Package ShrinkageTrees version 1.0.0 Index]