import functools
import logging
import glob
import os
import string
import time
import h5py
import numpy as np
from tqdm import tqdm
from sotodlib.core.metadata import ResultSet, ObsDb
from sotodlib.io import metadata, hkdb
logger = logging.getLogger(__name__)
[docs]
def denumpy(output):
    """Return ``output`` with nested numpy scalars turned into plain
    python scalars.

    The walk descends recursively through dict, list and tuple
    containers only.  ndarrays, and any other object, are handed back
    unchanged.

    """
    if isinstance(output, np.generic):
        return output.item()
    if isinstance(output, dict):
        converted = {}
        for key, value in output.items():
            converted[key] = denumpy(value)
        return converted
    if isinstance(output, (list, tuple)):
        items = [denumpy(entry) for entry in output]
        # Lists stay lists; anything tuple-like comes back as a tuple.
        return items if isinstance(output, list) else tuple(items)
    return output
[docs]
def intersect_ranges(r1, r2, empty_as_none=False):
    """Compute the overlap of two semi-open intervals, r1 = [a, b) and
    r2 = [c, d).

    The result is a 2-tuple.  When the intervals do not overlap, an
    empty interval (e, e) is returned -- or None, if empty_as_none is
    set.

    """
    # Order the inputs so that "first" is the one starting earliest.
    first, second = (r2, r1) if r2[0] < r1[0] else (r1, r2)
    if first[1] < second[0]:
        # Disjoint: report an empty interval anchored at the earliest start.
        result = (first[0], first[0])
    else:
        result = (max(first[0], second[0]), min(first[1], second[1]))
    if empty_as_none and result[0] == result[1]:
        return None
    return result
[docs]
def merge_rs(rs0, rs1):
    """Merge rows from rs1 into rs0 (both ResultSets), sorting rows by
    field 'timestamp'.  Because the primary application is to update
    some dataset based on expanding source data, rows that repeat some
    'timestamp' value are dropped, with precedence given to data from
    rs0.

    rs0 is modified in place (its ``rows`` attribute is replaced);
    nothing is returned.  Rows taken from rs1 are inserted as-is, so
    rs1 is assumed to have the same column layout as rs0.  The
    'timestamp' column may sit at any position in the rows.

    """
    # Tag every row as (timestamp, source, index).  Source 0 (rs0)
    # sorts ahead of source 1 (rs1), so for equal timestamps the rs0
    # row is encountered first and wins.
    tagged = [(t, 0, i) for i, t in enumerate(rs0['timestamp'])]
    tagged.extend((t, 1, i) for i, t in enumerate(rs1['timestamp']))
    tagged.sort()

    sources = (rs0.rows, rs1.rows)
    merged = []
    last_t = None
    for t, src, idx in tagged:
        # "merged" is empty only before the first row is accepted, so
        # the first timestamp can never be mistaken for a repeat.
        if merged and t == last_t:
            continue
        merged.append(sources[src][idx])
        # Track the timestamp from the sort key, rather than assuming
        # it is stored in column 0 of the row.
        last_t = t
    rs0.rows = merged
def get_engine(key, cfg):
    """Instantiate the engine described by the config dict ``cfg``.

    The engine class is looked up in ANCIL_ENGINES under the name
    ``cfg['class']`` if that entry is present, and under ``key``
    otherwise.  The remaining entries of ``cfg`` (everything but
    'class') are passed to the class constructor as a dict.

    """
    from . import ANCIL_ENGINES

    class_name = cfg['class'] if 'class' in cfg else key
    engine_cfg = {name: value for name, value in cfg.items()
                  if name != 'class'}
    return ANCIL_ENGINES[class_name](engine_cfg)
