Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into pr-63/aboucaud/remo…
Browse files Browse the repository at this point in the history
…ve-duplicate-sqrtk
  • Loading branch information
EiffL committed Oct 17, 2020
2 parents 22e9cae + 56ad082 commit 5dda720
Show file tree
Hide file tree
Showing 19 changed files with 144 additions and 192 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@
"code",
"bug"
]
},
{
"login": "aboucaud",
"name": "Alexandre Boucaud",
"avatar_url": "https://avatars0.githubusercontent.com/u/3065310?v=4",
"profile": "https://aboucaud.github.io",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# jax-cosmo

[![Join the chat at https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo](https://badges.gitter.im/DifferentiableUniverseInitiative/jax_cosmo.svg)](https://gitter.im/DifferentiableUniverseInitiative/jax_cosmo?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Documentation Status](https://readthedocs.org/projects/jax-cosmo/badge/?version=latest)](https://jax-cosmo.readthedocs.io/en/latest/?badge=latest) [![CI Test](https://github.com/DifferentiableUniverseInitiative/jax_cosmo/workflows/Python%20package/badge.svg)]() [![black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![PyPI](https://img.shields.io/pypi/v/jax-cosmo)](https://pypi.org/project/jax-cosmo/) [![PyPI - License](https://img.shields.io/pypi/l/jax-cosmo)](https://github.com/google/jax-cosmo/blob/master/LICENSE) <!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-6-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

<h3 align="center">Finally a differentiable cosmology library, and it's in JAX!</h3>
Expand Down Expand Up @@ -95,6 +95,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center"><a href="https://github.com/austinpeel"><img src="https://avatars0.githubusercontent.com/u/17024310?v=4" width="100px;" alt=""/><br /><sub><b>Austin Peel</b></sub></a><br /><a href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo/commits?author=austinpeel" title="Code">💻</a></td>
<td align="center"><a href="https://minaskaramanis.com"><img src="https://avatars2.githubusercontent.com/u/23280751?v=4" width="100px;" alt=""/><br /><sub><b>Minas Karamanis</b></sub></a><br /><a href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo/commits?author=minaskar" title="Code">💻</a></td>
<td align="center"><a href="https://faculty.sites.uci.edu/dkirkby/"><img src="https://avatars1.githubusercontent.com/u/185007?v=4" width="100px;" alt=""/><br /><sub><b>David Kirkby</b></sub></a><br /><a href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo/commits?author=dkirkby" title="Code">💻</a> <a href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo/issues?q=author%3Adkirkby" title="Bug reports">🐛</a></td>
<td align="center"><a href="https://aboucaud.github.io"><img src="https://avatars0.githubusercontent.com/u/3065310?v=4" width="100px;" alt=""/><br /><sub><b>Alexandre Boucaud</b></sub></a><br /><a href="https://github.com/DifferentiableUniverseInitiative/jax_cosmo/commits?author=aboucaud" title="Code">💻</a></td>
</tr>
</table>

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# Cosmology in JAX
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pkg_resources import DistributionNotFound
from pkg_resources import get_distribution

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module contains functions to compute angular cls for various tracers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

import jax.numpy as np
Expand Down
16 changes: 6 additions & 10 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module implements various functions for the background COSMOLOGY
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np

import jax_cosmo.constants as const
Expand Down Expand Up @@ -249,7 +245,7 @@ def dchioverdlna(y, x):


def a_of_chi(cosmo, chi):
r""" Computes the scale factor for corresponding (array) of radial comoving
r"""Computes the scale factor for corresponding (array) of radial comoving
distance by reverse linear interpolation.
Parameters:
Expand Down Expand Up @@ -366,7 +362,7 @@ def angular_diameter_distance(cosmo, a):


def growth_factor(cosmo, a):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalized such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -396,7 +392,7 @@ def growth_factor(cosmo, a):


def growth_rate(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor.
"""Compute growth rate dD/dlna at a given scale factor.
Parameters
----------
Expand Down Expand Up @@ -438,7 +434,7 @@ def growth_rate(cosmo, a):


def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -486,7 +482,7 @@ def D_derivs(y, x):


def _growth_rate_ODE(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor by solving the linear
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
Expand All @@ -510,7 +506,7 @@ def _growth_rate_ODE(cosmo, a):


def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
r""" Computes growth factor by integrating the growth rate provided by the
r"""Computes growth factor by integrating the growth rate provided by the
\gamma parametrization. Normalized such that D( a=1) =1
Parameters
Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/bias.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module contains implementations of galaxy bias
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax.tree_util import register_pytree_node_class

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax.experimental.ode import odeint
from jax.tree_util import register_pytree_node_class
Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np


Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module implements a few likelihoods useful
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
import jax.scipy as sp

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/parameters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module defines a few default cosmologies
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

from jax_cosmo.core import Cosmology
Expand Down
25 changes: 10 additions & 15 deletions jax_cosmo/power.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module computes power spectra
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax
import jax.numpy as np

Expand All @@ -17,14 +13,14 @@


def primordial_matter_power(cosmo, k):
""" Primordial power spectrum
Pk = k^n
"""Primordial power spectrum
Pk = k^n
"""
return k ** cosmo.n_s


def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs):
r""" Computes the linear matter power spectrum.
r"""Computes the linear matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -59,7 +55,7 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar


def sigmasqr(cosmo, R, transfer_fn, kmin=0.0001, kmax=1000.0, ksteps=5, **kwargs):
""" Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
"""Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
.. math::
Expand All @@ -84,15 +80,14 @@ def int_sigma(logk):


def linear(cosmo, k, a, transfer_fn):
"""Linear matter power spectrum
"""
"""Linear matter power spectrum"""
return linear_matter_power(cosmo, k, a, transfer_fn)


def _halofit_parameters(cosmo, a, transfer_fn):
r""" Computes the non linear scale,
effective spectral index,
spectral curvature
r"""Computes the non linear scale,
effective spectral index,
spectral curvature
"""
# Step 1: Finding the non linear scale for which sigma(R)=1
# That's our search range for the non linear scale
Expand Down Expand Up @@ -144,7 +139,7 @@ def integrand(logk):


def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
r""" Computes the non linear halofit correction to the matter power spectrum.
r"""Computes the non linear halofit correction to the matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -271,7 +266,7 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
def nonlinear_matter_power(
cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=halofit
):
""" Computes the non-linear matter power spectrum.
"""Computes the non-linear matter power spectrum.
This function is just a wrapper over several nonlinear power spectra.
"""
Expand Down
13 changes: 4 additions & 9 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module defines kernel functions for various tracers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax import jit
from jax import vmap
Expand Down Expand Up @@ -187,7 +183,7 @@ def noise(self):

@register_pytree_node_class
class NumberCounts(container):
""" Class representing a galaxy clustering probe, with a bunch of bins
"""Class representing a galaxy clustering probe, with a bunch of bins
Parameters:
-----------
Expand All @@ -214,14 +210,13 @@ def zmax(self):

@property
def n_tracers(self):
""" Returns the number of tracers for this probe, i.e. redshift bins
"""
"""Returns the number of tracers for this probe, i.e. redshift bins"""
# Extract parameters
pzs = self.params[0]
return len(pzs)

def kernel(self, cosmo, z, ell):
""" Compute the radial kernel for all nz bins in this probe.
"""Compute the radial kernel for all nz bins in this probe.
Returns:
--------
Expand All @@ -235,7 +230,7 @@ def kernel(self, cosmo, z, ell):
return kernel

def noise(self):
""" Returns the noise power for all redshifts
"""Returns the noise power for all redshifts
return: shape [nbins]
"""
# Extract parameters
Expand Down
27 changes: 9 additions & 18 deletions jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# Module to define redshift distributions we can differentiate through
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABC
from abc import abstractmethod

Expand All @@ -19,21 +15,18 @@

class redshift_distribution(container):
def __init__(self, *args, gals_per_arcmin2=1.0, zmax=10.0, **kwargs):
""" Initialize the parameters of the redshift distribution
"""
"""Initialize the parameters of the redshift distribution"""
self._norm = None
self._gals_per_arcmin2 = gals_per_arcmin2
super(redshift_distribution, self).__init__(*args, zmax=zmax, **kwargs)

@abstractmethod
def pz_fn(self, z):
""" Un-normalized n(z) function provided by sub classes
"""
"""Un-normalized n(z) function provided by sub classes"""
pass

def __call__(self, z):
""" Computes the normalized n(z)
"""
"""Computes the normalized n(z)"""
if self._norm is None:
self._norm = simps(lambda t: self.pz_fn(t), 0.0, self.config["zmax"], 256)
return self.pz_fn(z) / self._norm
Expand All @@ -44,15 +37,14 @@ def zmax(self):

@property
def gals_per_arcmin2(self):
""" Returns the number density of galaxies in gals/sq arcmin
"""Returns the number density of galaxies in gals/sq arcmin
TODO: find a better name
"""
return self._gals_per_arcmin2

@property
def gals_per_steradian(self):
""" Returns the number density of galaxies in steradian
"""
"""Returns the number density of galaxies in steradian"""
return self._gals_per_arcmin2 * steradian_to_arcmin2

# Operations for flattening/unflattening representation
Expand All @@ -69,7 +61,7 @@ def tree_unflatten(cls, aux_data, children):

@register_pytree_node_class
class smail_nz(redshift_distribution):
""" Defines a smail distribution with these arguments
"""Defines a smail distribution with these arguments
Parameters:
-----------
a:
Expand All @@ -88,7 +80,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class kde_nz(redshift_distribution):
""" A redshift distribution based on a KDE estimate of the nz of a
"""A redshift distribution based on a KDE estimate of the nz of a
given catalog currently uses a Gaussian kernel.
TODO: add more if necessary
Expand All @@ -106,8 +98,7 @@ class kde_nz(redshift_distribution):
"""

def _kernel(self, bw, X, x):
""" Gaussian kernel for KDE
"""
"""Gaussian kernel for KDE"""
return (1.0 / np.sqrt(2 * np.pi) / bw) * np.exp(
-((X - x) ** 2) / (bw ** 2 * 2.0)
)
Expand All @@ -124,7 +115,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class systematic_shift(redshift_distribution):
""" Implements a systematic shift in a redshift distribution
"""Implements a systematic shift in a redshift distribution
TODO: Find a better name for this
Arguments:
Expand Down
Loading

0 comments on commit 5dda720

Please sign in to comment.