""" The core tools used in pyfstat """
import getpass
import glob
import logging
import os
import socket
from datetime import datetime
from pprint import pformat
from weakref import finalize
import lal
import lalpulsar
import numpy as np
import scipy.optimize
import scipy.special
import pyfstat.tcw_fstat_map_funcs as tcw
import pyfstat.utils as utils
from ._version import get_versions
# pyplot handle obtained through the project helper
# (presumably usable without an X display -- confirm in pyfstat.utils)
plt = utils.safe_X_less_plt()
# module-level logger, named after this module
logger = logging.getLogger(__name__)
# colour specifier per lower-case detector name
# (presumably matplotlib colour-cycle codes used for plotting -- TODO confirm)
detector_colors = {"h1": "C0", "l1": "C1"}
class BaseSearchClass:
    """The base class providing parent methods to other PyFstat classes.

    This does not actually have any 'search' functionality,
    which needs to be added by child classes
    along with full initialization and any other custom methods.
    """

    binary_keys = ["asini", "period", "ecc", "tp", "argp"]
    """List of extra parameters for sources in binaries."""

    default_search_keys = [
        "F0",
        "F1",
        "F2",
        "Alpha",
        "Delta",
    ]
    """Default order of the traditionally supported search parameter names.

    FIXME: these are only used as fallbacks for the deprecated style
    of passing keys one by one;
    not needed when using the new parameters dictionary.
    """

    tex_labels = {
        # standard Doppler parameters
        "F0": r"$f$",
        "F1": r"$\dot{f}$",
        "F2": r"$\ddot{f}$",
        "F3": r"$\dddot{f}$",
        "Alpha": r"$\alpha$",
        "Delta": r"$\delta$",
        # binary parameters
        "asini": r"$\mathrm{asin}\,i$",
        "period": r"$P$",
        "ecc": r"$\mathrm{ecc}$",
        "tp": r"$t_p$",
        "argp": r"$\mathrm{argp}$",
        # transient parameters
        "transient_tstart": r"$t_\mathrm{start}$",
        "transient_duration": r"$\Delta T$",
        # glitch parameters
        "delta_F0": r"$\delta f$",
        "delta_F1": r"$\delta \dot{f}$",
        "tglitch": r"$t_\mathrm{glitch}$",
        # detection statistics
        "twoF": r"$\widetilde{2\mathcal{F}}$",
        "maxTwoF": r"$\max\widetilde{2\mathcal{F}}$",
        "log10BSGL": r"$\log_{10}\mathcal{B}_{\mathrm{SGL}}$",
        "lnBtSG": r"$\ln\mathcal{B}_{\mathrm{tS/G}}$",
    }
    """Formatted labels used for plot annotations."""

    unit_dictionary = dict(
        # standard Doppler parameters
        F0=r"Hz",
        F1=r"Hz/s",
        F2=r"Hz/s$^2$",
        F3=r"Hz/s$^3$",
        Alpha=r"rad",
        Delta=r"rad",
        # binary parameters
        asini="",
        period=r"s",
        ecc="",
        tp=r"s",
        argp="",
        # transient parameters
        transient_tstart=r"s",
        transient_duration=r"s",
        # glitch parameters
        delta_F0=r"Hz",
        delta_F1=r"Hz/s",
        tglitch=r"s",
    )
    """Units for standard parameters."""

    fmt_detstat = "%.9g"
    """Standard output precision for detection statistics."""

    fmt_doppler = "%.16g"
    """Standard output precision for Doppler (frequency evolution) parameters."""

    def __new__(cls, *args, **kwargs):
        """Log which class is being instantiated, then allocate the object.

        Constructor arguments are accepted but not used here.
        """
        logger.info(f"Creating {cls.__name__} object...")
        return super().__new__(cls)

    def _get_list_of_matching_sfts(self):
        """Returns a list of sfts matching the attribute sftfilepattern.

        The pattern string is split on ";" and each piece is expanded
        with `glob.glob`; all hits are concatenated in pattern order.

        Raises
        ------
        IOError
            If none of the patterns matches any file.
        """
        patterns = self.sftfilepattern.split(";")
        matches = [path for pattern in patterns for path in glob.glob(pattern)]
        if matches:
            return matches
        raise IOError(  # pragma: no cover
            "No sfts found matching {}".format(self.sftfilepattern)
        )

    def tex_label0(self, key):
        """Formatted labels used for annotating central values in plots.

        Parameters
        ----------
        key: str
            A key of `tex_labels`.

        Returns
        -------
        label0: str
            Math-mode string of the form ``$<label> - <label>_0$``.
        """
        label = self.tex_labels[key].strip("$")
        return f"${label} - {label}_0$"

    def set_ephemeris_files(self, earth_ephem=None, sun_ephem=None):
        """Set the ephemeris files to use for the Earth and Sun.

        NOTE: If not given explicit arguments,
        default values from utils.get_ephemeris_files()
        are used.

        Parameters
        ----------
        earth_ephem, sun_ephem: str
            Paths of the two files containing positions of Earth and Sun,
            respectively at evenly spaced times, as passed to CreateFstatInput
        """
        earth_ephem_default, sun_ephem_default = utils.get_ephemeris_files()
        # any falsy argument (None, "") falls back to the default path
        self.earth_ephem = earth_ephem or earth_ephem_default
        self.sun_ephem = sun_ephem or sun_ephem_default

    def _set_init_params_dict(self, argsdict):
        """Store the initial input arguments, e.g. for logging output.

        Parameters
        ----------
        argsdict: dict
            Mapping of argument names to values, typically `locals()`
            taken at the top of an `__init__`.
            A "self" entry, if present, is left out of the stored copy;
            the passed dictionary itself is not modified.
        """
        self.init_params_dict = {
            key: value for key, value in argsdict.items() if key != "self"
        }

    def pprint_init_params_dict(self):
        """Pretty-print a parameters dictionary for output file headers.

        Returns
        -------
        pretty_init_parameters: list
            A list of lines to be printed,
            including opening/closing "{" and "}" on lines of their own,
            consistent indentation,
            as well as end-of-line commas between entries,
            but no comment markers at start of lines.
        """
        text = pformat(self.init_params_dict, indent=2, width=74)
        # Remove exactly one outer brace on each side, so that braces
        # belonging to nested values survive and a dictionary that pformat
        # renders on a single line is handled like a multi-line one.
        if text.startswith("{") and text.endswith("}"):
            text = " " + text[1:-1]
        body = text.split("\n") if text.strip() else []
        return ["{", *body, "}"]

    def get_output_file_header(self):
        """Constructs a meta-information header for text output files.

        This will include
        PyFstat and LALSuite versioning,
        information about when/where/how the code was run,
        and input parameters of the instantiated class.

        Returns
        -------
        header: list
            A list of formatted header lines.
        """
        header = [
            "date: {}".format(str(datetime.now())),
            "user: {}".format(getpass.getuser()),
            "hostname: {}".format(socket.gethostname()),
            "PyFstat: {}".format(get_versions()["version"]),
        ]
        lalVCSinfo = lal.VCSInfoString(lalpulsar.PulsarVCSInfoList, 0, "")
        # skip the empty lines of the version string
        header.extend(line for line in lalVCSinfo.split("\n") if line)
        header += [
            "search: {}".format(type(self).__name__),
            "parameters: ",
        ]
        header += self.pprint_init_params_dict()
        return header

    def read_par(
        self, filename=None, label=None, outdir=None, suffix="par", raise_error=True
    ):
        """Read a `key=val` file and return a dictionary.

        Parameters
        ----------
        filename: str or None
            Filename (path) containing rows of `key=val` data to read in.
        label, outdir, suffix : str or None
            If filename is None, form the file to read as `outdir/label.suffix`.
            Falsy `label`/`outdir` fall back to the instance attributes
            of the same name, if those exist.
        raise_error : bool
            If True, raise an error for lines which are not comments,
            but cannot be read.

        Returns
        -------
        params_dict: dict
            A dictionary of the parsed `key=val` pairs.
        """
        return utils.read_par(
            filename=filename,
            label=label or getattr(self, "label", None),
            outdir=outdir or getattr(self, "outdir", None),
            suffix=suffix,
            raise_error=raise_error,
        )

    @staticmethod
    def translate_keys_to_lal(dictionary):
        """Convert input keys into lalpulsar convention.

        In PyFstat's convention, input keys (search parameter names)
        are F0, F1, F2, ...,
        while lalpulsar functions prefer to use Freq, f1dot, f2dot, ....

        Since lalpulsar keys are only used internally to call lalpulsar routines,
        this function is provided so the keys can be translated on the fly.

        Parameters
        ----------
        dictionary: dict
            Dictionary to translate. A copy will be made (and returned)
            before translation takes place.

        Returns
        -------
        translated_dict: dict
            Copy of "dictionary" with new keys according to lalpulsar convention.
            Keys without a known translation are kept unchanged.
        """
        translation = {
            "F0": "Freq",
            "F1": "f1dot",
            "F2": "f2dot",
            "phi": "phi0",
            "tref": "refTime",
            "asini": "orbitasini",
            "period": "orbitPeriod",
            "tp": "orbitTp",
            "argp": "orbitArgp",
            "ecc": "orbitEcc",
            "transient_tstart": "transient-t0Epoch",
            "transient_duration": "transient-tau",
        }
        # collect the keys first: the copy is modified while translating
        keys_to_translate = [key for key in dictionary if key in translation]
        translated_dict = dictionary.copy()
        for key in keys_to_translate:
            translated_dict[translation[key]] = translated_dict.pop(key)
        return translated_dict
class ComputeFstat(BaseSearchClass):
"""Base search class providing an interface to `lalpulsar.ComputeFstat`.
In most cases, users should be using one of the higher-level search classes
from the grid_based_searches or mcmc_based_searches modules instead.
See the lalpulsar documentation at https://lscsoft.docs.ligo.org/lalsuite/lalpulsar/group___compute_fstat__h.html
and R. Prix, The F-statistic and its implementation in ComputeFstatistic_v2 ( https://dcc.ligo.org/T0900149/public )
for details of the lalpulsar module and the meaning of various technical concepts
as embodied by some of the class's parameters.
Normally this will read in existing data through the `sftfilepattern` argument,
but if that option is `None` and the necessary alternative arguments are used,
it can also generate simulated data (including noise and/or signals) on the fly.
NOTE that the detection statistics that can be computed from an instance of this class
depend on the `BSGL`, `BtSG` and `transientWindowType` arguments given at initialisation.
See `get_fullycoherent_detstat()` and `get_transient_detstats()` for details.
To change what you want to compute,
you may need to initialise a new instance with different options.
NOTE for GPU users (`tCWFstatMapVersion="pycuda"`):
This class tries to conveniently deal with GPU context management behind the scenes.
A known problematic case is if you try to instantiate it twice from the same
session/script. If you then get some messages like
`RuntimeError: make_default_context()`
and `invalid device context`,
that is because the GPU is still blocked from the first instance when
you try to initiate the second.
To avoid this problem, use context management::
with pyfstat.ComputeFstat(
[...],
tCWFstatMapVersion="pycuda",
) as search:
search.get_fullycoherent_detstat([...])
or manually call the `search._finalizer()` method where needed.
"""
@utils.initializer
def __init__(
    self,
    tref,
    sftfilepattern=None,
    minStartTime=None,
    maxStartTime=None,
    Tsft=1800,
    binary=False,
    singleFstats=False,
    BSGL=False,
    BtSG=False,
    transientWindowType=None,
    t0Band=None,
    tauBand=None,
    tauMin=None,
    dt0=None,
    dtau=None,
    detectors=None,
    minCoverFreq=None,
    maxCoverFreq=None,
    search_ranges=None,
    injectSources=None,
    injectSqrtSX=None,
    randSeed=None,
    assumeSqrtSX=None,
    SSBprec=None,
    RngMedWindow=None,
    tCWFstatMapVersion="lal",
    cudaDeviceName=None,
    computeAtoms=False,
    earth_ephem=None,
    sun_ephem=None,
    allowedMismatchFromSFTLength=None,
):
    """Set up the F-statistic computation from the given options.

    Parameters
    ----------
    tref : int
        GPS seconds of the reference time.
    sftfilepattern : str
        Pattern to match SFTs using wildcards (`*?`) and ranges [0-9];
        multiple patterns can be given separated by semicolons.
    minStartTime, maxStartTime : int
        Only use SFTs with timestamps starting from within this range,
        following the XLALCWGPSinRange convention:
        half-open intervals [minStartTime,maxStartTime).
    Tsft: int
        SFT duration in seconds.
        Only required if `sftfilepattern=None` and hence simulated data is
        generated on the fly.
    binary : bool
        If true, search over binary parameters.
    singleFstats : bool
        If true, also compute the single-detector twoF values.
    BSGL : bool
        If true, compute the log10BSGL statistic rather than the twoF value.
        For details, see Keitel et al (PRD 89, 064023, 2014):
        https://arxiv.org/abs/1311.5738
        Note this automatically sets `singleFstats=True` as well.
        Tuning parameters are currently hardcoded:

        * `Fstar0=15` for coherent searches.
        * A p-value of 1e-6 and correspondingly recalculated Fstar0
          for semicoherent searches.
        * Uniform per-detector prior line-vs-Gaussian odds.
    BtSG: bool
        If true and `transientWindowType` is not `None`,
        compute the transient
        :math:`\\ln\\mathcal{B}_{\\mathrm{tS}/\\mathrm{G}}`
        statistic from Prix, Giampanis & Messenger (PRD 84, 023007, 2011)
        (tCWFstatMap marginalised over uniform t0, tau priors)
        rather than the maxTwoF value.
    transientWindowType: str
        If `rect` or `exp`,
        allow for the Fstat to be computed over a transient range.
        (`none` instead of `None` explicitly calls the transient-window
        function, but with the full range, for debugging.)
        (If not None, will also force atoms regardless of computeAtoms option.)
    t0Band, tauBand: int
        Search ranges for transient start-time t0 and duration tau.
        If >0, search t0 in (minStartTime,minStartTime+t0Band)
        and tau in (tauMin,2*Tsft+tauBand).
        If =0, only compute the continuous-wave Fstat with t0=minStartTime,
        tau=maxStartTime-minStartTime.
    tauMin: int
        Minimum transient duration to cover,
        defaults to 2*Tsft.
    dt0: int
        Grid resolution in transient start-time,
        defaults to Tsft.
    dtau: int
        Grid resolution in transient duration,
        defaults to Tsft.
    detectors : str
        Two-character references to the detectors for which to use data.
        Specify `None` for no constraint.
        For multiple detectors, separate by commas.
    minCoverFreq, maxCoverFreq : float
        The min and max cover frequency passed to lalpulsar.CreateFstatInput.
        For negative values, these will be used as offsets from the min/max
        frequency contained in the sftfilepattern.
        If either is `None`, the search_ranges argument is used to estimate them.
        If the automatic estimation fails and you do not have a good idea
        what to set these two options to, setting both to -0.5 will
        reproduce the default behaviour of PyFstat <=1.4 and may be a
        reasonably safe fallback in many cases.
    search_ranges: dict
        Dictionary of ranges in all search parameters,
        only used to estimate frequency band passed to lalpulsar.CreateFstatInput,
        if minCoverFreq, maxCoverFreq are not specified (==`None`).
        For actually running searches,
        grids/points will have to be passed separately to the .run() method.
        The entry for each parameter must be a list of length 1, 2 or 3:
        [single_value], [min,max] or [min,max,step].
    injectSources : dict or str
        Either a dictionary of the signal parameters to inject,
        or a string pointing to a .cff file defining a signal.
    injectSqrtSX : float or list or str
        Single-sided PSD values for generating fake Gaussian noise on the fly.
        Single float or str value: use same for all IFOs.
        List or comma-separated string: must match len(detectors)
        and/or the data in sftfilepattern.
        Detectors will be paired to list elements following alphabetical order.
    randSeed : int or None
        random seed for on-the-fly noise generation using `injectSqrtSX`.
        Setting this to 0 or None is equivalent; both will randomise the seed,
        following the behaviour of XLALAddGaussianNoise(),
        while any number not equal to 0 will produce a reproducible noise realisation.
    assumeSqrtSX : float or list or str
        Don't estimate noise-floors but assume this (stationary) single-sided PSD.
        Single float or str value: use same for all IFOs.
        List or comma-separated string: must match len(detectors)
        and/or the data in sftfilepattern.
        Detectors will be paired to list elements following alphabetical order.
        If working with signal-only data, please set assumeSqrtSX=1 .
    SSBprec : int
        Flag to set the Solar System Barycentring (SSB) calculation in lalpulsar:
        0=Newtonian, 1=relativistic,
        2=relativistic optimised, 3=DMoff, 4=NO_SPIN
    RngMedWindow : int
        Running-Median window size for F-statistic noise normalization
        (number of SFT bins).
    tCWFstatMapVersion: str
        Choose between implementations of the transient F-statistic functionality:
        standard `lal` implementation,
        `pycuda` for GPU version,
        and some others only for devel/debug.
    cudaDeviceName: str
        GPU name to be matched against drv.Device output,
        only for `tCWFstatMapVersion=pycuda`.
    computeAtoms: bool
        Request calculation of 'F-statistic atoms' regardless of transientWindowType.
    earth_ephem: str
        Earth ephemeris file path.
        If None, will check standard sources as per
        utils.get_ephemeris_files().
    sun_ephem: str
        Sun ephemeris file path.
        If None, will check standard sources as per
        utils.get_ephemeris_files().
    allowedMismatchFromSFTLength: float
        Maximum allowed mismatch from SFTs being too long
        [Default: what's hardcoded in XLALFstatMaximumSFTLength]
    """
    # _setup_finalizer() reads self.tCWFstatMapVersion, so the arguments are
    # presumably already stored as attributes by the utils.initializer
    # decorator at this point -- TODO confirm in pyfstat.utils.
    self._setup_finalizer()
    # locals() is passed before any other local name is introduced,
    # so the stored dictionary holds only the call arguments (and "self")
    self._set_init_params_dict(locals())
    self.set_ephemeris_files(earth_ephem, sun_ephem)
    self.init_computefstatistic()
    self.output_file_header = self.get_output_file_header()
    # default detection statistic getter
    self.get_det_stat = self.get_fullycoherent_detstat
    # NOTE(review): init_computefstatistic() above already reads
    # self.allowedMismatchFromSFTLength, so this assignment looks redundant
    # with the decorator -- confirm before removing.
    self.allowedMismatchFromSFTLength = allowedMismatchFromSFTLength