def _get_time_range(*time_ranges, now=None):
if now is None:
now = time.time()
for time_range in time_ranges:
if time_range is None:
continue
return tuple(now if t is None else t for t in time_range)
return (now, now)
class AncilEngine:
    """Base class for ancillary-data engines.

    An engine is configured through ``cfg``, may reference other
    engines ("friends"), maintains some "base data" on disk, and can
    produce per-target results that map onto obsdb columns.
    Subclasses are expected to set ``config_class`` and ``_fields``
    and to implement ``getter`` (and, where relevant, ``update_base``
    and ``check_base``).

    """

    #: The class used to provide configuration for this Engine
    #: type. (This is handled in subclassing.)
    config_class = None

    #: The configuration object (matching config_class) used to set up
    #: the instance.
    cfg = None

    #: A dict mapping name-of-friend to the Engine instance for each
    #: friend. On construction, the values will be None -- they must be
    #: associated using register_friends.
    friends = None

    #: Rows describing the fields this engine produces.  The first
    #: entry of each row is the bare field name; the second is used as
    #: the column type when obsdb columns are created.  Treated as
    #: read-only here; subclasses override it.
    _fields = []

    def __init__(self, cfg):
        """Set up the engine.

        Args:
          cfg: either an instance of ``config_class``, or a dict of
            keyword arguments from which one will be constructed.

        """
        if isinstance(cfg, dict):
            cfg = self.config_class(**cfg)
        self.cfg = cfg
        if self.cfg.friends:
            self.friends = {k: None for k in self.cfg.friends}
        else:
            self.friends = {}

    def register_friends(self, friend_configs):
        """Associate each configured friend with the matching entry of
        ``friend_configs`` (a mapping keyed by friend name).  A
        KeyError is raised if a configured friend is missing from
        ``friend_configs``.

        """
        for k in self.friends:
            self.friends[k] = friend_configs[k]

    def _get_friend(self, friend):
        """Return the engine registered under the name ``friend``.

        If the registered value is a dict, it is treated as a config
        and a new engine is instantiated from it, using get_engine, on
        each call.  Raises RuntimeError if ``friend`` is not one of
        the configured friends.

        """
        if friend not in self.friends:
            raise RuntimeError(f'Friend not configured: "{friend}"')
        entry = self.friends[friend]
        if isinstance(entry, dict):
            return get_engine(friend, entry)
        return entry

    @property
    def obsdb_fields(self):
        """The rows of ``_fields``, with the field names rewritten
        according to ``cfg.obsdb_format`` (if that is set).

        """
        if self.cfg.obsdb_format:
            # transform (bare_field, ...) into (obsdb_field, ...)
            return [
                (self.cfg.obsdb_format.format(
                    dataset=self.cfg.dataset_name, field=item[0]),) + item[1:]
                for item in self._fields]
        return self._fields

    def _obsdb_map(self, target=None):
        """Return a dict relating bare field names to obsdb field names.

        With no ``target``, the dict maps bare name to obsdb name.
        With ``target`` (a mapping keyed by bare name), the dict maps
        obsdb name to the corresponding value from ``target``.

        """
        pairs = [(bare[0], full[0])
                 for bare, full in zip(self._fields, self.obsdb_fields)]
        if target:
            return {full: target[bare] for bare, full in pairs}
        return dict(pairs)

    def obsdb_query(self, time_range=None, redo=False):
        """Get obsdb query string to identify records that need recomputation.

        If time_range is not None, then screening on timestamp will be
        included; either end of the range may be None to leave that
        side open (and if both ends are None, no timestamp screening
        is applied).  If redo is True, then the engine-specific value
        testing will be skipped and all records in the time_range will
        be queried.

        """
        if redo:
            vquery = '1'
        else:
            vquery = self.cfg.obsdb_query
            assert vquery is not None
        if time_range is not None:
            t0, t1 = time_range
            if t0 is None and t1 is None:
                # Fully open range -- nothing to screen on.
                tquery = None
            elif t0 is None:
                tquery = f'timestamp < {t1}'
            elif t1 is None:
                tquery = f'timestamp >= {t0}'
            else:
                tquery = f'(timestamp >= {t0}) and (timestamp < {t1})'
            if tquery is not None:
                vquery = f'{tquery} and {vquery}'
        # Bare field names in braces are replaced by obsdb field names.
        return vquery.format(**self._obsdb_map())

    def obsdb_check(self, obsdb, create_cols=False):
        """Check whether the obsdb contains the columns required by
        this engine's "update_obsdb" operation.  If create_cols is
        True, the columns will be created if they are missing and the
        function returns True.  Otherwise the function will simply
        return True if cols are there, and False if not.

        """
        fields = self.obsdb_fields
        if not fields:
            # No columns are required, so there is nothing to be missing.
            return True
        cols = ','.join(f'`{f[0]}`' for f in fields)
        try:
            obsdb.conn.execute(f'select {cols} from obs limit 1')
        except Exception:
            # The probe query failed; treat that as "columns missing".
            ok = False
        else:
            ok = True
        if not ok and create_cols:
            for field_row in fields:
                k, t = field_row[:2]
                obsdb.add_obs_columns([f'{k} {t}'])
            return True
        return ok

    @property
    def _base_dir(self):
        """Return the directory where base data should be stored."""
        dpre, ddir = self.cfg.data_prefix, self.cfg.data_dir
        if ddir is None:
            ddir = self.cfg.dataset_name
        if dpre is not None:
            ddir = os.path.join(dpre, ddir)
        return ddir

    def check_base(self):
        """Return a dict of information about the base data.  The base
        class has nothing to report and returns an empty dict.

        """
        return {}

    def update_base(self, time_range=None, reset=False):
        """Update the base dataset, for the indicated time_range.

        If reset is True, then new download / computation replaces all
        data in the time_range.

        The base class implementation does nothing.

        """
        pass

    def getter(self, targets=None, results=None):
        """Generator that yields results, one by one, for entry in
        target.  If results is provided, it must be a list of the same
        length as targets, containing dicts which will be updated in
        place and yielded.

        """
        raise NotImplementedError()

    def collect(self, targets=None, results=None, show_pbar=False, for_obsdb=False):
        """Run ``getter`` over all targets and return the results as a
        list.

        If show_pbar is True, a progress bar is displayed.  If
        for_obsdb is True and cfg.obsdb_format is set, the keys of
        each result dict are translated from bare field names to
        obsdb field names.

        """
        output = list(self.getter(targets=tqdm(targets, disable=not show_pbar),
                                  results=results))
        if for_obsdb and self.cfg.obsdb_format:
            remap = self._obsdb_map()
            output = [{remap[k]: v for k, v in row.items()}
                      for row in output]
        return output

    def _target_obs_ids(self, targets):
        """Yield the 'obs_id' entry of each target."""
        for t in targets:
            yield t['obs_id']

    def _target_time_ranges(self, targets):
        """Yield (start_time, stop_time) for each target."""
        for t in targets:
            yield (t['start_time'], t['stop_time'])
