Source code for sotodlib.site_pipeline.multilayer_preprocess_tod

import os
import yaml
import time
import logging
from typing import Optional, Union, Callable, List
import numpy as np
import argparse
import traceback
from typing import Optional
from sotodlib.utils.procs_pool import get_exec_env
import h5py
import copy
from tqdm import tqdm
from sotodlib.coords import demod as demod_mm
from sotodlib.hwp import hwp_angle_model
from sotodlib import core
from sotodlib.core.metadata.manifest import MultiDbBatchManager
from sotodlib.site_pipeline.jobdb import JobManager, JState
from sotodlib.preprocess import _Preprocess, Pipeline, processes
import sotodlib.preprocess.preprocess_util as pp_util
from sotodlib.preprocess.preprocess_util import PreprocessErrors
from sotodlib.site_pipeline.utils.pipeline import main_launcher
from sotodlib.site_pipeline.utils.obsdb import get_obslist


logger = pp_util.init_logger("preprocess")


[docs] def multilayer_preprocess_tod(obs_id: str, configs_init: Union[str, dict], configs_proc: Union[str, dict], group: list, verbosity: int = 0, compress: bool = False, overwrite: bool = False): """Meant to be run as part of a batched script, this function calls the preprocessing pipeline a specific Observation ID and group combination and saves the results in the ManifestDb specified in the configs. Arguments ---------- obs_id : str or ResultSet entry obs_id or obs entry that is passed to context.get_obs configs_init : str or dict Config file or loaded config dictionary for first layer database. configs_proc : str or dict Config file or loaded config dictionary for second layer database. group : list The group to be run. For example, this might be ['ws0', 'f090'] if ``group_by`` (specified by the subobs->use key in the preprocess config) is ['wafer_slot', 'wafer.bandpass']. verbosity : str The log level to use. 0 = error, 1 = warn, 2 = info, 3 = debug. compress : bool Whether or not to compress the preprocessing h5 files. overwrite : bool If True, overwrite contents of temporary h5 files. Returns ------- out_dict_init : dict or None Dictionary output for init config from get_preproc_group_out_dict if preprocessing ran successfully for init layer or ``None`` if preprocessing was loaded or ``preproc_or_load_group`` failed. out_dict_proc : dict or None Dictionary output for proc config from get_preproc_group_out_dict if preprocessing ran successfully for proc layer or ``None`` if preprocessing was loaded, that layer was not run or loaded, or ``preproc_or_load_group`` failed. errors : tuple A tuple containing the error from PreprocessError, an error message, and the traceback. Each will be None if preproc_or_load_group finished successfully. """ logger = pp_util.init_logger("preprocess", verbosity=verbosity) group_by = np.atleast_1d(configs_proc['subobs'].get('use', 'detset')) dets = {gb:gg for gb, gg in zip(group_by, group)} aman, out_dict_init, out_dict_proc, errors = pp_util.preproc_or_load_group( obs_id=obs_id, configs_init=configs_init, dets=dets, configs_proc=configs_proc, logger=logger, overwrite=overwrite, save_archive=False, compress=compress, ) return out_dict_init, out_dict_proc, errors
def _check_init_jobdb( jdb, init_db, init_jobs_map, obs_id, group, group_by, overwrite=False, ): """ Check if a job exists in the init jobdb. Create it if it doesn't or open it if overwritting or it is not in the init db. Arguments ---------- jdb : JobManager A preprocesing jobdb. init_db : ManifestDb or None Init preproc database. init_jobs_map : dict Mapping from jobdb tags to existing init jobs. obs_id : str The obs_id for jobdb entry. group : list The group for the jobdb entry. group_by : list Keys defining detector grouping (e.g. ["wafer", "band"]). overwrite : bool, optional Whether to reopen jobs even if they already exist in init_db. Returns ------- failed_job : bool True if a failed job should be recorded as failed. job : jobdb.Job or None New job to add to init db. None if none are to be added. """ # build tags tags = {'obs:obs_id': obs_id} for gb, g in zip(group_by, group): tags[f'dets:{gb}'] = g init_job = init_jobs_map.get(frozenset(tags.items())) # make job if it doesn't exist if init_job is None: tags["error"] = None init_job = jdb.create_job( jclass="init", tags=tags, commit=False, check_existing=False, ) return False, init_job # if job is open, don't need to do anything if init_job.jstate not in (JState.done, JState.failed): return False, None # check whether this job exists in init_db found = True if init_job.jstate == JState.done and init_db is not None: x = init_db.inspect({'obs:obs_id': obs_id}) found = any( [a[f'dets:{gb}'] for gb in group_by] == group for a in x ) # reopen job if overwriting or missing from init_db if overwrite or not found: with jdb.locked(init_job) as j: j.jstate = "open" for _t in j._tags: if _t.key == "error": _t.value = None return False, None # otherwise mark job as failed so we don't rerun it if init_job.jstate == JState.failed: return True, None return False, None def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"], as_completed_callable: Callable, configs_init: str, configs_proc: str, query: str = '', obs_id: Optional[str] = None, overwrite: bool = False, min_ctime: Optional[int] = None, max_ctime: Optional[int] = None, update_delay: Optional[int] = None, tags: Optional[str] = None, planet_obs: bool = False, verbosity: Optional[int] = None, nproc: int = 4, compress: bool = False, run_from_jobdb: bool = False, raise_error: bool = False, pb_path: Optional[str] = None): init_temp_subdir = "temp" proc_temp_subdir = "temp_proc" tqdm.monitor_interval = 0 configs_init, context_init = pp_util.get_preprocess_context(configs_init) configs_proc, context_proc = pp_util.get_preprocess_context(configs_proc) logger = pp_util.init_logger("preprocess", verbosity=verbosity) for fname in (configs_init['archive']['policy']['filename'], configs_proc['archive']['policy']['filename'] ): os.makedirs(os.path.dirname(fname), exist_ok=True) # jobdb jobdb_path = configs_proc.get("jobdb", None) if jobdb_path is not None: jdb = JobManager(sqlite_file=jobdb_path) # get init jobs new_init_jobs = [] init_jobs = jdb.get_jobs(jclass="init") init_jobs_map = { frozenset({k: v for k, v in j.tags.items() if k != 'error'}.items()): j for j in init_jobs } elif run_from_jobdb: raise ValueError(f"Need a jobdb if using it to make a run_list.") errlog = os.path.join(os.path.dirname(configs_proc['archive']['index']), 'errlog_proc.txt') if jobdb_path is not None and run_from_jobdb: if not overwrite: job_list = jdb.get_jobs("proc", jstate=JState.open) else: job_list = jdb.get_jobs("proc") obs_list = list(set([j.tags['obs:obs_id'] for j in job_list])) if len(obs_list) == 0: logger.warning(f"No open jobs in jobdb.") return else: obs_list = get_obslist(context_proc, query=query, obs_id=obs_id, min_ctime=min_ctime, max_ctime=max_ctime, update_delay=update_delay, tags=tags, planet_obs=planet_obs) obs_list = [obs['obs_id'] for obs in obs_list] if len(obs_list) == 0: logger.warning(f"No observations returned from query: {query}") return # clean up lingering files from previous incomplete runs init_policy_dir = os.path.join(os.path.dirname( configs_init['archive']['policy']['filename']), init_temp_subdir ) proc_policy_dir = os.path.join(os.path.dirname( configs_proc['archive']['policy']['filename']), proc_temp_subdir ) for obs_id in obs_list: pp_util.cleanup_obs(obs_id, init_policy_dir, errlog, configs_init, context_init, subdir=init_temp_subdir, remove=overwrite) pp_util.cleanup_obs(obs_id, proc_policy_dir, errlog, configs_proc, context_proc, subdir=proc_temp_subdir, remove=overwrite) # remove datasets from final archive file not found in init db for config in (configs_init, configs_proc): pp_util.cleanup_archive(config, logger) run_list = [] group_by = np.atleast_1d(configs_proc['subobs'].get('use', 'detset')) if overwrite or not os.path.exists(configs_init['archive']['index']): init_db = None else: init_db = core.metadata.ManifestDb(configs_init['archive']['index']) if overwrite or not os.path.exists(configs_proc['archive']['index']): proc_db = None else: proc_db = core.metadata.ManifestDb(configs_proc['archive']['index']) if not run_from_jobdb: futures = [] futures_dict = {} for obs_id in obs_list: futures.append(executor.submit(pp_util.get_groups, obs_id, configs_proc)) futures_dict[futures[-1]] = obs_id for future in tqdm(as_completed_callable(futures), total=len(futures), desc="building run list from obs list"): obs_id = futures_dict[future] _, groups, _ = future.result() if proc_db is not None and not overwrite: x = proc_db.inspect({'obs:obs_id': obs_id}) if x is not None and len(x) != 0 and len(x) != len(groups): [groups.remove([a[f'dets:{gb}'] for gb in group_by]) for a in x] for group in groups: if 'NC' not in group: failed_job = False if jobdb_path is not None: failed_job, new_init_job = _check_init_jobdb( jdb, init_db, init_jobs_map, obs_id, group, group_by, overwrite=False, ) if new_init_job: new_init_jobs.append(new_init_job) if not failed_job or overwrite: run_list.append((obs_id, group)) # filter by jobdb status if jobdb_path is not None: jdb.commit_jobs(new_init_jobs) run_list = pp_util.filter_preproc_runlist_by_jobdb( jdb=jdb, jclass="proc", db=proc_db, run_list=run_list, group_by=group_by, overwrite=overwrite, logger=logger ) else: for job in job_list: obs_id = job.tags["obs:obs_id"] groups = [[job.tags['dets:'+g] for g in group_by]] if overwrite: with jdb.locked(job) as j: j.jstate = JState.open for _t in j._tags: if _t.key == "error": _t.value = None elif proc_db is not None: x = proc_db.inspect({'obs:obs_id': obs_id}) if x is not None and len(x) != 0 and len(x) != len(groups): for a in x: entry = [a[f'dets:{gb}'] for gb in group_by] if entry in groups: groups.remove(entry) # if it was found in the db but is still in open # jobs, then it was added by cleanup_obs and # should be set to done. with jdb.locked(job) as j: j.jstate = JState.done for group in groups: failed_job = False if jobdb_path is not None: failed_job, new_init_job = _check_init_jobdb( jdb, init_db, init_jobs_map, obs_id, group, group_by, overwrite=False, ) if new_init_job: new_init_jobs.append(new_init_job) if not failed_job or overwrite: run_list.append((obs_id, group)) jdb.commit_jobs(new_init_jobs) if len(run_list) == 0: logger.info("Nothing to run") return logger.info(f'Run list created with {len(run_list)} obsid groups') # ensure dbs exist up front to prevent race conditions db_init = pp_util.get_preprocess_db(configs_init, group_by, logger) db_proc = pp_util.get_preprocess_db(configs_proc, group_by, logger) futures = [] futures_dict = {} obs_errors = {} for r in run_list: futures.append( executor.submit(multilayer_preprocess_tod, obs_id=r[0], configs_init=configs_init, configs_proc=configs_proc, group=r[1], verbosity=verbosity, compress=compress, overwrite=overwrite, ) ) futures_dict[futures[-1]] = (r[0], r[1]) if r[0] not in obs_errors: obs_errors[r[0]] = [] total = len(futures) batch_size_init = configs_init['archive'].get('batch_size', 1) batch_size_proc = configs_proc['archive'].get('batch_size', 1) pb_name = os.path.join(pb_path or '', f"pb_{str(int(time.time()))}.txt") with open(pb_name, 'w') as f: with MultiDbBatchManager( [db_init, db_proc], batch_size=[batch_size_init, batch_size_proc], logger=logger ) as (db_mgr_init, db_mgr_proc): for future in tqdm(as_completed_callable(futures), total=total, desc="multilayer_preprocess_tod", file=f, miniters=max(1, total // 100)): obs_id, group = futures_dict[future] out_meta = (obs_id, group) try: out_dict_init, out_dict_proc, errors = future.result() obs_errors[obs_id].append({'group': group, 'error': errors[0]}) logger.info(f"{obs_id}: {group} extracted successfully") except Exception as e: errmsg, tb = PreprocessErrors.get_errors(e) logger.error(f"Executor Future Result Error for {obs_id}: {group}:\n{errmsg}\n{tb}") obs_errors[obs_id].append({'group': group, 'error': PreprocessErrors.ExecutorFutureError}) out_dict_init = None out_dict_proc = None errors = (PreprocessErrors.ExecutorFutureError, errmsg, tb) futures.remove(future) # only run if first layer was run if out_dict_init is not None: logger.info(f"Adding future result to init db for {obs_id}: {group}") pp_util.cleanup_mandb(out_dict_init, out_meta, errors, configs_init, logger, overwrite, db_manager=db_mgr_init) logger.info(f"Adding future result to proc db for {obs_id}: {group}") pp_util.cleanup_mandb(out_dict_proc, out_meta, errors, configs_proc, logger, overwrite, db_manager=db_mgr_proc) # update jobdb if jobdb_path is not None: tags = {} tags["obs:obs_id"] = obs_id for gb, g in zip(group_by, group): tags['dets:' + gb] = g jobs = jdb.get_jobs(jstate=JState.open, tags=tags) for job in jobs: # init layer state will be JState.done if already run if job.jstate == JState.open: with jdb.locked(job) as j: j.mark_visited() if errors[0] is not None: j.jstate = JState.failed for _t in j._tags: if _t.key == "error": _t.value = errors[0] else: j.jstate = JState.done if raise_error: n_obs_fail = 0 n_groups_fail = 0 for obs_id, out_meta in obs_errors.items(): if all(entry['error'] is not None for entry in out_meta): n_obs_fail += 1 n_groups_fail += len(out_meta) else: for entry in out_meta: if entry['error'] is not None: n_groups_fail += 1 if n_groups_fail > 0: raise RuntimeError(f"multilayer_preprocess_tod ended with {n_obs_fail}/{len(obs_errors)} " f"failed obsids and {n_groups_fail}/{len(run_list)} failed groups") logger.info("multilayer_preprocess_tod is done") def get_parser(parser=None): if parser is None: parser = argparse.ArgumentParser() parser.add_argument('configs_init', help="Preprocessing Configuration File for first layer database") parser.add_argument('configs_proc', help="Preprocessing Configuration File for second layer database") parser.add_argument( '--query', help="Query to pass to the observation list. Use \\'string\\' to " "pass in strings within the query.", type=str ) parser.add_argument( '--obs-id', help="obs-id of particular observation if we want to run on just one" ) parser.add_argument( '--overwrite', help="If true, overwrites existing entries in the database", action='store_true', ) parser.add_argument( '--min-ctime', help="Minimum timestamp for the beginning of an observation list", ) parser.add_argument( '--max-ctime', help="Maximum timestamp for the beginning of an observation list", ) parser.add_argument( '--update-delay', help="Number of days (unit is days) in the past to start observation list.", type=int ) parser.add_argument( '--tags', help="Observation tags. Ex: --tags 'jupiter' 'setting'", nargs='*', type=str ) parser.add_argument( '--planet-obs', help="If true, takes all planet tags as logical OR and adjusts related configs", action='store_true', ) parser.add_argument( '--verbosity', help="increase output verbosity. 0:Error, 1:Warning, 2:Info(default), 3:Debug", default=2, type=int ) parser.add_argument( '--nproc', help="Number of parallel processes to run on.", type=int, default=4 ) parser.add_argument( '--compress', help="Compress preprocessing database.", action='store_true', default=False ) parser.add_argument( '--run-from-jobdb', help="If True, use open jobs in jobdb as the run_list.", default=False, action='store_true', ) parser.add_argument( '--raise-error', help="Raise an error upon completion if any obsids or groups fail.", type=bool, default=False ) parser.add_argument( '--pb-path', help="Path to where to save progress bar.", type=str, default=None ) return parser def main(configs_init: str, configs_proc: str, query: str = '', obs_id: Optional[str] = None, overwrite: bool = False, min_ctime: Optional[int] = None, max_ctime: Optional[int] = None, update_delay: Optional[int] = None, tags: Optional[List[str]] = None, planet_obs: bool = False, verbosity: Optional[int] = None, compress: bool = False, nproc: int = 4, run_from_jobdb: bool = False, raise_error: bool = False, pb_path: Optional[str] = None): rank, executor, as_completed_callable = get_exec_env(nproc) if rank == 0: _main(executor=executor, as_completed_callable=as_completed_callable, configs_init=configs_init, configs_proc=configs_proc, query=query, obs_id=obs_id, overwrite=overwrite, min_ctime=min_ctime, max_ctime=max_ctime, update_delay=update_delay, tags=tags, planet_obs=planet_obs, verbosity=verbosity, nproc=nproc, compress=compress, run_from_jobdb=run_from_jobdb, raise_error=raise_error, pb_path=pb_path) if __name__ == '__main__': main_launcher(main, get_parser)