# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
from copy import deepcopy
|
||
|
|
||
|
from ..base import BaseEstimator
|
||
|
from ..exceptions import NotFittedError
|
||
|
from ..utils import get_tags
|
||
|
from ..utils.metaestimators import available_if
|
||
|
from ..utils.validation import check_is_fitted
|
||
|
|
||
|
|
||
|
def _estimator_has(attr):
    """Build a predicate telling whether the wrapped `estimator` exposes `attr`.

    The returned callable is meant to be handed to `available_if`.

    Parameters
    ----------
    attr : str
        Name of the attribute to look up on the `estimator` attribute of the
        object passed to the predicate.

    Returns
    -------
    check : callable
        One-argument function that returns `True` when the lookup succeeds and
        lets the `AttributeError` of the failed lookup propagate otherwise.
    """

    def _probe(owner):
        # Perform the real lookup (rather than `hasattr`) so that a missing
        # attribute surfaces the inner estimator's own `AttributeError`.
        inner = owner.estimator
        getattr(inner, attr)
        return True

    return _probe
|
||
|
|
||
|
|
||
|
class FrozenEstimator(BaseEstimator):
    """Estimator that wraps a fitted estimator to prevent re-fitting.

    This meta-estimator takes an estimator and freezes it, in the sense that calling
    `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
    All other methods are delegated to the original estimator and original estimator's
    attributes are accessible as well.

    This is particularly useful when you have a fitted or a pre-trained model as a
    transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
    step.

    Parameters
    ----------
    estimator : estimator
        The estimator which is to be kept frozen.

    See Also
    --------
    None: No similar entry in the scikit-learn documentation.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.frozen import FrozenEstimator
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(random_state=0)
    >>> clf = LogisticRegression(random_state=0).fit(X, y)
    >>> frozen_clf = FrozenEstimator(clf)
    >>> frozen_clf.fit(X, y)  # No-op
    FrozenEstimator(estimator=LogisticRegression(random_state=0))
    >>> frozen_clf.predict(X)  # Predictions from `clf.predict`
    array(...)
    """

    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_estimator_has("__getitem__"))
    def __getitem__(self, *args, **kwargs):
        """Delegate item access to the wrapped estimator.

        `__getitem__` is defined in :class:`~sklearn.pipeline.Pipeline` and
        :class:`~sklearn.compose.ColumnTransformer`. The method is guarded by
        `available_if`, using a check on the wrapped estimator's `__getitem__`.
        """
        return self.estimator.__getitem__(*args, **kwargs)

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped estimator.

        `fit_predict` and `fit_transform` are deliberately hidden, since they
        would re-fit the wrapped estimator.
        """
        # `__getattr__` only runs when normal lookup failed. If `estimator` itself
        # is missing (e.g. an instance created without `__init__`), reading
        # `self.estimator` below would re-enter this method forever; fail fast
        # with a regular `AttributeError` instead of a `RecursionError`.
        if name == "estimator":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'estimator'"
            )
        if name in ("fit_predict", "fit_transform"):
            raise AttributeError(f"{name} is not available for frozen estimators.")
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        """Return this very instance; no copy is made."""
        return self

    def __sklearn_is_fitted__(self):
        """Report whether the wrapped estimator passes `check_is_fitted`."""
        try:
            check_is_fitted(self.estimator)
        except NotFittedError:
            return False
        else:
            return True

    def fit(self, X, y=None, *args, **kwargs):
        """No-op.

        As a frozen estimator, calling `fit` has no effect. The only work done is
        a `check_is_fitted` call on the wrapped estimator.

        Parameters
        ----------
        X : object
            Ignored.

        y : object, default=None
            Ignored.

        *args : tuple
            Additional positional arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        **kwargs : dict
            Additional keyword arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        check_is_fitted(self.estimator)
        return self

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        The only valid key here is `estimator`. You cannot set the parameters of the
        inner estimator.

        Parameters
        ----------
        **kwargs : dict
            Estimator parameters.

        Returns
        -------
        self : FrozenEstimator
            This estimator.

        Raises
        ------
        ValueError
            If any key other than `estimator` is passed. In that case the
            instance is left unchanged.
        """
        estimator = kwargs.pop("estimator", None)
        # Reject unsupported keys before touching any state, so that a failing
        # call never leaves the object half-updated.
        if kwargs:
            raise ValueError(
                "You cannot set parameters of the inner estimator in a frozen "
                "estimator since calling `fit` has no effect. You can use "
                "`frozenestimator.estimator.set_params` to set parameters of the inner "
                "estimator."
            )
        if estimator is not None:
            self.estimator = estimator
        # Return `self` as documented so that calls can be chained.
        return self

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns a `{"estimator": estimator}` dict. The parameters of the inner
        estimator are not included.

        Parameters
        ----------
        deep : bool, default=True
            Ignored.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {"estimator": self.estimator}

    def __sklearn_tags__(self):
        """Return a copy of the wrapped estimator's tags with `_skip_test` set."""
        # Work on a deep copy so the wrapped estimator's own tags object is never
        # modified by the `_skip_test` override below.
        tags = deepcopy(get_tags(self.estimator))
        tags._skip_test = True
        return tags
|