survdnn {survdnn}    R Documentation

Fit a Deep Neural Network for Survival Analysis

Description

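Trains a deep neural network (DNN) on right-censored survival data using one of the predefined loss functions: the Cox partial likelihood ('"cox"'), its L2-regularized variant ('"cox_l2"'), the accelerated failure time model ('"aft"'), or Cox-Time ('"coxtime"').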
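For intuition about the Cox-type losses, the following is a minimal base-R sketch of the negative Cox partial log-likelihood with Breslow handling of tied times, evaluated on a vector of risk scores 'eta' such as a network's output. The function name 'cox_nll' is hypothetical. survdnn computes its losses on torch tensors, and whether it sums or averages over events, or how it handles ties, is not stated on this page, so treat this as an illustration of the technique rather than the package's implementation. The AFT and Cox-Time losses are not covered.

```r
# Negative Cox partial log-likelihood (Breslow ties), base-R sketch.
# Hypothetical helper, not part of survdnn.
# eta: risk scores (higher = higher hazard); time: follow-up times;
# status: 1 = event observed, 0 = right-censored.
cox_nll <- function(eta, time, status) {
  # at_risk[i, j] is 1 when subject j is still at risk at time[i]
  at_risk <- outer(time, time, function(ti, tj) as.numeric(tj >= ti))
  # log of the summed exp(eta) over each subject's risk set
  log_denom <- log(drop(at_risk %*% exp(eta)))
  # only observed events contribute terms to the partial likelihood
  -sum(status * (eta - log_denom))
}

# Three subjects with equal risk scores: events at t = 1 (risk set of 3)
# and t = 2 (risk set of 2), so the loss is log(3) + log(2) = log(6).
cox_nll(eta = c(0, 0, 0), time = c(1, 2, 3), status = c(1, 1, 0))
```

Adding a constant to every risk score leaves this loss unchanged, which is why Cox-type networks need no output bias term to identify the baseline hazard.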

Usage

survdnn(
  formula,
  data,
  hidden = c(32L, 16L),
  activation = "relu",
  lr = 1e-04,
  epochs = 300L,
  loss = c("cox", "cox_l2", "aft", "coxtime"),
  verbose = TRUE
)

Arguments

formula

A survival formula of the form 'Surv(time, status) ~ predictors'.

data

A data frame containing the variables in the model.

hidden

Integer vector giving the number of units in each hidden layer; its length sets the number of hidden layers (default: c(32, 16)).

activation

Character string specifying the activation function to use in each layer. Supported options: '"relu"', '"leaky_relu"', '"tanh"', '"sigmoid"', '"gelu"', '"elu"', '"softplus"'.

lr

Learning rate for the Adam optimizer (default: '1e-4').

epochs

Number of training epochs (default: 300).

loss

Character string naming the loss function: one of '"cox"', '"cox_l2"', '"aft"', or '"coxtime"' (default: '"cox"').

verbose

Logical; whether to print loss progress every 50 epochs (default: TRUE).
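Each supported 'activation' name refers to a standard element-wise function. As a hedged reference, the base-R sketch below writes each one out. The list name 'activations' is hypothetical, and the parameter choices (leaky_relu slope 0.01, elu alpha 1, exact normal-CDF gelu, softplus with beta 1) follow common torch defaults; that survdnn uses exactly these settings is an assumption.

```r
# Element-wise definitions of the supported activation names (base R).
# Hypothetical list; slope/alpha/beta values assume torch-style defaults.
activations <- list(
  relu       = function(x) pmax(x, 0),
  leaky_relu = function(x) ifelse(x > 0, x, 0.01 * x),
  tanh       = function(x) tanh(x),
  sigmoid    = function(x) 1 / (1 + exp(-x)),
  gelu       = function(x) x * pnorm(x),  # x times the standard normal CDF
  elu        = function(x) ifelse(x > 0, x, exp(x) - 1),
  # numerically stable log(1 + exp(x)): avoids overflow for large x
  softplus   = function(x) pmax(x, 0) + log1p(exp(-abs(x)))
)

# Evaluate every activation on a few inputs (rows = inputs, cols = names)
x <- c(-2, 0, 2)
sapply(activations, function(f) f(x))
```

The bounded functions ('"tanh"', '"sigmoid"') saturate for large inputs, while '"relu"' and its variants do not, which is one reason ReLU-family activations are a common default for deeper networks.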

Value

An object of class '"survdnn"' containing:

model

Trained 'nn_module' object.

formula

Original survival formula.

data

Training data used for fitting.

xnames

Predictor variable names.

x_center

Column means of predictors.

x_scale

Column standard deviations of predictors.

loss_history

Vector of loss values per epoch.

final_loss

Final training loss.

loss

Name of the loss function used (e.g., '"cox"', '"aft"').

activation

Activation function used.

hidden

Hidden layer sizes.

lr

Learning rate.

epochs

Number of training epochs.
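The 'x_center' and 'x_scale' components suggest that predictors are standardized column-wise before training. The base-R sketch below shows that operation with 'scale()', assuming (not confirmed on this page) that new data passed for prediction is standardized with the stored training statistics rather than its own. All object names are hypothetical, and whether survdnn also rescales binary predictors such as 'x2' is an assumption.

```r
set.seed(1)
# Hypothetical training predictors: one continuous, one binary
x_train <- cbind(x1 = rnorm(50, mean = 10, sd = 3), x2 = rbinom(50, 1, 0.5))

# Training statistics, analogous to the x_center and x_scale components
x_center <- colMeans(x_train)
x_scale  <- apply(x_train, 2, sd)

# Standardize the training data, then new data with the SAME statistics
x_train_std <- scale(x_train, center = x_center, scale = x_scale)
x_new <- cbind(x1 = c(7, 13), x2 = c(0, 1))
x_new_std <- scale(x_new, center = x_center, scale = x_scale)

colMeans(x_train_std)        # ~0 for every column
apply(x_train_std, 2, sd)    # 1 for every column
```

Reusing the training statistics keeps new observations on the scale the network was trained on; recomputing means and standard deviations on a small prediction set would silently shift its inputs.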

Examples

library(survival)  # provides Surv()
library(survdnn)

set.seed(123)
df <- data.frame(
  time = rexp(100, rate = 0.1),
  status = rbinom(100, 1, 0.7),
  x1 = rnorm(100),
  x2 = rbinom(100, 1, 0.5)
)
# 5 epochs keeps the example fast; the default is 300
mod <- survdnn(Surv(time, status) ~ x1 + x2, data = df, epochs = 5,
               loss = "cox", verbose = FALSE)
mod$final_loss

[Package survdnn version 0.6.0 Index]