Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Sklearn1.6 compatibility #447

Merged
merged 46 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6901126
Fix _sgd imports
TamaraAtanasoska Nov 4, 2024
19b71c5
Fix _safe_tags import issue
TamaraAtanasoska Nov 4, 2024
e718fce
Change _construct_instance import
TamaraAtanasoska Nov 4, 2024
b3f6401
Change get_tags syntax
TamaraAtanasoska Nov 4, 2024
dadf4e3
Ignore FutureWarning in sklearn
TamaraAtanasoska Nov 4, 2024
0146a74
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 4, 2024
45af8a0
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 11, 2024
7abf51e
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 11, 2024
6332470
fix typo
TamaraAtanasoska Nov 11, 2024
d8da963
Fix variable name inconsitency
TamaraAtanasoska Nov 11, 2024
d9a163b
Add clearer message about warning supression
TamaraAtanasoska Nov 11, 2024
cb0b215
WIP
TamaraAtanasoska Nov 12, 2024
c367109
Add explicit typing
TamaraAtanasoska Nov 14, 2024
c96be0b
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 14, 2024
051eead
Remove stray WIP with prints
TamaraAtanasoska Nov 14, 2024
a1f4344
Fix tags issues
TamaraAtanasoska Nov 14, 2024
e6b4df3
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 18, 2024
ed77ced
Make the use of SGD models conditional on sklearn version
TamaraAtanasoska Nov 18, 2024
0983b80
Add relative paths to fix import errors
TamaraAtanasoska Nov 18, 2024
fdb4b20
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 21, 2024
0388d0b
Add construct_instances for both versions
TamaraAtanasoska Nov 21, 2024
926f972
Move imports for construct_instances
TamaraAtanasoska Nov 21, 2024
1fd8432
Partially make tags work between the two versions
TamaraAtanasoska Nov 21, 2024
de4774d
Tags working with both versions
TamaraAtanasoska Nov 21, 2024
c043a82
Remove typing import
TamaraAtanasoska Nov 22, 2024
81950ff
Attepmt to fix catboost issues
TamaraAtanasoska Nov 25, 2024
bb66ac7
Skip quantile-forest futurewarning sklearn 1.7
TamaraAtanasoska Nov 25, 2024
aeb6baf
Supress quantile-foreset warning
TamaraAtanasoska Nov 25, 2024
bdfc37d
Update spaces/skops_model_card_creator/requirements.txt
TamaraAtanasoska Nov 25, 2024
1cf9c87
Update skops/_min_dependencies.py
TamaraAtanasoska Nov 25, 2024
d5696ff
Add error for SGD class and incompatible sklearn version
TamaraAtanasoska Nov 25, 2024
cfeef0a
Copy code for scikit-learn for est tags
TamaraAtanasoska Nov 26, 2024
e1751fc
Fix loss issues
adrinjalali Nov 27, 2024
b916346
minor fix
adrinjalali Nov 27, 2024
e1d0132
reduce diff
adrinjalali Nov 27, 2024
960dff9
annotations import
adrinjalali Nov 28, 2024
712cb13
work with all instances from _construct_instances
TamaraAtanasoska Nov 28, 2024
c3da1b9
Refactor get_input()
TamaraAtanasoska Nov 29, 2024
a8dad87
trigger CI
adrinjalali Dec 2, 2024
69325c0
debug CI
adrinjalali Dec 2, 2024
d2ecc45
...
adrinjalali Dec 2, 2024
87065a3
...
adrinjalali Dec 2, 2024
d9eaaff
...
adrinjalali Dec 2, 2024
bcc78d4
...
adrinjalali Dec 2, 2024
24783a4
...
adrinjalali Dec 2, 2024
e2f0c82
...
adrinjalali Dec 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ filterwarnings = [
"ignore:DataFrameGroupBy.apply operated on the grouping columns.:DeprecationWarning",
# Ignore Pandas 2.2 warning on PyArrow. It might be reverted in a later release.
"ignore:\\s*Pyarrow will become a required dependency of pandas.*:DeprecationWarning",
# LightGBM sklearn 1.6 deprecation warning, fixed in the next release
"ignore:'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.:FutureWarning",
# RandomForestQuantileRegressor tags deprecation warning in sklearn 1.7
"ignore:The RandomForestQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
# ExtraTreesQuantileRegressor tags deprecation warning in sklearn 1.7
"ignore:The ExtraTreesQuantileRegressor or classes from which it inherits use `_get_tags` and `_more_tags`:FutureWarning",
# BaseEstimator._validate_data deprecation warning in sklearn 1.6 #TODO can be removed when a new release of quantile-forest is out
"ignore:`BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7:FutureWarning",
]
markers = [
"network: marks tests as requiring internet (deselect with '-m \"not network\"')",
Expand Down
5 changes: 2 additions & 3 deletions scripts/check_file_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from zipfile import ZIP_DEFLATED, ZipFile

import pandas as pd
from sklearn.utils._tags import _safe_tags
from sklearn.utils._testing import set_random_state

import skops.io as sio
Expand All @@ -29,6 +28,7 @@
_tested_estimators,
get_input,
)
from skops.utils._fixes import get_tags