def _setup_finalizer(self):
    """Register the GPU-context cleanup callback, if a cuda backend is used.

    Nothing is done unless "cuda" appears in `self.tCWFstatMapVersion`.
    Otherwise a `weakref.finalize` object calling
    `self._finalize_gpu_context` is created and kept as `self._finalizer`.

    Users should normally *not* have to call `self._finalizer()` manually;
    the handle is mainly stored for debugging/testing purposes.
    However, when initialising two or more of these objects from a single
    script, one has to clean up after each one, either through context
    management or by calling the `._finalizer()` method.

    NOTE(review): the registered callback is a bound method of `self`.
    The `weakref.finalize` documentation warns that such a callback keeps
    the object alive, so cleanup would then only run on an explicit call
    or at interpreter exit -- confirm whether that is intended.
    """
    if "cuda" not in self.tCWFstatMapVersion:
        return
    logger.debug(
        f"Setting up GPU context finalizer for {self.tCWFstatMapVersion} transient maps."
    )
    self._finalizer = finalize(self, self._finalize_gpu_context)
def _finalize_gpu_context(self):
    """Detach the stored GPU context, if there is one.

    Does nothing beyond logging when the `gpu_context` attribute
    is missing or falsy.
    """
    logger.debug("Leaving the ComputeFStat context...")
    context = getattr(self, "gpu_context", None)
    if not context:
        return
    logger.debug("Detaching GPU context...")
    # this is needed because we use pyCuda without autoinit
    context.detach()
def __enter__(self):
    """Return the instance itself, so it can be bound by a `with` statement."""
    logger.debug("Entering the ComputeFstat context...")
    return self
def __exit__(self, *args, **kwargs):
    """Run the stored finalizer when leaving a `with` block on a cuda backend.

    The arguments passed on exit are ignored.
    Returns None, so exceptions raised inside the block are not suppressed.
    """
    logger.debug("Leaving the ComputeFStat context...")
    uses_gpu = "cuda" in self.tCWFstatMapVersion
    if uses_gpu:
        self._finalizer()
