[MRG] Adds Permutation Importance (#13146)
* ENH Adds files
* ENH Adds permutation importance
* RFC Better names
* STY Flake8
* ENH: Adds inspect module
* DOC Adds pre_dispatch
* DOC Adds permutation importance example
* Trigger CI
* BLD Adds inspect to configuration
* RFC Update to only inspect fitted model
* RFC Removes parameters
* ENH: Adds pandas support
* STY Flake8
* DOC Adds new permutation importance example
* ENH Renames module to model_inspection
* DOC Fix links
* DOC Fixes image link
* DOC Fixes image link
* DOC Spelling
* DOC
* TST Fix keyword
* Rework RF Imp vs Perm Imp example (#4)
* WIP
* WIP
* WIP
* DOC Adds multcollinear features example
* WIP
* DOC: Clean up docs
* TST Adds tests for strings
* STY Indent correction
* WIP
* ENH Uses check_X_y
* TST Adds test with strings
* STY Fix
* TST Adds column transformer to test
* CLN Address comments
* CLN Removes import
* TST Adds test with nan
* CLN Removes import
* ENH Parallel
* DOC comments
* ENH Better handling of pandas
* ENH Clear checking of pandas dataframe
* STY Formatting
* ENH Copies in parallel helper
* DOC Adds comments
* BUG Fix copying
* BUG Fix for pandas
* BUG Fix for pandas
* REV
* BLD Trigger CI
* BUG Fix
* BUG Fix
* TST Does this work
* BUG Fixes test
* BUG Fixes test
* BUG Fix
* BUG Fix
* BUG Fix
* STY Fix
* TST Fix
* TST Fix segfault
* CLN Address comments
* CLN Address comments
* ENH Returns a bunch
* STY Flake8
* CLN Renames bunch key
* DOC Updates api
* DOC Updates api
* TST Adds permutation test with linear_regression
* DOC update
* DOC Fix label cutoff
* CLN Address comments
* TST Adds test for random_state effect
* DOC Adds permutation importance
* DOC Adds ogrisel suggestion
* DOC Address guillaumes comments
* DOC Address andreas comments
* DOC Update
1 parent e4aac74, commit d1c52f4
Showing 8 changed files with 653 additions and 2 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
|
||
.. _permutation_importance: | ||
|
||
Permutation feature importance | ||
============================== | ||
|
||
.. currentmodule:: sklearn.inspection | ||
|
||
Permutation feature importance is a model inspection technique that can be used | ||
for any `fitted` `estimator` when the data is rectangular. This is especially | ||
useful for non-linear or opaque `estimators`. The permutation feature | ||
importance is defined to be the decrease in a model score when the values of a | ||
single feature are randomly shuffled [1]_. This procedure breaks the | ||
relationship between the feature and the target, thus the drop in the model | ||
score is indicative of how much the model depends on the feature. This | ||
technique benefits from being model agnostic and can be calculated many times | ||
with different permutations of the feature. | ||
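
The definition above can be sketched by hand with NumPy. This is only an illustration of the idea, not the implementation added in this commit: the helper ``manual_permutation_importance`` and the toy scoring function ``r2_of_toy_model`` are hypothetical names, and the "fitted model" is a fixed linear rule that uses the first column only.

```python
import numpy as np


def manual_permutation_importance(score_fn, X, y, n_repeats=5, seed=0):
    """Drop in score when each column of X is shuffled in turn.

    ``score_fn(X, y)`` is a hypothetical callable that returns the score of
    an already fitted model (higher is better).
    """
    rng = np.random.default_rng(seed)
    baseline = score_fn(X, y)
    importances = np.zeros((X.shape[1], n_repeats))
    for col in range(X.shape[1]):
        for rep in range(n_repeats):
            X_perm = X.copy()
            # Shuffle a single column; this breaks its relationship with y
            X_perm[:, col] = rng.permutation(X_perm[:, col])
            importances[col, rep] = baseline - score_fn(X_perm, y)
    return importances


# Toy data: y depends on column 0 only, column 1 is never used by the model
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 2))
y = 3.0 * X[:, 0]


def r2_of_toy_model(X, y):
    """R^2 of a fixed 'model' that predicts 3 * X[:, 0]."""
    pred = 3.0 * X[:, 0]
    ss_res = np.sum((y - pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


imp = manual_permutation_importance(r2_of_toy_model, X, y, n_repeats=10)
print(imp.mean(axis=1))
```

Shuffling the unused column leaves the score unchanged, so its importance is exactly zero, while shuffling the column the model relies on lowers the score.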
|
||
The :func:`permutation_importance` function calculates the feature importance | ||
of `estimators` for a given dataset. The ``n_repeats`` parameter sets the number | ||
of times each feature is randomly shuffled, and the function returns a sample | ||
of feature importances. Permutation importances can be computed either on the | ||
training set or on a held-out testing or validation set. Using a held-out set | ||
makes it possible to highlight which features contribute the most to the | ||
generalization power of the inspected model. Features that are important on | ||
the training set but not on the held-out set might cause the model to overfit. | ||
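
A minimal usage sketch on a held-out set follows. The synthetic dataset and the choice of ``LogisticRegression`` are assumptions made for illustration; only ``permutation_importance`` and the ``importances``, ``importances_mean`` and ``importances_std`` keys of its result come from the library.

```python
from sklearn.datasets import make_classification
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic data (illustrative only): 6 features, 3 of them informative
X, y = make_classification(n_samples=500, n_features=6, n_informative=3,
                           n_redundant=0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Importances are computed on the held-out split, with 10 shuffles per feature
result = permutation_importance(model, X_test, y_test, n_repeats=10,
                                random_state=0)

# result.importances has one row per feature and one column per repeat
for i in result.importances_mean.argsort()[::-1]:
    print("feature %d: %.3f +/- %.3f"
          % (i, result.importances_mean[i], result.importances_std[i]))
```

Passing ``X_train, y_train`` instead gives the training-set importances, which can be compared against the held-out ones to spot features the model may be using to overfit.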
|
||
Note that features that are deemed non-important for some model with a | ||
low predictive performance could be highly predictive for a model that | ||
generalizes better. The conclusions should always be drawn in the context of | ||
the specific model under inspection and cannot be automatically generalized to | ||
the intrinsic predictive value of the features by themselves. Therefore it is | ||
always important to evaluate the predictive power of a model using a held-out | ||
set (or better with cross-validation) prior to computing importances. | ||
|
||
Relation to impurity-based importance in trees | ||
---------------------------------------------- | ||
|
||
Tree-based models provide a different measure of feature importance based | ||
on the mean decrease in impurity (MDI, the splitting criterion). This gives | ||
importance to features that may not be predictive on unseen data. The | ||
permutation feature importance avoids this issue, since it can be applied to | ||
unseen data. Furthermore, impurity-based feature importances for trees | ||
are strongly biased and favor high cardinality features | ||
(typically numerical features). Permutation-based feature importances do not | ||
exhibit such a bias. Additionally, the permutation feature importance may use | ||
an arbitrary metric on the tree's predictions. These two methods of obtaining | ||
feature importance are explored in: | ||
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`. | ||
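
The two measures can be placed side by side on a small synthetic problem. The data below is an assumption made for illustration (a noisy binary signal plus an unrelated high-cardinality column), not the titanic setup of the linked example. One would typically expect the noise column to receive a visible share of the MDI while its permutation importance on the test set stays close to zero, but the exact numbers depend on the seed and the library version.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
n_samples = 600
signal = rng.randint(2, size=n_samples)   # low cardinality, predictive
noise = rng.randn(n_samples)              # high cardinality, unrelated to y
X = np.column_stack([signal, noise])

# Target follows the signal, with 20% of the labels flipped at random
y = signal.copy()
flip = rng.rand(n_samples) < 0.2
y[flip] = 1 - y[flip]

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = RandomForestClassifier(n_estimators=50, random_state=0)
forest.fit(X_train, y_train)

# MDI: derived from training-set impurity decreases, normalized to sum to 1
mdi = forest.feature_importances_
# Permutation importance: drop in accuracy on the held-out split
perm = permutation_importance(forest, X_test, y_test, n_repeats=10,
                              random_state=0)

print("MDI:        ", mdi)
print("Permutation:", perm.importances_mean)
```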
|
||
Strongly correlated features | ||
---------------------------- | ||
|
||
When two features are correlated and one of the features is permuted, the model | ||
will still have access to the feature through its correlated feature. This will | ||
result in a lower importance for both features, even though they might | ||
*actually* be important. One way to handle this is to cluster features that are | ||
correlated and only keep one feature from each cluster. This use case is | ||
explored in: | ||
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`. | ||
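
The clustering step can be sketched on toy data as follows. This is one possible variant, not the code of the linked example: it assumes ``1 - |corr|`` as the distance between features, average linkage, and a hand-picked threshold of ``0.5``, whereas the example in this commit calls ``hierarchy.ward`` on the Spearman correlation matrix directly.

```python
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
a = rng.normal(size=300)
b = rng.normal(size=300)
# Columns 0 and 1 are near-duplicates; column 2 is independent of both
X = np.column_stack([a, a + 0.01 * rng.normal(size=300), b])

# Spearman rank-order correlation between the columns (3 x 3 matrix)
corr, _ = spearmanr(X)

# Assumption: strongly correlated features are "close" to each other
distance = 1.0 - np.abs(corr)
np.fill_diagonal(distance, 0.0)
linkage = hierarchy.average(squareform(distance, checks=False))
cluster_ids = hierarchy.fcluster(linkage, t=0.5, criterion="distance")

# Keep the first feature encountered in each cluster
first_in_cluster = {}
for idx, cluster_id in enumerate(cluster_ids):
    first_in_cluster.setdefault(cluster_id, idx)
selected = sorted(first_in_cluster.values())
print(selected)
```

The two near-duplicate columns end up in the same cluster, so only one of them is kept alongside the independent column. A model retrained on ``X[:, selected]`` can then be inspected without the correlated pair masking each other.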
|
||
.. topic:: Examples: | ||
|
||
* :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` | ||
* :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py` | ||
|
||
.. topic:: References: | ||
|
||
.. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, | ||
2001. https://doi.org/10.1023/A:1010933404324 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
""" | ||
================================================================ | ||
Permutation Importance vs Random Forest Feature Importance (MDI) | ||
================================================================ | ||
In this example, we will compare the impurity-based feature importance of | ||
:class:`~sklearn.ensemble.RandomForestClassifier` with the | ||
permutation importance on the titanic dataset using | ||
:func:`~sklearn.inspection.permutation_importance`. We will show that the | ||
impurity-based feature importance can inflate the importance of numerical | ||
features. | ||
Furthermore, the impurity-based feature importance of random forests suffers | ||
from being computed on statistics derived from the training dataset: the | ||
importances can be high even for features that are not predictive of the target | ||
variable, as long as the model has the capacity to use them to overfit. | ||
This example shows how to use Permutation Importances as an alternative that | ||
can mitigate those limitations. | ||
.. topic:: References: | ||
.. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, | ||
2001. https://doi.org/10.1023/A:1010933404324 | ||
""" | ||
print(__doc__) | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn.datasets import fetch_openml | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.impute import SimpleImputer | ||
from sklearn.inspection import permutation_importance | ||
from sklearn.compose import ColumnTransformer | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.pipeline import Pipeline | ||
from sklearn.preprocessing import OneHotEncoder | ||
|
||
|
||
############################################################################## | ||
# Data Loading and Feature Engineering | ||
# ------------------------------------ | ||
# Let's use pandas to load a copy of the titanic dataset. The following shows | ||
# how to apply separate preprocessing on numerical and categorical features. | ||
# | ||
# We further include two random variables that are not correlated in any way | ||
# with the target variable (``survived``): | ||
# | ||
# - ``random_num`` is a high cardinality numerical variable (as many unique | ||
# values as records). | ||
# - ``random_cat`` is a low cardinality categorical variable (3 possible | ||
# values). | ||
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) | ||
X['random_cat'] = np.random.randint(3, size=X.shape[0]) | ||
X['random_num'] = np.random.randn(X.shape[0]) | ||
|
||
categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat'] | ||
numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num'] | ||
|
||
X = X[categorical_columns + numerical_columns] | ||
|
||
X_train, X_test, y_train, y_test = train_test_split( | ||
    X, y, stratify=y, random_state=42) | ||
|
||
categorical_pipe = Pipeline([ | ||
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), | ||
    ('onehot', OneHotEncoder(handle_unknown='ignore')) | ||
]) | ||
numerical_pipe = Pipeline([ | ||
    ('imputer', SimpleImputer(strategy='mean')) | ||
]) | ||
|
||
preprocessing = ColumnTransformer( | ||
    [('cat', categorical_pipe, categorical_columns), | ||
     ('num', numerical_pipe, numerical_columns)]) | ||
|
||
rf = Pipeline([ | ||
    ('preprocess', preprocessing), | ||
    ('classifier', RandomForestClassifier(random_state=42)) | ||
]) | ||
rf.fit(X_train, y_train) | ||
|
||
############################################################################## | ||
# Accuracy of the Model | ||
# --------------------- | ||
# Prior to inspecting the feature importances, it is important to check that | ||
# the model's predictive performance is high enough. Indeed, there would be | ||
# little interest in inspecting the important features of a non-predictive | ||
# model. | ||
# | ||
# Here one can observe that the train accuracy is very high (the forest model | ||
# has enough capacity to completely memorize the training set) but it can still | ||
# generalize well enough to the test set thanks to the built-in bagging of | ||
# random forests. | ||
# | ||
# It might be possible to trade some accuracy on the training set for a | ||
# slightly better accuracy on the test set by limiting the capacity of the | ||
# trees (for instance by setting ``min_samples_leaf=5`` or | ||
# ``min_samples_leaf=10``) so as to limit overfitting while not introducing too | ||
# much underfitting. | ||
# | ||
# However let's keep our high capacity random forest model for now so as to | ||
# illustrate some pitfalls with feature importance on variables with many | ||
# unique values. | ||
print("RF train accuracy: %0.3f" % rf.score(X_train, y_train)) | ||
print("RF test accuracy: %0.3f" % rf.score(X_test, y_test)) | ||
|
||
|
||
############################################################################## | ||
# Tree's Feature Importance from Mean Decrease in Impurity (MDI) | ||
# -------------------------------------------------------------- | ||
# The impurity-based feature importance ranks the numerical features as the | ||
# most important features. As a result, the non-predictive ``random_num`` | ||
# variable is ranked the most important! | ||
# | ||
# This problem stems from two limitations of impurity-based feature | ||
# importances: | ||
# | ||
# - impurity-based importances are biased towards high cardinality features; | ||
# - impurity-based importances are computed on training set statistics and | ||
# therefore do not reflect the ability of a feature to be useful to make | ||
# predictions that generalize to the test set (when the model has enough | ||
# capacity). | ||
ohe = (rf.named_steps['preprocess'] | ||
.named_transformers_['cat'] | ||
.named_steps['onehot']) | ||
feature_names = ohe.get_feature_names(input_features=categorical_columns) | ||
feature_names = np.r_[feature_names, numerical_columns] | ||
|
||
tree_feature_importances = ( | ||
rf.named_steps['classifier'].feature_importances_) | ||
sorted_idx = tree_feature_importances.argsort() | ||
|
||
y_ticks = np.arange(0, len(feature_names)) | ||
fig, ax = plt.subplots() | ||
ax.barh(y_ticks, tree_feature_importances[sorted_idx]) | ||
ax.set_yticks(y_ticks) | ||
ax.set_yticklabels(feature_names[sorted_idx]) | ||
ax.set_title("Random Forest Feature Importances (MDI)") | ||
fig.tight_layout() | ||
plt.show() | ||
|
||
|
||
############################################################################## | ||
# As an alternative, the permutation importances of ``rf`` are computed on a | ||
# held-out test set. This shows that the low cardinality categorical feature, | ||
# ``sex``, is the most important feature. | ||
# | ||
# Also note that both random features have very low importances (close to 0) as | ||
# expected. | ||
result = permutation_importance(rf, X_test, y_test, n_repeats=10, | ||
random_state=42, n_jobs=2) | ||
sorted_idx = result.importances_mean.argsort() | ||
|
||
fig, ax = plt.subplots() | ||
ax.boxplot(result.importances[sorted_idx].T, | ||
vert=False, labels=X_test.columns[sorted_idx]) | ||
ax.set_title("Permutation Importances (test set)") | ||
fig.tight_layout() | ||
plt.show() | ||
|
||
############################################################################## | ||
# It is also possible to compute the permutation importances on the training | ||
# set. This reveals that ``random_num`` gets a significantly higher importance | ||
# ranking than when computed on the test set. The difference between those two | ||
# plots is a confirmation that the RF model has enough capacity to use that | ||
# random numerical feature to overfit. You can further confirm this by | ||
# re-running this example with a constrained RF using ``min_samples_leaf=10``. | ||
result = permutation_importance(rf, X_train, y_train, n_repeats=10, | ||
random_state=42, n_jobs=2) | ||
sorted_idx = result.importances_mean.argsort() | ||
|
||
fig, ax = plt.subplots() | ||
ax.boxplot(result.importances[sorted_idx].T, | ||
vert=False, labels=X_train.columns[sorted_idx]) | ||
ax.set_title("Permutation Importances (train set)") | ||
fig.tight_layout() | ||
plt.show() |
examples/inspection/plot_permutation_importance_multicollinear.py (111 additions, 0 deletions)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
================================================================= | ||
Permutation Importance with Multicollinear or Correlated Features | ||
================================================================= | ||
In this example, we compute the permutation importance on the Wisconsin | ||
breast cancer dataset using :func:`~sklearn.inspection.permutation_importance`. | ||
The :class:`~sklearn.ensemble.RandomForestClassifier` can easily get about 97% | ||
accuracy on a test dataset. Because this dataset contains multicollinear | ||
features, the permutation importance will show that none of the features are | ||
important. One approach to handling multicollinearity is by performing | ||
hierarchical clustering on the features' Spearman rank-order correlations, | ||
picking a threshold, and keeping a single feature from each cluster. | ||
.. note:: | ||
See also | ||
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` | ||
""" | ||
print(__doc__) | ||
from collections import defaultdict | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from scipy.stats import spearmanr | ||
from scipy.cluster import hierarchy | ||
|
||
from sklearn.datasets import load_breast_cancer | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.inspection import permutation_importance | ||
from sklearn.model_selection import train_test_split | ||
|
||
############################################################################## | ||
# Random Forest Feature Importance on Breast Cancer Data | ||
# ------------------------------------------------------ | ||
# First, we train a random forest on the breast cancer dataset and evaluate | ||
# its accuracy on a test set: | ||
data = load_breast_cancer() | ||
X, y = data.data, data.target | ||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) | ||
|
||
clf = RandomForestClassifier(n_estimators=100, random_state=42) | ||
clf.fit(X_train, y_train) | ||
print("Accuracy on test data: {:.2f}".format(clf.score(X_test, y_test))) | ||
|
||
############################################################################## | ||
# Next, we plot the tree based feature importance and the permutation | ||
# importance. The permutation importance plot shows that permuting a feature | ||
# drops the accuracy by at most `0.012`, which would suggest that none of the | ||
# features are important. This is in contradiction with the high test accuracy | ||
# computed above: some features must be important. The permutation importance | ||
# is calculated on the training set to show how much the model relies on each | ||
# feature during training. | ||
result = permutation_importance(clf, X_train, y_train, n_repeats=10, | ||
random_state=42) | ||
perm_sorted_idx = result.importances_mean.argsort() | ||
|
||
tree_importance_sorted_idx = np.argsort(clf.feature_importances_) | ||
tree_indices = np.arange(1, len(clf.feature_importances_) + 1) | ||
|
||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) | ||
ax1.barh(tree_indices, clf.feature_importances_[tree_importance_sorted_idx]) | ||
ax1.set_yticks(tree_indices) | ||
# Label each bar and box with the feature it shows, in the same sorted order | ||
ax1.set_yticklabels(data.feature_names[tree_importance_sorted_idx]) | ||
ax2.boxplot(result.importances[perm_sorted_idx].T, vert=False, | ||
            labels=data.feature_names[perm_sorted_idx]) | ||
fig.tight_layout() | ||
plt.show() | ||
|
||
############################################################################## | ||
# Handling Multicollinear Features | ||
# -------------------------------- | ||
# When features are collinear, permuting one feature will have little | ||
# effect on the model's performance because it can get the same information | ||
# from a correlated feature. One way to handle multicollinear features is by | ||
# performing hierarchical clustering on the Spearman rank-order correlations, | ||
# picking a threshold, and keeping a single feature from each cluster. First, | ||
# we plot a heatmap of the correlated features: | ||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) | ||
corr = spearmanr(X).correlation | ||
corr_linkage = hierarchy.ward(corr) | ||
dendro = hierarchy.dendrogram(corr_linkage, labels=data.feature_names, ax=ax1, | ||
leaf_rotation=90) | ||
dendro_idx = np.arange(0, len(dendro['ivl'])) | ||
|
||
ax2.imshow(corr[dendro['leaves'], :][:, dendro['leaves']]) | ||
ax2.set_xticks(dendro_idx) | ||
ax2.set_yticks(dendro_idx) | ||
ax2.set_xticklabels(dendro['ivl'], rotation='vertical') | ||
ax2.set_yticklabels(dendro['ivl']) | ||
fig.tight_layout() | ||
plt.show() | ||
|
||
############################################################################## | ||
# Next, we manually pick a threshold by visual inspection of the dendrogram | ||
# to group our features into clusters and choose a feature from each cluster to | ||
# keep, select those features from our dataset, and train a new random forest. | ||
# The test accuracy of the new random forest did not change much compared to | ||
# the random forest trained on the complete dataset. | ||
cluster_ids = hierarchy.fcluster(corr_linkage, 1, criterion='distance') | ||
cluster_id_to_feature_ids = defaultdict(list) | ||
for idx, cluster_id in enumerate(cluster_ids): | ||
    cluster_id_to_feature_ids[cluster_id].append(idx) | ||
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()] | ||
|
||
X_train_sel = X_train[:, selected_features] | ||
X_test_sel = X_test[:, selected_features] | ||
|
||
clf_sel = RandomForestClassifier(n_estimators=100, random_state=42) | ||
clf_sel.fit(X_train_sel, y_train) | ||
print("Accuracy on test data with features removed: {:.2f}".format( | ||
clf_sel.score(X_test_sel, y_test))) |