fusedTree {fusedTree}                                        R Documentation

Fit a fusedTree model with or without fusion penalty

Description

Fits a fusedTree model by solving a penalized regression problem using either a linear, logistic, or Cox model. The model includes both a standard ridge (L2) penalty and an optional fusion penalty to encourage similarity between leaf node-specific omics effects.

Usage

fusedTree(
  Tree,
  X,
  Y,
  Z,
  LinVars = TRUE,
  model,
  lambda,
  alpha,
  maxIter = 50,
  minSuccDiff = 10^(-10),
  dat = FALSE,
  verbose = TRUE
)

Arguments

Tree

The fitted tree object. Should be created using rpart. Support for other tree-fitting packages (e.g., partykit) may be added in the future.

X

A matrix of omics data with dimensions (sample size × number of omics variables).

Y

The response variable. Can be:

  • numeric (for linear regression),

  • binary (0/1, for logistic regression),

  • a survival object created by Surv() (right-censored data only).

Z

A data frame of clinical covariates used to fit the tree. Must be a data.frame, not a matrix.

LinVars

Logical. Whether to include continuous clinical variables linearly in the model (in addition to the tree structure). Defaults to TRUE.

model

Character. Specifies the type of outcome model to fit. One of: "linear", "logistic", or "cox".

lambda

Numeric. Value for the standard ridge (L2) penalty.

alpha

Numeric. Value for the fusion penalty.

maxIter

Integer. Maximum number of iterations for the IRLS (iteratively reweighted least squares) algorithm. Used only when model = "logistic" or "cox". Defaults to 50.

minSuccDiff

Numeric. Convergence threshold for IRLS: the algorithm stops once the difference in log-likelihood between successive iterations falls below this value. Used only when model = "logistic" or "cox". Defaults to 10^(-10).

dat

Logical. Whether to return the data used in model fitting (i.e., omics, clinical, and response). Defaults to FALSE.

verbose

Logical. Whether to print progress updates from the IRLS algorithm. Applies only when model = "logistic" or "cox". Defaults to TRUE.

Details

Linear model: Estimated using a closed-form analytic solution.
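
As a rough illustration of what a closed-form solution of this kind can look like, the sketch below solves a generalized ridge problem in base R. It is not the package's implementation. The leaf membership vector node, the design construction, and in particular the fusion matrix Omega (which penalizes deviations of node-specific effects from their across-node mean) are assumptions made for illustration only.

```r
# Illustrative sketch only: generalized ridge with an ASSUMED fusion penalty.
set.seed(1)
n <- 60; p <- 2; M <- 3                           # samples, omics variables, leaf nodes
node <- sample(rep(seq_len(M), length.out = n))   # hypothetical leaf membership
X <- matrix(rnorm(n * p), nrow = n)               # omics data
y <- rnorm(n)                                     # continuous response

# n x M indicator matrix of leaf membership (node-specific intercepts)
ind <- outer(node, seq_len(M), "==") * 1
# Node-specific omics design: for each variable j, one column per node
Xn <- do.call(cbind, lapply(seq_len(p), function(j) ind * X[, j]))

# ASSUMED fusion matrix: shrinks node-specific effects of each variable
# toward their mean across nodes
Omega <- kronecker(diag(p), diag(M) - matrix(1 / M, M, M))

fit_closed_form <- function(lambda, alpha) {
  D <- cbind(ind, Xn)                             # full design matrix
  P <- matrix(0, ncol(D), ncol(D))                # penalty matrix
  idx <- (M + 1):ncol(D)                          # intercepts are left unpenalized
  P[idx, idx] <- lambda * diag(p * M) + alpha * Omega
  drop(solve(crossprod(D) + P, crossprod(D, y)))  # (D'D + P)^{-1} D'y
}

b_nofuse <- fit_closed_form(lambda = 1, alpha = 0)
b_fused  <- fit_closed_form(lambda = 1, alpha = 1e6)

# Rows: leaf nodes; columns: omics variables
matrix(b_nofuse[-seq_len(M)], nrow = M)
matrix(b_fused[-seq_len(M)], nrow = M)
```

Under these assumptions, a very large fusion penalty makes the node-specific effects of each omics variable nearly identical, while a fusion penalty of zero leaves only the ridge shrinkage.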

Logistic and Cox models: Estimated using IRLS (iteratively reweighted least squares), which is equivalent to the Newton-Raphson algorithm.
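
To make the role of maxIter and minSuccDiff concrete, here is a minimal base-R sketch of IRLS for a plain ridge-penalized logistic regression. It is a generic textbook version, not the package's code: the function name irls_ridge_logistic is hypothetical, there is no tree structure or fusion penalty, and the intercept is penalized along with the other coefficients for simplicity.

```r
# Generic IRLS sketch for ridge-penalized logistic regression (illustration only).
irls_ridge_logistic <- function(X, y, lambda, maxIter = 50, minSuccDiff = 1e-10) {
  beta <- rep(0, ncol(X))
  pen_loglik <- function(b) {
    eta <- drop(X %*% b)
    sum(y * eta - log1p(exp(eta))) - 0.5 * lambda * sum(b^2)
  }
  ll_old <- pen_loglik(beta)
  for (iter in seq_len(maxIter)) {
    eta <- drop(X %*% beta)
    mu <- plogis(eta)                  # fitted probabilities
    w <- mu * (1 - mu)                 # IRLS weights
    # Newton-Raphson step written as a weighted ridge least-squares solve:
    # beta = (X'WX + lambda I)^{-1} X'(W eta + y - mu)
    beta <- drop(solve(crossprod(X, w * X) + lambda * diag(ncol(X)),
                       crossprod(X, w * eta + y - mu)))
    ll_new <- pen_loglik(beta)
    # Stop when the penalized log-likelihood no longer changes meaningfully
    if (abs(ll_new - ll_old) < minSuccDiff) break
    ll_old <- ll_new
  }
  list(beta = beta, iterations = iter, loglik = ll_new)
}

set.seed(2)
n <- 200
X <- cbind(1, matrix(rnorm(n * 3), nrow = n))    # intercept plus 3 covariates
y <- rbinom(n, 1, plogis(drop(X %*% c(-0.5, 1, -1, 0.5))))
res <- irls_ridge_logistic(X, y, lambda = 1)
res$beta
res$iterations
```

Because each update is a Newton-Raphson step, convergence is typically reached in far fewer iterations than the default cap.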

Cox model: The full likelihood approach is used, following van Houwelingen et al. (2005). See also van de Wiel et al. (2021) for additional details on penalized regression for survival outcomes.

Value

A list with the following components:

Tree

The fitted tree object from 'rpart'.

Effects

A named numeric vector of estimated effect sizes, including: intercepts (tree leaf nodes), omics effects (per node), and linear clinical effects (if LinVars = TRUE).

Breslow

(Optional) The Breslow estimates of the baseline hazard ht and the cumulative baseline hazard Ht for each time point. Only returned for model = "cox".

Parameters

A list of model parameters used in fitting (e.g., lambda, alpha, model, etc.).

Clinical

(Optional) The clinical design matrix used in fitting, if dat = TRUE.

Omics

(Optional) The omics design matrix used in fitting, if dat = TRUE.

Response

(Optional) The response vector used in fitting, if dat = TRUE.

The returned list is an S3 object for which a predict() method is available.

References

porridge

van Houwelingen, H. C., et al. (2005). Cross-validated Cox regression on microarray gene expression data. Stat Med.

van de Wiel, M. A., et al. (2021). Fast Cross-validation for Multi-penalty High-dimensional Ridge Regression. J Comput Graph Stat.

Examples

p <- 5      # number of omics variables (low for illustration)
p_Clin <- 5 # number of clinical variables
N <- 100    # sample size
# simulate from a Friedman-like function
g <- function(z) {
  15 * sin(pi * z[,1] * z[,2]) + 10 * (z[,3] - 0.5)^2 + 2 * exp(z[,4]) + 2 * z[,5]
}
set.seed(11)
Z <- as.data.frame(matrix(runif(N * p_Clin), nrow = N))
X <- matrix(rnorm(N * p), nrow = N)            # omics data
betas <- c(1,-1,3,4,2)                         # omics effects
Y <- g(Z) + X %*% betas + rnorm(N)             # continuous outcome
Y <- as.vector(Y)
dat <- cbind.data.frame(Y, Z)                  # set up data for rpart
rp <- rpart::rpart(Y ~ ., data = dat,
                   control = rpart::rpart.control(xval = 5, minbucket = 10),
                   model = TRUE)
# complexity parameter with the lowest cross-validated error (column 4, xerror)
cp <- rp$cptable[, 1][which.min(rp$cptable[, 4])]
Treefit <- rpart::prune(rp, cp = cp)
plot(Treefit)
folds <- CVfoldsTree(Y = Y, Tree = Treefit, Z = Z, model = "linear")
optPenalties <- PenOpt(Tree = Treefit, X = X, Y = Y, Z = Z,
                       model = "linear", lambdaInit = 10, alphaInit = 10,
                       loss = "loglik",
                       LinVars = FALSE,
                       folds = folds, multistart = FALSE)
optPenalties

# with fusion
fit <- fusedTree(Tree = Treefit, X = X, Y = Y, Z = Z,
                 LinVars = FALSE, model = "linear",
                 lambda = optPenalties[1],
                 alpha = optPenalties[2])
# without fusion
fit1 <- fusedTree(Tree = Treefit, X = X, Y = Y, Z = Z,
                  LinVars = FALSE, model = "linear",
                  lambda = optPenalties[1],
                  alpha = 0)
# compare effect estimates
fit$Effects
fit1$Effects

[Package fusedTree version 1.0.1 Index]