def _get_SFTCatalog(self):
    """Load the SFTCatalog.

    If sftfilepattern is specified, load the data. If not, attempt to
    create data on the fly.

    Does nothing if the instance already has an `SFTCatalog` attribute.
    Otherwise sets `SFTCatalog`, `detector_names` and `numDetectors`;
    when reading from files it also sets `SFT_timestamps`,
    overwrites `Tsft` with the value derived from the first SFT
    and fills in `minStartTime`/`maxStartTime` if they were `None`.

    Raises
    ------
    ValueError
        If `sftfilepattern` is None and the options needed to build
        a fake catalog are incomplete, or if no data ends up loaded.
    IOError
        If no SFTs match the file pattern and constraints.
    """
    if hasattr(self, "SFTCatalog"):
        logger.info("Already have SFTCatalog.")
        return

    if self.sftfilepattern is None:
        logger.info("No sftfilepattern given, making fake SFTCatalog.")
        for k in ["minStartTime", "maxStartTime", "detectors"]:
            if getattr(self, k) is None:
                raise ValueError(
                    "If sftfilepattern==None, you must provide '{}'.".format(k)
                )
        no_signal = getattr(self, "injectSources", None) is None
        no_noise = getattr(self, "injectSqrtSX", None) is None
        no_Tsft = getattr(self, "Tsft", None) is None
        if (no_signal and no_noise) or no_Tsft:
            raise ValueError(
                "If sftfilepattern==None, you must specify Tsft and"
                " either one of injectSources or injectSqrtSX."
            )
        SFTCatalog = lalpulsar.SFTCatalog()
        Toverlap = 0
        self.detector_names = self.detectors.split(",")
        self.numDetectors = len(self.detector_names)
        detNames = lal.CreateStringVector(*self.detector_names)
        # MakeMultiTimestamps follows the same [minStartTime,maxStartTime)
        # convention as the SFT library, so we can pass Tspan like this
        Tspan = self.maxStartTime - self.minStartTime
        multiTimestamps = lalpulsar.MakeMultiTimestamps(
            self.minStartTime, Tspan, self.Tsft, Toverlap, detNames.length
        )
        SFTCatalog = lalpulsar.MultiAddToFakeSFTCatalog(
            SFTCatalog, detNames, multiTimestamps
        )
        self.SFTCatalog = SFTCatalog
        return

    logger.info("Initialising SFTCatalog from sftfilepattern.")
    constraints = lalpulsar.SFTConstraints()
    constr_str = []
    if self.detectors:
        if "," in self.detectors:
            logger.warning(
                "Multiple-detector constraints not available,"
                " using all available data."
            )
        else:
            constraints.detector = self.detectors
            constr_str.append("detector=" + constraints.detector)
    if self.minStartTime:
        constraints.minStartTime = lal.LIGOTimeGPS(self.minStartTime)
        constr_str.append("minStartTime={}".format(self.minStartTime))
    if self.maxStartTime:
        constraints.maxStartTime = lal.LIGOTimeGPS(self.maxStartTime)
        constr_str.append("maxStartTime={}".format(self.maxStartTime))
    logger.info(
        "Loading data matching SFT file name pattern '{}'"
        " with constraints {}.".format(self.sftfilepattern, ", ".join(constr_str))
    )
    self.SFTCatalog = lalpulsar.SFTdataFind(self.sftfilepattern, constraints)
    if self.SFTCatalog.length == 0:
        raise IOError("No SFTs found.")

    # Round rather than truncate: 1/deltaF is a floating-point quotient and
    # may land marginally below the integer SFT length.
    Tsft_from_catalog = int(round(1.0 / self.SFTCatalog.data[0].header.deltaF))
    if Tsft_from_catalog != self.Tsft:
        # plain "{}" for the pre-set value, which may be None or a float
        logger.info(
            "Overwriting pre-set Tsft={} with {:d} obtained from SFTs.".format(
                self.Tsft, Tsft_from_catalog
            )
        )
    self.Tsft = Tsft_from_catalog

    # NOTE: in multi-IFO case, this will be a joint list of timestamps
    # over all IFOs, probably sorted and not cleaned for uniqueness.
    SFT_timestamps = [d.header.epoch for d in self.SFTCatalog.data]
    self.SFT_timestamps = [float(s) for s in SFT_timestamps]
    if len(SFT_timestamps) == 0:
        raise ValueError("Failed to load any data")

    dtstr1 = utils.gps_to_datestr_utc(int(SFT_timestamps[0]))
    dtstr2 = utils.gps_to_datestr_utc(int(SFT_timestamps[-1]))
    logger.info(
        f"Data contains SFT timestamps from {SFT_timestamps[0]} ({dtstr1})"
        f" to (including) {SFT_timestamps[-1]} ({dtstr2})."
    )

    if self.minStartTime is None:
        self.minStartTime = int(SFT_timestamps[0])
    if self.maxStartTime is None:
        # XLALCWGPSinRange() convention: half-open intervals,
        # maxStartTime must always be > last actual SFT timestamp
        self.maxStartTime = int(SFT_timestamps[-1]) + self.Tsft

    # sorted() gives a reproducible, alphabetical detector order
    # instead of the arbitrary iteration order of a set
    self.detector_names = sorted({d.header.name for d in self.SFTCatalog.data})
    self.numDetectors = len(self.detector_names)
    if self.numDetectors == 0:
        raise ValueError("No data loaded.")
    logger.info(
        "Loaded {} SFTs from {} detectors: {}".format(
            len(SFT_timestamps), self.numDetectors, self.detector_names
        )
    )
