Source code for sotodlib.preprocess.pcore

"""Base Class and PIPELINE register for the preprocessing pipeline scripts."""
import os
import copy
import logging
import numpy as np
from .. import core
from so3g.proj import Ranges, RangesMatrix
from scipy.sparse import csr_array
from matplotlib import pyplot as plt


class _Preprocess(object):
    """Base class for preprocessing modules.

    It defines the hooks called by the processing scripts in site_pipeline
    and the configuration keys that enable each hook.

    A module is built from a configuration dictionary with six special keys:
    ``name``, ``process``, ``calc``, ``save``, ``select`` and ``plot``.
    ``name`` is the name used to register the module with the PIPELINE
    registry. Each of the other five keys is matched to an overwritable
    function of the module (``process``, ``calc_and_save``, ``save``,
    ``select`` and ``plot``). If a key is not present, the base
    implementation of the matching function is a no-op; if the key is present
    and the subclass did not override the function, ``NotImplementedError``
    is raised.

    Two more keys are read from the configuration: ``skip_on_sim`` (no
    default, stored as ``None`` when missing) and ``use_data_aman`` (defaults
    to ``False``). Both are consulted by ``Pipeline.run``.

    There are two special AxisManagers expected to be part of the
    preprocessing pipeline. ``aman`` is the "standard" time ordered data
    AxisManager that is loaded via our default styles. ``proc_aman`` is the
    preprocess AxisManager; it carries the data products that will be saved
    to whatever Metadata Archive is connected to the preprocessing pipeline.
    """

    def __init__(self, step_cfgs):
        # One ``<key>_cfgs`` attribute per hook; ``None`` means "hook disabled".
        for hook_key in ("process", "calc", "save", "select", "plot"):
            setattr(self, f"{hook_key}_cfgs", step_cfgs.get(hook_key))
        self.skip_on_sim = step_cfgs.get("skip_on_sim")
        self.use_data_aman = step_cfgs.get("use_data_aman", False)

    def process(self, aman, proc_aman, sim=False, data_aman=None):
        """Apply changes to the time ordered data AxisManager.

        Ex: calibrating or detrending the timestreams. This function will use
        any configuration information under the ``process`` key of the
        configuration dictionary and is not expected to change or alter
        proc_aman.

        Arguments
        ---------
        aman : AxisManager
            The time ordered data
        proc_aman : AxisManager
            Any information generated by previous elements in the
            preprocessing pipeline.
        sim : bool
            False by default when analyzing data. Should be True when doing
            Transfer Function simulations and determining which steps should
            be run.
        data_aman : AxisManager
            (Optional) An AxisManager containing the preprocessed data to be
            used by this process.

        Returns
        -------
        aman, proc_aman
            Returned unchanged by this base implementation when no
            ``process`` configuration was given.
        """
        if self.process_cfgs is not None:
            raise NotImplementedError
        return aman, proc_aman

    def calc_and_save(self, aman, proc_aman):
        """Calculate data products from the time ordered data AxisManager.

        Ex: Calculating the white noise of the timestream. This function will
        use any configuration information under the ``calc`` key of the
        configuration dictionary and can call the save function to make
        changes to proc_aman.

        Arguments
        ---------
        aman : AxisManager
            The time ordered data
        proc_aman : AxisManager
            Any information generated by previous elements in the
            preprocessing pipeline.

        Returns
        -------
        aman, proc_aman
            Returned unchanged by this base implementation when no ``calc``
            configuration was given.
        """
        if self.calc_cfgs is not None:
            raise NotImplementedError
        return aman, proc_aman

    def save(self, proc_aman, *args):
        """Wrap new information into proc_aman.

        Uses any configuration information under the ``save`` key of the
        configuration dictionary.

        Arguments
        ---------
        proc_aman : AxisManager
            Any information generated by previous elements in the
            preprocessing pipeline.
        args : any
            Any additional information ``calc_and_save`` needs to send to the
            save function.
        """
        if self.save_cfgs is not None:
            raise NotImplementedError
        return

    def select(self, meta, proc_aman=None, in_place=True):
        """Run any desired data selection of the preprocessing pipeline results.

        Assumes the pipeline has already been run and that the resulting
        proc_aman is now saved under the ``preprocess`` key in the ``meta``
        AxisManager loaded via context. Ex: removing detectors with white
        noise above some limit. This function will use any configuration
        information under the ``select`` key.

        Arguments
        ---------
        meta : AxisManager
            Metadata related to the specific observation
        proc_aman : AxisManager
            Optional. Any information generated by previous elements in the
            preprocessing pipeline.
        in_place : bool
            Optional. Apply selection and return restricted axis manager if
            True, else return the flag array.

        Returns
        -------
        meta : AxisManager
            Metadata where non-selected detectors have been removed
        """
        if self.select_cfgs is not None:
            raise NotImplementedError
        return meta

    def plot(self, aman, proc_aman, filename):
        """Create plots using results from ``calc_and_save``.

        Ex: Plotting det bias flags. This function will use any configuration
        information under the ``plot`` key of the configuration dictionary.

        Arguments
        ---------
        aman : AxisManager
            The time ordered data
        proc_aman : AxisManager
            Any information generated by previous elements in the
            preprocessing pipeline.
        filename : str
            Filename should be a concatenation of the global ``plot_dir``
            config with a name with process step number and placeholder
            {name} as shown in Pipeline.run().
        """
        if self.plot_cfgs is not None:
            raise NotImplementedError
        return

    @classmethod
    def gen_metric(cls, meta, proc_aman):
        """Generate a QA metric from the output of this process.

        Arguments
        ---------
        meta : AxisManager
            Metadata related to the specific observation
        proc_aman : AxisManager
            The output of the preprocessing pipeline.

        Returns
        -------
        line : dict
            InfluxDB line entry elements to be fed to
            `site_pipeline.monitor.Monitor.record`
        """
        raise NotImplementedError

    @staticmethod
    def register(process_class):
        """Register a new module with the PIPELINE under ``process_class.name``.

        Raises ``ValueError`` if that name is already registered.
        """
        name = process_class.name
        if Pipeline.PIPELINE.get(name) is not None:
            raise ValueError(
                f"Preprocess Module of name {name} is already Registered"
            )
        Pipeline.PIPELINE[name] = process_class
def _zeros_cls(item):
    """Return a callable ``f(shape)`` building an empty object of ``item``'s type.

    Intended for use as the ``cls`` argument of ``AxisManager.wrap_new``.
    Supported types are ``np.ndarray``, ``RangesMatrix``, ``Ranges`` and
    ``csr_array``; dtype is preserved for arrays and sparse arrays.

    Raises
    ------
    ValueError
        If ``item`` is of none of the supported types.
    """
    if isinstance(item, np.ndarray):
        return lambda shape: np.zeros(shape, dtype=item.dtype)
    elif isinstance(item, RangesMatrix):
        return RangesMatrix.zeros
    elif isinstance(item, Ranges):
        def temp(shape):
            # A Ranges object is one-dimensional by construction.
            assert len(shape) == 1
            return Ranges(shape[0])
        return temp
    elif isinstance(item, csr_array):
        return lambda shape: csr_array(tuple(shape), dtype=item.dtype)
    else:
        raise ValueError(f"Cannot find zero type for {type(item)}")


def _reform_csr_array(arr, oidx, nidx, shape):
    """Embed the 2d sparse array ``arr`` into a larger frame of ``shape``.

    Support function for ``_expand``. Assumptions are that ``arr`` is 2d,
    with shape (dets, samps), so ``oidx`` and ``nidx`` each contain
    (array of idx, slice): ``nidx`` addresses ``arr`` and ``oidx`` addresses
    the output frame.
    """
    # Map each row of arr to its row in the output frame.
    row_map = {n: o for n, o in zip(nidx[0], oidx[0])}
    col_shift = oidx[1].start - nidx[1].start
    # For coo array you have:
    #   r, c, v = arr.row, arr.col, arr.data
    # and the equivalent for csr is:
    #   r = np.repeat(np.arange(arr.shape[0]), np.diff(arr.indptr))
    #   c = arr.indices
    #   v = arr.data
    # So the expression for r1, here, is equivalent to
    #   r = np.repeat(np.arange(arr.shape[0]), np.diff(arr.indptr))
    #   r1 = [row_map[_r] for _r in r]
    r1 = np.repeat([row_map[_r] for _r in range(arr.shape[0])],
                   np.diff(arr.indptr))
    c1 = arr.indices + col_shift
    v = arr.data
    return csr_array((v, (r1, c1)), shape=shape, dtype=arr.dtype)


def _ranges_matrix_match(o, n, oidx, nidx):
    """Align RangesMatrix ``n`` entries to RangesMatrix ``o``.

    ``oidx[0]`` / ``nidx[0]`` are iterated pairwise as row indices into ``o``
    and ``n``; ``oidx[1]`` / ``nidx[1]`` are forwarded to ``_ranges_match``
    as the sample index. ``o`` is modified in place and a copy is returned.

    Raises
    ------
    NotImplementedError
        If more than two index entries are supplied.
    """
    assert len(oidx) == len(nidx)
    if len(oidx) > 2:
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would produce a confusing TypeError.
        raise NotImplementedError
    for i, x in zip(oidx[0], nidx[0]):
        o.ranges[i] = _ranges_match(o.ranges[i], n.ranges[x],
                                    [oidx[1]], [nidx[1]])
    return o.copy()


def _ranges_match(o, n, oidx, nidx):
    """Align Ranges ``n`` to Ranges ``o``.

    Copies the mask of ``n`` at ``nidx[0]`` into the mask of ``o`` at
    ``oidx[0]`` and returns a new Ranges built from the result.
    """
    assert len(oidx) == len(nidx)
    assert len(oidx) == 1
    omsk = o.mask()
    nmsk = n.mask()
    omsk[oidx[0]] = nmsk[nidx[0]]
    return Ranges.from_mask(omsk)


def _intersect(new, out):
    """Get detector and samples intersection between ``new`` and ``out``.

    Returns
    -------
    fs_dets, ns_dets, fs_samps, ns_samps
        Indices into ``out`` (``fs_*``) and into ``new`` (``ns_*``) as given
        by the axes' ``intersection(..., return_slices=True)``. When ``new``
        lacks an axis, the ``fs_*`` entry covers all of ``out`` along that
        axis and the ``ns_*`` entry is ``None``.
    """
    if 'dets' in new._axes:
        _, fs_dets, ns_dets = out.dets.intersection(
            new.dets, return_slices=True
        )
    else:
        fs_dets = range(out.dets.count)
        ns_dets = None
    if 'samps' in new._axes:
        _, fs_samps, ns_samps = out.samps.intersection(
            new.samps, return_slices=True
        )
    else:
        fs_samps = slice(None)
        ns_samps = None
    return fs_dets, ns_dets, fs_samps, ns_samps


def _wrap_valid_ranges(new, out, valid_name="valid", wrap_name=None):
    """Wrap a new Ranges field into ``out`` that tracks the current number
    of detectors and samples that intersect with ``new``.

    Arguments
    ---------
    new : AxisManager
        The (possibly restricted) AxisManager to intersect with ``out``.
    out : AxisManager
        Receives the (dets, samps) RangesMatrix; modified in place.
    valid_name : str
        Name of the wrapped RangesMatrix field.
    wrap_name : str or None
        If given, the field is placed in a fresh child AxisManager that is
        wrapped into ``out`` under this name; otherwise the field is wrapped
        into ``out`` directly. Any existing entry of the target name is
        removed first.
    """
    fs_dets, _, fs_samps, _ = _intersect(new, out)
    x = Ranges(out.samps.count)
    m = x.mask()
    m[fs_samps] = True
    v = Ranges.from_mask(m)
    # Build the membership set once rather than scanning fs_dets per detector.
    kept_dets = set(fs_dets)
    valid = RangesMatrix(
        [v if i in kept_dets else x for i in range(out.dets.count)]
    )
    if wrap_name:
        if wrap_name in out:
            out.move(wrap_name, None)
        valid_aman = core.AxisManager(out.dets, out.samps)
        valid_aman.wrap(valid_name, valid, [(0, 'dets'), (1, 'samps')])
        out.wrap(wrap_name, valid_aman)
    else:
        if valid_name in out:
            out.move(valid_name, None)
        out.wrap(valid_name, valid, [(0, 'dets'), (1, 'samps')])


def _expand(new, full, wrap_valid=True):
    """Return a copy of ``new`` resized to the dets/samps axes of ``full``.

    ``new`` will become a top level axismanager in ``full`` once it is
    matched to size. Each field is allocated empty at full size and the
    values of ``new`` are copied into the intersecting region.

    Arguments
    ---------
    new : AxisManager
        AxisManager whose dets/samps axes are a subset of those in ``full``.
    full : AxisManager
        AxisManager providing the target axes.
    wrap_valid : bool
        If True, also wrap a ``valid`` RangesMatrix (see
        ``_wrap_valid_ranges``) into the result.
    """
    fs_dets, ns_dets, fs_samps, ns_samps = _intersect(new, full)
    out = core.AxisManager()
    # Axes taken from ``full``: everything ``new`` uses, plus dets and samps.
    keep_axes = set(new._axes) | {'dets', 'samps'}
    for k, v in full._axes.items():
        if k in keep_axes:
            out._axes[k] = v
    for a in new._axes:
        if a not in out:
            out.add_axis(new[a])
    for k, v in new._fields.items():
        if isinstance(v, core.AxisManager):
            # NOTE(review): ``wrap_valid`` is not forwarded here, so nested
            # AxisManagers always use the default (True) -- confirm intended.
            out.wrap(k, _expand(v, full))
        else:
            if np.isscalar(v):
                # Skip expansion for wrapped scalars.
                out.wrap(k, v)
                continue
            out.wrap_new(k, new._assignments[k], cls=_zeros_cls(v))
            oidx = []
            nidx = []
            for ii, a in enumerate(new._assignments[k]):
                if a == 'dets':
                    oidx.append(fs_dets)
                    nidx.append(ns_dets)
                elif a == 'samps':
                    oidx.append(fs_samps)
                    nidx.append(ns_samps)
                else:
                    if (ii == 0) and isinstance(out[k], RangesMatrix):
                        # Treat like dets
                        # _ranges_matrix_match expects oidx[0] and nidx[0] to
                        # be list(inds), not slice.
                        # Unknown axes treated as dets if first entry, else
                        # like samps. Added to support (subscans, samps)
                        # RangesMatrix.
                        if a in full._axes:
                            _, fs, ns = full[a].intersection(
                                new[a], return_slices=True)
                        else:
                            fs = range(new[a].count)
                            ns = range(new[a].count)
                        oidx.append(fs)
                        nidx.append(ns)
                    else:
                        # Treat like samps
                        oidx.append(slice(None))
                        nidx.append(slice(None))
            oidx = tuple(oidx)
            nidx = tuple(nidx)
            if isinstance(out[k], RangesMatrix):
                assert new._assignments[k][-1] == 'samps'
                out[k] = _ranges_matrix_match(out[k], v, oidx, nidx)
            elif isinstance(out[k], Ranges):
                assert new._assignments[k][0] == 'samps'
                out[k] = _ranges_match(out[k], v, oidx, nidx)
            elif isinstance(out[k], csr_array):
                assert tuple(new._assignments[k]) == ('dets', 'samps')
                out[k] = _reform_csr_array(v, oidx, nidx, out[k].shape)
            else:
                try:
                    out[k][oidx] = v[nidx]
                except TypeError:
                    # Skip expansion for scalar array with no axes.
                    out[k] = v
    if wrap_valid:
        _wrap_valid_ranges(new, out)
    return out


