fit_nnet {ARTtransfer}    R Documentation

fit_nnet: Neural Network Wrapper for the ARTtransfer package

Description

This function fits a neural network model to a binary response using 'nnet()' from the R package nnet. It returns the deviance on a validation set and the predicted probabilities on a test set. It is designed for use in the 'ART' adaptive and robust transfer learning framework.
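
As a rough illustration of the kind of call being wrapped, the sketch below fits a small network with 'nnet()' directly and predicts probabilities for new rows. The settings 'size = 3', 'entropy = TRUE', and 'trace = FALSE' are assumptions chosen for the illustration; the settings that 'fit_nnet()' uses internally are not stated on this page.

```r
library(nnet)

set.seed(1)
X_train <- matrix(rnorm(100 * 5), 100, 5)
y_train <- rbinom(100, 1, 0.5)
X_test <- matrix(rnorm(20 * 5), 20, 5)

# Assumed settings for illustration only, not necessarily those of fit_nnet()
net <- nnet(X_train, y_train, size = 3, entropy = TRUE, trace = FALSE)

# With a single logistic output unit, predict() returns an n x 1 matrix
# of fitted probabilities; flatten it to a plain numeric vector
p_test <- as.numeric(predict(net, X_test))
```

Each element of 'p_test' lies between 0 and 1 because the output unit is logistic.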

Usage

fit_nnet(
  X,
  y,
  X_val,
  y_val,
  X_test,
  min_prod = 1e-05,
  max_prod = 1 - 1e-05,
  ...
)

Arguments

X

A matrix of predictors for the training set.

y

A vector of binary responses for the training set.

X_val

A matrix of predictors for the validation set. If 'NULL', deviance is not calculated.

y_val

A vector of binary responses for the validation set. If 'NULL', deviance is not calculated.

X_test

A matrix of predictors for the test set. If 'NULL', predictions are not generated.

min_prod

A numeric value indicating the minimum probability bound for predictions. Default is '1e-5'.

max_prod

A numeric value indicating the maximum probability bound for predictions. Default is '1-1e-5'.
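
Bounding probabilities in this way is commonly done with 'pmin()' and 'pmax()'. The following is a minimal base R sketch of that idea, using made-up probabilities; it is not taken from the package source, and the exact clipping code inside 'fit_nnet()' may differ.

```r
# Made-up predicted probabilities, including the problematic values 0 and 1
p_raw <- c(0, 0.2, 0.999999, 1)

min_prod <- 1e-5
max_prod <- 1 - 1e-5

# Clip from below with pmax(), then from above with pmin()
p_clipped <- pmin(pmax(p_raw, min_prod), max_prod)

log(p_raw[1])      # -Inf, which would make a log-likelihood unusable
log(p_clipped[1])  # finite after clipping
```

Values already inside the bounds, such as 0.2, are left unchanged.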

...

Additional arguments passed to 'nnet()'.

Value

A list containing:

dev

The deviance (negative log-likelihood) on the validation set if 'X_val' and 'y_val' are provided, otherwise 'NULL'.

pred

The predicted probabilities on the test set if 'X_test' is provided, otherwise 'NULL'.
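
For reference, a binomial negative log-likelihood can be computed by hand as in the sketch below. The helper 'binom_nll()' is hypothetical and not part of ARTtransfer. Whether 'fit_nnet()' sums or averages over observations, or applies any scaling factor, is not stated on this page, so the sum used here is an assumption.

```r
# Hypothetical helper, not exported by ARTtransfer:
# summed negative log-likelihood for binary responses y and probabilities p
binom_nll <- function(y, p) {
  -sum(y * log(p) + (1 - y) * log(1 - p))
}

# Made-up validation responses and predicted probabilities
y_val <- c(1, 0, 1, 1)
p_val <- c(0.9, 0.2, 0.6, 0.5)

dev_val <- binom_nll(y_val, p_val)
```

Here 'dev_val' equals '-log(0.9 * 0.8 * 0.6 * 0.5)', since each observation contributes the log of the probability assigned to its observed class.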

Examples

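# Fit a neural network model with validation and test data
set.seed(1)  # make the simulated data reproducible
X_train <- matrix(rnorm(100 * 5), 100, 5)
y_train <- rbinom(100, 1, 0.5)
X_val <- matrix(rnorm(50 * 5), 50, 5)
y_val <- rbinom(50, 1, 0.5)
X_test <- matrix(rnorm(20 * 5), 20, 5)

fit <- fit_nnet(X_train, y_train, X_val, y_val, X_test)

# Inspect the returned list
fit$dev   # deviance on the validation set
fit$pred  # predicted probabilities on the test set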


[Package ARTtransfer version 1.0.0 Index]