Source code for sotodlib.tod_ops.gapfill

import numpy as np
import so3g
from . import pca
import logging

logger = logging.getLogger(__name__)

class Extract:
    """Container for storage of sparse sub-segments of a vector.

    This ties together a Ranges object and a data array whose length
    equals the total length of the Ranges segments.  It is useful when
    a relatively small number of samples need to be excised from a
    signal vector, but stored and perhaps restored at a later time.

    In the docstrings below a "full data vector" is a 1-d array with
    self.ranges.count elements; the "tracked samples" are the subset
    of self.n_ex elements covered by the intervals of self.ranges.

    Attributes:
      ranges: The Ranges object mapping tracked samples into the full
        data vector (a private copy of the constructor argument).
      n_ex: The number of tracked samples.
      data: The data vector containing tracked samples (or None).

    """

    def __init__(self, ranges, init_data=True):
        """Constructor.

        Arguments:
          ranges: a Ranges object; all positive ranges will be
            tracked.  It is copied, so later changes to the argument
            do not affect this object.
          init_data: array or boolean.  If False (or None), self.data
            is initialized to None.  If True, self.data is
            initialized to zeros.  Otherwise, init_data is stored in
            self.data.  Note this is a reference, not a copy, and its
            length is not checked against n_ex.

        """
        self.ranges = ranges.copy()
        rr = self.ranges.ranges()
        # Sum of (hi - lo) over all intervals.
        self.n_ex = rr[:, 1].sum() - rr[:, 0].sum()
        if init_data is False or init_data is None:
            self.data = None
        elif init_data is True:
            self.data = np.zeros(self.n_ex)
        else:
            self.data = init_data

    @property
    def dtype(self):
        """The dtype of the tracked-sample array (requires data)."""
        return self.data.dtype

    @property
    def shape(self):
        """The shape of the full data vector, (ranges.count,)."""
        return (self.ranges.count,)

    def copy(self):
        """Returns an independent copy of this Extract.

        The ranges are copied by the constructor and the data array
        (if any) is duplicated, so modifying the copy never affects
        the original.  ExtractMatrix.copy relies on this method.

        """
        data = None if self.data is None else self.data.copy()
        return Extract(self.ranges, data)

    def offset_iter(self):
        """Returns an iterator for use in expanding / collapsing
        extracts.

        Each iteration returns a tuple (ex_lo, ex_hi, full_lo,
        full_hi) where ex_lo:ex_hi is a range in the extracted data
        vector, and full_lo:full_hi is the matching range in the full
        data vector.

        """
        offset = 0
        for lo, hi in self.ranges.ranges():
            yield offset, offset + hi - lo, lo, hi
            offset += hi - lo

    def expand(self, fill_value=0):
        """Expands the extract into a full-length data array, filling
        untracked samples with fill_value.

        """
        # zeros + fill_value (rather than np.full) keeps numpy's type
        # promotion between the data dtype and fill_value.
        output = np.zeros(self.ranges.count, self.data.dtype) + fill_value
        for elo, ehi, lo, hi in self.offset_iter():
            output[lo:hi] = self.data[elo:ehi]
        return output

    def swap(self, signal):
        """Swaps the current extract with the tracked samples of the
        full data vector signal.  Both signal and self.data are
        modified in place.

        """
        to_save = np.empty(self.n_ex, self.data.dtype)
        for elo, ehi, lo, hi in self.offset_iter():
            to_save[elo:ehi] = signal[lo:hi]
            signal[lo:hi] = self.data[elo:ehi]
        self.data[:] = to_save

    def patch(self, signal):
        """Copies the extract into the signal vector.  Untracked
        samples are not modified.  Returns signal.

        """
        for elo, ehi, lo, hi in self.offset_iter():
            signal[lo:hi] = self.data[elo:ehi]
        return signal

    def accumulate(self, signal, scale):
        """Adds the extract, scaled by factor scale, into the full
        length signal vector.  Untracked samples are not modified.
        Returns signal.

        """
        for elo, ehi, lo, hi in self.offset_iter():
            signal[lo:hi] += self.data[elo:ehi] * scale
        return signal
