Skip to content

Commit

Permalink
Re-implement sparse det in terms of slogdet
Browse files Browse the repository at this point in the history
  • Loading branch information
dkirkby committed Jul 25, 2020
1 parent ede457b commit ced8914
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 13 deletions.
43 changes: 35 additions & 8 deletions jax_cosmo/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,15 @@ def _block_det(sparse, k, N, P):
S = sparse[k + 1 : N, k + 1 : N, 0:P]
v = sparse[k + 1 : N, k : k + 1, 0:P]
Sinv_v = sparse_dot_sparse(inv(S), v)
return np.product(sparse[k, k] - sparse_dot_sparse(u, Sinv_v))
M = sparse[k, k] - sparse_dot_sparse(u, Sinv_v)
sign = np.product(np.sign(M))
logdet = np.sum(np.log(np.abs(M)))
return sign, logdet


@jit
def det(sparse):
"""Calculate the determinant of a sparse matrix.
def slogdet(sparse):
"""Calculate the log(determinant) of a sparse matrix.
Based on equation (2.2) of https://arxiv.org/abs/1112.4379
Expand All @@ -347,15 +350,39 @@ def det(sparse):
Returns
-------
float
Determinant result.
tuple
Tuple (sign, logdet) such that sign * exp(logdet) is the
determinant. If the determinant is zero, logdet = -inf.
"""
sparse = check_sparse(sparse, square=True)
N, _, P = sparse.shape
result = np.product(sparse[-1, -1])
sign = np.product(np.sign(sparse[-1, -1]))
logdet = np.sum(np.log(np.abs(sparse[-1, -1])))
# The individual blocks can be calculated in any order so there
# should be a better way to express this using lax.map but I
# can't get it to work without "concretization" errors.
for i in range(N - 1):
result *= _block_det(sparse, i, N, P)
return result
s, ld = _block_det(sparse, i, N, P)
sign *= s
logdet += ld
return sign, logdet


@jit
def det(sparse):
"""Calculate the determinant of a sparse matrix.
Uses :func:`slogdet`.
Parameters
----------
sparse : array
3D array of shape (ny, nx, ndiag) of block diagonal elements.
Returns
-------
float
Determinant result.
"""
sign, logdet = slogdet(sparse)
return sign * np.exp(logdet)
14 changes: 9 additions & 5 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,13 @@ def test_inv():


def test_det():
X = [
[[1, 2, 3], [4, 5, 6], [-1, 7, -2]],
[[1, 2, 3], [-4, -5, -6], [2, -3, 9]],
[[7, 8, 9], [5, -4, 6], [-3, -2, -1]],
]
X = np.array(
[
[[1, 2, 3], [4, 5, 6], [-1, 7, -2]],
[[1, 2, 3], [-4, -5, -6], [2, -3, 9]],
[[7, 8, 9], [5, -4, 6], [-3, -2, -1]],
]
)
assert_array_equal(-det(-X), det(X))
assert_array_equal(det(0.0 * X), 0.0)
assert_allclose(det(X), np.linalg.det(to_dense(X)), rtol=1e-6)

0 comments on commit ced8914

Please sign in to comment.