Skip to content

Conversation

@comane
Copy link
Member

@comane comane commented Oct 1, 2024

Allows to extend the log likelihood function for wmin specific regularisation terms

An idea that allows for more flexibility is the following:

@partial(jax.jit, static_argnames=("self",))
    def log_likelihood(
        self,
        params,
        central_values,
        inv_covmat,
        fast_kernel_arrays,
        *penalty_funcs_and_args
    ):
        predictions, pdf = self.pred_and_pdf(params, fast_kernel_arrays)
        chi2_val = self.chi2(central_values, predictions, inv_covmat)
        
        penalty_sum = 0.0
        for penalty_func, penalty_args in penalty_funcs_and_args:
            penalty_sum += jnp.sum(penalty_func(*penalty_args), axis=-1)
        
        return -0.5 * (chi2_val + penalty_sum)

Note that to do so, penalty functions should follow some sort of logic. An example is that they could take as first argument the parameters (params).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants