Source code for skysurvey.dataset

"""
This module concerns the data as observed.
DataSet joins information for a Transient (list of true data) and a Survey (what has been observed when).
It generates realistic lightcurve observations.
"""

import numpy as np
import pandas
import sncosmo
import warnings

from .target.collection import TargetCollection
from .tools import speedutils


# ================== #
#                    #
#    DataSet         #
#                    #
# ================== #


class DataSet(object):
    """Generate and manage realistic transient light curves.

    A dataset joins true target parameters with survey observing logs. It
    provides methods to load, select and visualize the resulting light curve
    data.

    The classmethod ``DataSet.from_targets_and_survey()`` should be favored
    for loading the dataset.

    Parameters
    ----------
    data : `pandas.DataFrame`
        Multi-index dataframe corresponding to the concatenation of all
        targets observations.

    targets : ``skysurvey.Target`` or child of, optional
        Target data corresponding to the true target parameters
        (as given by nature).

    survey : ``skysurvey.Survey`` or child of, optional
        Survey that has been used to generate the dataset (if known).

    See Also
    --------
    :func:`from_targets_and_survey` : Loads a dataset (observed data) given targets and survey.
    :func:`read_parquet` : Loads a stored dataset.
    """

    def __init__(self, data, targets=None, survey=None):
        """Initialize the DataSet class."""
        self.set_data(data)
        self.set_targets(targets)
        self.set_survey(survey)

    @classmethod
    def from_targets_and_survey(cls, targets, survey,
                                incl_error=True,
                                phase_range=(-50, +200),
                                progress_bar=False,
                                seed=None,
                                discard_bands=True):
        """Load a dataset (observed data) given targets and a survey.

        This first matches the targets (given ``targets.data[['ra','dec']]``)
        with the survey to find which target has been observed with which
        field. Then it simulates the targets lightcurves given the observing
        data (``survey.data``).

        Parameters
        ----------
        targets: ``skysurvey.Target``, list, tuple, ``skysurvey.TargetCollection``
            Target data corresponding to the true target parameters
            (as given by nature). A list or tuple is wrapped into a
            ``TargetCollection``.

        survey: ``skysurvey.Survey`` (or child of)
            Sky observation (what was observed when with which situation).

        incl_error: bool, optional
            Include error in the lightcurve. If False, the flux is the true
            model flux.

        phase_range: sequence of two floats, None, optional
            Rest-frame phase range to be used for simulating the lightcurves.
            If None, no cut is applied on time range for the logs.

        progress_bar: bool, optional
            Should a progress bar be displayed while looping over the
            targets? (uses tqdm)

        seed : None, int, Generator, RandomState, optional
            = ignored if incl_error=False =
            (docstring adapted from ``np.random.default_rng``)
            If None, a fresh seed will be pulled.
            If an `int`, it will be passed to `SeedSequence` to derive the
            initial `BitGenerator` state.
            Additionally, when passed a `(Bit)Generator`, it will be returned
            unaltered. When passed a legacy `RandomState` instance it will be
            coerced to a `Generator`.

        discard_bands : bool, optional
            If True, discards the bands that include wavelengths for which
            the (observer-frame) target SED is not defined. This prevents
            crashing the code due to an error from `sncosmo`.

        Returns
        -------
        dataset
            instance of a `DataSet` loaded from the given targets.

        Raises
        ------
        NotImplementedError
            If the survey fields are identified by more than 2 entries.
        """
        if progress_bar:
            from tqdm import tqdm

        # a plain sequence of targets is wrapped into a TargetCollection
        if isinstance(targets, (list, tuple)):
            targets = TargetCollection(targets)

        # fail fast: the target/field matching below only supports fields
        # identified by one or two entries.
        field_names = survey.fieldids.names
        nfields = len(field_names)
        if nfields > 2:
            raise NotImplementedError(
                f"more than 2 entries for {field_names=}. Not implemented."
            )

        # fields in which targets fall into
        dfieldids_ = survey.radec_to_fieldid(targets.data[["ra", "dec"]])

        # make sure index of dfieldids_ corresponds to the input one.
        index_name = targets.data.index.name
        if index_name is None:
            index_name = "index"
        dfieldids_.index.name = index_name

        # merge target dataframe with matching fields.
        # note: pandas.merge conserves dtypes of fieldids, not pandas.join
        targets_data = targets.data.merge(dfieldids_, left_index=True, right_index=True)
        target_fields = np.stack(targets_data[field_names].values, dtype="int")

        # only the columns needed to realise the lightcurves are kept.
        survey_columns = ["mjd", "band", "skynoise", "gain", "zp"] + list(field_names)
        survey_data = survey.data[survey_columns].copy()
        if survey_data.index.name is None:
            survey_data.index.name = "index_obs"

        gsurvey_indexed = survey_data.groupby(field_names, observed=True, group_keys=False)

        # check which fields have been observed
        # to avoid looping over un-observed targets.
        nobs = gsurvey_indexed.size()
        fields_observed = np.stack(nobs.index.values, dtype="int")

        # build boolean mask to see which "target" could have data
        # given the "field" (all field_names) that have been observed.
        if nfields == 2:
            # speed trick for matching pairs
            is_target_observed = speedutils.isin_pair_elements(target_fields, fields_observed)
        else:
            is_target_observed = np.isin(target_fields, fields_observed)

        # list of observed targets
        targets_data_observed = targets_data[is_target_observed]

        # make sure phase_range is an array so it can be multiplied by (1+z)
        if phase_range is not None:
            phase_range = np.asarray(phase_range)

        # loop on targets:
        # each lightcurve's flux and associated error are stored inside
        # `bandflux`, which is then concatenated into a single dataframe
        # with `speedutils.eff_concat`.
        bandflux = []
        targets_observed = targets_data_observed.index.unique()
        iterator = tqdm(targets_observed) if progress_bar else targets_observed
        for index_target in iterator:
            # get the target model, that will be used to generate the flux
            # this model is set to the target parameters.
            model = targets.get_target_template(index=index_target,
                                                as_model=True,
                                                set_magabs=True)

            # grab the target information (could be several rows)
            this_target = targets_data_observed.loc[[index_target]]

            # logs associated to this target.
            this_target_logs = pandas.concat(
                [
                    gsurvey_indexed.get_group(tuple(entry_))
                    for entry_ in this_target[field_names].values
                ]
            )

            # limit the logs to the given rest-frame phase range
            if phase_range is not None:
                # 1. get the model t0 and redshift to get rest-frame phase
                t0 = model.parameters[model.param_names.index("t0")]
                redshift = model.parameters[model.param_names.index("z")]
                # 2. create the mjd range to consider for this target.
                this_mjd_range = t0 + phase_range * (1 + redshift)
                # 3. limit the logs to mjd matching this condition.
                in_range = this_target_logs["mjd"].between(*this_mjd_range)
                used_logs = this_target_logs[in_range].copy()
            else:
                used_logs = this_target_logs.copy()

            if discard_bands:
                # drop bands reaching outside the model wavelength coverage
                for band in np.unique(used_logs["band"]):
                    bandpass = sncosmo.get_bandpass(band)
                    if (bandpass.minwave() < model.minwave()
                            or bandpass.maxwave() > model.maxwave()):
                        used_logs = used_logs[used_logs["band"] != band]

            used_logs = used_logs.sort_values("mjd")

            # realise the flux lightcurves and its error
            used_logs["flux"] = model.bandflux(
                used_logs["band"], used_logs["mjd"], zp=used_logs["zp"], zpsys="ab"
            )
            used_logs["fluxerr"] = np.sqrt(
                used_logs["skynoise"] ** 2 + np.abs(used_logs["flux"]) / used_logs["gain"]
            )
            # and store.
            bandflux.append(used_logs)

        # create a dataframe concatenating all lightcurves
        lcs = speedutils.eff_concat(bandflux,
                                    int(np.sqrt(len(targets_observed))),
                                    keys=targets_observed.values)
        lcs.index.set_names("index", level=0, inplace=True)

        # if incl_error, the true flux is converted into an observed flux
        if incl_error:
            rng = np.random.default_rng(seed)
            lcs["flux"] += rng.normal(loc=0, scale=lcs["fluxerr"])

        return cls(lcs, targets=targets, survey=survey)

    @classmethod
    def read_parquet(cls, parquetfile, survey=None, targets=None, **kwargs):
        """Load a stored dataset.

        Only the observation data can be loaded this way, not the survey nor
        the targets (truth).

        Parameters
        ----------
        parquetfile: str
            path to the parquet file containing the dataset (pandas.DataFrame)

        survey: ``skysurvey.Survey`` (or child of), None
            survey that has been used to generate the dataset (if you know it)

        targets: ``skysurvey.Target`` (or child of), None
            target data corresponding to the true target parameters
            (as given by nature)

        **kwargs goes to `pandas.read_parquet`

        Returns
        -------
        class instance
            with a dataset loaded but maybe no survey nor targets

        See also
        --------
        :func:`from_targets_and_survey`: loads a dataset (observed data) given targets and survey
        """
        data = pandas.read_parquet(parquetfile, **kwargs)
        return cls(data, survey=survey, targets=targets)

    @classmethod
    def read_from_directory(cls, dirname, **kwargs):
        """Load a directory containing the dataset, the survey and the targets.

        = Not Implemented Yet =

        Parameters
        ----------
        dirname: str
            path to the directory.

        Returns
        -------
        class instance

        Raises
        ------
        NotImplementedError
            Always, this loader is not available yet.

        See also
        --------
        :func:`from_targets_and_survey`: loads a dataset (observed data) given targets and survey
        :func:`read_parquet`: loads a stored dataset
        """
        raise NotImplementedError("read_from_directory is not yet available.")

    # ============== #
    #   Method       #
    # ============== #
    # -------- #
    #  SETTER  #
    # -------- #
    def set_data(self, data):
        """Set the lightcurve data as observed by the survey.

        = It is unlikely you need to use that directly. =

        Parameters
        ----------
        data: `pandas.DataFrame`
            multi-index dataframe ((id, observation index)) corresponding to
            the concatenation of all targets observations

        Returns
        -------
        None

        See also
        --------
        :func:`read_parquet`: loads a stored dataset
        """
        self._data = data
        # reset the cached `obs_index`, it is derived from the data.
        self._obs_index = None

    def set_targets(self, targets):
        """Set the targets.

        = It is unlikely you need to use that directly. =

        Parameters
        ----------
        targets: ``skysurvey.Target`` (or child of), None
            target data corresponding to the true target parameters
            (as given by nature)

        Returns
        -------
        None

        See also
        --------
        :func:`from_targets_and_survey`: loads a dataset (observed data) given targets and survey
        """
        self._targets = targets

    def set_survey(self, survey):
        """Set the survey.

        = It is unlikely you need to use that directly. =

        Parameters
        ----------
        survey: ``skysurvey.Survey`` (or child of), None
            survey that has been used to generate the dataset (if you know it)

        Returns
        -------
        None

        See also
        --------
        :func:`from_targets_and_survey`: loads a dataset (observed data) given targets and survey
        """
        self._survey = survey

    # -------- #
    #  GETTER  #
    # -------- #
    def get_data(self, add_phase=False, phase_range=None, index=None,
                 redshift_key="z", detection=None, zp=None,
                 join_bandday=False, join_stats="first"):
        """Access a copy of the data, with optional selections and extra columns.

        Parameters
        ----------
        add_phase: bool
            should the phase information 'phase_obs' (obs-frame) and 'phase'
            (rest-frame) be added to the dataframe, using the targets' t0 and
            redshift ?

        phase_range: array
            min and max phases to be returned. Applied on phase (rest-frame).
            Setting this sets add_phase to True.

        index: `pandas.Index`, list, None
            select the index (targets id) you want.

        redshift_key: string
            name of the redshift column in ``self.targets.data``.
            = ignored if add_phase is False =

        detection: bool, None
            should this be limited to (non)detected points only ?
            A point is detected when flux/fluxerr >= 5.
            - detection=None: no selection
            - detection=False: only non-detected points
            - detection=True: only detected points

        zp: float
            get the simulated data in the given zp system

        join_bandday: bool
            if there are multiple observations per band and day (int of mjd)
            for a given target, should these be joined ? (see join_stats).

        join_stats: str
            if join_bandday is True, how multiple observations should be
            combined. Only "first" is implemented.

        Returns
        -------
        `pandas.DataFrame`

        Raises
        ------
        NotImplementedError
            If join_bandday is True and join_stats is not "first".
        """
        if phase_range is not None:
            add_phase = True

        if index is not None:
            data = self.data.loc[index].copy()
        else:
            data = self.data.copy()
            index = data.index.levels[0]

        if join_bandday:
            index_colnames = data.index.names
            data["mjd_date"] = data["mjd"].astype("int")
            gb_data = data.reset_index().groupby(by=[self._data_index, "band", "mjd_date"])
            if join_stats == "first":
                data = gb_data.first().reset_index().set_index(index_colnames)
            else:
                raise NotImplementedError(
                    f"{join_stats=} not implemented. Only first() is."
                )

        if add_phase:
            target_info = self.targets.data.loc[index][["t0", redshift_key]]
            data["phase_obs"] = data["mjd"] - target_info["t0"]
            data["phase"] = data["phase_obs"] / (1 + target_info[redshift_key])

            if phase_range is not None:
                data = data[data["phase"].between(*phase_range)]

        if detection is not None:
            flag_detection = (data["flux"] / data["fluxerr"]) >= 5
            data = data[flag_detection] if detection else data[~flag_detection]

        if zp is not None:
            # rescale fluxes from their own zero point to the requested one.
            coef = 10 ** (-(data["zp"].values - zp) / 2.5)
            data["flux"] *= coef
            data["fluxerr"] *= coef
            data["zp"] = zp

        return data

    def get_ndetection(self, phase_range=None, per_band=False, join_bandday=False):
        """Get the number of detections for each lightcurve.

        Counts the datapoints returned by ``get_data(detection=True)``,
        i.e. those with flux/fluxerr >= 5.

        Parameters
        ----------
        phase_range: array
            rest-frame phase range to be considered.

        per_band: bool
            should the computation be made per band ?
            If True it will then be per target *and* per band.

        join_bandday: bool
            if there are multiple observations per band and day (int of mjd)
            for a given target, should these be joined ?

        Returns
        -------
        `pandas.Series`
            the number of detected points per target
            (and per band if per_band=True)
        """
        data = self.get_data(phase_range=phase_range,
                             detection=True,
                             join_bandday=join_bandday)
        groupby = [self._data_index, "band"] if per_band else self._data_index
        return data.groupby(groupby).size()

    def get_target_lightcurve(self, index, detection=None, phase_range=None):
        """Get the observations of the given target.

        = short cut to self.get_data(index=index) =

        Parameters
        ----------
        index : int
            The index of the target whose light curve is to be taken.

        detection: bool, None
            should this be limited to (non)detected points only ?
            - detection=None: no selection
            - detection=False: only non-detected points
            - detection=True: only detected points

        phase_range: array
            min and max phases to be returned. Applied on phase (rest-frame).
            Setting this adds the phase columns to the returned dataframe.

        Returns
        -------
        `pandas.DataFrame`
            the lightcurve
        """
        return self.get_data(index=index, phase_range=phase_range, detection=detection)

    # -------- #
    #  PLOTTER #
    # -------- #
    def show_target_lightcurve(self, ax=None, fig=None, index=None, zp=25,
                               lc_prop=None, bands=None, show_truth=True,
                               format_time=True, t0_format="mjd",
                               phase_window=None, **kwargs):
        """Plot the light curve of a target.

        If `index` is None, a random index will be used.
        If `bands` is None, the target's observed bands will be used.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`, optional
            The axes on which to plot the light curve. If None, new axes are
            created on `fig`.

        fig : `matplotlib.figure.Figure`, optional
            The figure on which to plot the light curve. If None (and `ax` is
            None), a new figure is created. Ignored if `ax` is given.

        index : int, optional
            The index of the target whose light curve is to be plotted.
            If None, a random index is chosen among ``self.obs_index``.

        zp : float, optional
            Zero point magnitude for flux conversion. Default is 25.

        lc_prop : dict, optional
            Additional keyword arguments passed to
            ``self.targets.show_lightcurve`` (used if show_truth is True).

        bands : list of str, optional
            The bands to plot. If None, all observed bands for the target
            will be used.

        show_truth : bool, optional
            Whether to show the true light curve. Default is True.

        format_time : bool, optional
            Whether to format the time axis as dates. Default is True.

        t0_format : str, optional
            The format of the reference time. Default is "mjd".

        phase_window : array-like, optional
            Time window, in days relative to the target's t0, to plot.
            If None, the entire light curve will be plotted.

        **kwargs : dict
            Additional keyword arguments passed to both ``ax.scatter`` and
            ``ax.errorbar``.

        Returns
        -------
        `matplotlib.figure.Figure` or None
            The figure object containing the light curve plot, or None if
            there is no observation to show.
        """
        from matplotlib.colors import to_rgba
        from .config import get_band_color

        if format_time:
            from astropy.time import Time

        if lc_prop is None:
            lc_prop = {}

        if index is None:
            rng = np.random.default_rng()
            index = rng.choice(self.obs_index)

        # Data
        obs_ = self.get_target_lightcurve(index).copy()
        if phase_window is not None:
            t0 = self.targets.data["t0"].loc[index]
            phase_window = np.asarray(phase_window) + t0
            obs_ = obs_[obs_["mjd"].astype("float").between(*phase_window)]

        coef = 10 ** (-(obs_["zp"] - zp) / 2.5)
        obs_["flux_zp"] = obs_["flux"] * coef
        obs_["fluxerr_zp"] = obs_["fluxerr"] * coef

        if len(obs_) == 0:
            warnings.warn(f"No detections for the SN index={index} (detections possibly outside phase_window).")
            return None

        # Model
        if bands is None:
            bands = np.unique(obs_["band"])

        # = axes and figure = #
        if ax is None:
            if fig is None:
                import matplotlib.pyplot as plt
                fig = plt.figure(figsize=[7, 4])
            ax = fig.add_subplot(111)
        else:
            fig = ax.figure

        colors = get_band_color(bands)
        if show_truth:
            fig = self.targets.show_lightcurve(bands, ax=ax, fig=fig, index=index,
                                               format_time=format_time,
                                               t0_format=t0_format,
                                               zp=zp, colors=colors,
                                               zorder=2,
                                               **lc_prop)
        elif format_time:
            from matplotlib import dates as mdates
            locator = mdates.AutoDateLocator()
            formatter = mdates.ConciseDateFormatter(locator)
            ax.xaxis.set_major_locator(locator)
            ax.xaxis.set_major_formatter(formatter)
        else:
            ax.set_xlabel("time [in day]", fontsize="large")

        # loop over bands
        for band_, color_ in zip(bands, colors):
            # error bars use a transparent version of the band color
            ecolor = to_rgba("0.4" if color_ is None else color_, 0.2)

            obs_band = obs_[obs_["band"] == band_]
            if format_time:
                times = Time(obs_band["mjd"], format=t0_format).datetime
            else:
                times = obs_band["mjd"]

            ax.scatter(times, obs_band["flux_zp"], color=color_, zorder=4, **kwargs)
            ax.errorbar(times, obs_band["flux_zp"],
                        yerr=obs_band["fluxerr_zp"],
                        ls="None", marker="None", ecolor=ecolor,
                        zorder=3,
                        **kwargs)

        return fig

    # ============== #
    #   Properties   #
    # ============== #
    @property
    def data(self):
        """Lightcurve data as observed by the survey."""
        return self._data

    @property
    def _data_index(self):
        """Name of data index."""
        if not hasattr(self, "_hdata_index"):
            self._hdata_index = "index"
        return self._hdata_index

    @property
    def targets(self):
        """Target data corresponding to the true target parameters."""
        return self._targets

    @property
    def survey(self):
        """Survey that has been used to generate the dataset."""
        return self._survey

    @property
    def obs_index(self):
        """Index of the observed targets (sorted unique first-level values)."""
        if not hasattr(self, "_obs_index") or self._obs_index is None:
            self._obs_index = self.data.index.get_level_values(0).unique().sort_values()
        return self._obs_index