"""
This module defines the `BaseSurvey` base class, providing the core data structure and shared methods for all survey types.
"""
import warnings
import numpy as np
[docs]
class BaseSurvey( object ):
""" The `BaseSurvey` class.
Parameters
----------
data: `pandas.DataFrame`
observing data.
REQUIRED_COLUMNS : list
List of column names that must be present in the input DataFrame:
* ``mjd``: Modified Julian Date of the observation.
* ``band``: Filter/bandpass used (e.g., 'g', 'r', 'i').
* ``skynoise``: Image background contribution to flux error.
* ``gain``: CCD gain (e.g., electrons/ADU).
* ``zp``: Photometric zeropoint.
"""
REQUIRED_COLUMNS = ['mjd', 'band', 'skynoise', "gain", "zp"]
# NOTE
# -----
# ``skynoise`` is the image background contribution to the flux measurement
# error (in units corresponding to the specified zeropoint and zeropoint
# system). To get the error on a given measurement, ``skynoise`` is added
# in quadrature to the photon noise from the source.
#
# It is left up to the user to calculate ``skynoise`` as they see fit as the
# details depend on how photometry is done and possibly how the PSF is
# is modeled. As a simple example, assuming a Gaussian PSF, and perfect
# PSF photometry, ``skynoise`` would be ``4 * pi * sigma_PSF * sigma_pixel``
# where ``sigma_PSF`` is the standard deviation of the PSF in pixels and
# ``sigma_pixel`` is the background noise in a single pixel in counts.
# -- note from sncosmo
def __init__(self, data):
"""
Initialize the BaseSurvey class."""
self.set_data(data)
def __array__(self):
""" Numpy array representation of the data. """
return self.data.__array__()
# ============== #
# Methods #
# ============== #
[docs]
def set_data(self, data, lower_precision=True, sort_mjd=True):
""" Set the observing data.
= It is unlikely you need to use that directly. =
Parameters
----------
data: `pandas.DataFrame`
observing data. see REQUIRED_COLUMNS for the list of
required columns.
lower_precision: bool
change the types from 64 to 32 precision when possible.
sort_mjd: bool
should this sort by mjd (if needed) as required to draw dataset
Returns
-------
None
"""
if data is None:
self._data = None
return
if not np.isin(self.REQUIRED_COLUMNS, data.columns).all():
warnings.warn(f"at least one of the following column name if missing {self.REQUIRED_COLUMNS}")
if self.fields is not None and self.fieldids.name is not None:
if not np.all([f_name in data for f_name in self.fieldids.names]):
warnings.warn(f"fieldid {self.fieldids.names} are not in the input data")
if lower_precision:
data = data.astype( {k: str(v).replace("64","32") for k, v in data.dtypes.to_dict().items()})
if sort_mjd and not (data["mjd"].is_monotonic_increasing or data["mjd"].is_monotonic_decreasing):
data = data.sort_values("mjd")
self._data = data
# ------------ #
# GETTER #
# ------------ #
[docs]
def get_timerange(self, timekey="mjd"):
""" Returns the min and max of the given timekey column.
Parameters
----------
timekey: str
column name of the time column.
Returns
-------
`numpy.array`
"""
return self.data[timekey].agg(["min", "max"]).values
[docs]
def get_fieldcoverage(self, incl_zeros=False, fillna=np.nan,
**kwargs):
""" Short cut to ``get_fieldstat('size')``.
Parameters
----------
incl_zeros: bool
fields will no entries will not be shown
except if incl_zeros is True
fillna: float, str
format of the N/A entries
**kwargs goes to ``get_fieldstat()``
Returns
-------
DataFrame or Serie
following `groupby.agg()`
See also
--------
``get_fieldstat``: get observing statistics for the fields
"""
return self.get_fieldstat(stat="size", columns=None,
incl_zeros=incl_zeros,
fillna=fillna, **kwargs)
[docs]
def get_fieldstat(self, stat, columns=None,
incl_zeros=False, fillna=np.nan,
data=None):
""" Get observing statistics for the fields.
basically a shortcut to ``data.groupby("fieldid")[`column`].`stat`()``
Parameters
----------
stat: str, list
element to be passed to `groupby.agg()`
could be e.g.: 'mean' or ['mean', 'std'] or [np.median, 'mean'] etc.
If stat = 'size', this returns data["fieldid"].value_counts()
(slightly faster than groupby("fieldid").size()).
columns: str, list, None
name of the columns to be kept.
None means no cut.
incl_zeros: bool
fields will no entries will not be shown
except if incl_zeros is True
fillna: float, str
format of the N/A entries
data: `pandas.DataFrame`, None
data you want this to be applied to.
if None, a copy of self.data is used.
= leave to None if unsure =
Returns
-------
DataFrame or Serie
following `groupby.agg()`
"""
if data is None:
data = self.data.copy()
fieldids = self.fieldids.names
fieldgrouped = self.data.groupby(fieldids)
if stat in ["size","value_counts"]:
data = fieldgrouped.size()
elif columns is None:
data = fieldgrouped.agg(stat)
else:
data = fieldgrouped[columns].agg(stat)
if not incl_zeros:
return data
return data.reindex(self.fieldids, level=0)
[docs]
def radec_to_fieldid(self, radec):
""" Get the fieldid of the given (list of) coordinates.
Parameters
----------
radec: `pandas.DataFrame` or 2d array
coordinates in degree
Returns
-------
`pandas.Series`
"""
raise NotImplementedError("you have not implemented radec_to_fieldid for your survey")
[docs]
def get_observations_from_coords(self, radec):
""" Returns the data associated to the input radec coordinates.
(calls ``radec_to_fieldid`` and select data matching the fieldid)
Parameters
----------
radec: `pandas.DataFrame` or 2d array
coordinates in degree
(see format ``radec_to_fieldid()``)
Returns
-------
`pandas.DataFrame`
copy of the data observed in the given radec coordinates
"""
fields = self.radec_to_fieldid(radec, observed_fields=True)
return self.data[ self.data[self.fieldids.name].isin(fields[self.fieldids.name]) ].copy()
# ----------- #
# PLOTTER #
# ----------- #
[docs]
def show(self):
""" Shows the sky coverage.
Raises
------
NotImplementedError
This method is not implemented for this survey.
"""
raise NotImplementedError("you have not implemented show for your survey")
[docs]
def show_nexposures(self, ax=None, exposure_key="expid",
bands=None,perband=True, band_key="band", band_colors=None,
fieldid=None,
legend=True, **kwargs):
""" Show the number of exposures per day.
Parameters
----------
ax: `matplotlib.axes`
axes to plot on.
exposure_key: str
column name of the exposure id.
bands: list
list of bands to plot.
perband: bool
if True, plot the number of exposures per band.
band_key: str
column name of the band.
band_colors: dict
dictionary of colors for each band.
fieldid: int or list
field id to plot.
legend: bool
if True, show the legend.
**kwargs goes to ax.bar
Returns
-------
`matplotlib.figure`
"""
from astropy.time import Time
day = self.data["mjd"].astype(int)
day.name = "day"
data = self.data.join(day) # this is a copy
if fieldid is not None:
data = data[data["fieldid"].isin( np.atleast_1d(fieldid) )]
if ax is None:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(9,2))
ax = fig.add_axes([0.075, 0.2, 0.85, 0.6])
else:
fig = ax.figure
if not perband:
nobs = data.groupby(exposure_key).first().groupby("day").size()
max_obs = nobs.max()
all_days = nobs.index
nbands = 1
else:
nobs = data.groupby(exposure_key).first().groupby(["day", band_key]).size()
max_obs = nobs.groupby(level=0).sum().max()
all_days = nobs.index.levels[0]
if bands is None:
bands = nobs.index.levels[1]
nbands = len(bands)
times = Time(all_days.astype( float ), format="mjd").datetime
# plotting properties
prop = {**dict(zorder=3, width=0.95), **kwargs}
if perband:
bottom = 0
if band_colors is None:
band_colors = [None for i_ in range(len(bands))]
for band_, color_ in zip( bands, band_colors ):
d_ = nobs.xs(band_, level=1).reindex(all_days).fillna(0).astype(int).values
ax.bar(times, d_, color=color_,
bottom=bottom, label=f"{band_}",
**prop)
bottom += d_
else:
ax.bar(times, nobs.values, **prop)
from matplotlib import dates as mdates
locator = mdates.AutoDateLocator()
formatter = mdates.ConciseDateFormatter(locator)
ax.xaxis.set_major_locator(locator)
ax.xaxis.set_major_formatter(formatter)
clearwhich = ["left","right","top"] # "bottom"
[ax.spines[which].set_visible(False) for which in clearwhich]
ax.tick_params(axis="y", labelsize="small",
labelcolor="0.7", color="0.7")
ax.grid(axis="y", lw=0.5, color='0.7', zorder=1, alpha=0.5)
ax.set_ylabel("exposures per day", color="0.7", fontsize="small")
ax.set_ylim(ymin=0, ymax=np.round(max_obs*1.05,decimals=-1) )
if legend:
ax.legend(loc=[0,1], ncol=nbands, frameon=False, fontsize="small")
return fig
# ============== #
# Properties #
# ============== #
@property
def data(self):
""" Dataframe containing what has been observed when.
aka. the observing data
"""
return self._data
@property
def metadata(self):
""" Metadata associated to the survey, """
meta = {"type":self.of_type}
return meta
@property
def nfields(self):
""" Number of fields """
if not hasattr(self,"_nfields") or self._nfields is None:
warnings.warn("no nfields set, so this is assuming max of data['fieldid'].")
self._nfields =self.data["fieldid"].max()
return self._nfields
@property
def fields(self):
""" Geodataframe containing the fields coordinates. """
if not hasattr(self,"_fields"):
return None
return self._fields
@property
def of_type(self):
""" Kind of survey that is. """
return str(type(self)).split("'")[-2].split(".")[-1]
@property
def date_range(self):
""" First and last date of the survey. """
return np.min(self.data["mjd"]), np.max(self.data["mjd"])