-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MNT Sklearn1.6 compatibility #447
Changes from 37 commits
6901126
19b71c5
e718fce
b3f6401
dadf4e3
0146a74
45af8a0
7abf51e
6332470
d8da963
d9a163b
cb0b215
c367109
c96be0b
051eead
a1f4344
e6b4df3
ed77ced
0983b80
fdb4b20
0388d0b
926f972
1fd8432
de4774d
c043a82
81950ff
bb66ac7
aeb6baf
bdfc37d
1cf9c87
d5696ff
cfeef0a
e1751fc
b916346
e1d0132
960dff9
712cb13
c3da1b9
a8dad87
69325c0
d2ecc45
87065a3
d9eaaff
bcc78d4
24783a4
e2f0c82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,44 +3,94 @@ | |
from typing import Any, Optional, Sequence, Type | ||
|
||
from sklearn.cluster import Birch | ||
from sklearn.tree._tree import Tree | ||
|
||
from ._general import TypeNode | ||
from ._audit import Node, get_tree | ||
from ._general import TypeNode, unsupported_get_state | ||
from ._protocol import PROTOCOL | ||
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype | ||
from .exceptions import UnsupportedTypeException | ||
|
||
try: | ||
# TODO: remove once support for sklearn<1.2 is dropped. See #187 | ||
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys | ||
except ImportError: | ||
_DictWithDeprecatedKeys = None | ||
|
||
from sklearn.linear_model._sgd_fast import ( | ||
EpsilonInsensitive, | ||
Hinge, | ||
Huber, | ||
Log, | ||
LossFunction, | ||
Comment on lines
-18
to
-20
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. scikit-learn has moved its loss functions/classes to a more central location, so the ones that are shared between estimators have been removed from the SGD-specific file. |
||
ModifiedHuber, | ||
SquaredEpsilonInsensitive, | ||
SquaredHinge, | ||
SquaredLoss, | ||
) | ||
from sklearn.tree._tree import Tree | ||
|
||
from ._audit import Node, get_tree | ||
from ._general import unsupported_get_state | ||
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype | ||
from .exceptions import UnsupportedTypeException | ||
|
||
ALLOWED_SGD_LOSSES = { | ||
ModifiedHuber, | ||
Hinge, | ||
SquaredHinge, | ||
Log, | ||
SquaredLoss, | ||
Huber, | ||
ALLOWED_LOSSES = { | ||
EpsilonInsensitive, | ||
Hinge, | ||
ModifiedHuber, | ||
SquaredEpsilonInsensitive, | ||
SquaredHinge, | ||
} | ||
|
||
try: | ||
# TODO: remove once support for sklearn<1.6 is dropped. | ||
from sklearn.linear_model._sgd_fast import ( | ||
Huber, | ||
Log, | ||
SquaredLoss, | ||
) | ||
|
||
ALLOWED_LOSSES |= { | ||
Huber, | ||
Log, | ||
SquaredLoss, | ||
} | ||
except ImportError: | ||
pass | ||
|
||
try: | ||
# sklearn>=1.6 | ||
from sklearn._loss._loss import ( | ||
CyAbsoluteError, | ||
CyExponentialLoss, | ||
CyHalfBinomialLoss, | ||
CyHalfGammaLoss, | ||
CyHalfMultinomialLoss, | ||
CyHalfPoissonLoss, | ||
CyHalfSquaredError, | ||
CyHalfTweedieLoss, | ||
CyHalfTweedieLossIdentity, | ||
CyHuberLoss, | ||
CyPinballLoss, | ||
) | ||
|
||
Comment on lines
+54
to
+67
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These are the new loss classes. |
||
ALLOWED_LOSSES |= { | ||
CyAbsoluteError, | ||
CyExponentialLoss, | ||
CyHalfBinomialLoss, | ||
CyHalfGammaLoss, | ||
CyHalfMultinomialLoss, | ||
CyHalfPoissonLoss, | ||
CyHalfSquaredError, | ||
CyHalfTweedieLoss, | ||
CyHalfTweedieLossIdentity, | ||
CyHuberLoss, | ||
CyPinballLoss, | ||
} | ||
except ImportError: | ||
pass | ||
|
||
# This import is for the parent class of all loss functions, which is used to | ||
# set the dispatch function for all loss functions. | ||
try: | ||
# From sklearn>=1.6 | ||
from sklearn._loss._loss import CyLossFunction as ParentLossClass | ||
except ImportError: | ||
# sklearn<1.6 | ||
from sklearn.linear_model._sgd_fast import LossFunction as ParentLossClass | ||
|
||
|
||
UNSUPPORTED_TYPES = {Birch} | ||
|
||
|
||
|
@@ -163,13 +213,13 @@ def __init__( | |
super().__init__(state, load_context, constructor=Tree, trusted=self.trusted) | ||
|
||
|
||
def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: | ||
def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: | ||
state = reduce_get_state(obj, save_context) | ||
state["__loader__"] = "SGDNode" | ||
state["__loader__"] = "LossNode" | ||
return state | ||
|
||
|
||
class SGDNode(ReduceNode): | ||
class LossNode(ReduceNode): | ||
def __init__( | ||
self, | ||
state: dict[str, Any], | ||
|
@@ -178,7 +228,7 @@ def __init__( | |
) -> None: | ||
# TODO: make sure trusted here makes sense and used. | ||
self.trusted = self._get_trusted( | ||
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_SGD_LOSSES] | ||
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_LOSSES] | ||
) | ||
super().__init__( | ||
state, | ||
|
@@ -240,15 +290,16 @@ def _construct(self): | |
|
||
# tuples of type and function that gets the state of that type | ||
GET_STATE_DISPATCH_FUNCTIONS = [ | ||
(LossFunction, sgd_loss_get_state), | ||
(ParentLossClass, loss_get_state), | ||
(Tree, tree_get_state), | ||
] | ||
|
||
for type_ in UNSUPPORTED_TYPES: | ||
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state)) | ||
|
||
# tuples of type and function that creates the instance of that type | ||
NODE_TYPE_MAPPING = { | ||
("SGDNode", PROTOCOL): SGDNode, | ||
NODE_TYPE_MAPPING: dict[tuple[str, int], Any] = { | ||
("LossNode", PROTOCOL): LossNode, | ||
("TreeNode", PROTOCOL): TreeNode, | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
scikit-learn now uses dataclasses for estimator tags; these are emulated here for older scikit-learn versions as well.