CVfoldsTree {fusedTree} | R Documentation |
Create balanced cross-validation folds for hyperparameter tuning
Description
Constructs repeated K-fold cross-validation folds that are balanced with respect to the leaf nodes of the fitted tree and, where applicable, the outcome. Each fold contains only the test sample indices. This function is useful for tuning the penalty parameters of the fusedTree model.
Usage
CVfoldsTree(Y, Tree, Z, model = NULL, kfold = 5, nrepeat = 3)
Arguments
Y |
The response variable. Should be:
Only right-censored survival data is currently supported. |
Tree |
A fitted decision tree, typically created using |
Z |
A |
model |
Character. Specifies the type of outcome model. Must be one of:
|
kfold |
Integer. Number of folds K for cross-validation. Defaults to 5. |
nrepeat |
Integer. Number of times the K-fold cross-validation is repeated. Defaults to 3. |
Details
For binary and survival outcomes, the function keeps the proportion of cases vs. controls (or events vs. censored observations) approximately constant across folds. In addition, samples are balanced across the leaf nodes of the fitted tree, so that each fold has a similar node composition.
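The balancing idea can be illustrated with a short base R sketch. This is not the package's implementation: `balanced_folds` is a hypothetical helper, and the leaf-node labels and binary outcome are simulated toy data. The sketch treats each combination of leaf node and outcome class as a stratum and deals the members of each stratum out evenly over the folds.

```r
# Hypothetical helper (not part of fusedTree): stratified fold assignment.
balanced_folds <- function(strata, kfold = 5) {
  fold_id <- integer(length(strata))
  for (idx in split(seq_along(strata), strata)) {
    # shuffle the stratum, then deal its members out over the folds in turn;
    # the fold order is randomized so leftovers do not always land in fold 1
    idx <- idx[sample.int(length(idx))]
    fold_id[idx] <- rep_len(sample.int(kfold), length(idx))
  }
  # one element per fold, each holding test indices
  unname(split(seq_along(strata), factor(fold_id, levels = seq_len(kfold))))
}

set.seed(1)
node  <- sample(c(2, 5, 6), 60, replace = TRUE)  # toy leaf-node labels
event <- rbinom(60, 1, 0.3)                      # toy binary outcome
strata <- interaction(node, event, drop = TRUE)  # node-by-outcome strata
folds <- balanced_folds(strata, kfold = 5)
sapply(folds, function(i) mean(event[i]))        # event rate per fold
```

Within each stratum the fold sizes differ by at most one, which is what keeps both the event rate and the node composition similar across folds. Repeating the call `nrepeat` times and concatenating the results would give repeated K-fold cross-validation.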
Value
A list of length kfold × nrepeat, where each element contains the test indices for a specific fold. These indices can be used to systematically split the data during cross-validation.
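As a rough illustration of how such a list of test indices might be consumed, the sketch below loops over the folds, derives the training indices as the complement, and records a test error per fold. The `folds` list here is a hand-made stand-in for the function's output, and the "model" is just the training mean; both are assumptions made for illustration only.

```r
# Toy stand-in for the returned list: each element holds test indices.
set.seed(1)
N <- 20
folds <- list(1:5, 6:10, 11:15, 16:20)  # hypothetical; one repeat of 4-fold CV
Y <- rnorm(N)

cv_err <- vapply(folds, function(test_idx) {
  train_idx <- setdiff(seq_len(N), test_idx)  # everything not held out
  pred <- mean(Y[train_idx])                  # placeholder model: training mean
  mean((Y[test_idx] - pred)^2)                # mean squared error on the test fold
}, numeric(1))
mean(cv_err)                                  # average error over all folds
```

In a tuning loop, the placeholder model would be replaced by a fit for each candidate penalty value, and the value with the lowest average error would be selected.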
Examples
p <- 5 # number of omics variables (low for illustration)
p_Clin <- 5 # number of clinical variables
N <- 100 # sample size
# simulate from a Friedman-like function
g <- function(z) {
  15 * sin(pi * z[, 1] * z[, 2]) + 10 * (z[, 3] - 0.5)^2 + 2 * exp(z[, 4]) + 2 * z[, 5]
}
Z <- as.data.frame(matrix(runif(N * p_Clin), nrow = N)) # clinical data
X <- matrix(rnorm(N * p), nrow = N) # omics data
betas <- c(1, -1, 3, 4, 2) # omics effects
Y <- g(Z) + X %*% betas + rnorm(N) # continuous outcome
Y <- as.vector(Y)
dat <- cbind.data.frame(Y, Z) # set up the data for rpart
rp <- rpart::rpart(Y ~ ., data = dat,
                   control = rpart::rpart.control(xval = 5, minbucket = 10),
                   model = TRUE)
# complexity parameter with the lowest cross-validated error
cp <- rp$cptable[which.min(rp$cptable[, "xerror"]), "CP"]
Treefit <- rpart::prune(rp, cp = cp) # prune the tree at that cp
plot(Treefit)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z, model = "linear")