Skip to content

Commit

Permalink
ENH add support for Array API to `mean_pinball_loss` and `explained_variance_score` (#29978)
Browse files Browse the repository at this point in the history
  • Loading branch information
virchan authored Dec 3, 2024
1 parent fba028b commit 8ded7f4
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 25 deletions.
2 changes: 2 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,13 @@ Metrics
- :func:`sklearn.metrics.cluster.entropy`
- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
- :func:`sklearn.metrics.mean_gamma_deviance`
- :func:`sklearn.metrics.mean_pinball_loss`
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
- :func:`sklearn.metrics.mean_squared_error`
- :func:`sklearn.metrics.mean_squared_log_error`
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/29978.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :func:`sklearn.metrics.explained_variance_score` and
:func:`sklearn.metrics.mean_pinball_loss` now support Array API compatible inputs.
by :user:`Virgil Chan <virchan>`
62 changes: 38 additions & 24 deletions sklearn/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def mean_absolute_error(
if multioutput == "raw_values":
return output_errors
elif multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

# Average across the outputs (if needed).
Expand Down Expand Up @@ -360,35 +360,45 @@ def mean_pinball_loss(
>>> from sklearn.metrics import mean_pinball_loss
>>> y_true = [1, 2, 3]
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.1)
np.float64(0.03...)
0.03...
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.1)
np.float64(0.3...)
0.3...
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.9)
np.float64(0.3...)
0.3...
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.9)
np.float64(0.03...)
0.03...
>>> mean_pinball_loss(y_true, y_true, alpha=0.1)
np.float64(0.0)
0.0
>>> mean_pinball_loss(y_true, y_true, alpha=0.9)
np.float64(0.0)
0.0
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)

_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)
diff = y_true - y_pred
sign = (diff >= 0).astype(diff.dtype)
sign = xp.astype(diff >= 0, diff.dtype)
loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff
output_errors = np.average(loss, weights=sample_weight, axis=0)
output_errors = _average(loss, weights=sample_weight, axis=0)

if isinstance(multioutput, str) and multioutput == "raw_values":
return output_errors

if isinstance(multioutput, str) and multioutput == "uniform_average":
# pass None as weights to np.average: uniform mean
# pass None as weights to _average: uniform mean
multioutput = None

return np.average(output_errors, weights=multioutput)
# Average across the outputs (if needed).
# The second call to `_average` should always return
# a scalar array that we convert to a Python float to
# consistently return the same eager evaluated value.
# Therefore, `axis=None`.
return float(_average(output_errors, weights=multioutput))


@validate_params(
Expand Down Expand Up @@ -949,12 +959,12 @@ def _assemble_r2_explained_variance(
# return scores individually
return output_scores
elif multioutput == "uniform_average":
# Passing None as weights to np.average results is uniform mean
# pass None as weights to _average: uniform mean
avg_weights = None
elif multioutput == "variance_weighted":
avg_weights = denominator
if not xp.any(nonzero_denominator):
# All weights are zero, np.average would raise a ZeroDiv error.
# All weights are zero, _average would raise a ZeroDiv error.
# This only happens when all y are constant (or 1-element long)
# Since weights are all equal, fall back to uniform weights.
avg_weights = None
Expand Down Expand Up @@ -1083,28 +1093,32 @@ def explained_variance_score(
>>> explained_variance_score(y_true, y_pred, force_finite=False)
-inf
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight, multioutput)

_, y_true, y_pred, sample_weight, multioutput = (
_check_reg_targets_with_floating_dtype(
y_true, y_pred, sample_weight, multioutput, xp=xp
)
)

check_consistent_length(y_true, y_pred, sample_weight)

y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = np.average(
y_diff_avg = _average(y_true - y_pred, weights=sample_weight, axis=0)
numerator = _average(
(y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
)

y_true_avg = np.average(y_true, weights=sample_weight, axis=0)
denominator = np.average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)
y_true_avg = _average(y_true, weights=sample_weight, axis=0)
denominator = _average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)

return _assemble_r2_explained_variance(
numerator=numerator,
denominator=denominator,
n_outputs=y_true.shape[1],
multioutput=multioutput,
force_finite=force_finite,
xp=get_namespace(y_true)[0],
# TODO: update once Array API support is added to explained_variance_score.
device=None,
xp=xp,
device=device,
)


Expand Down
8 changes: 8 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,10 +2084,18 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_regression_metric_multioutput,
],
cosine_similarity: [check_array_api_metric_pairwise],
explained_variance_score: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_absolute_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_pinball_loss: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
],
mean_squared_error: [
check_array_api_regression_metric,
check_array_api_regression_metric_multioutput,
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def test_mean_pinball_loss_on_constant_predictions(distribution, target_quantile
# Check that the loss of this constant predictor is greater or equal
# than the loss of using the optimal quantile (up to machine
# precision):
assert pbl >= best_pbl - np.finfo(best_pbl.dtype).eps
assert pbl >= best_pbl - np.finfo(np.float64).eps

# Check that the value of the pinball loss matches the analytical
# formula.
Expand Down

0 comments on commit 8ded7f4

Please sign in to comment.