TOPK = 10 # number of largest estimators reported
MAX_ALLOWED_SIZE = 1024 # maximum allowed file size in kb
Expand All @@ -46,8 +46,7 @@ def check_file_size() -> None:
set_random_state(estimator, random_state=0)

X, y = get_input(estimator)
tags = _safe_tags(estimator)
if tags.get("requires_fit", True):
if get_tags(estimator).requires_fit:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scikit-learn now uses dataclasses for tags, which are emulated here for old sklearn as well

with warnings.catch_warnings():
warnings.filterwarnings("ignore", module="sklearn")
if y is not None:
Expand Down
5 changes: 2 additions & 3 deletions scripts/check_persistence_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Any

import pandas as pd
from sklearn.utils._tags import _safe_tags
from sklearn.utils._testing import set_random_state

import skops.io as sio
Expand All @@ -24,6 +23,7 @@
_tested_estimators,
get_input,
)
from skops.utils._fixes import get_tags

ATOL = 1 # seconds absolute difference allowed at max
NUM_REPS = 10 # number of times the check is repeated
Expand All @@ -43,8 +43,7 @@ def check_persist_performance() -> None:
set_random_state(estimator, random_state=0)

X, y = get_input(estimator)
tags = _safe_tags(estimator)
if tags.get("requires_fit", True):
if get_tags(estimator).requires_fit:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", module="sklearn")
if y is not None:
Expand Down
4 changes: 3 additions & 1 deletion skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
# required for persistence tests of external libraries
"lightgbm": ("3", "tests", None),
"xgboost": ("1.6", "tests", None),
"catboost": ("1.0", "tests", None),
# remove python constraint when catboost supports 3.13
# https://github.com/catboost/catboost/issues/2748
"catboost": ("1.0", "tests", 'python_version < "3.13"'),
TamaraAtanasoska marked this conversation as resolved.
Show resolved Hide resolved
"fairlearn": ("0.7.0", "docs, tests", None),
"rich": ("12", "tests, rich", None),
}
Expand Down
101 changes: 76 additions & 25 deletions skops/io/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,94 @@
from typing import Any, Optional, Sequence, Type

from sklearn.cluster import Birch
from sklearn.tree._tree import Tree

from ._general import TypeNode
from ._audit import Node, get_tree
from ._general import TypeNode, unsupported_get_state
from ._protocol import PROTOCOL
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

try:
# TODO: remove once support for sklearn<1.2 is dropped. See #187
from sklearn.covariance._graph_lasso import _DictWithDeprecatedKeys
except ImportError:
_DictWithDeprecatedKeys = None

