compute_pfi {veesa}    R Documentation

Compute permutation feature importance (PFI)

Description

Function for computing PFI for a given model and dataset (training or testing)

Usage

compute_pfi(x, y, f, K, metric, eps = 1e-15)

Arguments

x

Dataset with n observations and p variables (training or testing)

y

Response variable (or matrix) associated with x

f

Model to explain

K

Number of repetitions to perform for PFI

metric

Metric used to compute PFI (choose from "accuracy", "logloss", and "nmse")

eps

Log loss is undefined for p = 0 or p = 1, so probabilities are clipped to max(eps, min(1 - eps, p)). Default is 1e-15.

Value

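List containing the PFI results computed from the K repetitions. (NOTE: the itemized list of elements is missing from this rendered page.)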

Examples

# Example: compute permutation feature importance (PFI) for a random forest
# trained on a small subset of the shifted peaks data.
#
# NOTE: shifted_peaks, prep_training_data(), and compute_pfi() are not
# defined in this script; they are expected to come from the veesa package,
# which must be attached when running this code outside of example().

# Load packages
library(dplyr)
library(tidyr)
library(randomForest)

# Select a subset of functions from shifted peaks data:
# keep the first four ids within each data/group combination
sub_ids <-
  shifted_peaks$data |>
  select(data, group, id) |>
  distinct() |>
  group_by(data, group) |>
  slice(1:4) |>
  ungroup()

# Create a smaller version of shifted data (only the selected ids)
shifted_peaks_sub <-
  shifted_peaks$data |>
  filter(id %in% sub_ids$id)

# Extract the unique observation times
shifted_peaks_times <- unique(shifted_peaks_sub$t)

# Convert training data to matrix:
# widen so each function (id) is a row with one column per time index
# ("t<index>"), drop the identifier columns, then transpose so that rows
# correspond to time indices and columns to functions
shifted_peaks_train_matrix <-
  shifted_peaks_sub |>
  filter(data == "Training") |>
  select(-t) |>
  mutate(index = paste0("t", index)) |>
  pivot_wider(names_from = index, values_from = y) |>
  select(-data, -id, -group) |>
  as.matrix() |>
  t()

# Obtain veesa pipeline training data
# ("jfpca" selects the fPCA method; see ?prep_training_data for the options)
veesa_train <-
  prep_training_data(
    f = shifted_peaks_train_matrix,
    time = shifted_peaks_times,
    fpca_method = "jfpca"
  )

# Obtain response variable values (one row per training function)
# NOTE(review): the row order here is assumed to match the column order of
# the training matrix above (both follow first appearance of id) -- confirm
model_output <-
  shifted_peaks_sub |>
  filter(data == "Training") |>
  select(id, group) |>
  distinct()

# Prepare data for model:
# fPCA coefficients as predictors plus the group label as a factor response
model_data <-
  veesa_train$fpca_res$coef |>
  data.frame() |>
  mutate(group = factor(model_output$group))

# Train model (seed fixed so the fitted forest is reproducible)
set.seed(20210301)
rf <-
  randomForest(
    formula = group ~ .,
    data = model_data
  )

# Compute feature importance values
# (a single repetition, K = 1, scored with classification accuracy)
pfi <-
  compute_pfi(
    x = model_data |> select(-group),
    y = model_data$group,
    f = rf,
    K = 1,
    metric = "accuracy"
  )

[Package veesa version 0.1.6 Index]