Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MNT Sklearn1.6 compatibility #447

Merged
merged 46 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6901126
Fix _sgd imports
TamaraAtanasoska Nov 4, 2024
19b71c5
Fix _safe_tags import issue
TamaraAtanasoska Nov 4, 2024
e718fce
Change _construct_instance import
TamaraAtanasoska Nov 4, 2024
b3f6401
Change get_tags syntax
TamaraAtanasoska Nov 4, 2024
dadf4e3
Ignore FutureWarning in sklearn
TamaraAtanasoska Nov 4, 2024
0146a74
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 4, 2024
45af8a0
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 11, 2024
7abf51e
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 11, 2024
6332470
fix typo
TamaraAtanasoska Nov 11, 2024
d8da963
Fix variable name inconsitency
TamaraAtanasoska Nov 11, 2024
d9a163b
Add clearer message about warning supression
TamaraAtanasoska Nov 11, 2024
cb0b215
WIP
TamaraAtanasoska Nov 12, 2024
c367109
Add explicit typing
TamaraAtanasoska Nov 14, 2024
c96be0b
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 14, 2024
051eead
Remove stray WIP with prints
TamaraAtanasoska Nov 14, 2024
a1f4344
Fix tags issues
TamaraAtanasoska Nov 14, 2024
e6b4df3
Update skops/io/_sklearn.py
TamaraAtanasoska Nov 18, 2024
ed77ced
Make the use of SGD models conditional on sklearn version
TamaraAtanasoska Nov 18, 2024
0983b80
Add relative paths to fix import errors
TamaraAtanasoska Nov 18, 2024
fdb4b20
Merge branch 'main' into sklearn1.6-compatibility
TamaraAtanasoska Nov 21, 2024
0388d0b
Add construct_instances for both versions
TamaraAtanasoska Nov 21, 2024
926f972
Move imports for construct_instances
TamaraAtanasoska Nov 21, 2024
1fd8432
Partially make tags work between the two versions
TamaraAtanasoska Nov 21, 2024
de4774d
Tags working with both versions
TamaraAtanasoska Nov 21, 2024
c043a82
Remove typing import
TamaraAtanasoska Nov 22, 2024
81950ff
Attepmt to fix catboost issues
TamaraAtanasoska Nov 25, 2024
bb66ac7
Skip quantile-forest futurewarning sklearn 1.7
TamaraAtanasoska Nov 25, 2024
aeb6baf
Supress quantile-foreset warning
TamaraAtanasoska Nov 25, 2024
bdfc37d
Update spaces/skops_model_card_creator/requirements.txt
TamaraAtanasoska Nov 25, 2024
1cf9c87
Update skops/_min_dependencies.py
TamaraAtanasoska Nov 25, 2024
d5696ff
Add error for SGD class and incompatible sklearn version
TamaraAtanasoska Nov 25, 2024
cfeef0a
Copy code for scikit-learn for est tags
TamaraAtanasoska Nov 26, 2024
e1751fc
Fix loss issues
adrinjalali Nov 27, 2024
b916346
minor fix
adrinjalali Nov 27, 2024
e1d0132
reduce diff
adrinjalali Nov 27, 2024
960dff9
annotations import
adrinjalali Nov 28, 2024
712cb13
work with all instances from _construct_instances
TamaraAtanasoska Nov 28, 2024
c3da1b9
Refactor get_input()
TamaraAtanasoska Nov 29, 2024
a8dad87
trigger CI
adrinjalali Dec 2, 2024
69325c0
debug CI
adrinjalali Dec 2, 2024
d2ecc45
...
adrinjalali Dec 2, 2024
87065a3
...
adrinjalali Dec 2, 2024
d9eaaff
...
adrinjalali Dec 2, 2024
bcc78d4
...
adrinjalali Dec 2, 2024
24783a4
...
adrinjalali Dec 2, 2024
e2f0c82
...
adrinjalali Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 30 additions & 26 deletions skops/io/tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from skops.utils._fixes import construct_instances, get_tags

# Default settings for X
N_SAMPLES = 100
N_SAMPLES = 120
N_FEATURES = 20