def update_full_aman(proc_aman, full, wrap_valid):
    """Copy new fields from proc_aman[dets,samps] over to
    full[full-dets,full-samps] after correct re-sizing and indexing.

    Only fields not already present in ``full`` are copied; each must be an
    AxisManager.

    Arguments
    ----------
    proc_aman: AxisManager
        A preprocess AxisManager from a pipeline run. The dets,samps axes in
        proc_aman is assumed to be a subset of the dets,samps axes in full
    full: AxisManager
        A full shape AxisManager that begins the pipeline as the original
        shape of the TOD AxisManager
    wrap_valid: bool
        Forwarded to ``_expand``; if True a ``valid`` RangesMatrix is wrapped
        into each newly copied field.
    """
    for fld in proc_aman._fields:
        if fld not in full._fields:
            assert isinstance(proc_aman[fld], core.AxisManager)
            full.wrap(fld, _expand(proc_aman[fld], full,
                                   wrap_valid=wrap_valid))
class Pipeline(list):
    """This class is designed to create and run pipelines out of a series of
    different preprocessing modules (classes that inherit from _Preprocess).
    It inherits from the list object. It also contains the registration of
    all possible preprocess modules in Pipeline.PIPELINE
    """

    # Registry mapping a module name to its _Preprocess subclass.
    PIPELINE = {}

    def __init__(self, modules, plot_dir='./', logger=None, wrap_valid=True):
        """
        Arguments
        ---------
        modules: iterable
            A list or other iterable that contains either instantiated
            _Preprocess instances or the configuration dictionary used to
            instantiate a module. The iterable is deep-copied before use.
        plot_dir: str
            Directory prefix for preprocess plots
        logger: optional
            logging.logger instance used by the pipeline to send updates.
            Defaults to ``logging.getLogger("pipeline")``.
        wrap_valid: bool
            Passed to ``update_full_aman`` when the pipeline is run.
        """
        if logger is None:
            logger = logging.getLogger("pipeline")
        self.logger = logger
        self.plot_dir = plot_dir
        self.wrap_valid = wrap_valid
        super().__init__(
            [self._check_item(item) for item in copy.deepcopy(modules)])

    def _check_item(self, item):
        """Return ``item`` as a _Preprocess instance.

        Instances are passed through; dictionaries are turned into an
        instance of the class registered under their ``name`` key.

        Raises
        ------
        ValueError
            If a dictionary has no ``name``, the name is not registered, or
            ``item`` is neither a _Preprocess nor a dict.
        """
        if isinstance(item, _Preprocess):
            return item
        elif isinstance(item, dict):
            name = item.get("name")
            if name is None:
                raise ValueError("Processes made from dictionary must have a 'name' key")
            cls = self.PIPELINE.get(name)
            if cls is None:
                raise ValueError(f"'{name}' not registered as a pipeline element")
            return cls(item)
        else:
            raise ValueError("Unknown type created a pipeline element")

    # make pipeline have all the list pieces
    def append(self, item):
        """Validate ``item`` with ``_check_item`` and append it."""
        super().append(self._check_item(item))

    def insert(self, index, item):
        """Validate ``item`` with ``_check_item`` and insert it at ``index``."""
        super().insert(index, self._check_item(item))

    def extend(self, index, other=None):
        """Extend the pipeline with the elements of an iterable.

        Supports the standard ``list.extend(iterable)`` call. The historical
        two-argument form ``extend(index, other)`` is still accepted; in that
        form ``index`` is ignored (as it always was) and ``other`` is the
        iterable.

        Arguments
        ---------
        index : iterable or any
            The iterable to add when ``other`` is omitted; otherwise unused.
        other : iterable, optional
            The iterable to add (two-argument form). Elements are validated
            with ``_check_item`` unless ``other`` is already a Pipeline.
        """
        if other is None:
            other = index
        if isinstance(other, type(self)):
            super().extend(other)
        else:
            super().extend([self._check_item(item) for item in other])

    def __setitem__(self, index, item):
        super().__setitem__(index, self._check_item(item))

    def __getitem__(self, index):
        result = super().__getitem__(index)
        if isinstance(index, slice):
            # Keep the parent's settings on the sub-pipeline instead of
            # falling back to constructor defaults. Note that the constructor
            # deep-copies the sliced modules.
            return Pipeline(result, plot_dir=self.plot_dir,
                            logger=self.logger, wrap_valid=self.wrap_valid)
        else:
            return result

    def _plot_filename(self, step):
        """Return the plot filename template for the zero-based ``step``.

        The ``{ctime}``, ``{obsid}`` and ``{name}`` placeholders are left in
        the string for the plotting code to fill in.
        """
        return os.path.join(self.plot_dir, '{ctime}/{obsid}',
                            f'{step+1}_{{name}}.png')

    def run(self, aman, proc_aman=None, full_aman=None, select=True,
            sim=False, update_plot=False, data_amans=None):
        """
        The main workhorse function for the pipeline class. This function
        takes an AxisManager TOD and successively runs the pipeline of
        preprocessing modules on the AxisManager. The order of operations
        called by run are::

            for process in pipeline:
                process.process()
                process.calc_and_save()
                    process.save() ## called by process.calc_and_save()
                process.select()

        Arguments
        ---------
        aman: AxisManager
            A TOD object. Generally expected to be raw, unprocessed data.
            This axismanager will be edited in place by the process and
            select functions of each preprocess module
        proc_aman: AxisManager (Optional)
            A preprocess axismanager. If this is provided it is assumed that
            the pipeline has previously been run on this specific TOD and has
            returned this preprocess axismanager. In this case, calls to
            ``process.calc_and_save()`` are skipped as the information is
            expected to be present in this AxisManager.
        full_aman: AxisManager (Optional)
            A preprocess axismanager. This axis manager stores the outputs of
            preprocessing functions (proc_aman) but without any of the
            detector or samps restrictions applied, thus maintaining its
            original shape. This is returned at the end of the pipeline. If
            not passed it is instantiated with the same number of dets and
            samps as aman.
        select: boolean (Optional)
            if True, the aman detector axis is restricted as described in
            each preprocess module. Most pipelines are developed with
            select=True. Running select=False may produce unstable behavior
        sim: boolean (Optional)
            if running on sim (``sim=True``), processes with the flag
            ``skip_on_sim`` will be skipped. Every process must then define
            ``skip_on_sim``.
        update_plot: boolean (Optional)
            if True, re-runs plotting (along with processes and selects)
            given ``proc_aman`` is ``aman.preprocess``. This assumes
            ``process.calc_and_save()`` has been run on this aman before and
            has ingested flags and other information into ``proc_aman``.
            Ignored (forced to False) when ``proc_aman`` is not passed.
        data_amans: dict (Optional)
            A dictionary of AxisManagers with keys (step, process.name)
            filled with AxisManager processed up to step-1. This is used to
            pre-load all data AxisManager which could be required when
            processing simulations (e.g. to provide a T2P template)

        Returns
        -------
        full_aman: AxisManager
            A preprocess axismanager that contains all data products
            calculated throughout the running of the pipeline.
        success: str
            A string that stores the name of the last process step that the
            pipeline completed. If the pipeline successfully finishes all
            steps, success = 'end'.

        Raises
        ------
        ValueError
            If ``sim`` is True and a process lacks ``skip_on_sim``, or a
            process requests a data AxisManager while ``sim`` is True and
            ``data_amans`` was not provided.
        KeyError
            If a process requests a data AxisManager that is missing from
            ``data_amans``.
        """
        if proc_aman is None:
            if 'preprocess' in aman:
                proc_aman = aman.preprocess.copy()
                if full_aman is None:
                    full_aman = aman.preprocess.copy()
            else:
                proc_aman = core.AxisManager(aman.dets, aman.samps)
                if full_aman is None:
                    full_aman = core.AxisManager(aman.dets, aman.samps)
            run_calc = True
            # Plots are produced right after calc_and_save in this mode.
            update_plot = False
        else:
            if aman.dets.count != proc_aman.dets.count or not np.all(aman.dets.vals == proc_aman.dets.vals):
                self.logger.warning("proc_aman has different detectors than aman. Cutting aman to match")
                # Keep proc_aman's detector order; set gives O(1) membership.
                aman_dets = set(aman.dets.vals)
                det_list = [det for det in proc_aman.dets.vals if det in aman_dets]
                aman.restrict('dets', det_list)
                proc_aman.restrict('dets', det_list)
            if full_aman is None:
                full_aman = proc_aman.copy()
            run_calc = False

        if 'frequency_cutoffs' not in proc_aman:
            # Stays NaN unless some entry of aman.iir_params carries filter
            # coefficients ('a').
            freq_cutoff = np.nan
            proc_aman.wrap('frequency_cutoffs', core.AxisManager())
            for sub_iir_params in aman.iir_params._fields.values():
                if not isinstance(sub_iir_params, core.AxisManager):
                    continue
                if 'a' not in sub_iir_params._fields:
                    continue
                if sub_iir_params['a'] is None:
                    continue
                from ..tod_ops import filters
                n = len(aman.timestamps)
                delta_t = (aman.timestamps[-1] - aman.timestamps[0])/n
                freqs = np.fft.rfftfreq(n, delta_t)
                iir = filters.iir_filter()(freqs, aman)
                mag = np.abs(iir) / np.max(np.abs(iir))
                # 3dB scale
                scale = 10 ** (-3. / 20)
                # First frequency at which the response drops below -3 dB.
                freq_cutoff = freqs[np.min(np.where(np.array(mag < scale * np.max(mag)))[0])]
            proc_aman['frequency_cutoffs'].wrap('signal', freq_cutoff)

        success = 'end'
        for step, process in enumerate(self):
            if sim and (process.skip_on_sim is None):
                raise ValueError(f"Process {process.name} missing required field `skip_on_sim`")
            if sim and process.skip_on_sim:
                continue
            self.logger.debug(f"Running {process.name}")

            if (data_amans is not None) and process.use_data_aman:
                try:
                    data_aman = data_amans[step, process.name]
                except KeyError as err:
                    raise KeyError(f"Requested to use data AxisManager for process {process.name} but not found in data_amans") from err
            else:
                if process.use_data_aman and sim:
                    raise ValueError(f"Process {process.name} requested to use data_aman but none was provided to Pipeline.run()")
                data_aman = None

            # The return value of process() is not used here.
            process.process(aman, proc_aman, sim, data_aman)
            if run_calc:
                aman, proc_aman = process.calc_and_save(aman, proc_aman)
                process.plot(aman, proc_aman,
                             filename=self._plot_filename(step))
                update_full_aman(proc_aman, full_aman, self.wrap_valid)
            if update_plot:
                process.plot(aman, proc_aman,
                             filename=self._plot_filename(step))
            plt.close()
            if select:
                process.select(aman, proc_aman)
                proc_aman.restrict('dets', aman.dets.vals)
            self.logger.debug(f"{proc_aman.dets.count} detectors remaining")

            if aman.dets.count == 0:
                # Nothing left to process: report the step that emptied aman.
                success = process.name
                break

        if run_calc:
            _wrap_valid_ranges(proc_aman, full_aman, valid_name='valid_data',
                               wrap_name='valid_data')
            # copy updated frequency cutoffs to full_aman
            if "frequency_cutoffs" in full_aman:
                full_aman.move("frequency_cutoffs", None)
            full_aman.wrap("frequency_cutoffs", proc_aman["frequency_cutoffs"])

        return full_aman, success
