Skip to content

Commit

Permalink
Remove __future__ imports
Browse files Browse the repository at this point in the history
Since Python2 is not supported by tis program
the `__future__` imports are irrelevant.
  • Loading branch information
aboucaud committed Sep 28, 2020
1 parent 842225f commit 60e51f2
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 98 deletions.
4 changes: 0 additions & 4 deletions jax_cosmo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# Cosmology in JAX
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pkg_resources import DistributionNotFound
from pkg_resources import get_distribution

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module contains functions to compute angular cls for various tracers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

import jax.numpy as np
Expand Down
16 changes: 6 additions & 10 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module implements various functions for the background COSMOLOGY
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np

import jax_cosmo.constants as const
Expand Down Expand Up @@ -249,7 +245,7 @@ def dchioverdlna(y, x):


def a_of_chi(cosmo, chi):
r""" Computes the scale factor for corresponding (array) of radial comoving
r"""Computes the scale factor for corresponding (array) of radial comoving
distance by reverse linear interpolation.
Parameters:
Expand Down Expand Up @@ -366,7 +362,7 @@ def angular_diameter_distance(cosmo, a):


def growth_factor(cosmo, a):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalized such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -396,7 +392,7 @@ def growth_factor(cosmo, a):


def growth_rate(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor.
"""Compute growth rate dD/dlna at a given scale factor.
Parameters
----------
Expand Down Expand Up @@ -438,7 +434,7 @@ def growth_rate(cosmo, a):


def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
""" Compute linear growth factor D(a) at a given scale factor,
"""Compute linear growth factor D(a) at a given scale factor,
normalised such that D(a=1) = 1.
Parameters
Expand Down Expand Up @@ -486,7 +482,7 @@ def D_derivs(y, x):


def _growth_rate_ODE(cosmo, a):
""" Compute growth rate dD/dlna at a given scale factor by solving the linear
"""Compute growth rate dD/dlna at a given scale factor by solving the linear
growth ODE.
Parameters
Expand All @@ -510,7 +506,7 @@ def _growth_rate_ODE(cosmo, a):


def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
r""" Computes growth factor by integrating the growth rate provided by the
r"""Computes growth factor by integrating the growth rate provided by the
\gamma parametrization. Normalized such that D( a=1) =1
Parameters
Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/bias.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module contains implementations of galaxy bias
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax.tree_util import register_pytree_node_class

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax.experimental.ode import odeint
from jax.tree_util import register_pytree_node_class
Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/jax_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np


Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module implements a few likelihoods useful
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
import jax.scipy as sp

Expand Down
4 changes: 0 additions & 4 deletions jax_cosmo/parameters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module defines a few default cosmologies
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

from jax_cosmo.core import Cosmology
Expand Down
25 changes: 10 additions & 15 deletions jax_cosmo/power.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module computes power spectra
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax
import jax.numpy as np

Expand All @@ -17,14 +13,14 @@


def primordial_matter_power(cosmo, k):
""" Primordial power spectrum
Pk = k^n
"""Primordial power spectrum
Pk = k^n
"""
return k ** cosmo.n_s


def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs):
r""" Computes the linear matter power spectrum.
r"""Computes the linear matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -59,7 +55,7 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar


def sigmasqr(cosmo, R, transfer_fn, kmin=0.0001, kmax=1000.0, ksteps=5, **kwargs):
""" Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
"""Computes the energy of the fluctuations within a sphere of R h^{-1} Mpc
.. math::
Expand All @@ -84,15 +80,14 @@ def int_sigma(logk):


def linear(cosmo, k, a, transfer_fn):
"""Linear matter power spectrum
"""
"""Linear matter power spectrum"""
return linear_matter_power(cosmo, k, a, transfer_fn)


def _halofit_parameters(cosmo, a, transfer_fn):
r""" Computes the non linear scale,
effective spectral index,
spectral curvature
r"""Computes the non linear scale,
effective spectral index,
spectral curvature
"""
# Step 1: Finding the non linear scale for which sigma(R)=1
# That's our search range for the non linear scale
Expand Down Expand Up @@ -144,7 +139,7 @@ def integrand(logk):


def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
r""" Computes the non linear halofit correction to the matter power spectrum.
r"""Computes the non linear halofit correction to the matter power spectrum.
Parameters
----------
Expand Down Expand Up @@ -271,7 +266,7 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
def nonlinear_matter_power(
cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=halofit
):
""" Computes the non-linear matter power spectrum.
"""Computes the non-linear matter power spectrum.
This function is just a wrapper over several nonlinear power spectra.
"""
Expand Down
13 changes: 4 additions & 9 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module defines kernel functions for various tracers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np
from jax import jit
from jax import vmap
Expand Down Expand Up @@ -187,7 +183,7 @@ def noise(self):

@register_pytree_node_class
class NumberCounts(container):
""" Class representing a galaxy clustering probe, with a bunch of bins
"""Class representing a galaxy clustering probe, with a bunch of bins
Parameters:
-----------
Expand All @@ -214,14 +210,13 @@ def zmax(self):