Expand Down Expand Up @@ -146,20 +146,21 @@ def _tested_estimators(type_filter=None):
# scikit-learn < 1.4.0) is not available in scipy >= 1.11.0. The
# default solver will be "highs" from scikit-learn >= 1.4.0.
# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.QuantileRegressor.html
estimator = construct_instances(partial(Estimator, solver="highs"))
estimators = construct_instances(partial(Estimator, solver="highs"))
else:
estimator = construct_instances(Estimator)

# with the kind of data we pass, it needs to be 1 for the few
# estimators which have this.
if "n_components" in estimator.get_params():
estimator.set_params(n_components=1)
# Then n_best needs to be <= n_components
if "n_best" in estimator.get_params():
estimator.set_params(n_best=1)
if "patch_size" in estimator.get_params():
# set patch size to fix PatchExtractor test.
estimator.set_params(patch_size=(3, 3))
estimators = construct_instances(Estimator)

for estimator in estimators:
# with the kind of data we pass, it needs to be 1 for the few
# estimators which have this.
if "n_components" in estimator.get_params():
estimator.set_params(n_components=1)
# Then n_best needs to be <= n_components
if "n_best" in estimator.get_params():
estimator.set_params(n_best=1)
if "patch_size" in estimator.get_params():
# set patch size to fix PatchExtractor test.
estimator.set_params(patch_size=(3, 3))
except SkipTest:
continue

Expand Down Expand Up @@ -270,17 +271,18 @@ def _unsupported_estimators(type_filter=None):
message="Can't instantiate estimator",
)
# Get the first instance directly from the generator
estimator = construct_instances(Estimator)
estimators = construct_instances(Estimator)
# with the kind of data we pass, it needs to be 1 for the few
# estimators which have this.
if "n_components" in estimator.get_params():
estimator.set_params(n_components=1)
# Then n_best needs to be <= n_components
if "n_best" in estimator.get_params():
estimator.set_params(n_best=1)
if "patch_size" in estimator.get_params():
# set patch size to fix PatchExtractor test.
estimator.set_params(patch_size=(3, 3))
for estimator in estimators:
if "n_components" in estimator.get_params():
estimator.set_params(n_components=1)
# Then n_best needs to be <= n_components
if "n_best" in estimator.get_params():
estimator.set_params(n_best=1)
if "patch_size" in estimator.get_params():
# set patch size to fix PatchExtractor test.
estimator.set_params(patch_size=(3, 3))
except SkipTest:
continue

Expand Down Expand Up @@ -317,7 +319,10 @@ def get_input(estimator):
tags = get_tags(estimator)

if tags.input_tags.pairwise:
return np.random.rand(N_FEATURES, N_FEATURES), None
if not tags.target_tags.required:
return np.random.rand(N_FEATURES, N_FEATURES), None
else:
return np.random.rand(N_FEATURES, N_FEATURES), y[:N_FEATURES]
TamaraAtanasoska marked this conversation as resolved.
Show resolved Hide resolved

if tags.input_tags.two_d_array:
# Some models require positive X
Expand All @@ -338,7 +343,7 @@ def get_input(estimator):

if tags.input_tags.categorical:
X = [["Male", 1], ["Female", 3], ["Female", 2]]
y = y[: len(X)] if tags.y_required else None
y = y[: len(X)] if tags.target_tags.required else None
return X, y

if tags.input_tags.dict:
Expand Down Expand Up @@ -417,7 +422,6 @@ def test_can_trust_types(type_):
def test_unsupported_type_raises(estimator):
"""Estimators that are known to fail should raise an error"""
set_random_state(estimator, random_state=0)

X, y = get_input(estimator)
if get_tags(estimator).requires_fit:
with warnings.catch_warnings():
Expand Down
4 changes: 2 additions & 2 deletions skops/utils/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ def construct_instances(estimator):
try:
from sklearn.utils._test_common.instance_generator import _construct_instances

return next(_construct_instances(estimator))
return list(_construct_instances(estimator))

except ImportError:
from sklearn.utils.estimator_checks import _construct_instance

return _construct_instance(estimator)
return [_construct_instance(estimator)]


"""
Expand Down
Loading