Skip to content

Commit

Permalink
Merge pull request DifferentiableUniverseInitiative#67 from Different…
Browse files Browse the repository at this point in the history
…iableUniverseInitiative/u/EiffL/jax_update

Updating to jax >v0.2
  • Loading branch information
EiffL authored Oct 17, 2020
2 parents 842225f + c84bfed commit e0c7b37
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 134 deletions.
12 changes: 6 additions & 6 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def dchioverdlna(y, x):


def a_of_chi(cosmo, chi):
r""" Computes the scale factor for corresponding (array) of radial comoving
r"""Computes the scale factor for corresponding (array) of radial comoving
distance by reverse linear interpolation.
Parameters:
Expand Down Expand Up @@ -366,7 +366,7 @@ def angular_diameter_distance(cosmo, a):


def growth_factor(cosmo, a):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalized such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -396,7 +396,7 @@ def growth_factor(cosmo, a):


def growth_rate(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor.
"""Compute growth rate dD/dlna at a given scale factor.
Parameters
----------
Expand Down Expand Up @@ -438,7 +438,7 @@ def growth_rate(cosmo, a):


def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -486,7 +486,7 @@ def D_derivs(y, x):


def _growth_rate_ODE(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor by solving the linear
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
Expand All @@ -510,7 +510,7 @@ def _growth_rate_ODE(cosmo, a):


def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
r""" Computes growth factor by integrating the growth rate provided by the
r"""Computes growth factor by integrating the growth rate provided by the
\gamma parametrization. Normalized such that D( a=1) =1
Parameters
Expand Down
21 changes: 10 additions & 11 deletions jax_cosmo/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@


def primordial_matter_power(cosmo, k):
""" Primordial power spectrum
Pk = k^n
"""Primordial power spectrum
Pk = k^n
"""
return k ** cosmo.n_s


def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs):
r""" Computes the linear matter power spectrum.
r"""Computes the linear matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -59,7 +59,7 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar


def sigmasqr(cosmo, R, transfer_fn, kmin=0.0001, kmax=1000.0, ksteps=5, **kwargs):
""" Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
"""Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
.. math::
Expand All @@ -84,15 +84,14 @@ def int_sigma(logk):


def linear(cosmo, k, a, transfer_fn):
"""Linear matter power spectrum
"""
"""Linear matter power spectrum"""
return linear_matter_power(cosmo, k, a, transfer_fn)


def _halofit_parameters(cosmo, a, transfer_fn):
r""" Computes the non linear scale,
effective spectral index,
spectral curvature
r"""Computes the non linear scale,
effective spectral index,
spectral curvature
"""
# Step 1: Finding the non linear scale for which sigma(R)=1
# That's our search range for the non linear scale
Expand Down Expand Up @@ -144,7 +143,7 @@ def integrand(logk):


def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
r""" Computes the non linear halofit correction to the matter power spectrum.
r"""Computes the non linear halofit correction to the matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -271,7 +270,7 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
def nonlinear_matter_power(
cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=halofit
):
""" Computes the non-linear matter power spectrum.
"""Computes the non-linear matter power spectrum.
This function is just a wrapper over several nonlinear power spectra.
"""
Expand Down
9 changes: 4 additions & 5 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def noise(self):

@register_pytree_node_class
class NumberCounts(container):
""" Class representing a galaxy clustering probe, with a bunch of bins
"""Class representing a galaxy clustering probe, with a bunch of bins
Parameters:
-----------
Expand All @@ -214,14 +214,13 @@ def zmax(self):

@property
def n_tracers(self):
""" Returns the number of tracers for this probe, i.e. redshift bins
"""
"""Returns the number of tracers for this probe, i.e. redshift bins"""
# Extract parameters
pzs = self.params[0]
return len(pzs)

def kernel(self, cosmo, z, ell):
""" Compute the radial kernel for all nz bins in this probe.
"""Compute the radial kernel for all nz bins in this probe.
Returns:
--------
Expand All @@ -235,7 +234,7 @@ def kernel(self, cosmo, z, ell):
return kernel

def noise(self):
""" Returns the noise power for all redshifts
"""Returns the noise power for all redshifts
return: shape [nbins]
"""
# Extract parameters
Expand Down
23 changes: 9 additions & 14 deletions jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,18 @@

class redshift_distribution(container):
def __init__(self, *args, gals_per_arcmin2=1.0, zmax=10.0, **kwargs):
""" Initialize the parameters of the redshift distribution
"""
"""Initialize the parameters of the redshift distribution"""
self._norm = None
self._gals_per_arcmin2 = gals_per_arcmin2
super(redshift_distribution, self).__init__(*args, zmax=zmax, **kwargs)

@abstractmethod
def pz_fn(self, z):
""" Un-normalized n(z) function provided by sub classes
"""
"""Un-normalized n(z) function provided by sub classes"""
pass

def __call__(self, z):
""" Computes the normalized n(z)
"""
"""Computes the normalized n(z)"""
if self._norm is None:
self._norm = simps(lambda t: self.pz_fn(t), 0.0, self.config["zmax"], 256)
return self.pz_fn(z) / self._norm
Expand All @@ -44,15 +41,14 @@ def zmax(self):

@property
def gals_per_arcmin2(self):
""" Returns the number density of galaxies in gals/sq arcmin
"""Returns the number density of galaxies in gals/sq arcmin
TODO: find a better name
"""
return self._gals_per_arcmin2

@property
def gals_per_steradian(self):
""" Returns the number density of galaxies in steradian
"""
"""Returns the number density of galaxies in steradian"""
return self._gals_per_arcmin2 * steradian_to_arcmin2

# Operations for flattening/unflattening representation
Expand All @@ -69,7 +65,7 @@ def tree_unflatten(cls, aux_data, children):

@register_pytree_node_class
class smail_nz(redshift_distribution):
""" Defines a smail distribution with these arguments
"""Defines a smail distribution with these arguments
Parameters:
-----------
a:
Expand All @@ -88,7 +84,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class kde_nz(redshift_distribution):
""" A redshift distribution based on a KDE estimate of the nz of a
"""A redshift distribution based on a KDE estimate of the nz of a
given catalog currently uses a Gaussian kernel.
TODO: add more if necessary
Expand All @@ -106,8 +102,7 @@ class kde_nz(redshift_distribution):
"""

def _kernel(self, bw, X, x):
""" Gaussian kernel for KDE
"""
"""Gaussian kernel for KDE"""
return (1.0 / np.sqrt(2 * np.pi) / bw) * np.exp(
-((X - x) ** 2) / (bw ** 2 * 2.0)
)
Expand All @@ -124,7 +119,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class systematic_shift(redshift_distribution):
""" Implements a systematic shift in a redshift distribution
"""Implements a systematic shift in a redshift distribution
TODO: Find a better name for this
Arguments:
Expand Down
Loading

0 comments on commit e0c7b37

Please sign in to comment.