
Commit fdbb375

add a new feature_importance_permutation function

1 parent 9c8529a

12 files changed: +894 -11 lines

docs/mkdocs.yml (+1)

@@ -57,6 +57,7 @@ pages:
   - user_guide/evaluate/BootstrapOutOfBag.md
   - user_guide/evaluate/cochrans_q.md
   - user_guide/evaluate/confusion_matrix.md
+  - user_guide/evaluate/feature_importance_permutation.md
   - user_guide/evaluate/lift_score.md
   - user_guide/evaluate/mcnemar_table.md
   - user_guide/evaluate/mcnemar_tables.md

docs/sources/CHANGELOG.md (+4 -5)

@@ -17,11 +17,10 @@ The CHANGELOG for the current development version is available at

 ##### New Features

-- The fit method of the ExhaustiveFeatureSelector now optionally accepts
-**fit_params for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith)
-- The fit method of the SequentialFeatureSelector now optionally accepts
-**fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith)
+- A new `feature_importance_permutation` function to compute the feature importance in classifiers and regressors via the *permutation importance* method ([#358](https://github.com/rasbt/mlxtend/pull/358))
+- The fit method of the ExhaustiveFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#354](https://github.com/rasbt/mlxtend/pull/354) by Zach Griffith)
+- The fit method of the SequentialFeatureSelector now optionally accepts **fit_params for the estimator that is used for the feature selection. ([#350](https://github.com/rasbt/mlxtend/pull/350) by Zach Griffith)

docs/sources/USER_GUIDE_INDEX.md (+1)

@@ -29,6 +29,7 @@
 - [BootstrapOutOfBag](user_guide/evaluate/BootstrapOutOfBag.md)
 - [cochrans_q](user_guide/evaluate/cochrans_q.md)
 - [confusion_matrix](user_guide/evaluate/confusion_matrix.md)
+- [feature_importance_permutation](user_guide/evaluate/feature_importance_permutation.md)
 - [lift_score](user_guide/evaluate/lift_score.md)
 - [mcnemar_table](user_guide/evaluate/mcnemar_table.md)
 - [mcnemar_tables](user_guide/evaluate/mcnemar_tables.md)

docs/sources/user_guide/evaluate/feature_importance_permutation.ipynb (+639)

Large diffs are not rendered by default.
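Since the notebook diff above is not rendered, the user-guide example itself is not visible here. As a rough stand-in, below is a minimal usage sketch of the new API, based only on the function signature added in this commit; the classifier and toy data are illustrative assumptions, not content from the notebook:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from mlxtend.evaluate import feature_importance_permutation

# toy data and model, for illustration only
X, y = make_classification(n_samples=1000, n_features=6, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
svm = SVC(C=1.0, kernel='rbf', random_state=0).fit(X_train, y_train)

# importance = baseline accuracy minus accuracy after permuting each column
imp_vals, imp_all = feature_importance_permutation(
    X=X_test, y=y_test,
    predict_method=svm.predict,
    metric='accuracy',
    num_rounds=10,
    seed=1)
print(imp_vals)  # shape: (n_features,)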

mlxtend/evaluate/__init__.py (+9 -6)

@@ -4,17 +4,19 @@
 #
 # License: BSD 3 clause

-from .scoring import scoring
+from .bootstrap import bootstrap
+from .bootstrap_outofbag import BootstrapOutOfBag
+from .bootstrap_point632 import bootstrap_point632_score
+from .cochrans_q import cochrans_q
 from .confusion_matrix import confusion_matrix
+from .feature_importance import feature_importance_permutation
 from .lift_score import lift_score
 from .mcnemar import mcnemar_table
 from .mcnemar import mcnemar_tables
 from .mcnemar import mcnemar
-from .bootstrap import bootstrap
-from .bootstrap_outofbag import BootstrapOutOfBag
-from .bootstrap_point632 import bootstrap_point632_score
 from .permutation import permutation_test
-from .cochrans_q import cochrans_q
+from .scoring import scoring
 from .ttest import paired_ttest_resampled
 from .ttest import paired_ttest_kfold_cv
 from .ttest import paired_ttest_5x2cv
@@ -26,4 +28,5 @@
            "bootstrap", "permutation_test",
            "BootstrapOutOfBag", "bootstrap_point632_score",
            "cochrans_q", "paired_ttest_resampled",
-           "paired_ttest_kfold_cv", "paired_ttest_5x2cv"]
+           "paired_ttest_kfold_cv", "paired_ttest_5x2cv",
+           "feature_importance_permutation"]
mlxtend/evaluate/feature_importance.py (new file, +96)

@@ -0,0 +1,96 @@
# Sebastian Raschka 2014-2018
# mlxtend Machine Learning Library Extensions
#
# Feature Importance Estimation Through Permutation
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

import numpy as np


def feature_importance_permutation(X, y, predict_method,
                                   metric, num_rounds=1, seed=None):
    """Feature importance estimation via permutation importance

    Parameters
    ----------
    X : NumPy array, shape = [n_samples, n_features]
        Dataset, where n_samples is the number of samples and
        n_features is the number of features.

    y : NumPy array, shape = [n_samples]
        Target values.

    predict_method : prediction function
        A callable that predicts the target values from X.

    metric : str, callable
        The metric for evaluating the feature importance through
        permutation. The string 'accuracy' is recommended for
        classifiers and the string 'r2' is recommended for
        regressors. Optionally, a custom scoring function
        (e.g., `metric=scoring_func`) can be used; it must accept
        two arguments, y_true and y_pred, which have the same
        shape as the `y` array.

    num_rounds : int (default=1)
        Number of rounds the feature columns are permuted to
        compute the permutation importance.

    seed : int or None (default=None)
        Random seed for permuting the feature columns.

    Returns
    ---------
    mean_importance_vals, all_importance_vals : NumPy arrays.
        The first array, mean_importance_vals, has shape [n_features, ]
        and contains the importance values for all features.
        The shape of the second array is [n_features, num_rounds] and
        contains the feature importance for each repetition. If
        num_rounds=1, it contains the same values as the first array,
        mean_importance_vals.

    """

    if not isinstance(num_rounds, int):
        raise ValueError('num_rounds must be an integer.')
    if num_rounds < 1:
        raise ValueError('num_rounds must be at least 1.')

    if not (metric in ('r2', 'accuracy') or hasattr(metric, '__call__')):
        raise ValueError('metric must be either "r2", "accuracy", '
                         'or a function with signature func(y_true, y_pred).')

    if metric == 'r2':
        def score_func(y_true, y_pred):
            # R^2 = 1 - SS_res / SS_tot
            ss_res = np.sum(np.square(y_true - y_pred))
            ss_tot = np.sum(np.square(y_true - y_true.mean()))
            return 1. - (ss_res / ss_tot)

    elif metric == 'accuracy':
        def score_func(y_true, y_pred):
            return np.mean(y_true == y_pred)

    else:
        # custom scoring function with signature func(y_true, y_pred)
        score_func = metric

    rng = np.random.RandomState(seed)

    mean_importance_vals = np.zeros(X.shape[1])
    all_importance_vals = np.zeros((X.shape[1], num_rounds))

    baseline = score_func(y, predict_method(X))

    for round_idx in range(num_rounds):
        for col_idx in range(X.shape[1]):
            # permute one column, score, then restore the original values
            save_col = X[:, col_idx].copy()
            rng.shuffle(X[:, col_idx])
            new_score = score_func(y, predict_method(X))
            X[:, col_idx] = save_col
            importance = baseline - new_score
            mean_importance_vals[col_idx] += importance
            all_importance_vals[col_idx, round_idx] = importance
    mean_importance_vals /= num_rounds

    return mean_importance_vals, all_importance_vals
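Since `all_importance_vals` retains the per-round importance values, a natural follow-up is to summarize the variability across rounds. A minimal sketch, assuming scikit-learn is available; the random-forest model and data here are illustrative assumptions, not part of this commit:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from mlxtend.evaluate import feature_importance_permutation

X, y = make_classification(n_samples=500, n_features=4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
forest = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_train, y_train)

imp_vals, imp_all = feature_importance_permutation(
    X=X_test, y=y_test,
    predict_method=forest.predict,
    metric='accuracy',
    num_rounds=50,
    seed=1)

# per-feature standard deviation across the 50 permutation rounds
imp_std = np.std(imp_all, axis=1)
for idx, (m, s) in enumerate(zip(imp_vals, imp_std)):
    print('Feature %d: %.3f +/- %.3f' % (idx, m, s))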
mlxtend/evaluate/tests/test_feature_importance.py (new file, +144)

@@ -0,0 +1,144 @@
# Sebastian Raschka 2014-2018
# mlxtend Machine Learning Library Extensions
#
# Feature Importance Estimation Through Permutation
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

import numpy as np
from sklearn.datasets import make_classification
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.svm import SVR
from mlxtend.utils import assert_raises
from mlxtend.evaluate import feature_importance_permutation


def test_num_rounds_not_int():
    # positional args follow the (X, y, predict_method, metric, num_rounds)
    # signature; the dummy predict_method is never called because
    # validation fails first
    assert_raises(ValueError,
                  'num_rounds must be an integer.',
                  feature_importance_permutation,
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  lambda x: x,
                  'accuracy',
                  1.23)


def test_num_rounds_negative_int():
    assert_raises(ValueError,
                  'num_rounds must be at least 1.',
                  feature_importance_permutation,
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  lambda x: x,
                  'accuracy',
                  -1)


def test_metric_wrong():
    assert_raises(ValueError,
                  ('metric must be either "r2", "accuracy", or a '
                   'function with signature '
                   'func(y_true, y_pred).'),
                  feature_importance_permutation,
                  np.array([[1], [2], [3]]),
                  np.array([1, 2, 3]),
                  lambda x: x,
                  'some-metric')


def test_classification():

    X, y = make_classification(n_samples=1000,
                               n_features=6,
                               n_informative=3,
                               n_redundant=0,
                               n_repeated=0,
                               n_classes=2,
                               random_state=0,
                               shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0, stratify=y)

    svm = SVC(C=1.0, kernel='rbf', random_state=0)
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='accuracy',
        num_rounds=1,
        seed=1)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 1)
    assert imp_vals[0] > 0.2
    assert imp_vals[1] > 0.2
    assert imp_vals[2] > 0.2
    assert sum(imp_vals[3:]) <= 0.02


def test_regression():

    X, y = make_regression(n_samples=1000,
                           n_features=5,
                           n_informative=2,
                           n_targets=1,
                           random_state=123,
                           shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=123)

    svm = SVR(kernel='rbf')
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='r2',
        num_rounds=1,
        seed=123)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 1)
    assert imp_vals[0] > 0.2
    assert imp_vals[1] > 0.2
    assert sum(imp_vals[3:]) <= 0.01


def test_n_rounds():

    X, y = make_classification(n_samples=1000,
                               n_features=6,
                               n_informative=3,
                               n_redundant=0,
                               n_repeated=0,
                               n_classes=2,
                               random_state=0,
                               shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=0, stratify=y)

    svm = SVC(C=1.0, kernel='rbf', random_state=0)
    svm.fit(X_train, y_train)

    imp_vals, imp_all = feature_importance_permutation(
        predict_method=svm.predict,
        X=X_test,
        y=y_test,
        metric='accuracy',
        num_rounds=100,
        seed=1)

    assert imp_vals.shape == (X_train.shape[1], )
    assert imp_all.shape == (X_train.shape[1], 100)
    assert imp_vals[0] > 0.2
    assert imp_vals[1] > 0.2
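One path these tests do not exercise is the custom-metric callable described in the docstring. A hedged sketch of what such a callable could look like; the `neg_mean_absolute_error` helper is hypothetical, named here only for illustration:

import numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.svm import SVR
from mlxtend.evaluate import feature_importance_permutation

def neg_mean_absolute_error(y_true, y_pred):
    # negated so that higher is better, matching how the
    # baseline-minus-permuted-score drop is computed
    return -np.mean(np.abs(y_true - y_pred))

X, y = make_regression(n_samples=500, n_features=5,
                       n_informative=2, random_state=123)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
svr = SVR(kernel='rbf').fit(X_train, y_train)

imp_vals, _ = feature_importance_permutation(
    X=X_test, y=y_test,
    predict_method=svr.predict,
    metric=neg_mean_absolute_error,
    num_rounds=5,
    seed=123)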