def init_computefstatistic(self):
    """Initialization step for the F-statistic computation internals.

    This sets up the special input and output structures the lalpulsar module needs,
    the ephemerides,
    optional on-the-fly signal injections,
    and extra options for multi-detector consistency checks and transient searches.

    All inputs are taken from the pre-initialized object,
    so this function does not have additional arguments of its own.

    Raises
    ------
    ValueError
        For inconsistent options: BSGL with fewer than two detectors,
        BSGL together with BtSG, an `injectSqrtSX` list not matching
        the number of detectors, an unknown `transientWindowType`,
        or an Fstar0 search hitting the end of its grid.
    """
    self._get_SFTCatalog()

    # some sanity checks on user options
    if self.BSGL:  # pragma: no cover
        if len(self.detector_names) < 2:
            raise ValueError("Can't use BSGL with single detector data")
        if getattr(self, "BtSG", False):
            raise ValueError("Please choose only one of [BSGL,BtSG].")

    logger.info("Initialising ephems")
    ephems = lalpulsar.InitBarycenter(self.earth_ephem, self.sun_ephem)

    logger.info("Initialising Fstat arguments")
    dFreq = 0
    self.whatToCompute = lalpulsar.FSTATQ_2F
    if self.transientWindowType or self.computeAtoms:
        self.whatToCompute += lalpulsar.FSTATQ_ATOMS_PER_DET

    FstatOAs = lalpulsar.FstatOptionalArgs()
    # explicit None test: SSBprec=0 is a documented, valid choice
    # and must not be replaced by the default
    if self.SSBprec is not None:
        logger.info("Using SSBprec={}".format(self.SSBprec))
        FstatOAs.SSBprec = self.SSBprec
    else:
        FstatOAs.SSBprec = lalpulsar.FstatOptionalArgsDefaults.SSBprec
    FstatOAs.Dterms = lalpulsar.FstatOptionalArgsDefaults.Dterms
    if self.RngMedWindow:
        FstatOAs.runningMedianWindow = self.RngMedWindow
    else:
        FstatOAs.runningMedianWindow = (
            lalpulsar.FstatOptionalArgsDefaults.runningMedianWindow
        )
    FstatOAs.FstatMethod = lalpulsar.FstatOptionalArgsDefaults.FstatMethod
    if self.assumeSqrtSX is None:
        FstatOAs.assumeSqrtSX = lalpulsar.FstatOptionalArgsDefaults.assumeSqrtSX
    else:
        mnf = lalpulsar.MultiNoiseFloor()
        assumeSqrtSX = utils.parse_list_of_numbers(self.assumeSqrtSX)
        mnf.sqrtSn[: len(assumeSqrtSX)] = assumeSqrtSX
        mnf.length = len(assumeSqrtSX)
        FstatOAs.assumeSqrtSX = mnf
    FstatOAs.prevInput = lalpulsar.FstatOptionalArgsDefaults.prevInput
    FstatOAs.collectTiming = lalpulsar.FstatOptionalArgsDefaults.collectTiming
    if self.allowedMismatchFromSFTLength:
        FstatOAs.allowedMismatchFromSFTLength = self.allowedMismatchFromSFTLength

    # optional signal injection: parameter dictionary, parameter file, or none
    if hasattr(self, "injectSources") and isinstance(self.injectSources, dict):
        logger.info("Injecting source with params: {}".format(self.injectSources))
        PPV = lalpulsar.CreatePulsarParamsVector(1)
        PP = PPV.data[0]
        h0 = self.injectSources["h0"]
        cosi = self.injectSources["cosi"]
        use_aPlus = "aPlus" in dir(PP.Amp)
        if use_aPlus:  # lalsuite interface changed in aff93c45
            PP.Amp.aPlus = 0.5 * h0 * (1.0 + cosi**2)
            PP.Amp.aCross = h0 * cosi
        else:
            PP.Amp.h0 = h0
            PP.Amp.cosi = cosi
        PP.Amp.phi0 = self.injectSources["phi"]
        PP.Amp.psi = self.injectSources["psi"]
        PP.Doppler.Alpha = self.injectSources["Alpha"]
        PP.Doppler.Delta = self.injectSources["Delta"]
        if "fkdot" in self.injectSources:
            PP.Doppler.fkdot = np.array(self.injectSources["fkdot"])
        else:
            PP.Doppler.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
            for i, key in enumerate(["F0", "F1", "F2"]):
                PP.Doppler.fkdot[i] = self.injectSources[key]
        PP.Doppler.refTime = self.tref
        if "t0" not in self.injectSources:
            PP.Transient.type = lalpulsar.TRANSIENT_NONE
        FstatOAs.injectSources = PPV
    elif hasattr(self, "injectSources") and isinstance(self.injectSources, str):
        logger.info(
            "Injecting source from param file: {}".format(self.injectSources)
        )
        PPV = lalpulsar.PulsarParamsFromFile(self.injectSources, self.tref)
        FstatOAs.injectSources = PPV
    else:
        FstatOAs.injectSources = lalpulsar.FstatOptionalArgsDefaults.injectSources

    # optional noise generation, one sqrtSX value per detector
    if hasattr(self, "injectSqrtSX") and self.injectSqrtSX is not None:
        self.injectSqrtSX = utils.parse_list_of_numbers(self.injectSqrtSX)
        if len(self.injectSqrtSX) != len(self.detector_names):
            raise ValueError(
                "injectSqrtSX must be of same length as detector_names ({}!={})".format(
                    len(self.injectSqrtSX), len(self.detector_names)
                )
            )
        FstatOAs.injectSqrtSX = lalpulsar.MultiNoiseFloor()
        FstatOAs.injectSqrtSX.length = len(self.injectSqrtSX)
        FstatOAs.injectSqrtSX.sqrtSn[: FstatOAs.injectSqrtSX.length] = (
            self.injectSqrtSX
        )
    else:
        FstatOAs.injectSqrtSX = lalpulsar.FstatOptionalArgsDefaults.injectSqrtSX
    # Here we are treating 0 and None as equivalent
    # (use default, which is 0 and means "randomise the seed").
    # See XLALAddGaussianNoise().
    FstatOAs.randSeed = (
        getattr(self, "randSeed", None)
        or lalpulsar.FstatOptionalArgsDefaults.randSeed
    )

    self._set_min_max_cover_freqs()

    logger.info("Initialising FstatInput")
    self.FstatInput = lalpulsar.CreateFstatInput(
        self.SFTCatalog,
        self.minCoverFreq,
        self.maxCoverFreq,
        dFreq,
        ephems,
        FstatOAs,
    )

    logger.info("Initialising PulsarDoplerParams")
    PulsarDopplerParams = lalpulsar.PulsarDopplerParams()
    PulsarDopplerParams.refTime = self.tref
    PulsarDopplerParams.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
    self.PulsarDopplerParams = PulsarDopplerParams

    logger.info("Initialising FstatResults")
    self.FstatResults = lalpulsar.FstatResults()

    # always initialise the twoFX array,
    # but only actually compute it if requested
    self.twoF = 0
    self.twoFX = np.zeros(lalpulsar.PULSAR_MAX_DETECTORS)
    self.singleFstats = self.singleFstats or self.BSGL  # BSGL implies twoFX
    if self.singleFstats:
        self.whatToCompute += lalpulsar.FSTATQ_2F_PER_DET

    if self.BSGL:
        logger.info("Initialising BSGL")
        self.log10BSGL = np.nan
        # Tuning parameters - to be reviewed
        # We use a fixed Fstar0 for coherent searches,
        # and recompute it from a fixed p-value for the semicoherent case.
        nsegs_eff = max([getattr(self, "nsegs", 1), getattr(self, "nglitch", 1)])
        if nsegs_eff > 1:
            p_val_threshold = 1e-6
            Fstar0s = np.linspace(0, 1000, 10000)
            p_vals = scipy.special.gammaincc(2 * nsegs_eff, Fstar0s)
            # grid point whose p-value is closest to the threshold
            self.Fstar0 = Fstar0s[np.argmin(np.abs(p_vals - p_val_threshold))]
            if self.Fstar0 == Fstar0s[-1]:
                raise ValueError("Max Fstar0 exceeded")
        else:
            self.Fstar0 = 15.0
        logger.info("Using Fstar0 of {:1.2f}".format(self.Fstar0))
        # assume uniform per-detector prior line-vs-Gaussian odds
        self.oLGX = np.zeros(lalpulsar.PULSAR_MAX_DETECTORS)
        self.oLGX[: self.numDetectors] = 1.0 / self.numDetectors
        self.BSGLSetup = lalpulsar.CreateBSGLSetup(
            numDetectors=self.numDetectors,
            Fstar0sc=self.Fstar0,
            oLGX=self.oLGX,
            useLogCorrection=True,
            numSegments=getattr(self, "nsegs", 1),
        )

    if self.transientWindowType:
        logger.info(
            f"Initialising transient parameters for window type '{self.transientWindowType}'"
        )
        self.maxTwoF = 0
        if getattr(self, "BtSG", False):
            self.lnBtSG = np.nan
        self.windowRange = lalpulsar.transientWindowRange_t()
        transientWindowTypes = {
            "none": lalpulsar.TRANSIENT_NONE,
            "rect": lalpulsar.TRANSIENT_RECTANGULAR,
            "exp": lalpulsar.TRANSIENT_EXPONENTIAL,
        }
        if self.transientWindowType in transientWindowTypes:
            self.windowRange.type = transientWindowTypes[self.transientWindowType]
        else:
            raise ValueError(
                "Unknown window-type ({}) passed as input, [{}] allows.".format(
                    self.transientWindowType, ", ".join(transientWindowTypes)
                )
            )

        # default spacing
        self.windowRange.dt0 = self.Tsft
        self.windowRange.dtau = self.Tsft

        # special treatment of window_type = none
        # ==> replace by rectangular window spanning all the data
        if self.windowRange.type == lalpulsar.TRANSIENT_NONE:
            self.windowRange.t0 = int(self.minStartTime)
            self.windowRange.t0Band = 0
            self.windowRange.tau = int(self.maxStartTime - self.minStartTime)
            self.windowRange.tauBand = 0
        else:  # user-set bands and spacings
            if getattr(self, "t0Band", None) is None:
                self.windowRange.t0Band = 0
            else:
                if not isinstance(self.t0Band, int):
                    logger.warning(
                        "Casting non-integer t0Band={} to int...".format(
                            self.t0Band
                        )
                    )
                    self.t0Band = int(self.t0Band)
                self.windowRange.t0Band = self.t0Band
                if self.dt0:
                    self.windowRange.dt0 = self.dt0
            if getattr(self, "tauBand", None) is None:
                self.windowRange.tauBand = 0
            else:
                if not isinstance(self.tauBand, int):
                    logger.warning(
                        "Casting non-integer tauBand={} to int...".format(
                            self.tauBand
                        )
                    )
                    self.tauBand = int(self.tauBand)
                self.windowRange.tauBand = self.tauBand
                if self.dtau:
                    self.windowRange.dtau = self.dtau
                if self.tauMin is None:
                    self.windowRange.tau = int(2 * self.Tsft)
                else:
                    if not isinstance(self.tauMin, int):
                        logger.warning(
                            "Casting non-integer tauMin={} to int...".format(
                                self.tauMin
                            )
                        )
                        self.tauMin = int(self.tauMin)
                    self.windowRange.tau = self.tauMin

        logger.info("Initialising transient FstatMap features...")
        (
            self.tCWFstatMapFeatures,
            self.gpu_context,
        ) = tcw.init_transient_fstat_map_features(
            self.tCWFstatMapVersion, self.cudaDeviceName
        )

        if self.BSGL:
            self.twoFXatMaxTwoF = np.zeros(lalpulsar.PULSAR_MAX_DETECTORS)
