interpret_community.common.model_wrapper module
Defines model wrappers and utilities for implicitly rewrapping a model so that it conforms to the prediction contracts that explainers expect.
- class interpret_community.common.model_wrapper.BaseWrappedModel(model, eval_function, examples, model_task)
Bases: object
A base class for WrappedClassificationModel and WrappedRegressionModel.
- class interpret_community.common.model_wrapper.WrappedClassificationModel(model, eval_function, examples=None)
Bases: interpret_community.common.model_wrapper.BaseWrappedModel
A class for wrapping a classification model.
- predict(dataset)
Predict the output using the wrapped classification model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict on.
- predict_proba(dataset)
Predict the output probability using the wrapped model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict probabilities on.
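The class above pairs a model with an eval_function. The following is a rough illustration only, not the library's implementation. The hypothetical `SketchClassificationWrapper` assumes that `eval_function` returns a two-dimensional array of shape (n_samples, n_classes), and it takes plain NumPy arrays instead of a `DatasetWrapper`. It derives `predict` by taking the most probable class index for each row.

```python
import numpy as np


class SketchClassificationWrapper:
    """Hypothetical stand-in for a wrapped classifier (not the library class)."""

    def __init__(self, model, eval_function, examples=None):
        self._model = model
        # Assumption: eval_function maps an (n, d) array to (n, n_classes) probabilities.
        self._eval_function = eval_function
        self._examples = examples

    def predict_proba(self, dataset):
        # Delegate directly to the evaluation function.
        return np.asarray(self._eval_function(dataset))

    def predict(self, dataset):
        # Pick the index of the most probable class for each row.
        return np.argmax(self.predict_proba(dataset), axis=1)


def two_class_proba(x):
    # Toy probability function: P(class 1) is the sigmoid of the row sum.
    p1 = 1.0 / (1.0 + np.exp(-x.sum(axis=1)))
    return np.column_stack([1.0 - p1, p1])


wrapped = SketchClassificationWrapper(model=None, eval_function=two_class_proba)
X = np.array([[1.0, 2.0], [-3.0, 0.5]])
print(wrapped.predict(X))  # [1 0]
```

The first row sums to 3, so class 1 is most probable. The second row sums to -2.5, so class 0 is most probable.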
- class interpret_community.common.model_wrapper.WrappedClassificationWithoutProbaModel(model)
Bases: object
A class for wrapping a classifier without a predict_proba method.
Note: the classifier may not output numeric values for its predictions, so a trivial boolean version of predict_proba is generated.
- predict(dataset)
Predict the output using the wrapped classification model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict on.
- predict_proba(dataset)
Predict the output probability using the wrapped model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict probabilities on.
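One plausible reading of the "trivial boolean" predict_proba is a one-hot encoding of the predicted labels. Each row gets 1.0 in the column of its predicted class and 0.0 everywhere else. The sketch below works under that assumption. It also assumes the wrapped classifier exposes a scikit-learn style `classes_` attribute. `SketchNoProbaWrapper` and `LabelOnlyClassifier` are hypothetical names, and plain arrays stand in for `DatasetWrapper`.

```python
import numpy as np


class SketchNoProbaWrapper:
    """Hypothetical wrapper that fakes predict_proba from hard labels."""

    def __init__(self, model):
        self._model = model
        # Assumption: the model exposes its label set as `classes_`, like scikit-learn.
        self._class_to_index = {c: i for i, c in enumerate(model.classes_)}
        self._num_classes = len(model.classes_)

    def predict(self, dataset):
        # Labels may be non-numeric (e.g. strings); pass them through unchanged.
        return self._model.predict(dataset)

    def predict_proba(self, dataset):
        predictions = self.predict(dataset)
        proba = np.zeros((len(predictions), self._num_classes))
        # Put all probability mass on the predicted class of each row.
        for row, label in enumerate(predictions):
            proba[row, self._class_to_index[label]] = 1.0
        return proba


class LabelOnlyClassifier:
    """Toy classifier with string labels and no predict_proba."""

    classes_ = np.array(["cat", "dog"])

    def predict(self, X):
        return np.where(np.asarray(X)[:, 0] > 0, "dog", "cat")


wrapped = SketchNoProbaWrapper(LabelOnlyClassifier())
X = np.array([[1.0], [-2.0], [0.5]])
print(wrapped.predict_proba(X))
```

Because every row has exactly one 1.0, explainers that expect a probability matrix can consume the output. It carries no information about the model's confidence.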
- class interpret_community.common.model_wrapper.WrappedPytorchModel(model)
Bases: object
A class for wrapping a PyTorch model so that it follows the scikit-learn prediction interface.
- predict(dataset)
Predict the output using the wrapped PyTorch model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict on.
- predict_classes(dataset)
Predict the class using the wrapped PyTorch model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict on.
- predict_proba(dataset)
Predict the output probability using the wrapped PyTorch model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict probabilities on.
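A PyTorch module is a callable that maps an input tensor to output scores. Adapting it to the scikit-learn interface therefore mostly means converting inputs and outputs and deriving class labels. The sketch below is hypothetical: `SketchTorchStyleWrapper` is not the library class. To stay runnable without PyTorch installed, it uses a plain NumPy callable as a stand-in for the module. It also assumes the module already outputs per-class probabilities. With a real module you would typically convert inputs with `torch.as_tensor` and convert outputs with `.detach().numpy()`.

```python
import numpy as np


class SketchTorchStyleWrapper:
    """Hypothetical scikit-learn style adapter around a callable 'module'."""

    def __init__(self, module):
        # Assumption: module(x) returns an (n, n_classes) array of probabilities.
        self._module = module

    def predict(self, dataset):
        # Raw forward pass, returned as a NumPy array.
        return np.asarray(self._module(np.asarray(dataset, dtype=np.float32)))

    def predict_proba(self, dataset):
        # Under the assumption above, the forward pass already yields probabilities.
        return self.predict(dataset)

    def predict_classes(self, dataset):
        # Most probable class index per row.
        return np.argmax(self.predict_proba(dataset), axis=1)


def softmax_module(x):
    # Stand-in for a torch.nn.Module: a fixed linear layer followed by softmax.
    weights = np.array([[1.0, -1.0], [0.0, 2.0]], dtype=np.float32)
    logits = x @ weights
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)


wrapped = SketchTorchStyleWrapper(softmax_module)
X = [[2.0, 0.0], [0.0, 1.0]]
print(wrapped.predict_classes(X))  # [0 1]
```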
- class interpret_community.common.model_wrapper.WrappedRegressionModel(model, eval_function, examples=None)
Bases: interpret_community.common.model_wrapper.BaseWrappedModel
A class for wrapping a regression model.
- predict(dataset)
Predict the output using the wrapped regression model.
- Parameters
dataset (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The dataset to predict on.
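For regression, the wrapper only needs predict. This minimal sketch assumes that `eval_function` may return either a flat array or an (n, 1) column. In the column case it flattens the result, so callers always receive a one-dimensional array of predictions. `SketchRegressionWrapper` is hypothetical, and plain arrays stand in for `DatasetWrapper`.

```python
import numpy as np


class SketchRegressionWrapper:
    """Hypothetical wrapper giving any prediction function a predict method."""

    def __init__(self, model, eval_function, examples=None):
        self._model = model
        self._eval_function = eval_function
        self._examples = examples

    def predict(self, dataset):
        result = np.asarray(self._eval_function(dataset))
        # Flatten an (n, 1) column to shape (n,); other shapes are returned unchanged.
        if result.ndim == 2 and result.shape[1] == 1:
            result = result[:, 0]
        return result


def column_output(x):
    # Toy regressor that returns an (n, 1) column: y = 2 * x0 + 1.
    return 2.0 * x[:, :1] + 1.0


wrapped = SketchRegressionWrapper(model=None, eval_function=column_output)
X = np.array([[0.0], [1.5], [-1.0]])
print(wrapped.predict(X))
```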
- interpret_community.common.model_wrapper.wrap_model(model, examples, model_task)
If needed, wraps the model in a common API based on the model task and the prediction function contract.
- Parameters
model (model with a predict or predict_proba function.) – The model to evaluate on the examples.
examples (interpret_community.dataset.dataset_wrapper.DatasetWrapper) – The model evaluation examples.
model_task (str) – Optional parameter to specify whether the model is a classification or regression model. In most cases, the type of the model can be inferred from the shape of its output: a classifier has a predict_proba method and outputs a two-dimensional array, while a regressor has a predict method and outputs a one-dimensional array.
- Returns
The wrapped model.
- Return type
model
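The model_task description explains how the task is usually inferred from the model's interface and output shape. The sketch below encodes that heuristic literally. It uses a hypothetical `infer_model_task` helper that is not part of the library. A model with predict_proba whose output on the examples is two-dimensional is treated as a classifier. A model whose predict output is one-dimensional is treated as a regressor. The string values `"classification"` and `"regression"` are assumptions for illustration, and plain arrays stand in for `DatasetWrapper`.

```python
import numpy as np


def infer_model_task(model, examples):
    """Hypothetical helper: guess the task from the model's interface and output shape."""
    if hasattr(model, "predict_proba"):
        proba = np.asarray(model.predict_proba(examples))
        # A 2-D probability array (n_samples, n_classes) signals a classifier.
        if proba.ndim == 2:
            return "classification"
    if hasattr(model, "predict"):
        preds = np.asarray(model.predict(examples))
        # A 1-D output (n_samples,) signals a regressor.
        if preds.ndim == 1:
            return "regression"
    raise ValueError("Could not infer the model task; pass model_task explicitly.")


class ToyClassifier:
    def predict_proba(self, X):
        return np.full((len(X), 2), 0.5)


class ToyRegressor:
    def predict(self, X):
        return np.zeros(len(X))


X = np.ones((4, 3))
print(infer_model_task(ToyClassifier(), X))  # classification
print(infer_model_task(ToyRegressor(), X))  # regression
```

Raising an error when neither check passes mirrors the advice that model_task should be given explicitly whenever the output shape is ambiguous.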