ASpredict.as_train {ASML} | R Documentation |
Predicting the KPI value for the algorithms
Description
For each algorithm, the output (KPI) is predicted using the models trained with AStrain().
Usage
## S3 method for class 'as_train'
ASpredict(training_object, newdata = NULL, f = NULL, ...)
Arguments
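training_object | list of class as_train, as returned by AStrain(). |
newdata | data frame with the new data to predict. If not present, predictions are computed using the training data. |
f | function to use for the predictions (the Examples pass it by name, as a character string). If NULL, caret's default prediction function is used. |
... | arguments passed to the prediction function f when f is not NULL. |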
Details
ASpredict() uses the prediction function from caret to compute, for each of the trained models, the predictions for the new data provided by the user.
If the user supplied a custom function to AStrain() (through its parameter f), caret's default prediction function might not work, and the user might also have to provide a custom prediction function to ASpredict().
Additionally, a custom prediction function allows additional arguments to be passed, which caret's default prediction function does not.
The object returned by the train function used in AStrain() (caret's or a custom one) is the one passed to the user-defined f function. This f function must return a vector with the predictions.
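The contract for a custom f can be illustrated with a minimal base-R sketch that assumes nothing about ASML internals. A plain lm() fit stands in for the object returned by the train function, and shifted_predict and its shift argument are hypothetical names. The function receives the fitted object and the new data, accepts an extra argument, and returns a plain vector. It is called by name through do.call(), which is only an assumption about how a character-string f might be resolved.

```r
# Hypothetical stand-in for one trained model: a base-R linear model.
# The KPI is an exact linear function of the features, so the fit is exact.
train_df <- data.frame(x1 = c(1, 2, 3, 4, 5), x2 = c(2, 1, 4, 3, 5))
train_df$kpi <- 3 + 2 * train_df$x1 - train_df$x2
fit <- lm(kpi ~ x1 + x2, data = train_df)

# Hypothetical custom prediction function: it receives the fitted object and
# the new data, accepts an extra argument (shift), and returns a plain vector.
shifted_predict <- function(modelFit, newdata, shift = 0) {
  out <- predict(modelFit, newdata = newdata)
  as.vector(out) + shift  # as.vector() drops the row names predict() attaches
}

new_df <- data.frame(x1 = c(6, 7), x2 = c(1, 2))

# Call the function through its name, mirroring f = "shifted_predict";
# the extra argument travels alongside, as ASpredict()'s ... does for f.
preds <- do.call("shifted_predict", list(fit, new_df, shift = 1))
preds  # approximately 15 16
```

Returning an unnamed numeric vector keeps the sketch aligned with the requirement that f return a vector with the predictions.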
Value
A data frame with the predictions for each instance (rows) and each algorithm (columns). If f is specified, some additional processing might be needed to extract the predictions from the value it returns.
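The shape of the returned value can be sketched in base R. This is not ASML's implementation: lm() stands in for the trained models, and the algorithm names alg_A and alg_B, as well as the simulated features, are hypothetical. One model is fitted per algorithm, each model predicts a vector for the new instances, and the vectors are bound column-wise.

```r
set.seed(1)
# Hypothetical training data: two instance features (f1, f2) and one KPI
# column per algorithm (alg_A and alg_B are made-up algorithm names).
train_x <- data.frame(f1 = runif(20), f2 = runif(20))
train_y <- data.frame(alg_A = 2 * train_x$f1 + rnorm(20, sd = 0.1),
                      alg_B = 1 - train_x$f2 + rnorm(20, sd = 0.1))

# One model per algorithm; lm() is only a stand-in for the models AStrain() fits.
models <- lapply(train_y, function(kpi) {
  lm(kpi ~ f1 + f2, data = cbind(train_x, kpi = kpi))
})

# New instances to predict.
new_x <- data.frame(f1 = runif(5), f2 = runif(5))

# Predict with every model and bind the vectors column-wise:
# rows are instances, columns are algorithms.
pred_df <- as.data.frame(lapply(models, function(m) {
  as.vector(predict(m, newdata = new_x))
}))

dim(pred_df)    # 5 rows (instances), 2 columns (algorithms)
names(pred_df)  # "alg_A" "alg_B"
```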
Examples
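library(ASML)
data(branchingsmall)
# Split into training and test sets (by instance family) and normalize
data_object <- partition_and_normalize(branchingsmall$x, branchingsmall$y, test_size = 0.3,
                                       family_column = 1, split_by_family = TRUE)
# Train one model per algorithm and predict with caret's default function
training <- AStrain(data_object, method = "glm")
predictions <- ASpredict(training, newdata = data_object$x.test)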
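# Custom prediction function for quantile regression forests (caret method "qrf"):
# it predicts the quantile given by `what` and returns a vector
qrf_q_predict <- function(modelFit, newdata, what = 0.5, submodels = NULL) {
  out <- predict(modelFit, newdata, what = what)
  # keep only the first column if a matrix of quantiles is returned
  if (is.matrix(out))
    out <- out[, 1]
  out
}
# The custom function targets models trained with method = "qrf"
# (caret's "qrf" method requires the quantregForest package)
qrf_training <- AStrain(data_object, method = "qrf")
custom_predictions <- ASpredict(qrf_training, newdata = data_object$x.test, f = "qrf_q_predict",
                                what = 0.25)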