def _set_min_max_cover_freqs(self):
    """Decide which minCoverFreq and maxCoverFreq to use.

    The values come either from direct user input,
    from estimate_min_max_CoverFreq() (if both are None),
    or from offsets relative to the SFT frequency range
    (if given as negative numbers).
    The attributes are updated in place.

    Raises
    ------
    ValueError
        If only one of the two is given,
        if neither they nor search_ranges are given,
        if negative offsets are used without an sftfilepattern,
        or if the final band is not contained in the SFT frequency range.
    """
    have_sfts = self.sftfilepattern is not None
    if have_sfts:
        sft_fmin, sft_fmax = self._get_min_max_freq_from_SFTCatalog()

    min_unset = self.minCoverFreq is None
    max_unset = self.maxCoverFreq is None
    if min_unset != max_unset:
        raise ValueError(
            "Please use either both or none of [minCoverFreq,maxCoverFreq]."
        )
    # from here on, both are set or both are unset
    if min_unset and self.search_ranges is None:
        raise ValueError(
            "Please use either search_ranges or both of [minCoverFreq,maxCoverFreq]."
        )

    if min_unset:
        logger.info(
            "[minCoverFreq,maxCoverFreq] not provided, trying to estimate"
            " from search ranges."
        )
        self.estimate_min_max_CoverFreq()
    elif self.minCoverFreq < 0.0 or self.maxCoverFreq < 0.0:
        if not have_sfts:
            raise ValueError(
                "If sftfilepattern==None, cannot use negative values for"
                " minCoverFreq or maxCoverFreq (interpreted as offsets from"
                " min/max SFT frequency)."
                " Please use actual frequency values for both,"
                " or set both to None (automated estimation)."
            )
        if self.minCoverFreq < 0.0:
            logger.info(
                "minCoverFreq={:f} provided, using as offset from min(SFTs).".format(
                    self.minCoverFreq
                )
            )
            # to set *above* min, since minCoverFreq is negative: subtract it
            self.minCoverFreq = sft_fmin - self.minCoverFreq
        if self.maxCoverFreq < 0.0:
            logger.info(
                "maxCoverFreq={:f} provided, using as offset from max(SFTs).".format(
                    self.maxCoverFreq
                )
            )
            # to set *below* max, since maxCoverFreq is negative: add it
            self.maxCoverFreq = sft_fmax + self.maxCoverFreq

    # the covered band must lie within what the SFTs contain
    if have_sfts and (
        self.minCoverFreq < sft_fmin or self.maxCoverFreq > sft_fmax
    ):
        raise ValueError(
            "[minCoverFreq,maxCoverFreq]=[{:f},{:f}] Hz incompatible with"
            " SFT files content [{:f},{:f}] Hz".format(
                self.minCoverFreq, self.maxCoverFreq, sft_fmin, sft_fmax
            )
        )
    logger.info(
        "Using minCoverFreq={} and maxCoverFreq={}.".format(
            self.minCoverFreq, self.maxCoverFreq
        )
    )
def _get_min_max_freq_from_SFTCatalog(self):
    """Return the lowest and highest frequency over all catalog entries.

    For each entry of `self.SFTCatalog.data`, the lower edge is `header.f0`
    and the upper edge is `header.f0 + (numBins - 1) * header.deltaF`.

    Returns
    -------
    minFreq_SFTs, maxFreq_SFTs:
        Minimum of the lower edges and maximum of the upper edges.
    """
    lower_edges = []
    upper_edges = []
    for entry in self.SFTCatalog.data:
        head = entry.header
        lower_edges.append(head.f0)
        upper_edges.append(head.f0 + (entry.numBins - 1) * head.deltaF)
    return np.min(lower_edges), np.max(upper_edges)
