Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
dkirkby committed Jul 23, 2020
1 parent d8eb788 commit 22a1f66
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions jax_cosmo/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,35 @@ def sparse_dot_sparse(sparse1, sparse2):
)(sparse1, sparse2)


@jit
def bilinear(X, Y, Z):
"""Calculate the bilinear form X @ Y @ Z where B is sparse.
Inputs must be jax numpy arrays. No error checking is performed.
Parameters
----------
X : array
2D array of shape (a, b * ndiag) with dense matrix elements.
Y : array
3D array of shape (b, c, ndiag) with sparse matrix elements.
Z : array
2D array of shape (c * ndiag, d) with dense matrix elements.
Returns
-------
array
2D array of shape (a, d) with dense matrix elements.
"""
return vmap(
vmap(
lambda row, sparse, col: np.dot(row, sparse_dot_vec(sparse, col)),
(None, None, 1),
),
(0, None, None),
)(X, Y, Z)


@jit
def inv(sparse):
"""Calculate the inverse of a square matrix in sparse format.
Expand Down Expand Up @@ -281,8 +310,8 @@ def _block_det(sparse, k, N, P):
u = sparse[k : k + 1, k + 1 : N, 0:P]
S = sparse[k + 1 : N, k + 1 : N, 0:P]
v = sparse[k + 1 : N, k : k + 1, 0:P]
Sinv_v = matmul(inv(S), v)
return np.product(sparse[k, k] - matmul(u, Sinv_v))
Sinv_v = sparse_dot_sparse(inv(S), v)
return np.product(sparse[k, k] - sparse_dot_sparse(u, Sinv_v))


@jit
Expand Down

0 comments on commit 22a1f66

Please sign in to comment.