[scikit-learn] Issues with clone for ensemble of classifiers
Luiz Gustavo Hafemann
luiz.gh at gmail.com
Wed Sep 19 10:40:52 EDT 2018
Hello,
I am one of the developers of a library for Dynamic Ensemble Selection
(DES) methods (the library is called DESlib), and we are currently working
to make the library fully compatible with scikit-learn (to submit it to
scikit-learn-contrib). We have "check_estimator" passing for most of the
classes, but now I am having problems making the classes compatible with
GridSearchCV and other CV functions.
One of the main use cases of this library is to facilitate research in this
field, and this led to a design decision: the base classifiers are fit
by the user, and the DES methods receive a pool of base classifiers that
were already fit (this allows users to compare many DES techniques with the
same base classifiers). This creates an issue with GridSearchCV, since
the clone function (defined in sklearn.base) does not clone the classes as we
would like. For every parameter that is an estimator, it returns a new
estimator with the same hyperparameters that has not been fit, so the fitted
state is lost. We would like the pool of base classifiers to be deep-copied
instead, fitted state included.
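To make the distinction above concrete, here is a small sketch comparing `clone` with `copy.deepcopy` on an already fitted estimator. It assumes a reasonably recent scikit-learn; the decision tree and the iris data are only placeholders for one member of a pool. The clone keeps the hyperparameters but not the fitted state, while the deep copy keeps both.

```python
import copy

from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
fitted = DecisionTreeClassifier(random_state=0).fit(X, y)

# clone() rebuilds the estimator from its constructor parameters only,
# so learned attributes such as tree_ and classes_ are not carried over.
cloned = clone(fitted)

# copy.deepcopy() copies the whole object, fitted state included.
deep_copied = copy.deepcopy(fitted)

try:
    cloned.predict(X)
    clone_is_fitted = True
except NotFittedError:
    clone_is_fitted = False

deep_copy_agrees = bool((deep_copied.predict(X) == fitted.predict(X)).all())
print(clone_is_fitted, deep_copy_agrees)  # False True
```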
I analyzed this issue and could not find a solution that does not require
changes to the scikit-learn code. Here is the sequence of steps that causes
the problem:
1. GridSearchCV calls "clone" on the DES estimator (link:
<https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/model_selection/_search.py#L677>).
2. The clone function calls the "get_params" method of the DES
estimator (link:
<https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/base.py#L60-L63>,
line 60). We don't re-implement this method, so it returns all the
parameters, including the pool of classifiers (at this point, they are
still "fitted").
3. The clone function then clones each parameter with safe=False (line
62). When cloning the pool of classifiers, the result is a pool that is not
"fitted" anymore.
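Steps 2 and 3 can be sketched roughly as follows. This is a simplified re-implementation for illustration only, not scikit-learn's actual code (the real clone also runs sanity checks on the result). `PoolHolder`, `simplified_clone` and `is_fitted` are hypothetical names; `PoolHolder` stands in for a DES estimator that receives an already fitted pool.

```python
from sklearn.base import BaseEstimator, clone
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.tree import DecisionTreeClassifier


class PoolHolder(BaseEstimator):
    """Hypothetical stand-in for a DES estimator that receives a fitted pool."""

    def __init__(self, pool=None, k=1):
        self.pool = pool
        self.k = k


def simplified_clone(estimator):
    """Rough sketch of steps 2 and 3 of clone()."""
    # Step 2: read the constructor parameters, as clone() does.
    params = estimator.get_params(deep=False)
    # Step 3: clone each parameter with safe=False. Estimator-valued
    # parameters come back as fresh, unfitted estimators; other values
    # (like the int k) are deep-copied.
    new_params = {name: clone(value, safe=False) for name, value in params.items()}
    # Build a new instance of the same class from the cloned parameters.
    return type(estimator)(**new_params)


def is_fitted(model, X):
    # Hypothetical helper: try to predict and report whether it worked.
    try:
        model.predict(X)
        return True
    except NotFittedError:
        return False


X, y = load_iris(return_X_y=True)
holder = PoolHolder(pool=DecisionTreeClassifier(random_state=0).fit(X, y), k=3)

sketch_copy = simplified_clone(holder)
real_copy = clone(holder)

print(sketch_copy.k, is_fitted(holder.pool, X))  # 3 True
print(is_fitted(sketch_copy.pool, X), is_fitted(real_copy.pool, X))  # False False
```

Both the sketch and the real clone keep `k` but hand back a pool that is no longer fitted, which is exactly the behaviour GridSearchCV runs into.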
The problem is that, to my knowledge, there is no way for my classifier to
tell "clone" that a parameter should always be deep-copied. I see that
other ensemble methods in sklearn always fit the base classifiers within
the "fit" method of the ensemble, so this problem does not happen there. I
would like to know if there is a solution for this problem that still lets
the base classifiers be fitted elsewhere.
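For contrast, the pattern used by the ensembles in sklearn can be sketched like this: the estimator stores only hyperparameters in `__init__` and trains its own pool inside `fit`, so the unfitted clones made by GridSearchCV are harmless. `FitInsideEnsemble` is a hypothetical, bagging-style toy class written for this illustration, not a scikit-learn API, and it assumes integer-like class labels that `np.searchsorted` can match against `classes_`.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier


class FitInsideEnsemble(ClassifierMixin, BaseEstimator):
    """Hypothetical bagging-style ensemble that trains its own pool in fit()."""

    def __init__(self, n_estimators=5, max_depth=None):
        # Only hyperparameters are stored here; no fitted models.
        self.n_estimators = n_estimators
        self.max_depth = max_depth

    def fit(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        rng = np.random.RandomState(0)
        self.classes_ = np.unique(y)
        self.estimators_ = []
        for _ in range(self.n_estimators):
            # Draw a bootstrap sample and train a fresh tree on it, so the
            # pool is rebuilt every time GridSearchCV fits a clone.
            idx = rng.randint(0, len(X), len(X))
            tree = DecisionTreeClassifier(max_depth=self.max_depth, random_state=0)
            self.estimators_.append(tree.fit(X[idx], y[idx]))
        return self

    def predict(self, X):
        # Majority vote: map each tree's labels to indices into classes_,
        # count the votes per sample, and return the most voted class.
        preds = np.array([est.predict(X) for est in self.estimators_])
        idx = np.searchsorted(self.classes_, preds)
        counts = np.apply_along_axis(
            np.bincount, 0, idx, minlength=len(self.classes_)
        )
        return self.classes_[counts.argmax(axis=0)]


X, y = load_iris(return_X_y=True)
grid = GridSearchCV(FitInsideEnsemble(), {"max_depth": [1, 3]}, cv=3)
grid.fit(X, y)
print(grid.best_params_)
```

The trade-off is the one described above: this works with clone, but every DES method would then retrain its own pool, which is what the DESlib design tries to avoid.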
Here is a short script that reproduces the issue:
---------------------------
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import BaggingClassifier
from sklearn.datasets import load_iris


class MyClassifier(ClassifierMixin, BaseEstimator):
    def __init__(self, base_classifiers, k):
        self.base_classifiers = base_classifiers  # Base classifiers that are already trained
        self.k = k  # Simulates a parameter that we want to grid-search over

    def fit(self, X_dsel, y_dsel):
        # Here we would fit the parameters of the dynamic selection
        # method, not the base classifiers
        return self

    def predict(self, X):
        # In practice, the DES methods would combine the predictions
        # of each base classifier
        return self.base_classifiers.predict(X)


X, y = load_iris(return_X_y=True)
X_train, X_dsel, y_train, y_dsel = train_test_split(X, y, test_size=0.5)

base_classifiers = BaggingClassifier()
base_classifiers.fit(X_train, y_train)

clf = MyClassifier(base_classifiers, k=1)
params = {'k': [1, 3, 5, 7]}
# error_score='raise' makes scoring failures raise instead of becoming NaN
grid = GridSearchCV(clf, params, error_score='raise')
grid.fit(X_dsel, y_dsel)  # Raises an error saying the bagging classifier is not fitted
---------------------------
Btw, here is the branch that we are using to make the library compatible
with sklearn: https://github.com/Menelau/DESlib/tree/sklearn-estimators.
The failing test related to this issue is in
https://github.com/Menelau/DESlib/blob/sklearn-estimators/deslib/tests/test_des_integration.py#L36
Thanks in advance for any help with this issue,
Luiz Gustavo Hafemann