
Implement optional sparse Gaussian covariance #58

Merged (21 commits) on Jul 26, 2020

Conversation

@dkirkby (Collaborator) commented Jul 22, 2020

Add a sparse option (defaulting to False) to angular_cl.gaussian_cl_covariance and angular_cl.gaussian_cl_covariance_and_mean that, when True, returns a sparse representation of the covariance matrix using a factor of n_ell less memory. Specifically, the sparse representation has shape (n_cls, n_cls, n_ell) compared with the dense shape (n_cls * n_ell, n_cls * n_ell).

Add a sparse module that implements efficient linear algebra operations on this sparse representation. Currently there are to_dense, inv and vecdot; I still need to implement det, which is more complicated. These are generally a factor of ~n_ell faster than their dense equivalents (on a CPU, at least).

@dkirkby (Collaborator, Author) commented Jul 22, 2020

I don't see any obvious error message in the output of the failing style check. Any ideas?

@EiffL (Member) commented Jul 22, 2020

^^' ok, fixed the issues. It's because I included reorder_python_imports in the style checks, in addition to Black. You can set up pre-commit hooks to solve this for you automatically. Check out that section of the design document: https://github.com/DifferentiableUniverseInitiative/jax_cosmo/blob/master/design.md#code-style

@dkirkby (Collaborator, Author) commented Jul 22, 2020

Thanks @EiffL. I did follow the instructions in the design doc, but I hadn't used pre-commit before and see now that you have to run pre-commit install after the pip installs. I just updated design.md accordingly.

@EiffL (Member) commented Jul 22, 2020

My bad. Thanks for the added documentation!
I've gone through the PR and it all looks good to me; thanks a lot for even including tests :-)

My only comment is that we might want to move sparse.py to scipy/sparse.py, as a way to keep track of functions and tools that are not yet available in jax.scipy but might one day be. As I was writing this, I saw that there are already a few ports of sparse functions (https://github.com/google/jax/tree/master/jax/scipy/sparse), but I don't think there is overlap with what you have coded here.

@dkirkby (Collaborator, Author) commented Jul 22, 2020

This isn't one of the sparse formats supported by scipy.sparse (I don't think it even has a name, but I like "diagonal block"), so should it still go in scipy/sparse.py if it is not implementing functionality that already exists there?

Note that we could store Gaussian covariances efficiently using the dia_matrix format, but most of the speedup in the linear algebra comes from the block structure which dia_matrix doesn't use.
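
Concretely, each (i, j) block of the dense covariance is a diagonal n_ell x n_ell matrix, so the whole matrix can be stored as an array of shape (n_cls, n_cls, n_ell) holding just the block diagonals. A minimal sketch of the expansion back to dense form, in plain jax.numpy (illustrative only; the actual to_dense in this PR may differ in detail):

    import jax.numpy as np

    def to_dense_sketch(S):
        """Expand a (n_cls, n_cls, n_ell) array of block diagonals into the
        dense (n_cls * n_ell, n_cls * n_ell) covariance it represents."""
        n_cls, _, n_ell = S.shape
        # Put each length-n_ell vector on the diagonal of its (n_ell, n_ell) block.
        blocks = S[..., None] * np.eye(n_ell)  # shape (n_cls, n_cls, n_ell, n_ell)
        # Reassemble the blocks into the full dense matrix.
        return blocks.transpose(0, 2, 1, 3).reshape(n_cls * n_ell, n_cls * n_ell)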

@EiffL (Member) commented Jul 22, 2020

Yeah, I just wanted to highlight that aspect, but indeed this is slightly different. I'm very ok with your current solution.

@EiffL (Member) commented Jul 22, 2020

I was in the middle of implementing the challenge metrics with these new tools. Because my brain is fried I got stuck on this part of the Fisher computation ^^' originally here: https://github.com/LSSTDESC/tomo_challenge/blob/280a5a566158e8775c91ea3cf297b27c5232fd38/tomo_challenge/jax_metrics.py#L159

        mu = mean(fid_params)
        dmu = jac_mean(fid_params)

        # Compute the covariance matrix
        cl_noise = jc.angular_cl.noise_cl(ell, probes)
        C = jc.angular_cl.gaussian_cl_covariance(ell, probes, mu, cl_noise)

        invCov = np.linalg.inv(C)

        # Compute Fisher matrix for constant covariance
        F = np.einsum('pa,pq,qb->ab', dmu, invCov, dmu)

Any ideas for an elegant way to do this einsum at the end when invCov is sparse?

@dkirkby (Collaborator, Author) commented Jul 23, 2020

Isn't that just dmu.T @ invCov @ dmu? I already implemented A @ B when A and B are both sparse, so I need to add support for either A or B being dense.
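
For reference, a quick self-contained check with plain jax.numpy that the einsum above is the same contraction as this chained matrix product (random arrays stand in for the Jacobian and inverse covariance):

    import jax.numpy as np
    from jax import random

    key1, key2 = random.split(random.PRNGKey(0))
    n_data, n_params = 12, 3
    dmu = random.normal(key1, (n_data, n_params))    # stand-in for the mean Jacobian
    invCov = random.normal(key2, (n_data, n_data))   # stand-in for the inverse covariance

    F_einsum = np.einsum('pa,pq,qb->ab', dmu, invCov, dmu)
    F_matmul = dmu.T @ invCov @ dmu
    print(np.allclose(F_einsum, F_matmul, atol=1e-5))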

Btw, why do you assume constant covariance?

@dkirkby (Collaborator, Author) commented Jul 23, 2020

I implemented the sparse determinant, which was trickier than I expected. I was never able to eliminate the outer python loop using any of the lax flow-control structures, but any tips are welcome! The computations within each iteration can be done in parallel, so it seems a shame to force sequential flow with the for loop.
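
As background on why the determinant can be cheap at all in this format: since every block is diagonal in ell, a simultaneous row/column permutation turns the matrix into a block-diagonal one with a single n_cls x n_cls block per ell, so the (log-)determinant factorizes over ell. A naive per-ell reference sketch along those lines (not the implementation in this PR):

    import jax.numpy as np
    from jax import vmap

    def sparse_slogdet_reference(S):
        """Sign and log|det| of a covariance stored in the (n_cls, n_cls, n_ell) format."""
        # One (n_cls, n_cls) matrix per ell value.
        per_ell = np.moveaxis(S, -1, 0)
        signs, logdets = vmap(np.linalg.slogdet)(per_ell)
        return np.prod(signs), np.sum(logdets)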

@EiffL (Member) commented Jul 23, 2020

I was assuming constant covariance because that's what the challenge metric was using; I think that's all the Fisher sampler in cosmosis can handle. I don't have issues with adding a cosmology-dependent covariance, which would be the right thing to do ^^' but that's a discussion for the tomo_challenge thread.

I'm gonna have a look at your sparse det :-)

@dkirkby (Collaborator, Author) commented Jul 23, 2020

Any ideas for an elegant way to do this einsum at the end when invCov is sparse?

I implemented a sparse.dot front end to the jit-compiled kernels for different type combinations so you can now do:

F = sparse.dot(dmu.T, inv(C), dmu)

or to use the compiled kernel directly (without any input validation):

F = sparse.dense_dot_sparse_dot_dense(dmu.T, inv(C), dmu)
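
As a quick sanity check that this reproduces the dense einsum from the earlier snippet (assuming to_dense, inv and dot from the sparse module behave as described in this PR; the random test matrix is made symmetric and diagonally dominant in each ell slice so it is invertible):

    import jax.numpy as np
    from jax import random
    from jax_cosmo import sparse

    key1, key2 = random.split(random.PRNGKey(0))
    n_cls, n_ell, n_params = 3, 4, 2

    # Random symmetric, diagonally dominant test matrix in the (n_cls, n_cls, n_ell) format.
    d = random.uniform(key1, (n_cls, n_cls, n_ell), minval=0.1, maxval=1.0)
    C = 0.5 * (d + np.transpose(d, (1, 0, 2))) + 2.0 * np.eye(n_cls)[:, :, None]

    dmu = random.normal(key2, (n_cls * n_ell, n_params))
    F_dense = np.einsum('pa,pq,qb->ab', dmu, np.linalg.inv(sparse.to_dense(C)), dmu)
    F_sparse = sparse.dot(dmu.T, sparse.inv(C), dmu)
    print(np.allclose(F_dense, F_sparse, atol=1e-5))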

@dkirkby (Collaborator, Author) commented Jul 23, 2020

I am marking this "ready for review" now but let me know if you think anything is still missing or could be improved.

@dkirkby marked this pull request as ready for review July 23, 2020 21:31
@dkirkby (Collaborator, Author) commented Jul 23, 2020

I just ran some timing benchmarks on colab of the likelihood calculation in docs/notebooks/jax-cosmo-intro:

| Mode | CPU | GPU | TPU |
| --- | --- | --- | --- |
| dense | 170ms | 48ms | 48ms |
| sparse | 159ms | 20ms | 26ms |

So we see a nice speedup with an accelerator, in addition to the smaller memory footprint (50x in this case). At this point, the cl calculation is the tall pole.

@EiffL self-requested a review July 25, 2020 16:55
@EiffL (Member) commented Jul 25, 2020

It all looks good to me; I'm just finishing testing that I don't get any problems when computing derivatives of the FoM.

@EiffL (Member) left a review comment

Everything works :-) I just made one suggestion on the code; another small suggestion would be to add the code for sparse likelihood evaluation in

def gaussian_log_likelihood(data, mu, C, constant_cov=True, inverse_method="inverse"):

(Review comment on jax_cosmo/sparse.py, resolved)
@dkirkby (Collaborator, Author) commented Jul 25, 2020

I am working on updating likelihood.py now...

Should we be using jax.lax.stop_gradient here when constant_cov == True and perhaps wrapping with @jit (using static_argnums for the last two args)?

@EiffL (Member) commented Jul 26, 2020

I've given it some thought: let's not put the stop gradient in the likelihood. Instead, I've renamed the option so that it just specifies whether or not the log determinant is computed, and added some documentation.
My worry is having gradient stops hidden in the code where the user doesn't see them; at the same time, without the gradient stop there is a higher chance that the likelihood function is used incorrectly with the keyword 'constant_cov'.

Regarding jit, I'm honestly not 100% sure where and when to jit. So far I've only jitted very low-level functions that may be reused in different places, and left more complex functions, or functions with complex arguments, alone; the user can jit their own functions built on top of those. So I think it's ok to leave the likelihood function as it is, but if you see an argument for jitting somewhere, I'm curious :-)
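
For illustration, the sparse path with an explicit log-determinant switch can look roughly like this (hypothetical function and argument names; the sparse.inv / sparse.dot / sparse.det calls follow the signatures discussed earlier in this thread rather than the merged likelihood.py):

    import jax.numpy as np
    from jax_cosmo import sparse

    def sparse_gaussian_log_likelihood(data, mu, C, include_logdet=True):
        """Sketch: Gaussian log-likelihood with C in the sparse (n_cls, n_cls, n_ell)
        format, up to a parameter-independent constant."""
        r = data - mu
        # Chi-square term r^T C^{-1} r, keeping the inverse in the sparse format.
        # Written with sparse.dot here; the merged code may use sparse.vecdot or
        # reshape r as needed.
        chi2 = sparse.dot(r, sparse.inv(C), r)
        log_like = -0.5 * chi2
        if include_logdet:
            # Only needed when the covariance depends on the parameters.
            log_like = log_like - 0.5 * np.log(sparse.det(C))
        return log_like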

@EiffL (Member) left a review comment

All good! I'm approving, thanks so much @dkirkby for this!

@EiffL merged commit 060ef0e into DifferentiableUniverseInitiative:master Jul 26, 2020
@EiffL (Member) commented Jul 26, 2020

@all-contributors please add @dkirkby for code

@allcontributors (Contributor) commented
@EiffL

I've put up a pull request to add @dkirkby! 🎉
