Using gridsearch to test different models #257
Comments
Phew, that's a tricky one ;). Looking at the code, I see that we have something like

```python
for key, value in six.iteritems(super(StackingClassifier,
                                      self).get_params(deep=False)):
    if key in ('classifiers', 'meta-classifier'):
        continue
    else:
        out['%s' % key] = value
return out
```

which is basically hiding those two from scikit-learn's grid search. The if clause could be removed to allow tuning the classifiers (and the meta-classifier) as well, for example:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from mlxtend.classifier import StackingClassifier
from sklearn import datasets
iris = datasets.load_iris()
X, y = iris.data[:, 1:3], iris.target
# Initializing models
clf1 = KNeighborsClassifier(n_neighbors=1)
clf2 = RandomForestClassifier(random_state=1)
clf3 = GaussianNB()
lr = LogisticRegression()
sclf = StackingClassifier(classifiers=[clf1, clf2, clf3],
                          meta_classifier=lr)
params = {'classifiers': [[clf1, clf2, clf3], [clf2, clf3]], 'kneighborsclassifier__n_neighbors': [1, 5]}
grid = GridSearchCV(estimator=sclf,
                    param_grid=params,
                    cv=5,
                    refit=True)
grid.fit(X, y)
cv_keys = ('mean_test_score', 'std_test_score', 'params')
for r, _ in enumerate(grid.cv_results_['mean_test_score']):
print("%0.3f +/- %0.2f %r"
% (grid.cv_results_[cv_keys[0]][r],
grid.cv_results_[cv_keys[1]][r] / 2.0,
grid.cv_results_[cv_keys[2]][r]))
print('Best parameters: %s' % grid.best_params_)
print('Accuracy: %.2f' % grid.best_score_)
```

So, in short, yeah, changing

```python
for key, value in six.iteritems(super(StackingClassifier,
                                      self).get_params(deep=False)):
    if key in ('classifiers', 'meta-classifier'):
        continue
    else:
        out['%s' % key] = value
return out
```

to

```python
for key, value in six.iteritems(super(StackingClassifier,
                                      self).get_params(deep=False)):
    out['%s' % key] = value
return out
```

would allow that! Happy to make that change (the reason why I included the if-else was that I wasn't sure how it's handled by GridSearchCV, but it seems to be okay :)).
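To make the mechanics a bit more concrete, here is a minimal sketch (not mlxtend's actual code, reusing `sclf`, `clf2`, and `clf3` from the example above) of why the key has to show up in `get_params()`: GridSearchCV works on clones of the estimator and applies each candidate value from the grid via `set_params()`, which only accepts keys that `get_params()` reports.

```python
from sklearn.base import clone

# GridSearchCV fits a fresh clone of the estimator for every candidate.
candidate = clone(sclf)

# Once the if-clause above is removed, 'classifiers' is reported by
# get_params() and can therefore be swapped in via set_params().
print('classifiers' in candidate.get_params())
candidate.set_params(classifiers=[clf2, clf3])
```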
Awesome, thanks Sebastian! I was having trouble with the StackingRegressor, but I finally got it working! It looks like I still need to set some default settings when initializing it?
Yeah, I think defaults are required to get it working (maybe in the future we could have some sensible defaults for the regressors and meta-regressors, though). Glad to hear that it's working, and I will update the mlxtend implementations with regard to the modification I mentioned above so that you don't have to tweak the code manually ;)
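For the regressor side, a rough sketch of what that could look like (the concrete regressors and grid values below are just placeholder choices, not from the thread): the StackingRegressor is initialized with default regressors and a meta-regressor, and the `regressors` list then becomes just another grid-searchable parameter.

```python
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import SVR
from sklearn.model_selection import GridSearchCV
from mlxtend.regressor import StackingRegressor

lin = LinearRegression()
ridge = Ridge()
svr = SVR(kernel='rbf')

# Defaults are still required at construction time, as discussed above.
stack = StackingRegressor(regressors=[lin, ridge, svr],
                          meta_regressor=SVR(kernel='linear'))

# Different regressor combinations are then just grid candidates.
params = {'regressors': [[lin, ridge, svr], [lin, svr]]}
grid = GridSearchCV(estimator=stack, param_grid=params, cv=5)
# grid.fit(X_train, y_train)  # X_train, y_train: your own regression data
```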
Addressed now via #259. There's one little caveat, though (and I added it to the docs): you cannot search over both the classifiers/regressors and their parameters at the same time (it may be due to how GridSearchCV sets the nested parameters internally). Anyway, what I meant is that, for instance, while a parameter dictionary like the `params` in the example above (which combines `'classifiers'` with `'kneighborsclassifier__n_neighbors'`) works in the sense that it does not produce an error, it will ignore the classifier-level settings and simply use the classifiers' current instance settings.
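A possible way around this caveat, sketched below with the estimators from the classifier example above (the `clf1_k*` names are made up for illustration), is to bake the parameter values into separate pre-configured instances, so that the grid only varies `classifiers` and the instance settings are exactly what gets used:

```python
from sklearn.neighbors import KNeighborsClassifier

# Pre-configure the variants instead of combining 'classifiers' with
# nested parameter names in a single grid entry.
clf1_k1 = KNeighborsClassifier(n_neighbors=1)
clf1_k5 = KNeighborsClassifier(n_neighbors=5)

params = {'classifiers': [[clf1_k1, clf2, clf3],
                          [clf1_k5, clf2, clf3],
                          [clf2, clf3]]}

grid = GridSearchCV(estimator=sclf, param_grid=params, cv=5, refit=True)
grid.fit(X, y)
```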
Apologies if this has been discussed already, but I couldn't find any mention of it.
I see an example in the documentation of how to tune the parameters using a gridsearch, but would it be possible to also try different combinations of models? Here's the code I used, but it throws an error:
And then I get an error saying I can't initialize the regressor without a meta model:
```
TypeError: __init__() missing 1 required positional argument: 'meta_model'
```
Similarly, it would also be nice to try different combinations of regressors.
Would it be possible to add this sort of functionality? Or is there a workaround I can use in the meantime?