Source code for sotodlib.core.metadata.obsdb

import sqlite3
import os
import numpy as np
import warnings

from .resultset import ResultSet
from . import common
from .. import util


DBROW_ALL = '_all'

[docs] class ObsDb(object): """Observation database. The ObsDb helps to associate observations, indexed by an obs_id, (or obs_id plus some wafer_info such as wafer_slot or bandpass) with properties of the observation that might be useful for selecting data or for identifying metadata. The main ObsDb table is called 'obs', and contains the columns obs_id (string), plus any others deemed important for this context (you will probably find timestamp (float representing a unix timestamp)). Additional columns may be added to this table as needed. The second ObsDb table is called 'tags', and facilitates grouping observations together using string labels. """
[docs] def __init__(self, map_file=None, init_db=True, wafer_info=None): """Instantiate an ObsDb. Args: map_file (str or sqlite3.Connection): If this is a string, it will be treated as the filename for the sqlite3 database, and opened as an sqlite3.Connection. If this is an sqlite3.Connection, it is cached and used. If this argument is None (the default), then the sqlite3.Connection is opened on ':memory:'. init_db (bool): If True, then any ObsDb tables that do not already exist in the database will be created. wafer_info (list of str): The additional primary keys for the obs table. The default is None, which defaults to ['obs_id'] only. An example of an alternative is ['wafer_slot', 'bandpass'] in which case the ObsDb will be indexed by obs_id, wafer_slot, and bandpass. This is only required when first initialializing a database; otherwise the primary fields are determined from the loaded database. Notes: If map_file is provided, the database will be connected to the indicated sqlite file on disk, and any changes made to this object be written back to the file. """ if isinstance(map_file, sqlite3.Connection): self.conn = map_file else: if map_file is None: map_file = ':memory:' self.conn = sqlite3.connect(map_file) self.conn.row_factory = sqlite3.Row # access columns by name if init_db: pkeys = ["`obs_id`"] if wafer_info: pkeys.extend([f"`{k}`" for k in wafer_info]) self._table_defs = {'obs': [ "`timestamp` float", *(f"{k} varchar(256)" for k in pkeys), f"PRIMARY KEY ({', '.join(pkeys)})" ], 'tags': [ "`tag` varchar(256)", *(f"{k} varchar(256)" for k in pkeys), f"PRIMARY KEY ({', '.join(pkeys)}, `tag`)" ]} # Define indices dynamically based on primary keys pkeys_str = ', '.join([k.strip('`') for k in pkeys]) self._indices = { 'idx_obs': f'obs({pkeys_str})', 'idx_tags': f'tags({pkeys_str})', } c = self.conn.cursor() c.execute("SELECT type, name FROM sqlite_master " "WHERE type in ('table', 'index') and name not like 'sqlite_%';") tables = [r[1] for r in c] changes = False for k, v in self._table_defs.items(): if k not in tables: q = ('create table if not exists `%s` (' % k + ','.join(v) + ')') c.execute(q) changes = True for index, cols in self._indices.items(): if index not in tables: c.execute(f'CREATE INDEX IF NOT EXISTS {index} on {cols}') changes = True if changes: self.conn.commit() self.primary_keys = self._get_primary_fields(wafer_info)
def _get_primary_fields(self, wafer_info=None): """Retrieve the primary keys of the specified table. This is used whether to index by obs_id or obs_id plus additional fields defined by wafer_info.""" query = "PRAGMA table_info('obs')" c = self.conn.execute(query) primary_keys = [row['name'] for row in c.fetchall() if row['pk'] > 0] if wafer_info: pkeys = ["obs_id"] pkeys.extend([f"{k}" for k in wafer_info]) if sorted(pkeys) != sorted(primary_keys): # sorted allows for different order raise ValueError(f"Primary keys do not match: {primary_keys} != {pkeys}"+ f" must use `wafer_info`=={primary_keys} or create a new dB with {pkeys}") return primary_keys def _convert_wafer_info(self, obs_id, wafer_info): """Helper function to allow flexibility in way obs_id and wafer_info are passed in.""" if isinstance(wafer_info, dict): wafer_info = tuple([wafer_info[k] for k in self.primary_keys[1:]]) if isinstance(obs_id, tuple): if len(obs_id) == len(self.primary_keys): wafer_info = tuple([wi for wi in obs_id[1:]]) obs_id = obs_id[0] else: raise ValueError(f"obs_id tuple must be of length {len(self.primary_keys)}") if isinstance(obs_id, dict): if len(obs_id) == len(self.primary_keys): wafer_info = tuple([obs_id[k] for k in self.primary_keys[1:]]) obs_id = obs_id['obs_id'] else: raise ValueError(f"obs_id dict must be of length {len(self.primary_keys)}") return obs_id, wafer_info def _warn_primary_keys(self, wafer_info): """Warn the user if the primary keys are not specified and we're defaulting to using _all.""" if len(self.primary_keys) == 1: return [] if len(wafer_info) != len(self.primary_keys) - 1: raise ValueError(f"Wafer info must be of length {len(self.primary_keys) - 1}") if wafer_info is None: wafer_info = [None] * (len(self.primary_keys) - 1) wafer_info = list(wafer_info) if (None in wafer_info): warn_str = 'WARNING: Primary key(s)' for i, wb in enumerate(wafer_info): if wb is None: wafer_info[i] = DBROW_ALL warn_str += f' wafer_info[{i}],' warn_str += f""" are not specified and ObsDb is indexed by {self.primary_keys}. These keys will be set to _all. """ warnings.warn(warn_str, UserWarning) return wafer_info def __len__(self): return self.conn.execute(f'SELECT COUNT({self.primary_keys[0]}) FROM obs').fetchone()[0]
[docs] def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): """Add columns to the obs table. Args: column_defs (list of pairs of str): Column descriptions, see notes. ignore_duplicates (bool): If true, requests for new columns will be ignored if the column name is already present in the table. Returns: self. Notes: The input format for column_defs is somewhat flexible. First of all, if a string is passed in, it will converted to a list by splitting on ",". Second, if the items in the list are strings (rather than tuples), the string will be broken into 2 components by splitting on whitespace. Finally, each pair of items is interpreted as a (name, data type) pair. The name can be a simple string, or a string inside backticks; so 'timestamp' and '`timestamp`' are equivalent. The data type can be any valid sqlite type expression (e.g. 'float', 'varchar(256)', etc) or it can be one of the three basic python type objects: str, float, int. Here are some examples of valid column_defs arguments:: [('timestamp', float), ('drift', str)] ['`timestamp` float', '`drift` varchar(32)'] 'timestamp float, drift str' """ current_cols = self.conn.execute('pragma table_info(obs)').fetchall() current_cols = [r[1] for r in current_cols] if isinstance(column_defs, str): column_defs = column_defs.split(',') for column_def in column_defs: if isinstance(column_def, str): column_def = column_def.split() name, typestr = column_def if typestr is float: typestr = 'float' elif typestr is int: typestr = 'int' elif typestr is str: typestr = 'text' check_name = name if name.startswith('`'): check_name = name[1:-1] else: name = '`' + name + '`' if check_name in current_cols: if ignore_duplicates: continue raise ValueError("Column %s already exists in table obs" % check_name) self.conn.execute('ALTER TABLE obs ADD COLUMN %s %s' % (name, typestr)) current_cols.append(check_name) if commit: self.conn.commit() return self
[docs] def update_obs(self, obs_id, data={}, tags=[], wafer_info=None, commit=True): """Update an entry in the obs table. Arguments: obs_id (str): The id of the obs to update. wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id. The default will be replaced with '_all' all primary keys other than obs_id data (dict): map from column_name to value. tags (list of str): tags to apply to this observation (if a tag name is prefxed with '!', then the tag name will be un-applied, i.e. cleared from this observation. Example of ways to pass updates to obsdb when there are multiple primary keys. 1) obs_id as str and wafer_info as tuple:: obsdb.update_obs('obs_2345_xyz_110', wafer_info=('ws0', 'f090'), ...) 2) obs_id as str and wafer_info as dict:: obsdb.update_obs('obs_2345_xyz_110', wafer_info={'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...) 3) obs_id as dict and wafer_info is None:: obsdb.update_obs({'obs_id': 'obs_2345_xyz_110', 'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...) 4) obs_id as tuple and wafer_info is None:: obsdb.update_obs(('obs_2345_xyz_110', 'ws0', 'f090'), ...) """ obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info) obs_key = {'obs_id': obs_id} if (len(self.primary_keys) > 1): wafer_info = self._warn_primary_keys(wafer_info) for i, k in enumerate(self.primary_keys[1:]): obs_key[k] = wafer_info[i] c = self.conn.cursor() columns = ', '.join(obs_key.keys()) placeholders = ', '.join(['?'] * len(obs_key)) c.execute(f'INSERT OR IGNORE INTO obs ({columns}) VALUES ({placeholders})', tuple(obs_key.values())) if len(data.keys()): settors = [f'{k}=?' for k in data.keys()] where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()]) c.execute(f'UPDATE obs SET {", ".join(settors)} WHERE {where_str}', tuple(data.values()) + tuple(obs_key.values())) for t in tags: if t[0] == '!': # Kill this tag where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()]) c.execute(f'DELETE FROM tags WHERE {where_str} AND tag=?', tuple(obs_key.values()) + (t[1:],)) else: # Add the tag for the specific primary key combination. columns = ', '.join(list(obs_key.keys()) + ['tag']) placeholders = ', '.join(['?'] * (len(obs_key) + 1)) c.execute(f'INSERT OR REPLACE INTO tags ({columns}) VALUES ({placeholders})', tuple(obs_key.values()) + (t,)) if commit: self.conn.commit() return self
[docs] def copy(self, map_file=None, overwrite=False): """ Duplicate the current database into a new database object, and return it. If map_file is specified, the new database will be connected to that sqlite file on disk. Note that a quick way of writing a Db to disk to call copy(map_file=...) and then simply discard the returned object. """ if map_file is not None and os.path.exists(map_file): if overwrite: os.remove(map_file) else: raise RuntimeError("Output file %s exists (overwrite=True " "to overwrite)." % map_file) new_db = ObsDb(map_file=map_file, init_db=False) script = ' '.join(self.conn.iterdump()) new_db.conn.executescript(script) return new_db
[docs] def to_file(self, filename, overwrite=True, fmt=None): """Write the present database to the indicated filename. Args: filename (str): the path to the output file. overwrite (bool): whether an existing file should be overwritten. fmt (str): 'sqlite', 'dump', or 'gz'. Defaults to 'sqlite' unless the filename ends with '.gz', in which it is 'gz'. """ return common.sqlite_to_file(self.conn, filename, overwrite=overwrite, fmt=fmt)
[docs] @classmethod def from_file(cls, filename, fmt=None, force_new_db=True): """This method calls :func:`sotodlib.core.metadata.common.sqlite_from_file` """ conn = common.sqlite_from_file(filename, fmt=fmt, force_new_db=force_new_db) return cls(conn, init_db=False)
[docs] def get(self, obs_id=None, wafer_info=None, tags=None, add_prefix=''): """Returns the entry for obs_id, as an ordered dict. If obs_id is None, returns all entries, as a ResultSet. However, this usage is deprecated in favor of self.query(). Args: obs_id (str): The observation id to get info for. wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id. The default will be replaced with '_all' all primary keys other than obs_id tags (bool): If True, include the tags associated with this observation in the output. The tags will be stored in a field called 'tags', which will be a list of strings. If False or None, the tags will not be included in the output. add_prefix (str): A string that will be prepended to each field name. This is for the lazy metadata system, because obsdb selectors are prefixed with 'obs:'. Returns: An ordered dict with the obs table entries for this obs_id, or None if the obs_id is not found. If tags have been requested, they will be stored in 'tags' as a list of strings. """ obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info) if obs_id is None: return self.query('1', add_prefix=add_prefix) wafer_info = self._warn_primary_keys(wafer_info) query_text = " AND ".join([f"{key} == '{val}'" for key, val in zip(self.primary_keys, [obs_id] + wafer_info)]) results = self.query(query_text, add_prefix=add_prefix) if len(results) == 0: return None if len(results) > 1: raise ValueError('Too many rows...') # or integrity error... output = results[0] if tags: # "distinct" should not be needed given uniqueness constraint. where_str = ' AND '.join([f"{k}='{v}'" for k, v in zip(self.primary_keys, [obs_id] + list(wafer_info))]) c = self.conn.execute(f'SELECT DISTINCT tag FROM tags WHERE {where_str}') output['tags'] = [r[0] for r in c] return output
[docs] def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''): """Queries the ObsDb using user-provided text. Returns a ResultSet. Args: query_text (str): The sqlite query string. All fields should refer to the obs table, or to tags explicitly listed in the tags argument. tags (list of str): Tags to include in the output; if they are listed here then they can also be used in the query string. Filtering on tag value can be done here by appending '=0' or '=1' to a tag name. Returns: A ResultSet with one row for each Observation matching the criteria. Notes: Tags are added to the output on request. For example, passing tags=['planet','stare'] will cause the output to include columns 'planet' and 'stare' in addition to all the columns defined in the obs table. The value of 'planet' and 'stare' in each row will be 0 or 1 depending on whether that tag is set for that observation. We can include expressions involving planet and stare in the query, for example:: obsdb.query('planet=1 or stare=1', tags=['planet', 'stare']) For simple filtering on tags, pass '=1' or '=0', like this:: obsdb.query(tags=['planet=1','hwp=1']) When filtering is activated in this way, the returned results must satisfy all the criteria (i.e. the individual constraints are AND-ed). If your tag name contains special characters (e.g. '-'), you will need to enclose it in backticks when using it in a query string, e.g.: obsdb.query('`bad-tag`=1', tags=['bad-tag']) For this reason, we generally advise against the use of non-alphanumeric characters in tags. """ sort_text = '' if sort is not None and len(sort): sort_text = ' ORDER BY ' + ','.join(sort) if '"' in query_text: warnings.warn('obsdb.query text contains double quotes (") -- ' 'replacing with single quotes (\').') query_text = query_text.replace('"', "'") joins = '' extra_fields = [] if tags is not None and len(tags): for tagi, t in enumerate(tags): if '=' in t: t, val = t.split('=') else: val = None if val is None: join_type = 'left join' extra_fields.append(f"ifnull(tt{tagi}.obs_id,'') != '' as '{t}'") elif val == '0': join_type = 'left join' extra_fields.append(f"ifnull(tt{tagi}.obs_id,'') != '' as '{t}'") query_text += f' and "{t}"=0' else: join_type = 'join' extra_fields.append(f'1 as "{t}"') joins += (f" {join_type} (select distinct obs_id from tags where tag='{t}') as tt{tagi} on " f"obs.obs_id = tt{tagi}.obs_id") extra_fields = ''.join([','+f for f in extra_fields]) q = 'select obs.* %s from obs %s where %s %s' % (extra_fields, joins, query_text, sort_text) c = self.conn.execute(q) results = ResultSet.from_cursor(c) if add_prefix is not None: results.keys = [add_prefix + k for k in results.keys] return results
[docs] def query_linked_dbs(self, secondary_dbs, query_text, add_prefix='', wafer_info=None): """ Query two ObsDb objects and link their results based on obs_id. Primary ObsDb can be either keyed by obs_id or obs_id and wafer_info (such as wafer_slot and bandpass). For every row returned from the primary database, the linked secondary databases are queried for rows with the same obs_id (and a specific wafer_info subset if wafer_info is passed). The results are returned as a list of tuples the first element of the tuple is the primary database result and the rest are the linked secondary database results. Args: secondary_dbs (list of ObsDb): A list of secondary database to query for linked rows. If a single ObsDb is passed, it will be converted to a list of length 1. query_text (str): The query text for the primary database. add_prefix (str): A string to prepend to field names in the result. wafer_info (tuple of str): The wafer_info to restrict what's returned from the secondary database. The default value is None, which means all wafer_info will be returned. Returns: results (list of ResultSet): A list containing tuples of resultsets from the primary and secondary databases. """ # Ensure secondary_dbs is a list if not isinstance(secondary_dbs, list): secondary_dbs = [secondary_dbs] # Query the primary database primary_results = self.query(query_text, add_prefix=add_prefix) if len(primary_results) == 0: return None results = [] for pr in primary_results: _res = (pr, ) for secondary_db in secondary_dbs: if wafer_info: _wafer_info = secondary_db._warn_primary_keys(wafer_info) query_str = ' and '.join([f"{k}=='{v}'" for k, v in zip(secondary_db.primary_keys, [pr['obs_id']] + list(_wafer_info))]) else: query_str = f"obs_id=='{pr['obs_id']}'" secondary_result = secondary_db.query(query_str, add_prefix=add_prefix) _res += ([sr for sr in secondary_result],) results.append(_res) return results
[docs] def info(self): """Return a dict summarizing the structure and contents of the obsdb; this is used by the CLI. """ def _short_list(items, max_len=40): i, acc, keepers = 0, 0, [] while (len(keepers) < 1 or acc < max_len) and i < len(items): keepers.append(str(items[i])) i += 1 acc += len(keepers[-1]) + 2 return ('[' + ', '.join(map(str, keepers)) + (' ...' * (i < len(items))) + ']') # Summarize the fields ... rs = self.query() fields = {} for k in rs.keys: items = np.unique(rs[k]) fields[k] = (len(items), _short_list(items)) # Count occurances of each tag ... c = self.conn.execute('select tag, count(obs_id) from tags group by tag order by tag') tags = {r[0]: r[1] for r in c} return { 'count': len(rs), 'fields': fields, 'tags': tags, }
def diff_obsdbs(obsdb_left, obsdb_right, return_detail=False): """Examine all records in two obsdbs and construct a list of changes that could made to obsdb_left in order to make it match obsdb_right. Returns a dict with following entries: - ``different`` (bool): whether the two databases carry different information. - ``patchable`` (bool): whether the function was able to construct patching instructions. - ``unpatchable_reason`` (str): if not patchable, a string explaining why. - ``detail`` (various): if not patchable, and return_detail, then this will contain detail about the offending data (e.g. obs rows in the two dbs that contain discrepant data). - ``patch_data`` (dict): if patchable, the data needed to patch obsdb_left. The fields are: - ``remove_obs`` (list of obs_id): entries to remove from obs table. - ``remove_tags`` (list of tuple): entries to remove from tags table. - ``new_obs`` (list of dict): rows of new data for obs table -- each dict can be passed directly to obsdb.update_obs. - ``new_tags`` (list of tuple): rows of data for tags table (obs_id, tag). Notes: In the present implementation, only changes involving adding rows to obsdb_left (either whole obs rows or tag rows) will yield a patchable result. Cases where some data has changed, or obs or tags have been deleted, will simply return as unpatchable. This is probably pretty easy to extend, should the need arise. """ if isinstance(obsdb_left, str): obsdb_left = ObsDb.from_file(obsdb_left, force_new_db=False) if isinstance(obsdb_right, str): obsdb_right = ObsDb.from_file(obsdb_right, force_new_db=False) def failure_declaration(reason, detail=None): if not return_detail: detail = None return {'different': True, 'patchable': False, 'unpatchable_reason': reason, 'detail': detail} full = [db.query() for db in [obsdb_left, obsdb_right]] common_cols = full[1].keys if not (set(full[0].keys) >= set(common_cols)): return failure_declaration( 'obsdb_left is missing some columns found in obsdb_right.', detail=set(common_cols).difference(full[0].keys)) # Convert to arrays. obs_ids = [set(f['obs_id']) for f in full] # Insist right is superset of left. left_not_right = sorted(list(obs_ids[0].difference(obs_ids[1]))) if len(left_not_right): return failure_declaration( f'obsdb_left contains {len(left_not_right)} ' 'obs not found in obsdb_right.', detail=left_not_right) # Any obs in common? unmatched_right = np.ones(len(full[1]), bool) common = sorted(list(obs_ids[0].intersection(obs_ids[1]))) if len(common): common, i0, i1 = util.get_coindices(*(f['obs_id'] for f in full)) diffs = [] Li, Ri = ([_f.keys.index(k) for k in common_cols] for _f in full) for i, (_i0, _i1) in enumerate(zip(i0, i1)): Lrow, Rrow = full[0].rows[_i0], full[1].rows[_i1] L = tuple(Lrow[_i] for _i in Li) R = tuple(Rrow[_i] for _i in Ri) if L != R: diffs.append((L, R)) if len(diffs): return failure_declaration( f'obsdb_left and obsdb_right have {len(diffs)} obs ' 'in common, with different data.', detail=diffs) unmatched_right[i1] = False # Ok finally pd = { 'remove_obs': [], 'remove_tags': [], 'new_obs': [], 'new_tags': [], } for idx in unmatched_right.nonzero()[0]: pd['new_obs'].append(full[1][idx]) # Tag check. tags_tuples = [ list(map(tuple, db.conn.execute( 'select distinct obs_id, tag from tags ' 'order by obs_id, tag').fetchall())) for db in [obsdb_left, obsdb_right]] # Collapse tags to single strings and eliminate duplicates. DELIM = ':::/:::' common, i0, i1 = util.get_coindices(*[[t[0] + DELIM + t[1] for t in tt] for tt in tags_tuples]) if len(i0) != len(tags_tuples[0]): return failure_declaration( f'obsdb_left contains {len(tags_tuples[0]) - len(i0)} ' 'tags not found in obsdb_right', detail=list(set(tags_tuples[0]).difference(tags_tuples[1]))) unmatched_right = np.ones(len(tags_tuples[1]), bool) unmatched_right[i1] = False for idx in unmatched_right.nonzero()[0]: pd['new_tags'].append(tags_tuples[1][idx]) different = any([(len(v) != 0) for v in pd.values()]) return { 'different': different, 'patchable': True, 'patch_data': pd, } def patch_obsdb(patch_data, target_db): """Update an ObsDb with a batch of changes. Args: target_db (ObsDb): the database where changes should be made. patch_data (dict): patch information, as returned by diff_obsdbs. """ assert len(patch_data['remove_obs']) == 0 assert len(patch_data['remove_tags']) == 0 for obs_entry in patch_data['new_obs']: target_db.update_obs(obs_entry['obs_id'], obs_entry, commit=False) # Group new tags by obs. tags_obsed = {} for k, v in patch_data['new_tags']: if k not in tags_obsed: tags_obsed[k] = [v] else: tags_obsed[k].append(v) for obs, tags in tags_obsed.items(): target_db.update_obs(obs, {}, tags=tags, commit=False) target_db.conn.commit()