|
8 | 8 | #
|
9 | 9 | # License: BSD 3 clause
|
10 | 10 |
|
11 |
| -from sklearn.base import BaseEstimator |
12 |
| -from sklearn.base import ClassifierMixin |
13 |
| -from sklearn.base import TransformerMixin |
14 |
| -from sklearn.preprocessing import LabelEncoder |
15 |
| -from sklearn.base import clone |
| 11 | +import numpy as np |
| 12 | +from sklearn.base import (BaseEstimator, ClassifierMixin, TransformerMixin, |
| 13 | + clone) |
16 | 14 | from sklearn.exceptions import NotFittedError
|
17 |
| -from ..externals.name_estimators import _name_estimators |
| 15 | +from sklearn.preprocessing import LabelEncoder |
| 16 | + |
18 | 17 | from ..externals import six
|
19 |
| -import numpy as np |
| 18 | +from ..externals.name_estimators import _name_estimators |
20 | 19 |
|
21 | 20 |
|
22 | 21 | class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
|
@@ -61,7 +60,7 @@ class EnsembleVoteClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
|
61 | 60 | fit_base_estimators : bool (default: True)
|
62 | 61 | Refits classifiers in `clfs` if True; uses references to the `clfs`,
|
63 | 62 | otherwise (assumes that the classifiers were already fit).
|
64 |
| - Note: fit_base_estimators=False will enforce use_clones to be False, |
| 63 | + Note: fit_base_estimators=False will enforce use_clones to be False, |
65 | 64 | and is incompatible to most scikit-learn wrappers!
|
66 | 65 | For instance, if any form of cross-validation is performed
|
67 | 66 | this would require the re-fitting classifiers to training folds, which
|
@@ -161,6 +160,7 @@ def fit(self, X, y, sample_weight=None):
|
161 | 160 | self.classes_ = self.le_.classes_
|
162 | 161 |
|
163 | 162 | if not self.fit_base_estimators:
|
| 163 | + print('Warning: enforce use_clones to be False') |
164 | 164 | self.use_clones = False
|
165 | 165 |
|
166 | 166 | if self.use_clones:
|
@@ -283,8 +283,8 @@ def get_params(self, deep=True):
|
283 | 283 | for key, value in six.iteritems(step.get_params(deep=True)):
|
284 | 284 | out['%s__%s' % (name, key)] = value
|
285 | 285 |
|
286 |
| - for key, value in six.iteritems(super(EnsembleVoteClassifier, |
287 |
| - self).get_params(deep=False)): |
| 286 | + for key, value in six.iteritems( |
| 287 | + super(EnsembleVoteClassifier, self).get_params(deep=False)): |
288 | 288 | out['%s' % key] = value
|
289 | 289 | return out
|
290 | 290 |
|
|
0 commit comments