Skip to content

Commit

Permalink
requested modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
renecotyfanboy committed Apr 18, 2024
1 parent 1ad72bd commit c446151
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions src/chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import logging
from typing import TYPE_CHECKING, Any, TypeAlias

import arviz as az
import numpy as np
import pandas as pd
from pydantic import Field, field_validator, model_validator
Expand Down Expand Up @@ -397,7 +396,7 @@ def from_numpyro(
cls,
mcmc: numpyro.infer.MCMC,
name: str,
var_names: list[str] = [],
var_names: list[str] | None = None,
**kwargs: Any,
) -> Chain:
"""Constructor from numpyro samples
Expand All @@ -424,7 +423,7 @@ def from_arviz(
cls,
arviz_id: arviz.InferenceData,
name: str,
var_names: list[str] = [],
var_names: list[str] | None = None,
**kwargs: Any,
) -> Chain:
"""Constructor from an arviz InferenceData object
Expand All @@ -440,6 +439,8 @@ def from_arviz(
A ChainConsumer Chain made from the arviz chain
"""

import arviz as az

var_names = _filter_var_names(var_names, list(arviz_id.posterior.keys()))
reduced_id = az.extract(arviz_id, var_names=var_names, group="posterior")
df = reduced_id.to_dataframe().drop(columns=["chain", "draw"])
Expand All @@ -460,28 +461,29 @@ def vec_coordinate(self) -> np.ndarray:
return np.array(list(self.coordinate.values()))


def _filter_var_names(var_names: list[str], all_vars: list[str]):
def _filter_var_names(var_names: list[str] | None, all_vars: list[str]) -> list[str]:
"""
Helper function to return the var_names to allows filtering parameters names.
"""

if not var_names:
if var_names is None:
return all_vars

elif var_names:
if not (all([var.startswith("~") for var in var_names]) or all([not var.startswith("~") for var in var_names])):
raise ValueError(
"all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset"
)
negations = set([var.startswith("~") for var in var_names])

if len(negations) != 1:
raise ValueError(
"all values in var_names must start with ~ to exclude a subset OR none of them to keep a subset"
)

if all([var.startswith("~") for var in var_names]):
# remove the ~ from the var names
var_names = [var[1:] for var in var_names]
var_names = [var for var in all_vars if var not in var_names]
if True in negations:
# remove the ~ from the var names
var_names = [var[1:] for var in var_names]
var_names = [var for var in all_vars if var not in var_names]

return var_names
return var_names

else:
# keep var_names as is but check if var is in all_vars
var_names = [var for var in all_vars if var in var_names]
return var_names
else:
# keep var_names as is but check if var is in all_vars
var_names = [var for var in all_vars if var in var_names]
return var_names

0 comments on commit c446151

Please sign in to comment.