@property
def n_tracers(self):
""" Returns the number of tracers for this probe, i.e. redshift bins
"""
"""Returns the number of tracers for this probe, i.e. redshift bins"""
# Extract parameters
pzs = self.params[0]
return len(pzs)

def kernel(self, cosmo, z, ell):
""" Compute the radial kernel for all nz bins in this probe.
"""Compute the radial kernel for all nz bins in this probe.
Returns:
--------
Expand All @@ -235,7 +230,7 @@ def kernel(self, cosmo, z, ell):
return kernel

def noise(self):
""" Returns the noise power for all redshifts
"""Returns the noise power for all redshifts
return: shape [nbins]
"""
# Extract parameters
Expand Down
27 changes: 9 additions & 18 deletions jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# Module to define redshift distributions we can differentiate through
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABC
from abc import abstractmethod

Expand All @@ -19,21 +15,18 @@

class redshift_distribution(container):
def __init__(self, *args, gals_per_arcmin2=1.0, zmax=10.0, **kwargs):
""" Initialize the parameters of the redshift distribution
"""
"""Initialize the parameters of the redshift distribution"""
self._norm = None
self._gals_per_arcmin2 = gals_per_arcmin2
super(redshift_distribution, self).__init__(*args, zmax=zmax, **kwargs)

@abstractmethod
def pz_fn(self, z):
""" Un-normalized n(z) function provided by sub classes
"""
"""Un-normalized n(z) function provided by sub classes"""
pass

def __call__(self, z):
""" Computes the normalized n(z)
"""
"""Computes the normalized n(z)"""
if self._norm is None:
self._norm = simps(lambda t: self.pz_fn(t), 0.0, self.config["zmax"], 256)
return self.pz_fn(z) / self._norm
Expand All @@ -44,15 +37,14 @@ def zmax(self):

@property
def gals_per_arcmin2(self):
""" Returns the number density of galaxies in gals/sq arcmin
"""Returns the number density of galaxies in gals/sq arcmin
TODO: find a better name
"""
return self._gals_per_arcmin2

@property
def gals_per_steradian(self):
""" Returns the number density of galaxies in steradian
"""
"""Returns the number density of galaxies in steradian"""
return self._gals_per_arcmin2 * steradian_to_arcmin2

# Operations for flattening/unflattening representation
Expand All @@ -69,7 +61,7 @@ def tree_unflatten(cls, aux_data, children):

@register_pytree_node_class
class smail_nz(redshift_distribution):
""" Defines a smail distribution with these arguments
"""Defines a smail distribution with these arguments
Parameters:
-----------
a:
Expand All @@ -88,7 +80,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class kde_nz(redshift_distribution):
""" A redshift distribution based on a KDE estimate of the nz of a
"""A redshift distribution based on a KDE estimate of the nz of a
given catalog currently uses a Gaussian kernel.
TODO: add more if necessary
Expand All @@ -106,8 +98,7 @@ class kde_nz(redshift_distribution):
"""

def _kernel(self, bw, X, x):
""" Gaussian kernel for KDE
"""
"""Gaussian kernel for KDE"""
return (1.0 / np.sqrt(2 * np.pi) / bw) * np.exp(
-((X - x) ** 2) / (bw ** 2 * 2.0)
)
Expand All @@ -124,7 +115,7 @@ def pz_fn(self, z):

@register_pytree_node_class
class systematic_shift(redshift_distribution):
""" Implements a systematic shift in a redshift distribution
"""Implements a systematic shift in a redshift distribution
TODO: Find a better name for this
Arguments:
Expand Down
10 changes: 2 additions & 8 deletions jax_cosmo/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@
- :fun:`sparse_dot_sparse`
- :fun:`dense_dot_sparse_dot_dense`
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import jax.numpy as np
Expand All @@ -39,14 +35,12 @@


def is_sparse(sparse):
"""Test if the input is interpretable as a sparse matrix.
"""
"""Test if the input is interpretable as a sparse matrix."""
return np.asarray(sparse).ndim == 3


def check_sparse(sparse, square=False):
"""Check for a valid sparse matrix.
"""
"""Check for a valid sparse matrix."""
sparse = np.asarray(sparse)
if sparse.ndim != 3:
raise ValueError("Expected 3D array of sparse diagonals.")
Expand Down
6 changes: 1 addition & 5 deletions jax_cosmo/transfer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# This module contains various transfer functions from the literatu
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import jax.numpy as np

import jax_cosmo.background as bkgrd
Expand All @@ -12,7 +8,7 @@


def Eisenstein_Hu(cosmo, k, type="eisenhu_osc"):
""" Computes the Eisenstein & Hu matter transfer function.
"""Computes the Eisenstein & Hu matter transfer function.
Parameters
----------
Expand Down
5 changes: 0 additions & 5 deletions jax_cosmo/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
# This defines a few utility functions
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def z2a(z):
""" converts from redshift to scale factor """
return 1.0 / (1.0 + z)
Expand Down

0 comments on commit 60e51f2

Please sign in to comment.