Forest {stochtree}    R Documentation
Class that stores a single ensemble of decision trees (often treated as the "active forest")
Description
Wrapper around a C++ tree ensemble
Public fields
forest_ptr
External pointer to a C++ TreeEnsemble class
internal_forest_is_empty
Whether the forest has not yet been "initialized"; its predict method can only be called once the forest has been initialized.
Methods
Public methods
Method new()
Create a new Forest object.
Usage
Forest$new( num_trees, leaf_dimension = 1, is_leaf_constant = FALSE, is_exponentiated = FALSE )
Arguments
num_trees
Number of trees in the forest
leaf_dimension
Dimensionality of the outcome model
is_leaf_constant
Whether the leaves are constant
is_exponentiated
Whether forest predictions should be exponentiated before being returned
Returns
A new Forest object.
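A minimal usage sketch, assuming the stochtree package is installed and that the Forest R6 generator is available after library(stochtree). It constructs a forest with the arguments listed above and queries the accessor methods documented further down this page.

```r
library(stochtree)

# Create a 50-tree forest with constant leaves (no basis regression)
forest <- Forest$new(num_trees = 50, leaf_dimension = 1,
                     is_leaf_constant = TRUE, is_exponentiated = FALSE)

# Query the structural settings fixed at construction time
n_trees    <- forest$num_trees()         # number of trees in the ensemble
leaf_dim   <- forest$leaf_dimension()    # size of each leaf node's parameter
const_leaf <- forest$is_constant_leaf()  # TRUE for this forest
exp_pred   <- forest$is_exponentiated()  # FALSE for this forest
```

These settings are fixed when the object is created; the forest is still empty at this point and must be initialized before predict() is used.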
Method predict()
Predict the forest on every observation in forest_dataset
Usage
Forest$predict(forest_dataset)
Arguments
forest_dataset
ForestDataset R class
Returns
Vector of predictions, with one entry per row (observation) of forest_dataset
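Conceptually, an additive tree ensemble predicts each observation by summing the leaf values that observation reaches in every tree, and is_exponentiated = TRUE means that sum is exponentiated before being returned. The base R sketch below illustrates this assumption on a hypothetical matrix of per-tree leaf values; it does not call stochtree and is not the package's C++ implementation.

```r
# Hypothetical: leaf_values[i, t] is the leaf value observation i reaches in tree t
sum_of_trees_predict <- function(leaf_values, is_exponentiated = FALSE) {
  pred <- rowSums(leaf_values)            # additive ensemble: sum across trees
  if (is_exponentiated) pred <- exp(pred) # e.g. a log-scale variance forest
  pred
}

leaf_values <- matrix(c(0.5, -0.25,
                        0.1,  0.2,
                        0.0,  0.0), nrow = 3, byrow = TRUE)
sum_of_trees_predict(leaf_values)        # 0.25 0.30 0.00
sum_of_trees_predict(leaf_values, TRUE)  # exp() of the values above
```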
Method predict_raw()
Predict "raw" leaf values (before multiplication by the basis) for every observation in forest_dataset
Usage
Forest$predict_raw(forest_dataset)
Arguments
forest_dataset
ForestDataset R class
Returns
Array of predictions for each observation in forest_dataset, with each prediction having the
dimensionality of the forest's leaf model. For a constant leaf model or univariate leaf
regression, this array is a vector whose length is the number of observations. For a
multivariate leaf regression, it is a matrix with one row per observation and one column
per leaf model dimension.
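For a leaf regression, predict_raw() returns leaf values before they are multiplied by the basis, whereas predict() returns the basis-weighted result. Assuming the basis-weighted prediction for an observation is the inner product of its basis row with its raw leaf values, the relationship can be sketched in base R; raw and basis below are hypothetical matrices, not stochtree objects.

```r
# Hypothetical raw output of a 2-dimensional leaf regression:
# one row per observation, one column per leaf model dimension
raw <- matrix(c(1.0,  2.0,
                0.5, -1.0), nrow = 2, byrow = TRUE)

# Hypothetical basis with the same shape (observations x leaf dimension)
basis <- matrix(c(1, 0.5,
                  2, 1.0), nrow = 2, byrow = TRUE)

# Assumption: basis-weighted prediction = row-wise inner product of raw values and basis
pred <- rowSums(raw * basis)
pred  # 2 0
```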
Method set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops with an error if any tree is more than a root node.
Usage
Forest$set_root_leaves(leaf_value)
Arguments
leaf_value
Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
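Because every tree receives the same root value, and assuming the forest's prediction is the sum of leaf values across trees, the initial ensemble prediction equals that value multiplied by the number of trees. A caller targeting a particular starting prediction would then pass the target divided by num_trees. The arithmetic, sketched in base R with hypothetical numbers:

```r
num_trees   <- 200
target_init <- 1.5                      # desired initial ensemble prediction
leaf_value  <- target_init / num_trees  # value one would pass to set_root_leaves()

# Assumption: forest predictions sum the leaf values across trees, so a forest
# of root-only trees, each carrying leaf_value, predicts num_trees * leaf_value
initial_prediction <- num_trees * leaf_value
initial_prediction  # 1.5
```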
Method prepare_for_sampler()
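Initialize every tree in the ensemble to a single root node with a constant predicted value and propagate this state to the ForestModel object used during sampling. Stops with an error if any tree is more than a root node.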
Usage
Forest$prepare_for_sampler( dataset, outcome, forest_model, leaf_model_int, leaf_value )
Arguments
dataset
ForestDataset class (covariates, basis, etc.)
outcome
Outcome class (residual / partial residual)
forest_model
ForestModel object storing tracking structures used in training / sampling
leaf_model_int
Integer value encoding the leaf model type (0 = constant Gaussian, 1 = univariate Gaussian, 2 = multivariate Gaussian, 3 = log linear variance).
leaf_value
Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension.
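The integer codes for leaf_model_int are easy to mix up. A small lookup helper can map descriptive names to the codes listed above; leaf_model_code() is hypothetical and is not a stochtree function.

```r
# Hypothetical helper (not part of stochtree): map a leaf model name to the
# integer code expected by prepare_for_sampler()'s leaf_model_int argument
leaf_model_code <- function(type = c("constant_gaussian", "univariate_gaussian",
                                     "multivariate_gaussian", "log_linear_variance")) {
  type <- match.arg(type)  # validates the name; defaults to the first choice
  codes <- c(constant_gaussian = 0L, univariate_gaussian = 1L,
             multivariate_gaussian = 2L, log_linear_variance = 3L)
  unname(codes[type])
}

leaf_model_code("log_linear_variance")  # 3
```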
Method adjust_residual()
Adjusts the residual based on the predictions of the forest
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Usage
Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)
Arguments
dataset
ForestDataset object storing the covariates and bases for a given forest
outcome
Outcome object storing the residuals to be updated based on forest predictions
forest_model
ForestModel object storing tracking structures used in training / sampling
requires_basis
Whether or not a forest requires a basis for prediction
add
Whether forest predictions should be added to or subtracted from residuals
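In vector terms, the update subtracts the forest's predictions from the residual when add = FALSE and adds them back when add = TRUE. The base R sketch below uses plain numeric vectors in place of Outcome and ForestDataset objects and omits basis handling (requires_basis); adjust_residual_sketch() is a hypothetical stand-in, not the method itself.

```r
# Hypothetical stand-in for adjust_residual(), operating on numeric vectors
adjust_residual_sketch <- function(residual, forest_pred, add) {
  if (add) residual + forest_pred else residual - forest_pred
}

y <- c(3.0, 1.0, 2.5)
root_pred <- rep(2.0, 3)  # e.g. constant root-node predictions after initialization

resid    <- adjust_residual_sketch(y, root_pred, add = FALSE)     # 1.0 -1.0 0.5
restored <- adjust_residual_sketch(resid, root_pred, add = TRUE)  # back to y
```

Subtracting and then adding the same predictions is a no-op, which is why the add flag lets one method serve both directions of the update.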
Method num_trees()
Return the number of trees in a Forest object
Usage
Forest$num_trees()
Returns
Tree count
Method leaf_dimension()
Return output dimension of trees in a Forest object
Usage
Forest$leaf_dimension()
Returns
Leaf node parameter size
Method is_constant_leaf()
Return constant leaf status of trees in a Forest object
Usage
Forest$is_constant_leaf()
Returns
TRUE if leaves are constant, FALSE otherwise
Method is_exponentiated()
Return exponentiation status of trees in a Forest object
Usage
Forest$is_exponentiated()
Returns
TRUE if leaf predictions must be exponentiated, FALSE otherwise
Method add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
Usage
Forest$add_numeric_split_tree( tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value )
Arguments
tree_num
Index of the tree to be split
leaf_num
Index of the leaf node to be split
feature_num
Index of the feature that defines the new split
split_threshold
Value that defines the cutoff of the new split
left_leaf_value
Value (or vector of values) to assign to the newly created left node
right_leaf_value
Value (or vector of values) to assign to the newly created right node
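A numeric split X[,i] <= c sends an observation to the new left node when its value of the split feature is at most the threshold, and to the right node otherwise. The base R sketch below applies one such split to a hypothetical covariate matrix; apply_numeric_split() illustrates the routing rule only and does not modify a stochtree forest. Its feature argument is a 1-based R column index, which may differ from the indexing feature_num expects.

```r
X <- matrix(c(1.5, 8.7,
              4.0, 2.7,
              6.2, 3.6), nrow = 3, byrow = TRUE)

# Hypothetical: route each row through a single split on column `feature`,
# returning the value of the leaf each observation lands in
apply_numeric_split <- function(X, feature, threshold, left_value, right_value) {
  ifelse(X[, feature] <= threshold, left_value, right_value)
}

apply_numeric_split(X, feature = 1, threshold = 4.0,
                    left_value = -5, right_value = 5)  # -5 -5 5
```

Note that an observation exactly at the threshold (the second row) goes left, matching the <= in the split definition.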
Method get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in the forest
Usage
Forest$get_tree_leaves(tree_num)
Arguments
tree_num
Index of the tree for which leaf indices will be retrieved
Method get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in the forest
Usage
Forest$get_tree_split_counts(tree_num, num_features)
Arguments
tree_num
Index of the tree for which split counts will be retrieved
num_features
Total number of features in the training set
Method get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in the forest
Usage
Forest$get_forest_split_counts(num_features)
Arguments
num_features
Total number of features in the training set
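Forest-level split counts can be viewed as the element-wise sum of the per-tree split counts. The base R sketch below assumes a hypothetical list recording which feature each tree's internal nodes split on (1-based column indices); it mirrors the relationship between get_tree_split_counts() and get_forest_split_counts() without calling either.

```r
num_features <- 4

# Hypothetical: features used by the internal nodes of each tree
tree_splits <- list(c(1, 3), c(3), integer(0), c(2, 3, 3))

# Per-tree counts (one length-num_features vector per tree)
per_tree <- lapply(tree_splits, tabulate, nbins = num_features)

# Forest-wide counts: element-wise sum over trees
forest_counts <- Reduce(`+`, per_tree)
forest_counts  # 1 1 4 0
```

A root-only tree (the third entry) contributes a vector of zeros, so it leaves the forest totals unchanged.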
Method tree_max_depth()
Maximum depth of a specific tree in the forest
Usage
Forest$tree_max_depth(tree_num)
Arguments
tree_num
Tree index within forest
Returns
Maximum leaf depth
Method average_max_depth()
Average, across all trees in the forest, of each tree's maximum depth
Usage
Forest$average_max_depth()
Returns
Average maximum depth
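The average maximum depth is the mean, across trees, of each tree's maximum depth. The base R sketch below computes both quantities for trees stored as hypothetical parent-index vectors (the root's parent is NA, and the root is assumed to have depth 0); this representation and the helper names are assumptions for illustration, not stochtree's internal format.

```r
# Depth of each node: number of edges between the node and the root
node_depths <- function(parent) {
  depth <- integer(length(parent))
  for (i in seq_along(parent)) {
    d <- 0L
    p <- parent[i]
    while (!is.na(p)) {  # walk up the tree until the root is passed
      d <- d + 1L
      p <- parent[p]
    }
    depth[i] <- d
  }
  depth
}

# Hypothetical analogue of tree_max_depth(): deepest node in one tree
tree_max_depth_sketch <- function(parent) max(node_depths(parent))

trees <- list(
  c(NA),             # root only: depth 0
  c(NA, 1, 1),       # one split: depth 1
  c(NA, 1, 1, 2, 2)  # a split below the root's left child: depth 2
)
depths <- vapply(trees, tree_max_depth_sketch, integer(1))
mean(depths)  # 1
```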
Method is_empty()
When a forest object is created, it is "empty" in the sense that none
of its component trees have leaves with values. There are two ways to
"initialize" a Forest object. First, the set_root_leaves() method
simply initializes every tree in the forest to a single node carrying
the same (user-specified) leaf value. Second, the prepare_for_sampler()
method initializes every tree in the forest to a single node with the
same value and also propagates this information through to a ForestModel
object, which must be synchronized with a Forest during a forest
sampler loop.
Usage
Forest$is_empty()
Returns
TRUE if the Forest has not yet been initialized with a constant root value, FALSE if it has already been initialized / grown.
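A minimal lifecycle sketch, assuming the stochtree package is installed: a freshly constructed forest reports itself as empty until set_root_leaves() (or prepare_for_sampler()) initializes it.

```r
library(stochtree)

forest <- Forest$new(num_trees = 10, leaf_dimension = 1, is_leaf_constant = TRUE)
before <- forest$is_empty()  # TRUE: no tree has a leaf value yet

# Initialize every tree as a single root node carrying the value 0
forest$set_root_leaves(0)
after <- forest$is_empty()   # FALSE: the forest is now initialized
```

Calling set_root_leaves() a second time at this point would stop with an error, since the forest is no longer empty.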