Skip to content

Commit

Permalink
MAINT PairwiseDistancesReduction: Do correctly warn on unused metri…
Browse files Browse the repository at this point in the history
…c kwargs (scikit-learn#22865)
  • Loading branch information
jjerphan authored Mar 17, 2022
1 parent b9bd2d5 commit 8123ed1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
20 changes: 18 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,13 @@ cdef class FastEuclideanPairwiseDistancesArgKmin(PairwiseDistancesArgKmin):
strategy=None,
metric_kwargs=None,
):
if metric_kwargs is not None and len(metric_kwargs) > 0:
if (
metric_kwargs is not None and
len(metric_kwargs) > 0 and
"Y_norm_squared" not in metric_kwargs
):
warnings.warn(
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't"
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
f"usable for this case ({self.__class__.__name__}) and will be ignored.",
UserWarning,
stacklevel=3,
Expand Down Expand Up @@ -1647,6 +1651,18 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood(PairwiseDistancesRad
sort_results=False,
metric_kwargs=None,
):
if (
metric_kwargs is not None and
len(metric_kwargs) > 0 and
"Y_norm_squared" not in metric_kwargs
):
warnings.warn(
f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
f"usable for this case ({self.__class__.__name__}) and will be ignored.",
UserWarning,
stacklevel=3,
)

super().__init__(
# The datasets pair here is used for exact distances computations
datasets_pair=DatasetsPair.get_for(X, Y, metric="euclidean"),
Expand Down
26 changes: 26 additions & 0 deletions sklearn/metrics/tests/test_pairwise_distances_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,19 @@ def test_argkmin_factory_method_wrong_usages():
X=np.asfortranarray(X), Y=Y, k=k, metric=metric
)

unused_metric_kwargs = {"p": 3}

message = (
r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
r" case \("
r"FastEuclideanPairwiseDistancesArgKmin\) and will be ignored."
)

with pytest.warns(UserWarning, match=message):
PairwiseDistancesArgKmin.compute(
X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs
)


def test_radius_neighborhood_factory_method_wrong_usages():
rng = np.random.RandomState(1)
Expand Down Expand Up @@ -216,6 +229,19 @@ def test_radius_neighborhood_factory_method_wrong_usages():
X=np.asfortranarray(X), Y=Y, radius=radius, metric=metric
)

unused_metric_kwargs = {"p": 3}

message = (
r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
r" case \(FastEuclideanPairwiseDistancesRadiusNeighborhood\) and will be"
r" ignored."
)

with pytest.warns(UserWarning, match=message):
PairwiseDistancesRadiusNeighborhood.compute(
X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs
)


@pytest.mark.parametrize("n_samples", [100, 1000])
@pytest.mark.parametrize("chunk_size", [50, 512, 1024])
Expand Down

0 comments on commit 8123ed1

Please sign in to comment.