Skip to content

Commit 5d1f831

Browse files
jameslamb authored and hcho3 committed
Adapt to scikit-learn 1.6 estimator tag changes (dmlc#11021)
1 parent 30a7fd5 commit 5d1f831

File tree

12 files changed

+200
-28
lines changed

12 files changed

+200
-28
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ credentials.csv
139139
.bloop
140140

141141
# python tests
142+
*.bin
142143
demo/**/*.txt
143144
*.dmatrix
144145
.hypothesis
145146
__MACOSX/
146147
model*.json
148+
/tests/python/models/models/
147149

148150
# R tests
149151
*.htm

python-package/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ extension-pkg-whitelist = ["numpy"]
6262
disable = [
6363
"attribute-defined-outside-init",
6464
"import-outside-toplevel",
65+
"too-few-public-methods",
66+
"too-many-ancestors",
6567
"too-many-nested-blocks",
6668
"unexpected-special-method-signature",
6769
"unsubscriptable-object",

python-package/xgboost/compat.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
4545

4646
# sklearn
4747
try:
48+
from sklearn import __version__ as _sklearn_version
4849
from sklearn.base import BaseEstimator as XGBModelBase
4950
from sklearn.base import ClassifierMixin as XGBClassifierBase
5051
from sklearn.base import RegressorMixin as XGBRegressorBase
51-
from sklearn.preprocessing import LabelEncoder
5252

5353
try:
54-
from sklearn.model_selection import KFold as XGBKFold
5554
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
5655
except ImportError:
57-
from sklearn.cross_validation import KFold as XGBKFold
5856
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
5957

58+
# sklearn.utils Tags types can be imported unconditionally once
59+
# xgboost's minimum scikit-learn version is 1.6 or higher
60+
try:
61+
from sklearn.utils import Tags as _sklearn_Tags
62+
except ImportError:
63+
_sklearn_Tags = object
64+
6065
SKLEARN_INSTALLED = True
6166

6267
except ImportError:
6368
SKLEARN_INSTALLED = False
6469

6570
# used for compatibility without sklearn
66-
XGBModelBase = object
67-
XGBClassifierBase = object
68-
XGBRegressorBase = object
69-
LabelEncoder = object
71+
class XGBModelBase: # type: ignore[no-redef]
72+
"""Dummy class for sklearn.base.BaseEstimator."""
73+
74+
class XGBClassifierBase: # type: ignore[no-redef]
75+
"""Dummy class for sklearn.base.ClassifierMixin."""
76+
77+
class XGBRegressorBase: # type: ignore[no-redef]
78+
"""Dummy class for sklearn.base.RegressorMixin."""
7079

71-
XGBKFold = None
7280
XGBStratifiedKFold = None
7381

82+
_sklearn_Tags = object
83+
_sklearn_version = object
84+
7485

7586
_logger = logging.getLogger(__name__)
7687

python-package/xgboost/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def c_array(
410410
def from_array_interface(interface: dict) -> NumpyOrCupy:
411411
"""Convert array interface to numpy or cupy array"""
412412

413-
class Array: # pylint: disable=too-few-public-methods
413+
class Array:
414414
"""Wrapper type for communicating with numpy and cupy."""
415415

416416
_interface: Optional[dict] = None

python-package/xgboost/dask/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# pylint: disable=too-many-arguments, too-many-locals
22
# pylint: disable=missing-class-docstring, invalid-name
33
# pylint: disable=too-many-lines
4-
# pylint: disable=too-few-public-methods
5-
# pylint: disable=import-error
64
"""
75
Dask extensions for distributed training
86
----------------------------------------

python-package/xgboost/sklearn.py

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@
2929

3030
# Do not use class names on scikit-learn directly. Re-define the classes on
3131
# .compat to guarantee the behavior without scikit-learn
32-
from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase
32+
from .compat import (
33+
SKLEARN_INSTALLED,
34+
XGBClassifierBase,
35+
XGBModelBase,
36+
XGBRegressorBase,
37+
_sklearn_Tags,
38+
_sklearn_version,
39+
)
3340
from .config import config_context
3441
from .core import (
3542
Booster,
@@ -45,7 +52,7 @@
4552
from .training import train
4653

4754

48-
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
55+
class XGBRankerMixIn:
4956
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
5057
base classes.
5158
@@ -69,7 +76,7 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool:
6976
return tree_method in ("hist", "gpu_hist", None, "auto")
7077

7178

72-
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
79+
class _SklObjWProto(Protocol):
7380
def __call__(
7481
self,
7582
y_true: ArrayLike,
@@ -787,6 +794,41 @@ def _more_tags(self) -> Dict[str, bool]:
787794
tags["non_deterministic"] = True
788795
return tags
789796

797+
@staticmethod
def _update_sklearn_tags_from_dict(
    *,
    tags: _sklearn_Tags,
    tags_dict: Dict[str, bool],
) -> _sklearn_Tags:
    """Copy dict-style estimator tags onto a ``sklearn.utils.Tags`` instance.

    ``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
    ref: https://github.com/scikit-learn/scikit-learn/pull/29677

    Parameters
    ----------
    tags :
        Tags object inherited from the ``scikit-learn`` base classes. It is
        modified in place.
    tags_dict :
        Dict-style tags, e.g. the value returned by ``self._more_tags()``.
        ``"no_validation"`` and ``"allow_nan"`` must be present, while
        ``"non_deterministic"`` is optional and falls back to ``False``.

    Returns
    -------
    The same ``tags`` object that was passed in, after the update.
    """
    # Each value is looked up immediately before it is assigned, so a missing
    # required key raises KeyError at the same point as a direct assignment.
    non_deterministic = tags_dict.get("non_deterministic", False)
    tags.non_deterministic = non_deterministic

    no_validation = tags_dict["no_validation"]
    tags.no_validation = no_validation

    allow_nan = tags_dict["allow_nan"]
    tags.input_tags.allow_nan = allow_nan

    return tags
814+
815+
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the dataclass-style estimator tags used by ``scikit-learn`` >= 1.6.

    Raises
    ------
    AttributeError
        If ``XGBModelBase`` does not define ``__sklearn_tags__`` (which is the
        case for scikit-learn<1.6).
    """
    # XGBModelBase.__sklearn_tags__() isn't defined for scikit-learn<1.6, so the
    # parent implementation can only be called after checking that it exists.
    if hasattr(XGBModelBase, "__sklearn_tags__"):
        # Start from whatever tags the parent classes provide, then overlay
        # the XGBoost-specific values on top of them.
        inherited_tags = super().__sklearn_tags__()  # pylint: disable=no-member
        return self._update_sklearn_tags_from_dict(
            tags=inherited_tags,
            tags_dict=self._more_tags(),
        )

    raise AttributeError(
        "__sklearn_tags__() should not be called when using scikit-learn<1.6. "
        f"Detected version: {_sklearn_version}"
    )
831+
790832
def __sklearn_is_fitted__(self) -> bool:
    """Report whether this estimator has a ``_Booster`` attribute."""
    booster_is_set = hasattr(self, "_Booster")
    return booster_is_set
792834

@@ -841,13 +883,27 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
841883
"""Get parameters."""
842884
# Based on: https://stackoverflow.com/questions/59248211
843885
# The basic flow in `get_params` is:
844-
# 0. Return parameters in subclass first, by using inspect.
845-
# 1. Return parameters in `XGBModel` (the base class).
886+
# 0. Return parameters in subclass (self.__class__) first, by using inspect.
887+
# 1. Return parameters in all parent classes (especially `XGBModel`).
846888
# 2. Return whatever in `**kwargs`.
847889
# 3. Merge them.
890+
#
891+
# This needs to accommodate being called recursively in the following
892+
# inheritance graphs (and similar for classification and ranking):
893+
#
894+
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
895+
# XGBRegressor -> XGBModel -> BaseEstimator
896+
# XGBModel -> BaseEstimator
897+
#
848898
params = super().get_params(deep)
849899
cp = copy.copy(self)
850-
cp.__class__ = cp.__class__.__bases__[0]
900+
# If the immediate parent defines get_params(), use that.
901+
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
902+
cp.__class__ = cp.__class__.__bases__[0]
903+
# Otherwise, skip it and assume the next class will have it.
904+
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
905+
else:
906+
cp.__class__ = cp.__class__.__bases__[1]
851907
params.update(cp.__class__.get_params(cp, deep))
852908
# if kwargs is a dict, update params accordingly
853909
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
@@ -1431,7 +1487,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
14311487
Number of boosting rounds.
14321488
""",
14331489
)
1434-
class XGBClassifier(XGBModel, XGBClassifierBase):
1490+
class XGBClassifier(XGBClassifierBase, XGBModel):
14351491
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
14361492
@_deprecate_positional_args
14371493
def __init__(
@@ -1447,6 +1503,12 @@ def _more_tags(self) -> Dict[str, bool]:
14471503
tags["multilabel"] = True
14481504
return tags
14491505

1506+
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the inherited tags with the classifier's multi-label flag set."""
    sk_tags = super().__sklearn_tags__()
    # The flag is taken from the dict-style tags returned by ``_more_tags()``.
    sk_tags.classifier_tags.multi_label = self._more_tags()["multilabel"]
    return sk_tags
1511+
14501512
@_deprecate_positional_args
14511513
def fit(
14521514
self,
@@ -1717,7 +1779,7 @@ def fit(
17171779
"Implementation of the scikit-learn API for XGBoost regression.",
17181780
["estimators", "model", "objective"],
17191781
)
1720-
class XGBRegressor(XGBModel, XGBRegressorBase):
1782+
class XGBRegressor(XGBRegressorBase, XGBModel):
17211783
# pylint: disable=missing-docstring
17221784
@_deprecate_positional_args
17231785
def __init__(
@@ -1731,6 +1793,13 @@ def _more_tags(self) -> Dict[str, bool]:
17311793
tags["multioutput_only"] = False
17321794
return tags
17331795

1796+
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the inherited tags with the regressor's target-output flags set."""
    sk_tags = super().__sklearn_tags__()
    legacy_tags = self._more_tags()

    supports_multi_output = legacy_tags["multioutput"]
    sk_tags.target_tags.multi_output = supports_multi_output

    # The dict-style tag says "multi-output only"; the dataclass field says
    # "single output supported", so the value is negated.
    multi_output_only = legacy_tags["multioutput_only"]
    sk_tags.target_tags.single_output = not multi_output_only

    return sk_tags
1802+
17341803

17351804
@xgboost_model_doc(
17361805
"scikit-learn API for XGBoost random forest regression.",
@@ -1858,7 +1927,7 @@ def _get_qid(
18581927
`qid` can be a special column of input `X` instead of a separated parameter, see
18591928
:py:meth:`fit` for more info.""",
18601929
)
1861-
class XGBRanker(XGBModel, XGBRankerMixIn):
1930+
class XGBRanker(XGBRankerMixIn, XGBModel):
18621931
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
18631932
@_deprecate_positional_args
18641933
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):

python-package/xgboost/spark/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import base64
44

5-
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
6-
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
5+
# pylint: disable=fixme, protected-access, no-member, invalid-name
6+
# pylint: disable=too-many-lines, too-many-branches
77
import json
88
import logging
99
import os

python-package/xgboost/spark/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Xgboost pyspark integration submodule for estimator API."""
22

3-
# pylint: disable=too-many-ancestors
4-
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
3+
# pylint: disable=fixme, protected-access, no-member, invalid-name
54
# pylint: disable=unused-argument, too-many-locals
65

76
import warnings

python-package/xgboost/spark/params.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import Dict
44

5-
# pylint: disable=too-few-public-methods
65
from pyspark.ml.param import TypeConverters
76
from pyspark.ml.param.shared import Param, Params
87

python-package/xgboost/spark/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _get_default_params_from_func(
4343
return filtered_params_dict
4444

4545

46-
class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
46+
class CommunicatorContext(CCtx):
4747
"""Context with PySpark specific task ID."""
4848

4949
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:

0 commit comments

Comments
 (0)