class _FracFlaggedMixIn(object):
    """Mix-in providing a ``gen_metric`` based on the fraction of flagged samples.

    The class it is mixed into is expected to define ``_influx_field`` and,
    when ``flags_key`` has a single component, ``name``.
    """

    @classmethod
    def gen_metric(
        cls,
        meta,
        proc_aman,
        flags_key=None,
        percentiles=None,
        tags=None,
    ):
        """ Generate a QA metric from the output of this process.

        Arguments
        ---------
        meta : AxisManager
            The full metadata container.
        proc_aman : AxisManager
            The metadata containing just the output of this process.
        flags_key : str or tuple
            The keys to access the flag dataset, either as a dotted string
            (e.g. 'glitches.glitch_flags') or as a tuple of keys. If a single
            term is passed, the first item will be inferred to be the process
            name.
        percentiles : list
            Percentiles to compute across detectors. Defaults to
            ``[0, 50, 75, 90, 95, 100]``.
        tags : dict
            The values are keys into `metadata.det_info` to record as tags
            with the Influx line. The keys are added to the default list
            ["wafer_slot", "tel_tube", "bandpass"] with "bandpass" being
            taken from "wafer.bandpass" or "det_cal.bandpass" if the former
            isn't found. Defaults to no extra tags.

        Returns
        -------
        line : dict
            InfluxDB line entry elements to be fed to
            `site_pipeline.monitor.Monitor.record`

        Raises
        ------
        ValueError
            If ``flags_key`` is missing or has more than two components.
        """
        if percentiles is None:
            percentiles = [0, 50, 75, 90, 95, 100]
        if tags is None:
            tags = {}

        # parse key for flags dataset
        if flags_key is None:
            raise ValueError("The flags_key parameter must be specified in the config.")
        if isinstance(flags_key, str):
            keys = flags_key.split(".")
        else:
            keys = list(flags_key)
        if len(keys) == 1:
            key1, key2 = cls.name, keys[0]
        elif len(keys) == 2:
            key1, key2 = keys
        else:
            raise ValueError(f"Could not parse flags_key {flags_key}")

        # add specified tags
        from ..qa.metrics import _get_tag, _has_tag
        tag_keys = {
            "wafer_slot": "wafer_slot",
            "tel_tube": "tel_tube",
        }
        if _has_tag(meta.det_info, 'wafer.bandpass'):
            bandpasses = meta.det_info.wafer.bandpass
            tag_keys["bandpass"] = "wafer.bandpass"
        else:
            bandpasses = meta.det_info.det_cal.bandpass
            tag_keys["bandpass"] = "det_cal.bandpass"
        tag_keys.update(tags)

        line_tags = []
        vals = []
        wafer_slots = np.unique(meta.det_info.wafer_slot)
        # record one metric per wafer slot, per bandpass
        for bp in np.unique(bandpasses):
            for ws in wafer_slots:
                subset = np.where(
                    (meta.det_info.wafer_slot == ws) & (bandpasses == bp)
                )[0]
                if len(subset) == 0:
                    # No detector has this (bandpass, wafer slot) pairing;
                    # there is nothing to summarize or tag.
                    continue

                # Compute the number of samples that were flagged:
                # each row of r.ranges() is dotted with [-1, 1] and summed.
                frac_flagged = np.array([
                    np.dot(r.ranges(), [-1, 1]).sum()
                    for r in proc_aman[key1][key2][subset]
                ], dtype=float)
                num_valid = np.array([
                    np.dot(r.ranges(), [-1, 1]).sum()
                    for r in proc_aman[key1].valid[subset]
                ])
                # Detectors without any valid samples get a fraction of 0.
                with np.errstate(divide="ignore"):
                    frac_flagged *= np.where(num_valid > 0, 1 / num_valid, 0)

                # record percentiles over detectors and fraction of samples flagged
                perc = np.percentile(frac_flagged, percentiles)
                mean = frac_flagged.mean()

                # get the tags for this wafer (all detectors in this subset share these)
                tags_base = {
                    k: _get_tag(meta.det_info, i, subset[0])
                    for k, i in tag_keys.items() if _has_tag(meta.det_info, i)
                }
                tags_base["telescope"] = meta.obs_info.telescope

                # add tags and values to respective lists in order
                tags_perc = [tags_base.copy() for _ in range(perc.size)]
                for i, t in enumerate(tags_perc):
                    t["det_stat"] = f"percentile_{percentiles[i]}"
                vals += list(perc)
                line_tags += tags_perc
                tags_mean = tags_base.copy()
                tags_mean["det_stat"] = "mean"
                vals.append(mean)
                line_tags.append(tags_mean)

        obs_time = [meta.obs_info.timestamp] * len(vals)
        return {
            "field": cls._influx_field,
            "values": vals,
            "timestamps": obs_time,
            "tags": line_tags,
        }