class ExtractMatrix:
    """This class is a simple container for a list (of length n_dets)
    of Extract objects of equal length (n_samps), to provide abstract
    access to data with shape (n_dets, n_samps).

    Each child Extract can be accessed with [item] indexing.  Most
    methods have (tod, signal=None) signature, and are simple loops
    that call the method of the same name on each child Extract
    object.

    """

    def __init__(self, items=None):
        """The complete list of Extract objects should be passed in.
        The list is stored by reference; None yields an empty matrix.

        """
        if items is None:
            items = []
        self.items = items

    def __repr__(self):
        """This repr shows the shape of the full array, along with the
        fraction (as a percentage) of the full space that is actually
        tracked.  An empty (or zero-size) matrix reports 0.0%.

        """
        shape = self.shape
        # shape is (0,) when there are no items, so guard the
        # two-axis product and the division.
        n_full = shape[0] * shape[1] if len(shape) == 2 else 0
        n_tracked = sum([e.n_ex for e in self.items])
        frac = n_tracked / n_full if n_full else 0.
        return 'ExtractMatrix(' + ','.join(map(str, shape)) + \
            '@%.1f%%)' % (frac*100)

    def __len__(self):
        return self.shape[0]

    def copy(self):
        """Returns a new ExtractMatrix whose children are independent
        copies (ranges and data) of the children of this one.

        """
        # Children are rebuilt through the Extract constructor (which
        # copies the ranges) so this does not depend on the child
        # object providing a copy() method of its own.
        return ExtractMatrix([
            Extract(x.ranges, None if x.data is None else x.data.copy())
            for x in self.items])

    @property
    def shape(self):
        """(n_dets, n_samps), or (0,) if there are no children."""
        if len(self.items) == 0:
            return (0,)
        return (len(self.items),) + self.items[0].shape

    def __getitem__(self, index):
        return self.items[index]

    def expand(self, fill_value=0):
        """Expands the extract into a full-size data array, filling
        missing bits with fill_value.  An empty matrix expands to an
        empty array of shape (0,).

        """
        if len(self.items) == 0:
            return np.zeros(self.shape) + fill_value
        dtype = self.items[0].dtype
        signal = np.zeros(self.shape, dtype) + fill_value
        for o, i in zip(signal, self.items):
            i.patch(o)
        return signal

    def swap(self, tod, signal=None):
        """Swaps the current extract with the tracked samples of the
        full data array signal (defaults to tod.signal).  Returns
        signal.

        """
        if signal is None:
            signal = tod.signal
        for o, i in zip(signal, self.items):
            i.swap(o)
        return signal

    def patch(self, tod, signal=None):
        """Copies the extract into the signal array (defaults to
        tod.signal).  Untracked samples are not modified.  Returns
        signal.

        """
        if signal is None:
            signal = tod.signal
        for o, i in zip(signal, self.items):
            i.patch(o)
        return signal

    def accumulate(self, tod, signal, scale):
        """Adds the extract, scaled by factor scale, into the full
        size signal array (tod.signal if signal is None).  Untracked
        samples are not modified.  Returns signal.

        """
        if signal is None:
            signal = tod.signal
        for o, i in zip(signal, self.items):
            i.accumulate(o, scale)
        return signal
