interpret {midr}    R Documentation
Fit MID Models
Description
interpret() is used to fit a MID model specifically as an interpretable surrogate for black-box predictive models.
A fitted MID model consists of a set of component functions, each with up to two variables.
Usage
interpret(object, ...)
## Default S3 method:
interpret(
object,
x,
y = NULL,
weights = NULL,
pred.fun = get.yhat,
link = NULL,
k = c(NA, NA),
type = c(1L, 1L),
frames = list(),
interaction = FALSE,
terms = NULL,
singular.ok = FALSE,
mode = 1L,
method = NULL,
lambda = 0,
kappa = 1e+06,
na.action = getOption("na.action"),
verbosity = 1L,
encoding.digits = 3L,
use.catchall = FALSE,
catchall = "(others)",
max.ncol = 10000L,
nil = 1e-07,
tol = 1e-07,
pred.args = list(),
...
)
## S3 method for class 'formula'
interpret(
formula,
data = NULL,
model = NULL,
pred.fun = get.yhat,
weights = NULL,
subset = NULL,
na.action = getOption("na.action"),
verbosity = 1L,
mode = 1L,
drop.unused.levels = FALSE,
pred.args = list(),
...
)
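Both methods take a pred.fun whose first argument is the fitted model and whose second is new data, and any further arguments are supplied through pred.args. The sketch below writes such a function by hand for a logistic glm. The name my.pred.fun and its type argument are hypothetical illustrations, not part of midr, and it is an assumption that a plain numeric vector with one value per row of new data is what interpret() expects.

```r
# Hypothetical pred.fun: first argument = fitted model, second = new data.
# Extra arguments (here 'type') are the kind of thing pred.args would supply.
my.pred.fun <- function(object, newdata, type = "response") {
  # predict.glm() returns a named vector; strip names to a plain numeric vector
  as.numeric(predict(object, newdata = newdata, type = type))
}

data(mtcars, package = "datasets")
model <- glm(am ~ wt, family = binomial, data = mtcars)

p <- my.pred.fun(model, mtcars)                   # predicted probabilities
eta <- my.pred.fun(model, mtcars, type = "link")  # predicted log-odds
head(round(p, 3))
```

Under that reading, it could be supplied as interpret(am ~ wt, mtcars, model, pred.fun = my.pred.fun, pred.args = list(type = "link")) to explain the model on the log-odds scale; treat this call as a sketch rather than a verified recipe.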
Arguments
object: a fitted model object to be interpreted.
...: for …
x: a matrix or data.frame of predictor variables to be used in the fitting process. The response variable should not be included.
y: an optional numeric vector of the model predictions or the response variable.
weights: a numeric vector of sample weights for each observation in x.
pred.fun: a function to obtain predictions from a fitted model, where the first argument is for the fitted model and the second argument is for new data. The default is get.yhat.
link: a character string specifying the link function, one of "logit", "probit", "cauchit", "cloglog", "identity", "log", "sqrt", "1/mu^2", "inverse", "translogit", "transprobit", "identity-logistic" or "identity-gaussian"; or an object containing two functions …
k: an integer or integer-valued vector of length two. The maximum number of sample points for each variable. If a vector is passed, …
type: an integer or integer-valued vector of length two. The type of encoding. The effects of quantitative variables are modeled as piecewise linear functions if …
frames: a named list of encoding frames ("numeric.frame" or "factor.frame" objects). Each encoding frame is used to encode the variable of the corresponding name. If the name begins with "|" or ":", the encoding frame is used only for main effects or interactions, respectively.
interaction: logical. If …
terms: a character vector of term labels specifying the set of component functions to be modeled. If not passed, …
singular.ok: logical. If FALSE, a singular fit is an error.
mode: an integer specifying the method of calculation. If …
method: an integer specifying the method to be used to solve the least squares problem. A non-negative value will be passed to …
lambda: the penalty factor for pseudo smoothing. The default is 0.
kappa: the penalty factor for centering constraints. Used only when …
na.action: a function or character string specifying the method of NA handling.
verbosity: the level of verbosity.
encoding.digits: an integer. The rounding digits for encoding numeric variables. Used only when …
use.catchall: logical. If …
catchall: a character string specifying the catchall level.
max.ncol: an integer. The maximum number of columns of the design matrix.
nil: a threshold for the intercept and coefficients to be treated as zero. The default is 1e-07.
tol: a tolerance for the singular value decomposition. The default is 1e-07.
pred.args: optional parameters, other than the fitted model and new data, to be passed to pred.fun.
formula: a symbolic description of the MID model to be fit.
data: a data.frame, list or environment containing the variables in formula.
model: a fitted model object to be interpreted.
subset: an optional vector specifying a subset of observations to be used in the fitting process.
drop.unused.levels: logical. If TRUE, unused levels of factors are dropped.
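To make k and type more concrete, here is a small base-R sketch of two kinds of encoding for a numeric variable: a piecewise linear ("hat") basis, in which each value is split between its two neighbouring sample points, and, as an assumed alternative shown only for contrast, a step-function (one-hot interval) basis. This is not midr's internal encoder. Choosing sample points from empirical quantiles, clamping values to the sample-point range, and the helper names sample_points(), encode_linear() and encode_step() are all assumptions made for illustration.

```r
# Illustration only: two ways to encode one numeric variable on a set of
# sample points. Not midr's internal encoder.

# Hypothetical helper: up to k sample points from the empirical quantiles.
sample_points <- function(x, k) {
  unique(quantile(x, probs = seq(0, 1, length.out = k), names = FALSE))
}

# Piecewise linear ("hat") basis: each value is split between its two
# neighbouring sample points, so every row sums to 1.
encode_linear <- function(x, knots) {
  x <- pmin(pmax(x, min(knots)), max(knots))            # clamp to knot range
  j <- findInterval(x, knots, rightmost.closed = TRUE)  # left neighbour index
  w <- (x - knots[j]) / (knots[j + 1L] - knots[j])      # share of right one
  X <- matrix(0, nrow = length(x), ncol = length(knots))
  X[cbind(seq_along(x), j)] <- 1 - w
  X[cbind(seq_along(x), j + 1L)] <- w
  X
}

# Step-function basis: a one-hot indicator of the interval between
# consecutive sample points that contains each value.
encode_step <- function(x, knots) {
  x <- pmin(pmax(x, min(knots)), max(knots))
  j <- findInterval(x, knots, rightmost.closed = TRUE)
  X <- matrix(0, nrow = length(x), ncol = length(knots) - 1L)
  X[cbind(seq_along(x), j)] <- 1
  X
}

speed <- as.numeric(datasets::cars$speed)
knots <- sample_points(speed, k = 5L)
XL <- encode_linear(speed, knots)   # one column per sample point
XS <- encode_step(speed, knots)     # one column per interval
```

Because the hat basis reproduces any linear function of x exactly, a model built on it can only bend at the sample points, while the step basis is constant within each interval. Either way, fewer sample points (a smaller k, as in the Nile examples) means a coarser fitted effect.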
Details
interpret() returns a global surrogate model of the target predictive model.
The prediction function of this surrogate model is derived from Maximum Interpretation Decomposition (MID) applied to the prediction function of the target model, denoted $f(\mathbf{x})$.
The prediction function of the global surrogate model, denoted $\mathcal{F}(\mathbf{x})$, has the following structure:
$$\mathcal{F}(\mathbf{x}) = f_\phi + \sum_{j} f_{j}(x_j) + \sum_{j<k} f_{jk}(x_j, x_k)$$
where $f_\phi$ is the intercept, $f_{j}(x_j)$ is the main effect of feature $j$, and $f_{jk}(x_j, x_k)$ is the second-order interaction effect between features $j$ and $k$.
To ensure the identifiability (uniqueness) of these decomposed components, they are subject to centering constraints during the fitting process.
Specifically, each main effect function $f_j(x_j)$ is constrained so that its average over the data distribution of feature $X_j$ is zero.
Similarly, each second-order interaction effect function $f_{jk}(x_j, x_k)$ is constrained so that its conditional average over $X_j$ (for any fixed value $x_k$) is zero, and its conditional average over $X_k$ (for any fixed value $x_j$) is also zero.
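When the data form a full grid with equal weights, so that the two features are independent, the constrained decomposition can be computed by hand with the classical two-way ANOVA formulas. The sketch below does this for a made-up function f of two variables and checks the centering constraints numerically. It illustrates what the constraints mean; it is not how midr computes its fit, which uses least squares on encoded data.

```r
# Illustration only: the constrained decomposition of a made-up function of
# two variables, using the full grid as the data (equal weights, independent
# features), where it reduces to the two-way ANOVA formulas.
x1 <- seq(0, 1, length.out = 6)
x2 <- seq(-1, 1, length.out = 5)
f <- function(a, b) exp(a) + b^2 + a * b   # hypothetical target function
Y <- outer(x1, x2, f)                      # Y[i, j] = f(x1[i], x2[j])

f_phi <- mean(Y)                           # intercept
f_1 <- rowMeans(Y) - f_phi                 # main effect of x1 (length 6)
f_2 <- colMeans(Y) - f_phi                 # main effect of x2 (length 5)
f_12 <- Y - f_phi - outer(f_1, f_2, "+")   # interaction: what remains

# Centering constraints: main effects average to zero, and the interaction
# averages to zero over x1 for each x2 (columns) and over x2 for each x1 (rows).
round(c(mean(f_1), mean(f_2),
        max(abs(colMeans(f_12))), max(abs(rowMeans(f_12)))), 12)
```

Because this decomposition keeps the full interaction term, the components add back up to f exactly; dropping f_12 would leave an approximation error that a main-effects-only model has to absorb.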
The surrogate model is fitted by least squares, minimizing the squared error between the predictions of the target model $f(\mathbf{x})$ and those of the surrogate model $\mathcal{F}(\mathbf{x})$ over the supplied observations (x or data), which should be representative of the data distribution.
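The least squares fit with centering constraints can be sketched in base R for a main-effects-only model. Here the constraints are appended as extra rows scaled by sqrt(kappa), which is one way to read the kappa argument ("the penalty factor for centering constraints"). The target model, the four-bin step encoding on quantiles, the hypothetical onehot() helper, and solving with qr.solve() are illustrative assumptions, not midr's implementation. In symbols, the sketch minimizes $\sum_i (f(\mathbf{x}_i) - \mathcal{F}(\mathbf{x}_i))^2 + \kappa \sum_j \big(\tfrac{1}{n}\sum_i f_j(x_{ij})\big)^2$.

```r
# Illustration only: main-effects-only fit with centering constraints added
# as penalty rows weighted by sqrt(kappa).
data(mtcars, package = "datasets")
target <- lm(mpg ~ wt * hp, data = mtcars)  # stands in for a black-box model
y <- unname(fitted(target))                  # target predictions f(x_i)

# Hypothetical helper: one-hot step encoding on quantile bins.
onehot <- function(v, n.bins = 4L) {
  br <- unique(quantile(v, probs = seq(0, 1, length.out = n.bins + 1L),
                        names = FALSE))
  bins <- cut(v, breaks = br, include.lowest = TRUE)
  outer(as.integer(bins), seq_len(nlevels(bins)), "==") * 1  # 0/1 matrix
}
D1 <- onehot(mtcars$wt)
D2 <- onehot(mtcars$hp)
X <- cbind(1, D1, D2)   # intercept + two main effects; rank-deficient alone

# Centering constraints: the sample mean of each main effect is zero.
C <- rbind(c(0, colMeans(D1), rep(0, ncol(D2))),
           c(0, rep(0, ncol(D1)), colMeans(D2)))
kappa <- 1e6
A <- rbind(X, sqrt(kappa) * C)
beta <- qr.solve(A, c(y, 0, 0))              # least squares on augmented system

F.hat <- drop(X %*% beta)                        # surrogate predictions
f1 <- drop(D1 %*% beta[1L + seq_len(ncol(D1))])  # main effect of wt, per row
```

In this setup the penalty only pins down the part of the coefficients that the fit itself cannot identify (a constant can shift freely between the intercept and each main effect), so the constraints end up satisfied essentially exactly and the fitted values equal the ordinary least squares fit.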
Value
interpret() returns a "mid" object with the following components:
weights: a numeric vector of the sample weights.
call: the matched call.
terms: the term labels.
link: a "link-glm" or "link-midr" object containing the link function.
intercept: the intercept.
encoders: a list of variable encoders.
main.effects: a list of data frames representing the main effects.
interactions: a list of data frames representing the interactions.
ratio: the ratio of the sum of squared errors between the target model predictions and the fitted MID values to the sum of squared deviations of the target model predictions. Values closer to 0 indicate that the MID model reproduces the target predictions more faithfully.
fitted.matrix: a matrix showing the breakdown of the predictions into the effects of the component functions.
linear.predictors: a numeric vector of the linear predictors.
fitted.values: a numeric vector of the fitted values.
residuals: a numeric vector of the working residuals.
na.action: information about the special handling of NAs.
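The ratio component can be reproduced from its description. The sketch below assumes that observation weights enter both sums and that deviations are taken from the weighted mean; mid_ratio() is a hypothetical helper, not a midr function, and a straight-line fit stands in for the MID fitted values.

```r
# Hypothetical helper following the description of 'ratio'. Assumption:
# weights enter both sums and deviations are from the weighted mean.
mid_ratio <- function(target, fitted, weights = rep(1, length(target))) {
  sse <- sum(weights * (target - fitted)^2)
  sst <- sum(weights * (target - weighted.mean(target, weights))^2)
  sse / sst
}

data(cars, package = "datasets")
target <- unname(fitted(lm(dist ~ I(speed^2) + speed, data = cars)))
speed <- cars$speed
linear.fit <- unname(fitted(lm(target ~ speed)))  # a straight-line approximation
mid_ratio(target, linear.fit)
```

For an unweighted least squares fit that includes an intercept, this value equals 1 minus the R-squared of that fit.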
Examples
# fit a MID model as a surrogate model
data(cars, package = "datasets")
model <- lm(dist ~ I(speed^2) + speed, cars)
mid <- interpret(dist ~ speed, cars, model)
plot(mid, "speed", intercept = TRUE)
points(cars)
# customize the flexibility of a MID model
data(Nile, package = "datasets")
mid <- interpret(x = 1L:100L, y = Nile, k = 100L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)
# reduce the number of knots by setting the 'k' parameter
mid <- interpret(x = 1L:100L, y = Nile, k = 10L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)
# apply pseudo smoothing by setting the 'lambda' parameter
mid <- interpret(x = 1L:100L, y = Nile, k = 100L, lambda = 100L)
plot(mid, "x", intercept = TRUE, limits = c(600L, 1300L))
points(x = 1L:100L, y = Nile)
# fit a MID model as a predictive model
data(airquality, package = "datasets")
mid <- interpret(Ozone ~ .^2, na.omit(airquality), lambda = .4)
plot(mid, "Wind")
plot(mid, "Temp")
plot(mid, "Wind:Temp", theme = "RdBu")
plot(mid, "Wind:Temp", main.effects = TRUE)