
ENH Add array api for log_loss #30439

Closed
OmarManzoor wants to merge 9 commits

Conversation

@OmarManzoor (Contributor) commented Dec 9, 2024:

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

  • Adds array API support for log_loss

Note: As discussed in #28626, we handle the conversion to numpy when using the LabelBinarizer, as it does not currently support the array API. However, this might not be efficient when we have to move data between devices. Therefore, this PR depends on LabelBinarizer supporting the array API in order to provide the expected performance gains.
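
For illustration, a minimal sketch of this conversion pattern (the helper name _binarize_labels is hypothetical; _convert_to_numpy is the existing private utility used in this PR):

from sklearn.preprocessing import LabelBinarizer
from sklearn.utils._array_api import _convert_to_numpy

def _binarize_labels(y_true, labels, xp, device_):
    # Hypothetical helper: binarize on CPU with numpy, then move the result
    # back to the caller's namespace and device.
    lb = LabelBinarizer()
    if labels is not None:
        lb.fit(_convert_to_numpy(labels, xp=xp))
    else:
        lb.fit(_convert_to_numpy(y_true, xp=xp))
    transformed = lb.transform(_convert_to_numpy(y_true, xp=xp))
    return xp.asarray(transformed, device=device_)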

Any other comments?

CC: @ogrisel @adrinjalali @betatim

github-actions bot commented Dec 9, 2024:

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit cfabbd0.

@OmarManzoor (Contributor, Author) commented:

The CUDA tests seem to be green 🟢

@ogrisel (Member) left a review comment:

Quick feedback about the new _allclose helper. Otherwise LGTM.

Review thread on sklearn/utils/_array_api.py (outdated, resolved)
Review thread on sklearn/metrics/_classification.py (outdated, resolved):
-    y_pred_sum = y_pred.sum(axis=1)
-    if not np.allclose(y_pred_sum, 1, rtol=np.sqrt(eps)):
+    y_pred_sum = xp.sum(y_pred, axis=1)
+    if not _allclose(y_pred_sum, 1, rtol=np.sqrt(eps), xp=xp):
@OmarManzoor (Contributor, Author) commented Dec 9, 2024:

Note: Since we internally convert to numpy in all cases in the _allclose helper, we can use np.sqrt instead of xp.sqrt here. eps is just a scalar value, and using xp.sqrt would require unnecessarily converting eps into an array to satisfy array-api-strict.
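
For reference, a rough sketch of what such a helper could look like (a hypothetical body, not necessarily the exact implementation added in sklearn/utils/_array_api.py):

import numpy as np
from sklearn.utils._array_api import _convert_to_numpy

def _allclose(a, b, rtol=1e-05, atol=1e-08, xp=None):
    # Convert the array operand to numpy so the comparison works regardless of
    # namespace / device; scalar operands (like 1 or rtol) stay plain floats.
    return np.allclose(_convert_to_numpy(a, xp=xp), b, rtol=rtol, atol=atol)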


     check_consistent_length(y_pred, y_true, sample_weight)
     lb = LabelBinarizer()

     if labels is not None:
-        lb.fit(labels)
+        lb.fit(_convert_to_numpy(labels, xp=xp))
A member commented:

This somewhat misses the point of "supporting array API" here. I'd say we support the array API if we don't convert to numpy, and here we do. So in effect, there's not much of an improvement with this PR.

I think in order to get this merged, LabelBinarizer should support array API.

@OmarManzoor (Contributor, Author) commented:

I agree that would be better, but we still perform computations after the LabelBinarizer part, particularly the sums, clipping, and xlogy. Those might still bring some improvements, since scipy's xlogy supports the array API.
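
As a rough sketch of what remains after binarization (with numpy standing in for the namespace xp; the helper name and exact clipping details are illustrative, not the PR's implementation):

import numpy as np
from scipy.special import xlogy

def _log_loss_core(transformed_labels, y_pred, eps=np.finfo(np.float32).eps):
    # The part that could stay in the array API namespace: clipping,
    # row renormalization, and the xlogy-based per-sample losses.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    y_pred = y_pred / y_pred.sum(axis=1, keepdims=True)
    return -xlogy(transformed_labels, y_pred).sum(axis=1)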

A member commented:

It might not be worth moving the data back and forth between devices.

@OmarManzoor (Contributor, Author) commented:

Yes, I think you are right.

@OmarManzoor (Contributor, Author) commented:

I updated the description to reflect this.

A member commented:

I think that means we need to shelve this PR until we fix LabelBinarizer.

@OmarManzoor (Contributor, Author) commented:

Closing for now, as this depends on LabelBinarizer supporting the array API.

@ogrisel (Member) commented Dec 10, 2024:

The array API does not and will never support str objects in arrays. Since LabelBinarizer is mostly about mapping an array of str class labels to an array of integers, its input will rarely be an array API array. So I don't think it makes sense for LabelBinarizer to ever "support" the array API.

But we want to have (some) classifiers that accept X as an array API array and y as a numpy array of object labels: y is internally encoded from class labels to integers (using a LabelBinarizer internally), then the resulting integer-coded y is moved into the same namespace and device as X, and the probabilistic predictions of that classifier will naturally be in the array API namespace.

So, to me, it would make sense to make log loss and other classification metrics accept array API inputs, even if it's just converting back to numpy internally. The computationally intensive part is in the fit of the classifier, not in the computation of the log loss.

For instance, we would like to be able to do something like:

import torch
from skrub import TableVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import RidgeClassifier
from sklearn.model_selection import cross_validate

pipeline = make_pipeline(
    TableVectorizer(),  # feature engineering with heterogeneous types in pandas
    FunctionTransformer(func=lambda x: torch.tensor(x, dtype=torch.float32).to("cuda")),
    RidgeClassifier(),
)


cross_validate(
   pipeline,
   X_pandas_dataframe,
   y_pandas_series,
   scoring="neg_log_loss",
)

And similarly for a RandomizedSearchCV call.

Note that the function transformer in charge of moving the numerical features (extracted from X_pandas_dataframe by skrub or similar pandas-capable transformers) to the GPU does not impact y_pandas_series, which is fed directly to the fit method of the classifier and to the log_loss metric function.

@adrinjalali (Member) commented:

But in the above example, the data passed to neg_log_loss would be numpy, not array API. So you can already do that.

@OmarManzoor (Contributor, Author) commented Dec 10, 2024:

If y_true or labels will usually be strings, then I don't think there should be any issue: they will already be on the CPU, and we convert them to the array API and move them to the GPU after they are binarized. The only concern is if y_true or labels are already on the GPU, and we move them to the CPU and then move the binarized output back to the GPU.

@ogrisel (Member) commented Dec 10, 2024:

But in the above example, the data passed to neg_log_loss would be numpy, not array API. So you can already do that.

The y_true argument passed to log_loss will be a sliced pandas series (derived from y_pandas_series by the CV splitter), but the second argument will be the output of pipeline.predict_proba(X_pandas_series_cv_split_i), which will be a pytorch tensor allocated on a CUDA GPU.

@ogrisel (Member) commented Dec 10, 2024:

Similarly to the "y-follows-X" policy we want to implement in the fit method of estimators, we might want to decide officially on a y_true-follows-y_pred or y_pred-follows-y_true policy for the metric functions in case both inputs do not stem from the same namespace.

For soft-classification metrics like log loss or ROC AUC, y_pred will be the output of either predict_proba or decision_function and could therefore naturally be allocated in the namespace / on the device of the X parameter passed to the fit method of the last step of the pipeline. y_true could either be an array API integer array, or a numpy array or pandas series with arbitrarily typed class labels:

>>> from sklearn.metrics import log_loss
>>> log_loss(y_true=["a", "b", "a"], y_pred=[[0.9, 0.1], [0.2, 0.8], [0.8, 0.2]], labels=["a", "b"])
0.18388253942874858

For log-loss we don't care if we do y_true-follows-y_pred or y_pred-follows-y_true as the computation should not be expensive, so CPU or GPU computation should not matter much compared to fitting or predicting with complex models and feature sets.

But for ROC-AUC, it might be interesting to do most of the computation using the namespace and device of y_pred because it involves sorting, which can be much faster on GPUs. So we could move the label-encoded y_true to the same device and namespace as y_pred.
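
A hedged sketch of what that y_true-follows-y_pred move could look like (the function name is hypothetical; get_namespace and device are sklearn's private array API helpers):

import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.utils._array_api import get_namespace, device

def _move_labels_to_pred(y_true, y_pred):
    # Encode the class labels (possibly strings) on CPU with numpy, then move
    # the integer codes to y_pred's namespace and device.
    xp, _ = get_namespace(y_pred)
    y_encoded = LabelEncoder().fit_transform(np.asarray(y_true))
    return xp.asarray(y_encoded, device=device(y_pred))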

The result is a float scalar in most cases. I'm not sure what the output namespace / device should be in case we output an array, e.g. roc_auc_score with average=None on multiclass problems...

@StefanieSenger (Contributor) commented Dec 10, 2024:

If y_true or labels will usually be strings, then I don't think there should be any issue: they will already be on the CPU, and we convert them to the array API and move them to the GPU after they are binarized. The only concern is if y_true or labels are already on the GPU, and we move them to the CPU and then move the binarized output back to the GPU.

Sorry, @OmarManzoor, can you explain that? Is it that moving data to GPU is less costly than the other way around? Or do you mean that moving twice is worse than moving once?

I'm asking to get more information on data transfer between devices. I believe that with my CPU-only PC and Colab (where I have to start a dedicated runtime), I cannot try this out myself?

@OmarManzoor (Contributor, Author) commented Dec 11, 2024:

Sorry, @OmarManzoor, can you explain that? Is it that moving data to GPU is less costly than the other way around? Or do you mean that moving twice is worse than moving once?

I don't think moving to the GPU versus from the GPU should be too different. But yes, moving back and forth is heavier than simply moving once and doing all the computations there.

@OmarManzoor (Contributor, Author) commented Dec 11, 2024:

For log-loss we don't care if we do y_true-follows-y_pred or y_pred-follows-y_true as the computation should not be expensive, so CPU or GPU computation should not matter much compared to fitting or predicting with complex models and feature sets.

I think since y_pred is the actual numerical array, it might make sense to extract the namespace based on that.

Would it be possible to fix the device to "cpu" inside log_loss and then do all the computations on the CPU? Even though that would eliminate any advantages that could be gained from the GPU, the computation, as you mentioned, is not too expensive.

@ogrisel (Member) commented Dec 11, 2024:

But yes moving back and forth

The impact of moving back and forth across devices really depends on how many times we do it (e.g. twice vs. hundreds of millions of times) and how much computation we do on the device (arithmetic intensity) compared to the data transfers. It's best to do measurements with timeit in case of doubt.

@StefanieSenger (Contributor) commented:

It's best to do measurements with timeit in case of doubt.

Yes. Currently, we don't have a place to put pre-defined data for these kinds of timeit tests, correct? We decide case by case what aspect of performance we are interested in: whether it is the edge cases where the data is very big or has specific characteristics, or the normal use cases with what users would typically provide. And we decide case by case how we evaluate the tradeoff between those different benchmarks, if I see this correctly.

I believe that having a defined set of data to use for testing would be very helpful to compare results between different array libraries that implement the array API, and it would make discussing the results more efficient than the current approach.

What do you think about defining data that we want to test on together and writing it down somewhere? People could then use it if they find it helpful (I certainly would).

@ogrisel (Member) commented Dec 11, 2024:

I believe that having a defined set of data to use for testing would be very helpful to compare results between different array libraries that implement the array API, and it would make discussing the results more efficient than the current approach.

In most cases, a quick CPU vs GPU perf check can be conducted with random data compatible with whatever the function to benchmark expects. Since functions have different expectations (as defined in their docstrings), it's hard to come up with generic test data.

For instance, if you want a 1D array with 1 million random floating-point values, both positive and negative:

import numpy as np
import torch
data_np = np.random.default_rng(0).normal(size=int(1e6)).astype(np.float32)
data_torch = torch.tensor(data_np)

If you want 1 million random integers between 0 and 9 inclusive:

import numpy as np
import torch
data_np = np.random.default_rng(0).integers(0, 10, size=int(1e6)).astype(np.int32)
data_torch = torch.tensor(data_np)

If you need 2D data with shape=(n_samples=10_000_000, n_classes=5) with positive-only values and rows that sum to 1 (like the output of y_pred = clf.predict_proba(X)), you can use something like:

import numpy as np
import torch
data_np = np.random.default_rng(0).uniform(0, 1, size=(int(1e7), 5)).astype(np.float32)
data_np /= data_np.sum(axis=1, keepdims=True)
data_torch = torch.tensor(data_np)

Depending on the functions we are benchmarking, we have to think about the typical data shapes our users who have access to GPUs will care about. And we need data that is large enough for the difference to be meaningfully measurable. Some functions are very scalable (linear complexity) and should therefore be benchmarked with datasets that are quite large (but still fit in host or device memory). Other functions are not as scalable, e.g. n log(n) or even quadratic in the number of data points or features, and therefore we need to adjust the test data size so that the benchmark stays tractable.

In general, my rule of thumb would be to feed input data that is large enough for the function to take at least a few seconds to run with the slowest of the two alternatives to compare (e.g. when comparing execution on CPU with numpy vs CUDA GPU with torch or cupy).
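
As a rough illustration of that rule of thumb, a timing sketch reusing the y_pred recipe above (shapes and repeat counts are only illustrative):

from timeit import timeit
import numpy as np
from sklearn.metrics import log_loss

rng = np.random.default_rng(0)
y_pred = rng.uniform(0, 1, size=(int(1e7), 5)).astype(np.float32)
y_pred /= y_pred.sum(axis=1, keepdims=True)  # rows sum to 1, like predict_proba output
y_true = rng.integers(0, 5, size=int(1e7))   # integer class labels in [0, 4]

# Time the numpy baseline; the same call could be repeated with torch tensors
# on CUDA to compare, once log_loss accepts them.
print("numpy log_loss:", timeit(lambda: log_loss(y_true, y_pred), number=5), "s")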