def estimate_min_max_CoverFreq(self):
    """Estimate `minCoverFreq` and `maxCoverFreq` from the search ranges.

    Extracts the spanned spin range at reference time from the template bank
    described by `self.search_ranges`, and stores the covering band returned
    by `utils.get_covering_band()` in `self.minCoverFreq`
    and `self.maxCoverFreq`.

    To use this method, `self.search_ranges` must be a dictionary of lists
    per search parameter,
    which can be either `[single_value]`, `[min,max]` or `[min,max,step]`.
    The keys `Alpha`, `Delta` and `F0` are required;
    `F1`, `F2`, `asini`, `period` and `ecc` are used if present.
    `self.tref`, `self.minStartTime` and `self.maxStartTime` are also read.

    Raises
    ------
    ValueError
        If `self.search_ranges` is not a dictionary, lacks a required key,
        or contains an entry that is not a list with 1 to 3 elements.
    """
    if not isinstance(self.search_ranges, dict):
        raise ValueError("Need a dictionary for search_ranges!")
    range_keys = list(self.search_ranges.keys())
    required_keys = ["Alpha", "Delta", "F0"]
    missing_keys = np.setdiff1d(required_keys, range_keys)
    if len(missing_keys) > 0:
        raise ValueError(
            "Required keys not found in search_ranges: {}".format(missing_keys)
        )
    for key in range_keys:
        entry = self.search_ranges[key]
        if not isinstance(entry, list) or not 1 <= len(entry) <= 3:
            raise ValueError(
                "search_ranges entry for {:s}"
                " is not a list of a known format"
                " (either [single_value], [min,max]"
                " or [min,max,step]): {}".format(key, entry)
            )

    def upper_end(key):
        # [single_value] entries have no separate maximum: use the value itself
        values = self.search_ranges[key]
        return values[1] if len(values) >= 2 else values[0]

    # start by constructing a DopplerRegion structure
    # which will be needed to conservatively account for sky-position dependent
    # Doppler shifts of the frequency range to be covered
    searchRegion = lalpulsar.DopplerRegion()
    # sky region
    Alpha = self.search_ranges["Alpha"][0]
    AlphaBand = upper_end("Alpha") - Alpha
    Delta = self.search_ranges["Delta"][0]
    DeltaBand = upper_end("Delta") - Delta
    searchRegion.skyRegionString = lalpulsar.SkySquare2String(
        Alpha,
        Delta,
        AlphaBand,
        DeltaBand,
    )
    searchRegion.refTime = self.tref
    # frequency and spindowns
    searchRegion.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
    searchRegion.fkdotBand = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
    for k in range(3):
        Fk = "F{:d}".format(k)
        if Fk in range_keys:
            searchRegion.fkdot[k] = self.search_ranges[Fk][0]
            searchRegion.fkdotBand[k] = (
                self.search_ranges[Fk][1] - self.search_ranges[Fk][0]
                if len(self.search_ranges[Fk]) >= 2
                else 0.0
            )
    # now construct DopplerFullScan from searchRegion
    scanInit = lalpulsar.DopplerFullScanInit()
    scanInit.searchRegion = searchRegion
    scanInit.stepSizes = lalpulsar.PulsarDopplerParams()
    scanInit.stepSizes.refTime = self.tref
    scanInit.stepSizes.Alpha = (
        self.search_ranges["Alpha"][-1]
        if len(self.search_ranges["Alpha"]) == 3
        else 0.001  # fallback, irrelevant for band estimate but must be > 0
    )
    scanInit.stepSizes.Delta = (
        self.search_ranges["Delta"][-1]
        if len(self.search_ranges["Delta"]) == 3
        else 0.001  # fallback, irrelevant for band estimate but must be > 0
    )
    scanInit.stepSizes.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
    for k in range(3):
        # the key must be built *before* testing for it: checking a stale Fk
        # from a previous iteration skips step sizes or hits missing keys
        Fk = "F{:d}".format(k)
        if Fk in range_keys:
            scanInit.stepSizes.fkdot[k] = (
                self.search_ranges[Fk][-1]
                if len(self.search_ranges[Fk]) == 3
                else 0.0
            )
    scanInit.startTime = self.minStartTime
    scanInit.Tspan = float(self.maxStartTime - self.minStartTime)
    scanState = lalpulsar.InitDopplerFullScan(scanInit)
    # now obtain the PulsarSpinRange extended over all relevant Doppler shifts
    spinRangeRef = lalpulsar.PulsarSpinRange()
    lalpulsar.GetDopplerSpinRange(spinRangeRef, scanState)
    # optional: binary parameters
    # (largest asini and ecc, first "period" entry; 0.0 if not searched over)
    maxOrbitAsini = upper_end("asini") if "asini" in range_keys else 0.0
    minOrbitPeriod = (
        self.search_ranges["period"][0] if "period" in range_keys else 0.0
    )
    maxOrbitEcc = upper_end("ecc") if "ecc" in range_keys else 0.0
    # finally call the wrapped lalpulsar estimation function with the
    # extended PulsarSpinRange and optional binary parameters
    self.minCoverFreq, self.maxCoverFreq = utils.get_covering_band(
        tref=self.tref,
        tstart=self.minStartTime,
        tend=self.maxStartTime,
        F0=spinRangeRef.fkdot[0],
        F1=spinRangeRef.fkdot[1],
        F2=spinRangeRef.fkdot[2],
        F0band=spinRangeRef.fkdotBand[0],
        F1band=spinRangeRef.fkdotBand[1],
        F2band=spinRangeRef.fkdotBand[2],
        maxOrbitAsini=maxOrbitAsini,
        minOrbitPeriod=minOrbitPeriod,
        maxOrbitEcc=maxOrbitEcc,
    )
def get_fullycoherent_detstat(
    self,
    F0=None,
    F1=None,
    F2=None,
    Alpha=None,
    Delta=None,
    asini=None,
    period=None,
    ecc=None,
    tp=None,
    argp=None,
    params=None,
    tstart=None,
    tend=None,
):
    """Evaluate the fully-coherent detection statistic(s) at one point.

    Supported statistics:

    * twoF (CW)

    * log10BSGL (CW or transient)

    * maxTwoF (transient)

    * lnBtSG (transient)

    Every statistic computed along the way is kept as an attribute,
    while a single one is handed back to the caller.

    `twoF` is the basic statistic of this class: it is always computed,
    stored as `self.twoF`, and is the default return value.

    With `self.singleFstats` set, the single-detector 2F-stat values
    are additionally stored in `self.twoFX`.

    With `self.BSGL` set, the CW `log10BSGL` statistic is additionally stored
    and returned in place of `twoF`.

    With transient parameters enabled (`self.transientWindowType` is set),
    `maxTwoF` is always computed and stored, and is the default return value.
    Depending on the `self.BSGL` and `self.BtSG` options,
    either `log10BSGL` (in its transient version, superseding the CW one)
    or `lnBtSG` is also computed and stored,
    and returned in place of `maxTwoF`.
    The full transient-F-stat map is computed as well,
    but only stored in `self.FstatMap`, never returned.

    NOTE: calling this with explicit [F0,F1,F2,Alpha,Delta,...] arguments
    is DEPRECATED and may be removed in future versions.
    For now, the method accepts either

    * a complete set of `(F0, F1, F2, Alpha, Delta)`
      (plus optional binary parameters),

    * OR a `params` dictionary;

    and only the dictionary form will be supported going forward.

    Parameters
    ----------
    F0, F1, F2, Alpha, Delta: float
        DEPRECATED: Parameters at which to compute the statistic.
    asini, period, ecc, tp, argp: float, optional
        DEPRECATED: Optional: Binary parameters at which to compute the statistic.
    params: dict
        A dictionary defining a parameter space point.
        See get_fullycoherent_twoF() for more information.
    tstart, tend: int or None
        GPS times to restrict the range of data used.
        If None: falls back to self.minStartTime and self.maxStartTime.
        This is only passed on to `self.get_transient_detstats()`,
        i.e. only used if `self.transientWindowType` is set.

    Returns
    -------
    stat: float
        A single value of the main detection statistic
        at the input parameter values.
    """
    point_kwargs = dict(
        F0=F0,
        F1=F1,
        F2=F2,
        Alpha=Alpha,
        Delta=Delta,
        asini=asini,
        period=period,
        ecc=ecc,
        tp=tp,
        argp=argp,
        params=params,
    )
    # the plain 2F is the basis for everything else, so it always comes first
    self.get_fullycoherent_twoF(**point_kwargs)
    if self.transientWindowType:
        # transient case: the time restriction is only meaningful here
        return self.get_transient_detstats(tstart=tstart, tend=tend)
    # CW case from here on
    if self.singleFstats:
        self.get_fullycoherent_single_IFO_twoFs()
    if self.BSGL:
        self.get_fullycoherent_log10BSGL()
        return self.log10BSGL
    return self.twoF