from sklearn.linear_model._sgd_fast import (
EpsilonInsensitive,
Hinge,
Huber,
Log,
LossFunction,
Comment on lines -18 to -20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scikit-learn has moved to a more central place for loss functions / classes, and therefore the ones which are used between estimators are removed from the sgd specific file.

ModifiedHuber,
SquaredEpsilonInsensitive,
SquaredHinge,
SquaredLoss,
)
from sklearn.tree._tree import Tree

from ._audit import Node, get_tree
from ._general import unsupported_get_state
from ._utils import LoadContext, SaveContext, get_module, get_state, gettype
from .exceptions import UnsupportedTypeException

ALLOWED_SGD_LOSSES = {
ModifiedHuber,
Hinge,
SquaredHinge,
Log,
SquaredLoss,
Huber,
ALLOWED_LOSSES = {
EpsilonInsensitive,
Hinge,
ModifiedHuber,
SquaredEpsilonInsensitive,
SquaredHinge,
}

try:
# TODO: remove once support for sklearn<1.6 is dropped.
from sklearn.linear_model._sgd_fast import (
Huber,
Log,
SquaredLoss,
)

ALLOWED_LOSSES |= {
Huber,
Log,
SquaredLoss,
}
except ImportError:
pass

try:
# sklearn>=1.6
from sklearn._loss._loss import (
CyAbsoluteError,
CyExponentialLoss,
CyHalfBinomialLoss,
CyHalfGammaLoss,
CyHalfMultinomialLoss,
CyHalfPoissonLoss,
CyHalfSquaredError,
CyHalfTweedieLoss,
CyHalfTweedieLossIdentity,
CyHuberLoss,
CyPinballLoss,
)

Comment on lines +54 to +67
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are the new loss classes

ALLOWED_LOSSES |= {
CyAbsoluteError,
CyExponentialLoss,
CyHalfBinomialLoss,
CyHalfGammaLoss,
CyHalfMultinomialLoss,
CyHalfPoissonLoss,
CyHalfSquaredError,
CyHalfTweedieLoss,
CyHalfTweedieLossIdentity,
CyHuberLoss,
CyPinballLoss,
}
except ImportError:
pass

# This import is for the parent class of all loss functions, which is used to
# set the dispatch function for all loss functions.
try:
# From sklearn>=1.6
from sklearn._loss._loss import CyLossFunction as ParentLossClass
except ImportError:
# sklearn<1.6
from sklearn.linear_model._sgd_fast import LossFunction as ParentLossClass


UNSUPPORTED_TYPES = {Birch}


Expand Down Expand Up @@ -163,13 +213,13 @@ def __init__(
super().__init__(state, load_context, constructor=Tree, trusted=self.trusted)


def sgd_loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
state = reduce_get_state(obj, save_context)
state["__loader__"] = "SGDNode"
state["__loader__"] = "LossNode"
return state


class SGDNode(ReduceNode):
class LossNode(ReduceNode):
def __init__(
self,
state: dict[str, Any],
Expand All @@ -178,7 +228,7 @@ def __init__(
) -> None:
# TODO: make sure trusted here makes sense and used.
self.trusted = self._get_trusted(
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_SGD_LOSSES]
trusted, [get_module(x) + "." + x.__name__ for x in ALLOWED_LOSSES]
)
super().__init__(
state,
Expand Down Expand Up @@ -240,15 +290,16 @@ def _construct(self):

# tuples of type and function that gets the state of that type
GET_STATE_DISPATCH_FUNCTIONS = [
(LossFunction, sgd_loss_get_state),
(ParentLossClass, loss_get_state),
(Tree, tree_get_state),
]

for type_ in UNSUPPORTED_TYPES:
GET_STATE_DISPATCH_FUNCTIONS.append((type_, unsupported_get_state))

# tuples of type and function that creates the instance of that type
NODE_TYPE_MAPPING = {
("SGDNode", PROTOCOL): SGDNode,
NODE_TYPE_MAPPING: dict[tuple[str, int], Any] = {
("LossNode", PROTOCOL): LossNode,
("TreeNode", PROTOCOL): TreeNode,
}

Expand Down
Loading
Loading