def get_gap_fill_single(data, flags, nbuf=10, order=1, swap=False):
    """Computes samples to fill the gaps in data identified by flags.

    Every flagged segment is modeled with a polynomial, fitted to up
    to nbuf unflagged points on each side of the segment.

    Arguments:
      data: 1d vector of samples.
      flags: Ranges object consistent with data size.
      nbuf: Maximum number of samples on each side of a flagged span
        to use for the model fit.  If the unflagged span between two
        flagged spans is shorter than nbuf, it is still used to
        anchor one end of the fit.
      order: Maximum order of the polynomial model.  The order
        actually fitted may be smaller if too few samples are
        available to constrain the coefficients (one term per 10
        anchor samples).  Take care when going beyond linear
        (order=1); this tends to be unstable unless the gaps are
        much smaller than the anchors at each end.
      swap: If False, do not modify the input data vector.  If True,
        patch the data with the model.

    Returns:
      An Extract object containing the modeled data.  If swap is
      True, the input data vector is patched with the model and the
      returned object instead contains the samples from data that
      were replaced.

    """
    # Anchor segments: unflagged samples lying within nbuf of some
    # flagged span.
    anchors = (flags.copy().buffer(nbuf) * ~flags).ranges()
    n_anchor = len(anchors)
    n_coef = order + 1

    # Normal-equation workspace, zeroed and refilled for each fit.
    normal = np.zeros((n_coef, n_coef))
    rhs = np.zeros(n_coef)
    t0, y0 = 0, 0

    coeffs = None
    anchor_i = -1  # Set to trigger update.

    sig_ex = Extract(flags)
    for elo, ehi, lo, hi in sig_ex.offset_iter():
        # Step forward while the next anchor segment starts before
        # this gap; any step invalidates the current model.
        while anchor_i + 1 < n_anchor and lo > anchors[anchor_i + 1][0]:
            coeffs = None
            anchor_i += 1

        if coeffs is None:
            # Fit relative to (t0, y0).
            # NOTE(review): for lo == 0 this reads data[-1] (python
            # negative indexing) -- confirm that is acceptable.
            t0, y0 = lo, data[lo - 1]
            rhs[:] = 0
            normal[:] = 0
            n_contrib = 0
            # Use the anchor segment at anchor_i and the one after it,
            # where they exist.
            for seg in (anchor_i, anchor_i + 1):
                if not 0 <= seg < n_anchor:
                    continue
                seg_lo, seg_hi = anchors[seg]
                dt = np.arange(seg_lo, seg_hi) - t0
                for row in range(n_coef):
                    rhs[row] += np.dot(dt**row, data[seg_lo:seg_hi] - y0)
                    for col in range(row, n_coef):
                        normal[row, col] += (dt**(row + col)).sum()
                        normal[col, row] = normal[row, col]
                n_contrib += seg_hi - seg_lo

            if n_contrib == 0:
                # Nothing to anchor on; the model is identically zero.
                y0 = 0.
                coeffs = [0.]
            else:
                # Only fit as many terms as you plausibly constrain --
                # 10 data points per term.
                n_keep = 1 + max(0, min(order, n_contrib // 10))
                # Reversed so the highest power comes first, as
                # np.polyval expects.
                coeffs = np.linalg.solve(normal[:n_keep, :n_keep],
                                         rhs[:n_keep])[::-1]

        t = np.arange(lo, hi) - t0
        sig_ex.data[elo:ehi] = np.polyval(coeffs, t) + y0

    if swap:
        sig_ex.swap(data)

    return sig_ex
def get_gap_fill(tod, nbuf=10, order=1, swap=False, signal=None, flags=None,
                 _method='fast'):
    """Runs polynomial gap filling on every detector.  See
    get_gap_fill_single for the meaning of arguments not described
    here.

    Arguments:
      tod: AxisManager with (dets, samps) axes.
      signal: signal to pass to get_gap_fill_single as data argument;
        defaults to tod.signal.
      flags: flags to pass to get_gap_fill_single; defaults to
        tod.flags.
      _method: 'fast' calls so3g.get_gap_fill_poly on the whole array
        (model samples are collected in a float32 buffer); any other
        value loops over detectors with get_gap_fill_single.  None
        picks 'fast' when so3g provides get_gap_fill_poly and 'slow'
        otherwise.

    Returns:
      The ExtractMatrix object with per-detector Extracts.

    """
    signal = tod.signal if signal is None else signal
    flags = tod.flags if flags is None else flags
    if _method is None:
        if hasattr(so3g, 'get_gap_fill_poly'):
            _method = 'fast'
        else:
            _method = 'slow'

    if _method != 'fast':
        # Pure-python path, one detector at a time.
        extracts = [
            get_gap_fill_single(d, f, order=order, nbuf=nbuf, swap=swap)
            for d, f in zip(signal, flags)]
        return ExtractMatrix(extracts)

    # Number of flagged samples per detector: sum of (hi - lo).
    sample_counts = [np.dot(r.ranges(), [-1, 1]).sum() for r in flags]
    # One flat buffer receives the output for all detectors ...
    dest = np.empty(sum(sample_counts), dtype='float32')
    so3g.get_gap_fill_poly(flags, signal, nbuf, order, swap, dest)
    # ... and is then sliced into one segment per detector.
    sample_idx = np.cumsum([0] + sample_counts)
    starts, stops = sample_idx[:-1], sample_idx[1:]
    extracts = [Extract(r, dest[i:j])
                for r, i, j in zip(flags, starts, stops)]
    return ExtractMatrix(extracts)
def get_gap_model_single(weights, modes, flags):
    """Computes samples to fill gaps in data identified by flags,
    based on weights and modes (such as would be contained in a PCA
    model).

    Arguments:
      weights: 1-d array of weights (n_mode) to apply to each mode.
      modes: 2-d array of modes (n_mode, n_samps).
      flags: Ranges object with count == n_samps.

    Returns:
      An Extract object containing the modeled data, i.e. the
      weighted sum of the modes evaluated on the flagged samples.

    """
    sig_ex = Extract(flags)
    # The segment layout is fixed, so compute it once and reuse it
    # for every mode.
    segments = list(sig_ex.offset_iter())
    for weight, mode in zip(weights, modes):
        for ex_lo, ex_hi, full_lo, full_hi in segments:
            sig_ex.data[ex_lo:ex_hi] += weight * mode[full_lo:full_hi]
    return sig_ex
def get_gap_model(tod, model, flags=None, weights=None, modes=None):
    """Calls get_gap_model_single on each detector and its
    corresponding model weights.

    Arguments:
      tod: AxisManager with (dets, samps) axes.
      model: AxisManager with (dets, eigen, samps) axes.
      flags: flags to pass to get_gap_model_single; defaults to
        tod.flags.
      weights: array of mode couplings (dets, eigen); defaults to
        model.weights.
      modes: array with time-dependent modes (eigen, samps);
        defaults to model.modes.

    Returns:
      The ExtractMatrix object with per-detector models from
      get_gap_model_single.

    """
    flags = tod.flags if flags is None else flags
    weights = model.weights if weights is None else weights
    modes = model.modes if modes is None else modes

    # The same modes are shared by all detectors; only the weights
    # and flags vary per detector.
    per_det = [get_gap_model_single(det_weights, modes, det_flags)
               for det_weights, det_flags in zip(weights, flags)]
    return ExtractMatrix(per_det)
def get_contaminated_ranges(good_flags, bad_flags):
    """Determine what intervals in good_flags are contaminated
    (overlap with) intervals in bad_flags.

    Note this isn't as simple as good_flags * bad_flags, because any
    contiguous region in good_flags is considered contaminated if
    even one sample of it is touched by bad_flags.

    Args:
      good_flags (RangesMatrix): The flags to check for
        contamination.  Must have shape (dets, samps).
      bad_flags (RangesMatrix): The flags marking bad data, which
        should be considered as a contaminant to good_flags.  Same
        shape as good_flags.

    Returns:
      RangesMatrix with same shape as inputs, indicating the
      intervals from good_flags that overlap at all with some
      interval of bad_flags.

    Example:
      Move contaminated intervals into bad_flags::

        contam = get_contaminated_ranges(source_flags, glitch_flags)
        source_flags *= ~contam
        glitch_flags += contam

    """
    contam = good_flags.zeros_like()
    rows = zip(good_flags.ranges, bad_flags.ranges, contam.ranges)
    for good, bad, out in rows:
        hits = (good * bad).ranges()
        if len(hits) == 0:
            continue
        # An intersection interval lies inside a good interval, so
        # testing where each intersection starts is sufficient.
        hit_starts = hits[:, 0]
        for lo, hi in good.ranges():
            touched = (lo <= hit_starts) & (hit_starts < hi)
            if touched.any():
                # The whole good interval is marked, not just the
                # overlapping part.
                out.add_interval(int(lo), int(hi))
    return contam
def fill_glitches(aman, nbuf=10, use_pca=False, modes=3, signal=None,
                  glitch_flags=None, wrap=True):
    """
    This function fills pre-computed glitches provided by the caller
    in time-ordered data using either a polynomial (default) or
    PCA-based approach.  Wraps the other functions in the
    ``tod_ops.gapfill`` module.

    Args
    ----
    aman : AxisManager
        AxisManager to fill glitches in.
    nbuf : int
        Number of buffer samples to use in polynomial gap filling.
    use_pca : bool
        Whether or not to fill glitches using a PCA model.  The PCA
        model is computed from the polynomial-filled signal.  Default
        is False.
    modes : int
        Number of modes in the PCA to use if ``use_pca`` is True.
        Clipped (with a warning) to the number of detectors.  Default
        is 3.
    signal : ndarray or None
        Array of data to fill glitches in.  If None then uses
        ``aman.signal``.  The input is copied, not modified.  Default
        is None.
    glitch_flags : RangesMatrix or None
        RangesMatrix containing flags to use for gap filling.  If None
        then uses ``aman.flags.glitches``.
    wrap : bool or str
        If True wraps new field called ``gap_filled``, if False only
        returns the gap filled array, if a string wraps new field
        with provided name.  An existing field of that name is
        removed first.

    Returns
    -------
    signal : ndarray
        Returns ndarray with gaps filled from input signal.
    """
    # Process Args.  Work on a copy so the input array is untouched.
    sig = np.copy(aman.signal if signal is None else signal)
    if glitch_flags is None:
        glitch_flags = aman.flags.glitches

    # Polyfill.  The model is computed on a float32 copy; swap() then
    # writes the model samples into sig.
    gaps = get_gap_fill(aman, nbuf=nbuf, flags=glitch_flags,
                        signal=np.float32(sig))
    sig = gaps.swap(aman, signal=sig)

    # PCA Fill
    if use_pca:
        if modes > aman.dets.count:
            logger.warning(f'modes = {modes} > number of detectors = ' +
                           f'{aman.dets.count}, setting modes = number of ' +
                           'detectors')
            modes = aman.dets.count
        # sig is already poly-filled at this point, so the PCA model
        # is not driven by the glitches.  The poly fit is anchored on
        # unflagged samples only, so repeating it here would just
        # reproduce the values already in sig.
        mod = pca.get_pca_model(tod=aman, n_modes=modes, signal=sig)
        gfill = get_gap_model(tod=aman, model=mod, flags=glitch_flags)
        sig = gfill.swap(aman, signal=sig)

    # Wrap and Return
    if isinstance(wrap, str):
        name = wrap
    elif wrap:
        name = 'gap_filled'
    else:
        return sig

    # Drop any existing field of this name before wrapping.
    if name in aman._assignments:
        aman.move(name, None)
    aman.wrap(name, sig, [(0, 'dets'), (1, 'samps')])
    return sig