[docs]
class LowResTable(AncilEngine):
    """Helper class for archiving fairly low resolution data into a set of
    HDF5 files, as the "base data".

    The archive is split into files that each cover a block of
    ``cfg.archive_block_seconds`` seconds; see _get_filenames.
    Subclasses must implement _get_raw and getter.

    """

    #: Number of recently read archive files kept in the per-instance
    #: read cache used by _load.
    _read_cache_size = 4

    def _get_raw(self, time_range):
        """Fetch data for time_range from the source.  Must be
        implemented by the subclass; update_base expects the result to
        be a ResultSet with a 'timestamp' column.

        """
        raise NotImplementedError()

    def update_base(self, time_range=None, reset=False):
        """Attempt to patch any gaps in the dataset, for time_range, by
        grabbing data from the source.

        If time_range is None, cfg.dataset_time_range is used.

        If reset is True, then the archive files overlapping
        time_range are rebuilt from grabbed data, each file receiving
        only the data for its own block.

        See _get_raw for the specific activity of the subclass.

        """
        time_range = _get_time_range(time_range, self.cfg.dataset_time_range)
        dataset = self.cfg.dataset_name

        for t0, t1, filename in self._get_filenames(time_range):
            if reset or not os.path.exists(filename):
                # NOTE(review): the 'pwv' column is hard-coded here,
                # in this otherwise generic class -- confirm.
                # NOTE(review): on reset, existing rows in this file
                # that lie outside time_range are also discarded.
                rs0 = ResultSet(keys=['timestamp', 'pwv'])
                os.makedirs(os.path.dirname(filename), exist_ok=True)
            else:
                rs0 = metadata.read_dataset(filename, dataset)

            if reset:
                # Only this file's portion of the time range -- not
                # the whole range, which would copy the data for every
                # block into every file.
                gap_ranges = [(t0, t1)]
            else:
                # Are there gaps that cross the target range?
                super_t = np.hstack((t0, rs0['timestamp'], t1))
                gaps = (np.diff(super_t) >= self.cfg.gap_size).nonzero()[0]
                candidates = (
                    intersect_ranges(time_range, (super_t[i], super_t[i + 1]))
                    for i in gaps)
                gap_ranges = [g for g in candidates if g[1] > g[0]]

            if not gap_ranges:
                continue

            # One query for this whole thing ...
            query_range = (gap_ranges[0][0], gap_ranges[-1][1])
            logger.info('Pulling %s data for %s (%.1f, %.1f) ...',
                        dataset, filename, query_range[0], query_range[1])
            rs = self._get_raw(query_range)

            # Filter down to only data that's inside a gap range.
            t = rs['timestamp']
            mask = np.zeros(len(t), bool)
            for g0, g1 in gap_ranges:
                mask |= (g0 <= t) & (t < g1)
            rs = rs.subset(rows=mask)

            # Merge into rs0.
            before = len(rs0)
            merge_rs(rs0, rs)
            after = len(rs0)
            logger.info(' ... row count %i -> %i', before, after)
            if after > before:
                logger.info(' ... writing new %s', filename)
                rs0_a = rs0.asarray(dtypes=self.cfg.dtypes)
                metadata.write_dataset(rs0_a, filename, dataset, overwrite=True)

    def check_base(self):
        """Check the base data -- report the output directory and
        count how many of the archive files expected for
        cfg.dataset_time_range exist.

        Returns a dict with entries 'output_dir' and 'files_found'.

        """
        time_range = _get_time_range(self.cfg.dataset_time_range)
        found = sum(1 for _, _, filename in self._get_filenames(time_range)
                    if os.path.exists(filename))
        return {
            'output_dir': self._base_dir,
            'files_found': found,
        }

    def _get_filenames(self, time_range):
        """Return a list of (t0, t1, filename), one entry per archive
        block that overlaps time_range.  Block boundaries are
        multiples of cfg.archive_block_seconds; the (t0, t1) reported
        for each block is clipped to time_range.

        """
        ns = int(self.cfg.archive_block_seconds)
        block_start = int((time_range[0] // ns) * ns)
        stop = time_range[1]
        rows = []
        while stop > block_start:
            fn = self.cfg.filename_pattern.format(
                timestamp=block_start,
                dataset_name=self.cfg.dataset_name)
            fn = os.path.join(self._base_dir, fn)
            rows.append((max(block_start, time_range[0]),
                         min(stop, block_start + ns), fn))
            block_start += ns
        return rows

    def _read_cached(self, filename, dataset):
        """Return the dataset stored in filename, or None if the file
        is not there.

        Reads are cached per instance.  The file's modification time
        is part of the cache key, so a file that is rewritten (e.g. by
        update_base) is read afresh.  Callers must not modify the
        returned object.

        """
        try:
            stamp = os.stat(filename).st_mtime_ns
        except OSError:
            return None
        reader = self.__dict__.get('_dataset_reader')
        if reader is None:
            # The cached function does not reference self, so the
            # cache does not keep the instance alive.
            @functools.lru_cache(maxsize=self._read_cache_size)
            def reader(filename, dataset, stamp):
                return metadata.read_dataset(filename, dataset)
            self._dataset_reader = reader
        return reader(filename, dataset, stamp)

    def _load(self, time_range):
        """Return a ResultSet with the archived rows whose timestamp
        lies in the semi-open time_range.

        """
        # If you don't cache, the dataset reads can be very
        # inefficient, especially when looping over a bunch of obs
        # from the same time period.
        dataset = self.cfg.dataset_name
        rs = None
        for _, _, filename in self._get_filenames(time_range):
            chunk = self._read_cached(filename, dataset)
            if chunk is None:
                continue
            if rs is None:
                # Accumulate into a fresh container, so that cached
                # objects are neither modified nor handed out.
                rs = ResultSet(keys=chunk.keys)
            rs.rows.extend(chunk.rows)
        if rs is None:
            rs = ResultSet(keys=['timestamp', 'pwv'])
        t = rs['timestamp']
        s = (time_range[0] <= t) & (t < time_range[1])
        if not np.all(s):
            rs = rs.subset(rows=s)
        return rs

    def getter(self, targets=None, **kwargs):
        """Must be implemented by the subclass; see AncilEngine.getter."""
        raise NotImplementedError()
def get_example_obsdb(start, end=None, step=3600, filename=None):
    """Construct an ObsDb filled with regularly spaced example
    observations.

    One observation is added for each time in np.arange(start, end,
    step); it gets obs_id 'obs_<time>', and columns timestamp and
    start_time equal to that time, and stop_time one step later.  If
    end is None, ten steps are generated.  If filename is given, the
    database is also written to that file.  The ObsDb is returned.

    """
    stop = start + step * 10 if end is None else end
    db = ObsDb()
    db.add_obs_columns(['timestamp float', 'start_time float', 'stop_time float'])
    for t_start in np.arange(start, stop, step):
        record = {
            'timestamp': t_start,
            'start_time': t_start,
            'stop_time': t_start + step,
        }
        db.update_obs(f'obs_{t_start:.0f}', denumpy(record))
    if filename is not None:
        db.to_file(filename)
    return db