CausalHorseForest {ShrinkageTrees} | R Documentation |
Causal Horseshoe Forests
Description
This function fits a Bayesian Causal Horseshoe Forest. It can be used to estimate conditional average treatment effects for survival data with high-dimensional covariates. The outcome is decomposed into a prognostic (control) part and a treatment effect part, and each part is modeled with a Horseshoe Trees regression function.
Usage
CausalHorseForest(
  y,
  status = NULL,
  X_train_control,
  X_train_treat,
  treatment_indicator_train,
  X_test_control = NULL,
  X_test_treat = NULL,
  treatment_indicator_test = NULL,
  outcome_type = "continuous",
  timescale = "time",
  number_of_trees = 200,
  k = 0.1,
  power = 2,
  base = 0.95,
  p_grow = 0.4,
  p_prune = 0.4,
  nu = 3,
  q = 0.9,
  sigma = NULL,
  N_post = 5000,
  N_burn = 5000,
  delayed_proposal = 5,
  store_posterior_sample = FALSE,
  seed = NULL,
  verbose = TRUE
)
Arguments
y |
Outcome vector. For survival outcomes, the follow-up times (on the original or log scale, depending on timescale). |
status |
Optional event indicator vector (1 = event occurred, 0 = censored). Required when outcome_type = "right-censored". |
X_train_control |
Covariate matrix for the control forest. Rows correspond to samples, columns to covariates. |
X_train_treat |
Covariate matrix for the treatment forest. Rows correspond to samples, columns to covariates. |
treatment_indicator_train |
Vector indicating treatment assignment for training samples (1 = treated, 0 = control). |
X_test_control |
Optional test covariate matrix for control forest. If
|
X_test_treat |
Optional test covariate matrix for treatment forest. If
|
treatment_indicator_test |
Optional vector indicating treatment assignment for test samples. |
outcome_type |
Type of outcome: one of "continuous" or "right-censored". Default is "continuous". |
timescale |
For survival outcomes: either "time" or "log". Default is "time". |
number_of_trees |
Number of trees in each forest. Default is 200. |
k |
Horseshoe prior scale hyperparameter. Default is 0.1. Controls global-local shrinkage on step heights. |
power |
Power parameter for tree structure prior. Default is 2.0. |
base |
Base parameter for tree structure prior. Default is 0.95. |
p_grow |
Probability of proposing a grow move. Default is 0.4. |
p_prune |
Probability of proposing a prune move. Default is 0.4. |
nu |
Degrees of freedom for the error variance prior. Default is 3. |
q |
Quantile parameter for error variance prior. Default is 0.90. |
sigma |
Optional known standard deviation of the outcome. If
|
N_post |
Number of posterior samples to store. Default is 5000. |
N_burn |
Number of burn-in iterations. Default is 5000. |
delayed_proposal |
Number of delayed iterations before proposal updates. Default is 5. |
store_posterior_sample |
Logical; whether to store posterior samples of predictions. Default is FALSE. |
seed |
Random seed for reproducibility. Default is NULL. |
verbose |
Logical; whether to print verbose output during sampling. Default is TRUE. |
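To give a feel for the power and base arguments, the sketch below tabulates split probabilities by node depth. It assumes the tree structure prior has the usual BART form, P(split at depth d) = base * (1 + d)^(-power); this page does not state the exact form, so treat that as an assumption. The helper split_probability() is hypothetical and not part of ShrinkageTrees.

```r
# Assumption: BART-style tree structure prior,
#   P(a node at depth d splits) = base * (1 + d)^(-power).
# split_probability() is an illustrative helper, not a package function.
split_probability <- function(depth, base = 0.95, power = 2) {
  base * (1 + depth)^(-power)
}

depths <- 0:4
probs <- split_probability(depths)

# Split probability falls quickly with depth, which favors shallow trees.
print(data.frame(depth = depths, p_split = round(probs, 4)))
```

Under this assumed form, a larger power or a smaller base makes deep trees less likely a priori.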
Details
The model separately regularizes the control and treatment trees using Horseshoe priors with global-local shrinkage on the step heights. This approach is designed for robust estimation of heterogeneous treatment effects in high-dimensional settings. It supports continuous and right-censored survival outcomes.
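The effect of global-local shrinkage can be illustrated with prior draws in base R. The sketch below is not the package's sampler: it assumes each step height is normal with scale k * lambda, where lambda is a half-Cauchy(0, 1) local scale and k is treated as a fixed global scale for simplicity. All variable names are illustrative.

```r
set.seed(1)
n_draws <- 10000
k <- 0.1  # global scale, mirroring the default of the k argument (assumed role)

# Local scales: half-Cauchy(0, 1), drawn as the absolute value of a Cauchy.
lambda <- abs(rcauchy(n_draws))

# Step heights under the assumed horseshoe prior.
h_horseshoe <- rnorm(n_draws, mean = 0, sd = k * lambda)

# Reference: a plain normal prior with the same global scale.
h_normal <- rnorm(n_draws, mean = 0, sd = k)

# Share of draws very close to zero (strong shrinkage of noise).
near_zero <- c(
  horseshoe = mean(abs(h_horseshoe) < 0.01),
  normal = mean(abs(h_normal) < 0.01)
)

# Largest absolute draw (heavy tails leave room for real signals).
largest <- c(
  horseshoe = max(abs(h_horseshoe)),
  normal = max(abs(h_normal))
)

print(near_zero)
print(largest)
```

Compared with the normal prior, the horseshoe draws put more mass near zero and also have much heavier tails, which is the behavior that makes this kind of prior attractive when most covariates are irrelevant.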
Value
A list containing:
- train_predictions
Posterior mean predictions on training data (combined forest).
- test_predictions
Posterior mean predictions on test data (combined forest).
- train_predictions_control
Estimated control outcomes on training data.
- test_predictions_control
Estimated control outcomes on test data.
- train_predictions_treat
Estimated treatment effects on training data.
- test_predictions_treat
Estimated treatment effects on test data.
- sigma
Vector of posterior samples for the error standard deviation.
- acceptance_ratio_control
Average acceptance ratio in control forest.
- acceptance_ratio_treat
Average acceptance ratio in treatment forest.
- train_predictions_sample_control
Matrix of posterior samples for control predictions (if store_posterior_sample = TRUE).
- test_predictions_sample_control
Matrix of posterior samples for control predictions (if store_posterior_sample = TRUE).
- train_predictions_sample_treat
Matrix of posterior samples for treatment effects (if store_posterior_sample = TRUE).
- test_predictions_sample_treat
Matrix of posterior samples for treatment effects (if store_posterior_sample = TRUE).
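The stored sample matrices can be summarized with base R alone. The sketch below uses a simulated stand-in matrix rather than real package output, and assumes rows are posterior draws and columns are observations, which matches how the Examples section applies rowMeans and colMeans to train_predictions_sample_treat.

```r
set.seed(1)
n_draws <- 200  # stand-in for N_post
n_obs <- 25     # stand-in for the number of training samples

# Simulated stand-in for fit$train_predictions_sample_treat
# (assumed layout: one row per posterior draw, one column per observation).
sample_treat <- matrix(
  rnorm(n_draws * n_obs, mean = 1, sd = 0.5),
  nrow = n_draws,
  ncol = n_obs
)

# Per-observation summaries: posterior mean and 95% credible interval.
cate_mean <- colMeans(sample_treat)
cate_ci <- apply(sample_treat, 2, quantile, probs = c(0.025, 0.975))

# Per-draw average over observations gives posterior draws of the ATE.
ate_draws <- rowMeans(sample_treat)
ate_ci <- quantile(ate_draws, probs = c(0.025, 0.975))

print(round(c(mean = mean(ate_draws), ate_ci), 3))
```

Averaging over columns within each draw, rather than averaging the per-observation means, keeps the full posterior of the ATE available for interval estimates.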
See Also
HorseTrees, ShrinkageTrees, CausalShrinkageForest
Examples
# Example: Continuous outcome and homogeneous treatment effect
n <- 50
p <- 3
X_control <- matrix(runif(n * p), ncol = p)
X_treat <- matrix(runif(n * p), ncol = p)
treatment <- rbinom(n, 1, 0.5)
tau <- 2
y <- X_control[, 1] + (0.5 - treatment) * tau + rnorm(n)
fit <- CausalHorseForest(
y = y,
X_train_control = X_control,
X_train_treat = X_treat,
treatment_indicator_train = treatment,
outcome_type = "continuous",
number_of_trees = 5,
N_post = 10,
N_burn = 5,
store_posterior_sample = TRUE,
verbose = FALSE,
seed = 1
)
## Example: Right-censored survival outcome
# Set data dimensions
n <- 100
p <- 1000
# Generate covariates
X <- matrix(runif(n * p), ncol = p)
X_treat <- X
# Treatment probability depends on the first covariate (one value per sample)
treatment <- rbinom(n, 1, pnorm(X_treat[, 1] - 1/2))
# Generate true survival times depending on X and treatment
linpred <- X[, 1] - X[, 2] + (treatment - 0.5) * (1 + X[, 2] / 2 + X[, 3] / 3
+ X[, 4] / 4)
true_time <- linpred + rnorm(n, 0, 0.5)
# Generate censoring times
censor_time <- log(rexp(n, rate = 1 / 5))
# Observed times and event indicator
time_obs <- pmin(true_time, censor_time)
status <- as.numeric(true_time == time_obs)
# Estimate propensity score using HorseTrees
fit_prop <- HorseTrees(
y = treatment,
X_train = X,
outcome_type = "binary",
number_of_trees = 200,
N_post = 1000,
N_burn = 1000
)
# Retrieve estimated probability of treatment (propensity score)
propensity <- fit_prop$train_probabilities
# Combine propensity score with covariates for control forest
X_control <- cbind(propensity, X)
# Fit the Causal Horseshoe Forest for survival outcome
fit_surv <- CausalHorseForest(
y = time_obs,
status = status,
X_train_control = X_control,
X_train_treat = X_treat,
treatment_indicator_train = treatment,
outcome_type = "right-censored",
timescale = "log",
number_of_trees = 200,
k = 0.1,
N_post = 1000,
N_burn = 1000,
store_posterior_sample = TRUE
)
## Evaluate and summarize results
# Evaluate C-index if survival package is available
if (requireNamespace("survival", quietly = TRUE)) {
predicted_survtime <- fit_surv$train_predictions
cindex_result <- survival::concordance(survival::Surv(time_obs, status) ~ predicted_survtime)
c_index <- cindex_result$concordance
cat("C-index:", round(c_index, 3), "\n")
} else {
cat("Package 'survival' not available. Skipping C-index computation.\n")
}
# Compute posterior ATE samples
ate_samples <- rowMeans(fit_surv$train_predictions_sample_treat)
mean_ate <- mean(ate_samples)
ci_95 <- quantile(ate_samples, probs = c(0.025, 0.975))
cat("Posterior mean ATE:", round(mean_ate, 3), "\n")
cat("95% credible interval: [", round(ci_95[1], 3), ", ", round(ci_95[2], 3), "]\n", sep = "")
# Plot histogram of ATE samples
hist(
ate_samples,
breaks = 30,
col = "steelblue",
freq = FALSE,
border = "white",
xlab = "Average Treatment Effect (ATE)",
main = "Posterior distribution of ATE"
)
abline(v = mean_ate, col = "orange3", lwd = 2)
abline(v = ci_95, col = "orange3", lty = 2, lwd = 2)
# True ATE in this simulation: E[1 + X2/2 + X3/3 + X4/4] = 1 + 13/24
abline(v = 1.541667, col = "darkred", lwd = 2)
legend(
"topright",
legend = c("Mean", "95% CI", "Truth"),
col = c("orange3", "orange3", "darkred"),
lty = c(1, 2, 1),
lwd = 2
)
## Plot individual CATE estimates
# Summarize posterior distribution per patient
posterior_matrix <- fit_surv$train_predictions_sample_treat
posterior_mean <- colMeans(posterior_matrix)
posterior_ci <- apply(posterior_matrix, 2, quantile, probs = c(0.025, 0.975))
df_cate <- data.frame(
mean = posterior_mean,
lower = posterior_ci[1, ],
upper = posterior_ci[2, ]
)
# Sort patients by posterior mean CATE
df_cate_sorted <- df_cate[order(df_cate$mean), ]
n_patients <- nrow(df_cate_sorted)
# Create the plot
plot(
x = df_cate_sorted$mean,
y = 1:n_patients,
type = "n",
xlab = "CATE per patient (95% credible interval)",
ylab = "Patient index (sorted)",
main = "Posterior CATE estimates",
xlim = range(df_cate_sorted$lower, df_cate_sorted$upper)
)
# Add CATE intervals
segments(
x0 = df_cate_sorted$lower,
x1 = df_cate_sorted$upper,
y0 = 1:n_patients,
y1 = 1:n_patients,
col = "steelblue"
)
# Add mean points
points(df_cate_sorted$mean, 1:n_patients, pch = 16, col = "orange3", lwd = 0.1)
# Add reference line at 0
abline(v = 0, col = "black", lwd = 2)
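The reference value 1.541667 drawn in the ATE histogram above can be checked directly from the simulation design. In the survival example the treatment effect is 1 + X2/2 + X3/3 + X4/4 with independent uniform(0, 1) covariates, so its expectation has a closed form. The sketch below computes it and adds a Monte Carlo check; the names true_ate, X_check, and mc_ate are illustrative only.

```r
# Closed form: each covariate has mean 1/2, so
#   E[1 + X2/2 + X3/3 + X4/4] = 1 + (1/2) * (1/2 + 1/3 + 1/4) = 1 + 13/24.
true_ate <- 1 + (1 / 2) * (1 / 2 + 1 / 3 + 1 / 4)

# Monte Carlo check with fresh uniform draws for X2, X3, X4.
set.seed(1)
X_check <- matrix(runif(1e5 * 3), ncol = 3)
mc_ate <- mean(1 + X_check[, 1] / 2 + X_check[, 2] / 3 + X_check[, 3] / 4)

print(round(c(closed_form = true_ate, monte_carlo = mc_ate), 4))
```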