CausalShrinkageForest {ShrinkageTrees}    R Documentation
General Causal Shrinkage Forests
Description
Fits a (Bayesian) Causal Shrinkage Forest model for estimating heterogeneous treatment effects.
This function generalizes CausalHorseForest by allowing flexible global-local shrinkage priors
on the step heights in both the control and treatment forests.
It supports continuous and right-censored survival outcomes.
Usage
CausalShrinkageForest(
y,
status = NULL,
X_train_control,
X_train_treat,
treatment_indicator_train,
X_test_control = NULL,
X_test_treat = NULL,
treatment_indicator_test = NULL,
outcome_type = "continuous",
timescale = "time",
number_of_trees_control = 200,
number_of_trees_treat = 200,
prior_type_control = "horseshoe",
prior_type_treat = "horseshoe",
local_hp_control,
local_hp_treat,
global_hp_control = NULL,
global_hp_treat = NULL,
power = 2,
base = 0.95,
p_grow = 0.4,
p_prune = 0.4,
nu = 3,
q = 0.9,
sigma = NULL,
N_post = 5000,
N_burn = 5000,
delayed_proposal = 5,
store_posterior_sample = FALSE,
seed = NULL,
verbose = TRUE
)
Arguments
y
Outcome vector. Numeric. Represents continuous outcomes or follow-up times.
status
Optional event indicator vector (1 = event occurred, 0 = censored).
Required when the outcome is a right-censored survival time (see outcome_type).
X_train_control
Covariate matrix for the control forest. Rows correspond to samples, columns to covariates.
X_train_treat
Covariate matrix for the treatment forest.
treatment_indicator_train
Vector indicating treatment assignment for training samples (1 = treated, 0 = control).
X_test_control
Optional covariate matrix for control forest test data. Defaults to the
column means of X_train_control.
X_test_treat
Optional covariate matrix for treatment forest test data. Defaults to the
column means of X_train_treat.
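treatment_indicator_test
Optional vector indicating treatment assignment for test data (1 = treated, 0 = control).
outcome_type
Type of outcome: continuous or right-censored survival. Default is "continuous".
timescale
For survival outcomes: the time scale on which y is supplied. Default is "time".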
number_of_trees_control
Number of trees in the control forest. Default is 200.
number_of_trees_treat
Number of trees in the treatment forest. Default is 200.
prior_type_control
Type of prior on control forest step heights. One of "horseshoe",
"horseshoe_fw", "horseshoe_EB", or "half-cauchy" (see Details). Default is "horseshoe".
prior_type_treat
Type of prior on treatment forest step heights. Same options as
prior_type_control. Default is "horseshoe".
local_hp_control
Local hyperparameter controlling shrinkage on individual step heights (control forest). Required for all prior types.
local_hp_treat
Local hyperparameter for the treatment forest. Required for all prior types.
global_hp_control
Global hyperparameter for the control forest. Required for the horseshoe-type
priors ("horseshoe", "horseshoe_fw", and "horseshoe_EB"); ignored for "half-cauchy".
global_hp_treat
Global hyperparameter for the treatment forest. Required and ignored in the same cases as global_hp_control.
power
Power parameter for the tree structure prior. Default is 2.
base
Base parameter for the tree structure prior. Default is 0.95.
p_grow
Probability of proposing a grow move. Default is 0.4.
p_prune
Probability of proposing a prune move. Default is 0.4.
nu
Degrees of freedom for the error variance prior. Default is 3.
q
Quantile parameter for the error variance prior. Default is 0.90.
sigma
Optional known standard deviation of the outcome. If NULL, it is estimated from the data.
N_post
Number of posterior samples to store. Default is 5000.
N_burn
Number of burn-in iterations. Default is 5000.
delayed_proposal
Number of delayed iterations before proposal updates. Default is 5.
store_posterior_sample
Logical; whether to store posterior samples of predictions.
Default is FALSE.
seed
Random seed for reproducibility. Default is NULL.
verbose
Logical; whether to print verbose output. Default is TRUE.
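The arguments power, base, nu and q carry the names of the usual BART priors. The sketch below assumes, without confirmation from this page, that CausalShrinkageForest uses those standard forms: a node at depth d splits with probability base * (1 + d)^(-power), and sigma^2 has a scaled inverse chi-squared prior with nu degrees of freedom whose scale is chosen so that P(sigma < sigma_hat) = q for a rough data-based estimate sigma_hat. The helpers split_prob() and calibrate_lambda() are hypothetical and not part of ShrinkageTrees.

```r
# Sketch only: assumes the standard BART tree and error-variance priors.
# split_prob() and calibrate_lambda() are hypothetical helpers, not package functions.

# Prior probability that a node at the given depth is split (non-terminal)
split_prob <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

# Scale lambda of sigma^2 ~ nu * lambda / chisq(nu), chosen so that
# P(sigma < sigma_hat) = q for a rough data-based estimate sigma_hat
calibrate_lambda <- function(sigma_hat, nu = 3, q = 0.9) {
  sigma_hat^2 * qchisq(1 - q, df = nu) / nu
}

# With the defaults, deeper splits become rapidly less likely
split_prob(0:3)

# Calibrate against the sample standard deviation of a toy outcome
set.seed(1)
y_toy <- rnorm(100, sd = 2)
sigma_hat <- sd(y_toy)
lambda <- calibrate_lambda(sigma_hat)

# Prior probability that sigma falls below sigma_hat; should be (numerically) q = 0.9
pchisq(3 * lambda / sigma_hat^2, df = 3, lower.tail = FALSE)
```

Smaller power or larger base allows deeper trees a priori; a larger q places more prior mass on error standard deviations below sigma_hat.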
Details
This function is a flexible generalization of CausalHorseForest.
The Causal Shrinkage Forest model decomposes the outcome into a prognostic
(control) part and a treatment effect part. Each part is modeled by its own
shrinkage tree ensemble with its own flexible global-local shrinkage
prior. It is particularly useful for estimating heterogeneous treatment
effects in high-dimensional settings.
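To make the decomposition concrete, the sketch below simulates data with the additive structure y = f_control(x) + z * f_treat(x) + noise, where z is the treatment indicator. This form is an assumption for illustration (the package may code treatment differently internally), and f_control() and f_treat() are hypothetical functions. The control forest targets the prognostic values, and the treatment forest targets the unit-level effects (CATEs).

```r
# Sketch: simulated data with the additive structure
#   y = f_control(x) + z * f_treat(x) + noise.
# f_control() and f_treat() are hypothetical functions chosen for illustration.
set.seed(2024)
n <- 200
x <- matrix(runif(n * 3), ncol = 3)
z <- rbinom(n, 1, 0.5)                     # 1 = treated, 0 = control

f_control <- function(x) 1 + 2 * x[, 1]    # prognostic (control) part
f_treat   <- function(x) 0.5 + x[, 2]      # heterogeneous treatment effect

# Noise-free part: the treatment effect is added only for treated units
signal <- f_control(x) + z * f_treat(x)
y <- signal + rnorm(n, sd = 0.5)

# Targets of the two forests: prognostic values and unit-level CATEs
prognostic <- f_control(x)
cate <- f_treat(x)
ate <- mean(cate)      # sample average treatment effect, close to 1 here
```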
The horseshoe prior is a fully Bayesian global-local shrinkage prior, where both
the global and local shrinkage parameters are assigned half-Cauchy distributions
with scale hyperparameters global_hp and local_hp, respectively. The global
shrinkage parameter is defined separately for each tree, allowing adaptive
regularization per tree.
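For reference, the standard horseshoe hierarchy for the step height of leaf j in tree t can be written as below. The Gaussian conditional form of the step heights is an assumption taken from the usual horseshoe construction; the text above specifies only the two half-Cauchy scales.

```latex
\begin{aligned}
h_{tj} \mid \lambda_{tj}, \tau_t &\sim \mathcal{N}\!\left(0,\ \lambda_{tj}^2 \tau_t^2\right),\\
\lambda_{tj} &\sim \mathcal{C}^{+}\!\left(0,\ \texttt{local\_hp}\right),\\
\tau_t &\sim \mathcal{C}^{+}\!\left(0,\ \texttt{global\_hp}\right).
\end{aligned}
```

Under the variants described below, horseshoe_fw replaces the per-tree \tau_t by a single \tau shared by all trees, horseshoe_EB fixes \tau at global_hp, and half-cauchy drops \tau altogether.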
The horseshoe_fw prior (forest-wide horseshoe) is similar to horseshoe, except
that a single global shrinkage parameter is shared by all trees in the forest.
The horseshoe_EB prior is an empirical Bayes variant of the horseshoe prior.
Here, the global shrinkage parameter (\tau) is not assigned a prior distribution;
instead, it must be specified directly via global_hp, while the local shrinkage
parameters still follow half-Cauchy priors. Note: \tau must be provided by the
user; it is not estimated by the software.
The half-cauchy prior considers only local shrinkage and does not include a
global shrinkage component. It places a half-Cauchy prior with scale
hyperparameter local_hp on each local shrinkage parameter.
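The four prior types differ only in how the global scale is handled. The sketch below draws step heights from each of them, assuming (as in the standard horseshoe) that step heights are Gaussian given their shrinkage scales. draw_step_heights() is a hypothetical helper written for illustration, not a ShrinkageTrees function, and the package's internal parameterization may differ.

```r
# Sketch: prior draws of step heights under the four prior types.
# Assumes step heights are Gaussian given their shrinkage scales (standard
# horseshoe form). draw_step_heights() is hypothetical, not a ShrinkageTrees function.
draw_step_heights <- function(prior_type, n_trees, n_leaves,
                              local_hp, global_hp = NULL) {
  # A half-Cauchy(0, s) draw is the absolute value of a Cauchy(0, s) draw
  half_cauchy <- function(k, s) abs(rcauchy(k, location = 0, scale = s))

  # One local scale per step height: rows = trees, columns = leaves
  lambda <- matrix(half_cauchy(n_trees * n_leaves, local_hp), nrow = n_trees)

  # One global scale per tree (recycled along each row of lambda)
  tau <- switch(prior_type,
    "horseshoe"    = half_cauchy(n_trees, global_hp),          # separate per tree
    "horseshoe_fw" = rep(half_cauchy(1, global_hp), n_trees),  # shared forest-wide
    "horseshoe_EB" = rep(global_hp, n_trees),                  # fixed by the user
    "half-cauchy"  = rep(1, n_trees),                          # no global component
    stop("unknown prior_type: ", prior_type)
  )

  # Conditionally Gaussian step heights with standard deviation lambda * tau
  matrix(rnorm(n_trees * n_leaves), nrow = n_trees) * lambda * tau
}

# Example: 5 trees with 4 leaves each, using the hyperparameter values
# of the horseshoe fit in the Examples section
set.seed(1)
h <- draw_step_heights("horseshoe", n_trees = 5, n_leaves = 4,
                       local_hp = 0.1 / sqrt(5), global_hp = 0.1 / sqrt(5))
summary(abs(as.vector(h)))
```

Because the global scale multiplies every step height in a tree, shrinking global_hp pulls all of that tree's step heights toward zero at once, while the heavy-tailed local scales still let individual steps escape shrinkage.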
Value
A list containing:
- train_predictions
Posterior mean predictions on training data (combined forest).
- test_predictions
Posterior mean predictions on test data (combined forest).
- train_predictions_control
Estimated control outcomes on training data.
- test_predictions_control
Estimated control outcomes on test data.
- train_predictions_treat
Estimated treatment effects on training data.
- test_predictions_treat
Estimated treatment effects on test data.
- sigma
Vector of posterior samples for the error standard deviation.
- acceptance_ratio_control
Average acceptance ratio in control forest.
- acceptance_ratio_treat
Average acceptance ratio in treatment forest.
- train_predictions_sample_control
Matrix of posterior samples for control predictions on training data (if store_posterior_sample = TRUE).
- test_predictions_sample_control
Matrix of posterior samples for control predictions on test data (if store_posterior_sample = TRUE).
- train_predictions_sample_treat
Matrix of posterior samples for treatment effects on training data (if store_posterior_sample = TRUE).
- test_predictions_sample_treat
Matrix of posterior samples for treatment effects on test data (if store_posterior_sample = TRUE).
Each sample matrix has one row per posterior draw and one column per observation.
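Beyond the posterior means computed in the Examples, the sample matrices also give credible intervals. The sketch below summarises a draws-by-observations matrix in the layout used for train_predictions_sample_treat. So that it runs without fitting a model, it uses a simulated stand-in matrix; the variable names are hypothetical, and in practice the matrix would be taken from the fitted object.

```r
# Sketch: summarising a posterior sample matrix such as train_predictions_sample_treat
# (rows = posterior draws, columns = observations).
# A simulated stand-in matrix is used so the snippet runs without fitting a model.
set.seed(7)
n_draws <- 1000
n_obs <- 25
true_cate <- seq(-1, 1, length.out = n_obs)
draws <- matrix(rnorm(n_draws * n_obs, mean = rep(true_cate, each = n_draws), sd = 0.3),
                nrow = n_draws)

cate_mean <- colMeans(draws)                                   # posterior mean CATE per unit
cate_ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))  # 2 x n_obs matrix of 95% intervals
ate_draws <- rowMeans(draws)                                   # one ATE value per posterior draw
ate_summary <- c(mean = mean(ate_draws),
                 quantile(ate_draws, probs = c(0.025, 0.975)))
```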
See Also
CausalHorseForest, ShrinkageTrees, HorseTrees
Examples
# Example: Continuous outcome, homogeneous treatment effect, two priors
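library(ShrinkageTrees)
set.seed(1)
n <- 50
p <- 3
X <- matrix(runif(n * p), ncol = p)
X_treat <- X_control <- X
treat <- rbinom(n, 1, X[, 1])  # treatment probability depends on X[, 1]
tau <- 2                       # constant treatment effect
y <- X[, 1] + (treat - 0.5) * tau + rnorm(n)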
# Fit a standard Causal Horseshoe Forest
fit_horseshoe <- CausalShrinkageForest(y = y,
X_train_control = X_control,
X_train_treat = X_treat,
treatment_indicator_train = treat,
outcome_type = "continuous",
number_of_trees_treat = 5,
number_of_trees_control = 5,
prior_type_control = "horseshoe",
prior_type_treat = "horseshoe",
local_hp_control = 0.1/sqrt(5),
local_hp_treat = 0.1/sqrt(5),
global_hp_control = 0.1/sqrt(5),
global_hp_treat = 0.1/sqrt(5),
N_post = 10,
N_burn = 5,
store_posterior_sample = TRUE,
verbose = FALSE,
seed = 1
)
# Fit a Causal Shrinkage Forest with half-cauchy prior
fit_halfcauchy <- CausalShrinkageForest(y = y,
X_train_control = X_control,
X_train_treat = X_treat,
treatment_indicator_train = treat,
outcome_type = "continuous",
number_of_trees_treat = 5,
number_of_trees_control = 5,
prior_type_control = "half-cauchy",
prior_type_treat = "half-cauchy",
local_hp_control = 1/sqrt(5),
local_hp_treat = 1/sqrt(5),
N_post = 10,
N_burn = 5,
store_posterior_sample = TRUE,
verbose = FALSE,
seed = 1
)
# Posterior mean CATEs
CATE_horseshoe <- colMeans(fit_horseshoe$train_predictions_sample_treat)
CATE_halfcauchy <- colMeans(fit_halfcauchy$train_predictions_sample_treat)
# Posterior draws of the ATE
post_ATE_horseshoe <- rowMeans(fit_horseshoe$train_predictions_sample_treat)
post_ATE_halfcauchy <- rowMeans(fit_halfcauchy$train_predictions_sample_treat)
# Posterior mean ATE
ATE_horseshoe <- mean(post_ATE_horseshoe)
ATE_halfcauchy <- mean(post_ATE_halfcauchy)