Specialize sample for sparse weights #943
AntonOresten wants to merge 5 commits into JuliaStats:master

Conversation
Bump.
My gut feeling is that we should address #885 first and then add a specialisation to the SparseArrays extension. By keeping a hard dependency on SparseArrays, StatsBase is holding back large parts of the Julia ecosystem.
Yeah, but this method is easy to move to an extension as soon as we create it, and it doesn't make things worse until then.
Co-authored-by: Milan Bouchet-Valat <[email protected]>
d852edf to db126d1
i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
return rowvals(wv.values)[i]
The code is unsafe: in general, AbstractWeights subtypes are not required to have a values field. Only a few AbstractWeights subtypes in StatsBase have one, and it is undocumented and internal.
So it would actually be better to define this method only for the types defined in StatsBase. Probably using:

for W in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights)
    @eval function sample(rng::AbstractRNG, wv::$W{<:Real,<:Real,<:SparseVector})
        ...
    end
end

(I'm saying this because AFAICT there's no public API which allows accessing the backing array. And anyway, I'm not aware of custom AbstractWeights types defined elsewhere, so we don't really care to apply this optimization to them.)
The tests are insufficient: since the method is implemented for AbstractWeights, to be sure it works not only for Weights, we should test all subtypes implemented in StatsBase as well as a custom subtype of AbstractWeights.
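A test along these lines might look as follows. This is only a sketch: MyWeights is a hypothetical custom subtype, and it mirrors the internal values/sum field layout that StatsBase's own subtypes happen to use, so the generic fallbacks apply to it.

```julia
using Test, Random, SparseArrays, StatsBase

# Hypothetical custom AbstractWeights subtype (not from the PR).
# It relies on the same internal `values`/`sum` layout as StatsBase's
# own types so that the generic AbstractWeights fallbacks work on it.
struct MyWeights{T<:Real} <: AbstractWeights{T,T,SparseVector{T,Int}}
    values::SparseVector{T,Int}
    sum::T
end

sv = SparseVector(10, [3, 7], [2.0, 3.0])   # mass only at indices 3 and 7
rng = MersenneTwister(42)

for wv in (weights(sv), aweights(sv), fweights(sv), pweights(sv),
           MyWeights(sv, sum(sv)))
    for _ in 1:100
        @test sample(rng, wv) in (3, 7)   # samples must stay in the support
    end
end
```

The loop checks every StatsBase subtype plus the custom one; any weight type for which the sparse specialization silently misbehaves would sample an index outside the stored support.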
This PR adds a new sample method for sparse weights, as well as tests. It brings the time complexity from O(n) to O(n_nonzero). This would be useful for e.g. top-p sampling, where one might have on the order of 100k tokens to sample from, but only a few are considered.
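The idea reduces to sampling among the stored entries only and mapping the result back to a full-length index. A sketch of it (not the exact PR code; sample_sparse is a hypothetical name, and SparseArrays.nonzeroinds is internal, as noted below):

```julia
using Random, SparseArrays, StatsBase

# Sketch: draw among the stored (nonzero) entries only, then map the
# drawn position back to the original index. O(n_nonzero) instead of O(n).
function sample_sparse(rng::AbstractRNG,
                       wv::Weights{<:Real,<:Real,<:SparseVector})
    v = wv.values                                    # backing SparseVector
    i = sample(rng, Weights(nonzeros(v), sum(wv)))   # index among stored entries
    return SparseArrays.nonzeroinds(v)[i]            # map back to full index
end

rng = MersenneTwister(0)
sv = SparseVector(100_000, [5, 99_999], [0.25, 0.75])
sample_sparse(rng, weights(sv))   # always returns 5 or 99_999
```

Note that the inner call reuses the existing weighted sample on a dense vector of length n_nonzero, so no new sampling logic is needed.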
Benchmarks across different sizes and densities
Results
This shows the dense baseline, and the relative performance increase over invoking sample with the generic method for sparse weights.

Benchmark setup
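The collapsed benchmark setup is not preserved in this scrape. A comparison along these lines (hypothetical size and density, using BenchmarkTools) would reproduce the measurement; the sparse path is only faster once this PR's specialization is in place:

```julia
using BenchmarkTools, Random, SparseArrays, StatsBase

n, density = 100_000, 0.001
sv = sprand(n, density)     # sparse weights, roughly n * density nonzeros
dv = Vector(sv)             # dense copy as the baseline

rng = MersenneTwister(1)
@btime sample($rng, weights($dv))   # dense baseline, O(n)
@btime sample($rng, weights($sv))   # sparse path (O(n_nonzero) with this PR)
```

Sweeping n and density over a grid would give the table of relative speedups the description refers to.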
Note: For small vector lengths (~10) and low densities (~0.2) the performance difference becomes noisy and less meaningful. The generic method can sometimes be faster in these cases due to less overhead when it happens to find the target probability mass early in the vector. However, for these small cases the absolute timing differences are negligible (few nanoseconds) and sparse storage isn't really beneficial anyway.
Note: The implementation uses SparseArrays.nonzeroinds, which is not public.
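As a possible public-API alternative (my suggestion, not from the PR), findnz returns the stored indices and values of a SparseVector without touching internals:

```julia
using Random, SparseArrays, StatsBase

sv = SparseVector(10, [2, 6], [0.4, 0.6])

# findnz is public API: returns (indices, values) of the stored entries.
inds, vals = findnz(sv)

# Sample among stored entries using only public accessors.
i = sample(MersenneTwister(0), Weights(vals, sum(vals)))
inds[i]   # either 2 or 6
```

The tradeoff is that findnz allocates copies of both arrays, whereas nonzeroinds returns the backing index vector directly, so the internal accessor avoids per-call allocation.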