def _set_PulsarDopplerParams(
    self,
    params=None,
    F0=None,
    F1=None,
    F2=None,
    Alpha=None,
    Delta=None,
    asini=None,
    period=None,
    ecc=None,
    tp=None,
    argp=None,
):
    """Helper function to set a PulsarDopplerParams struct from user inputs.

    No return value, the struct stored as `self.PulsarDopplerParams`
    is updated in place: its `fkdot` array is reset to zeros and filled
    from the "Fk"-style entries, `Alpha` and `Delta` are set,
    and any binary parameters present are set as attributes of the same name.

    FIXME: this can be simplified a lot when removing the deprecated
    way of calling with individual parameters instead of a dict!

    Parameters
    ----------
    params: dict, optional
        A dictionary defining a parameter space point.
        See get_fullycoherent_twoF() for more information.
    F0, F1, F2, Alpha, Delta: float, optional
        DEPRECATED: Parameters at which to compute the statistic.
    asini, period, ecc, tp, argp: float, optional
        DEPRECATED: Optional: Binary parameters at which to compute the statistic

    Raises
    ------
    ValueError
        If neither `params` nor a full set of the deprecated base parameters
        is given; if `params` lacks a required key or contains an unknown one;
        if `self.binary` is set but a deprecated binary parameter is None;
        or if an "Fk"-style key has a non-integer, negative or too large index.
    """
    base_params_oldstyle = {
        "F0": F0,
        "F1": F1,
        "F2": F2,
        "Alpha": Alpha,
        "Delta": Delta,
    }
    # explicit name->value mapping for the deprecated binary arguments
    # (avoids looking up local variables by name through eval())
    binary_params_oldstyle = {
        "asini": asini,
        "period": period,
        "ecc": ecc,
        "tp": tp,
        "argp": argp,
    }
    if params is not None:
        required_keys = ["F0", "Alpha", "Delta"]
        parkeys = list(params.keys())
        keysetdiff = np.setdiff1d(required_keys, parkeys)
        if len(keysetdiff) > 0:  # pragma: no cover
            raise ValueError(
                f"Required keys not found in params.keys(): {keysetdiff}"
            )
        # all supported parameters are either required, binary, or of "Fk" type
        keysetdiff = np.setdiff1d(parkeys, required_keys + self.binary_keys)
        unknown_keys = [key for key in keysetdiff if not key.startswith("F")]
        if unknown_keys:  # pragma: no cover
            raise ValueError(
                f"Unknown parameters in input dictionary: {unknown_keys}"
            )
    elif all(val is not None for val in base_params_oldstyle.values()):
        params = {key: float(val) for key, val in base_params_oldstyle.items()}
        if self.binary:
            for key in self.binary_keys:
                bpar = binary_params_oldstyle[key]
                if bpar is None:  # pragma: no cover
                    raise ValueError(f"We got self.binary but {key}=None.")
                params[key] = float(bpar)
        parkeys = list(params.keys())
    else:  # pragma: no cover
        raise ValueError(
            "Need either a 'params' dictionary"
            f" or a full set of {list(base_params_oldstyle.keys())} (DEPRECATED)"
            f" plus optional binary parameters (also DEPRECATED)."
        )
    self.PulsarDopplerParams.fkdot = np.zeros(lalpulsar.PULSAR_MAX_SPINS)
    for key in parkeys:
        if not key.startswith("F"):
            continue
        try:
            k = int(key[1:])
        except ValueError as err:  # pragma: no cover
            raise ValueError(
                f"Unknown parameter {key} in input dictionary, it looks like a 'Fk'-style spindown term but cannot convert the part after the 'F' to an integer."
            ) from err
        if k < 0:  # pragma: no cover
            # a negative index would silently wrap around to the end of fkdot
            raise ValueError(
                f"Input parameter {key} has a negative spindown order."
            )
        if k >= lalpulsar.PULSAR_MAX_SPINS:  # pragma: no cover
            raise ValueError(
                f"Input parameter {key} exceeds lalpulsar.PULSAR_MAX_SPINS={lalpulsar.PULSAR_MAX_SPINS}."
            )
        self.PulsarDopplerParams.fkdot[k] = params[key]
    self.PulsarDopplerParams.Alpha = float(params["Alpha"])
    self.PulsarDopplerParams.Delta = float(params["Delta"])
    # only binary parameters actually present in the input are set
    for key in self.binary_keys:
        if key in params:
            setattr(self.PulsarDopplerParams, key, float(params[key]))
def get_fullycoherent_twoF(
self,
F0=None,
F1=None,
F2=None,
Alpha=None,
Delta=None,
asini=None,
period=None,
ecc=None,
tp=None,
argp=None,
params=None,
):
"""Computes the fully-coherent 2F statistic at a single point.
NOTE: This always uses the full data set as defined when initialising
the search object.
If you want to restrict the range of data used for a single 2F computation,
you need to set a `self.transientWindowType` and then call
`self.get_fullycoherent_detstat()` with `tstart` and `tend` options
instead of this function.
NOTE the old way of calling this with explicit (F0,F1,F2,Alpha,Delta,...)
parameters is DEPRECATED and may be removed in future versions.
Currently, this method can be either called with
* a complete set of `(F0, F1, F2, Alpha, Delta)`
(plus optional binary parameters),
* OR a `params` dictionary;
and only the latter version will be supported going forward.
Parameters
----------
F0, F1, F2, Alpha, Delta: float
DEPRECATED: Parameters at which to compute the statistic.
asini, period, ecc, tp, argp: float, optional
DEPRECATED: Optional: Binary parameters at which to compute the statistic.
params: dict
A dictionary defining a parameter space point,
with `["F0","Alpha","Delta"]` required as a minimum set of keys.
Also supported:
Fk with `1