# Source code for sotodlib.site_pipeline.update_det_cal

"""
This module is used to compute detector calibration parameters from sodetlib
data products.

The naive computation is described in the `sodetlib documentation. <https://sodetlib.readthedocs.io/en/latest/operations/bias_steps.html#in-transition>`_

Details about the RP and loopgain correction `can be found on our confluence. <https://simonsobs.atlassian.net/wiki/spaces/~5570586d07625a6be74c8780e4b96f6156f5e6/blog/2024/02/02/286228683/Nonlinear+TES+model+using+RP+curve>`_
"""

import traceback
import os
import yaml
from dataclasses import dataclass, astuple, fields
import numpy as np
from tqdm.auto import tqdm
import logging
from typing import Optional, Union, Dict, List, Any, Tuple, Literal, Callable
from queue import Queue
import argparse

from so3g.proj import RangesMatrix
from sotodlib import core
from sotodlib.io.metadata import write_dataset, ResultSet
from sotodlib.io.load_book import get_cal_obsids
from sotodlib.utils.procs_pool import get_exec_env
from sotodlib.hwp import get_hwpss, subtract_hwpss
from sotodlib.site_pipeline.utils.pipeline import main_launcher
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed, Future
import sodetlib.tes_param_correction as tpc
from sodetlib.operations.iv import IVAnalysis
from sodetlib.operations.bias_steps import BiasStepAnalysis


# Borrowed from pysmurf: max bias volt / num_bits.
# NOTE(review): presumably volts per RTM DAC bit -- confirm against pysmurf.
DEFAULT_RTM_BIT_TO_VOLT = 10 / 2**19
# NOTE(review): presumably the default pA-per-flux-quantum conversion -- confirm.
DEFAULT_pA_per_phi0 = 9e6
TES_BIAS_COUNT = 12  # per detset / primary file group

# For converting bias group to bandpass.
# BGS maps a band half ('lb' / 'hb') to the bias-group indices that belong to
# it; BAND_STR maps a tube flavor and band half to the bandpass label.
BGS = {'lb': [0, 1, 4, 5, 8, 9], 'hb': [2, 3, 6, 7, 10, 11]}
BAND_STR = {'mf': {'lb': 'f090', 'hb': 'f150'},
            'uhf': {'lb': 'f220', 'hb': 'f280'},
            'lf': {'lb': 'f030', 'hb': 'f040'}}

from sotodlib.site_pipeline.utils.logging import init_logger as sp_init_logger

# Module-level logger shared by every function in this file.
logger = logging.getLogger("det_cal")
# Only run the site-pipeline logger initialisation when no handler is attached
# yet (hasHandlers also considers ancestor loggers).
# NOTE(review): presumably sp_init_logger installs the handlers -- confirm.
if not logger.hasHandlers():
    sp_init_logger("det_cal")


def get_data_root(ctx: core.Context, obs_id: str) -> str:
    """
    Get the root data directory for an observation based on the context file.

    The path of one file registered in the context's obsfiledb is used as a
    template: five trailing path components are stripped from it and the
    telescope string taken from ``obs_id`` is appended.

    Args
    -----
    ctx: core.Context
        Context whose ``obsfiledb`` connection is queried.
    obs_id: str
        Observation id; its third ``_``-separated field is used as the
        telescope string.

    Returns
    --------
    str
        ``<stripped file path>/<telescope string>``.

    Raises
    -------
    ValueError
        If the obsfiledb ``files`` table is empty, so no template path exists.
    IndexError
        If ``obs_id`` has fewer than three ``_``-separated fields.
    """
    cursor = ctx.obsfiledb.conn.execute("select name from files limit 1")
    row = next(iter(cursor), None)
    if row is None:
        raise ValueError(
            f"Cannot determine data root for {obs_id}: obsfiledb has no files"
        )

    res = row[0]
    # Strip the five trailing path components (file name, obs-id, timecode,
    # 'obs' and one more level) to get back to the common root.
    for _ in range(5):
        res = os.path.dirname(res)

    # add telescope str to path
    tel_str = obs_id.split("_")[2]
    return os.path.join(res, tel_str)


@dataclass
class DetCalCfg:
    """
    Class for configuring the behavior of the det-cal update script.

    Args
    -------------
    root_dir: str
        Path to the root of the results directory. Environment variables are
        expanded, and the directory must already exist.
    context_path: str
        Path to the context file to use. Environment variables are expanded.
    raise_exceptions: bool
        If Exceptions should be raised in the get_cal_resset function.
        Defaults to False.
    fit_tau: bool
        If True, re-fit the biasstep tau. Defaults to False.
    apply_cal_correction: bool
        If True, apply the RP calibration correction, and use corrected results
        for Rtes, Si, Pj, and loopgain when successful. Defaults to True.
    hwpss_subtraction: bool
        If True, reanalyze biasstep with hwpss subtraction. Defaults to False.
    metadata_list: str or List of str
        List of metadata labels to load. Defaults to 'all'.
    index_path: str
        Path to the index file to use for the det_cal database. Relative paths
        are taken relative to ``root_dir``. Defaults to "det_cal.sqlite".
    h5_path: str
        Path to the HDF5 file to use for the det_cal database. Relative paths
        are taken relative to ``root_dir``. Defaults to "det_cal.h5".
    h5_unix_digits: int
        Number of digits of unixtime to be added to h5_path. For example if
        h5_unix_digits = 4, h5_path will be modified to "det_cal_1700.h5".
        Defaults to 0.
    cache_failed_obsids: bool
        If True, will cache failed obs-ids to avoid re-running them. Defaults
        to True.
    failed_cache_file: str
        Path to the yaml file that will store failed obsids. Relative paths
        are taken relative to ``root_dir``. Defaults to "failed_obsids.yaml".
    show_pb: bool
        If True, show progress bar in the run_update function. Defaults to
        True.
    param_correction_config: dict
        Configuration for the TES param correction. If None, default values
        are used. A ``tpc.AnalysisCfg`` instance is used as-is.
    run_method: str
        Must be "site" or "nersc". If "site", this function will not
        parallelize SQLite access, and will only parallelize the TES parameter
        correction. If "nersc", this will parallelize both SQLite access and
        the TES param correction, using ``nprocs_obs_info`` and
        ``nprocs_result_set`` processes respectively. Defaults to "site".
    nprocs_obs_info: int
        Number of processes to use to acquire observation info from the file
        system. Defaults to 1.
    nprocs_result_set: int
        Number of parallel processes used to compute the TES parameters, and
        to run the TES parameter correction. Defaults to 10.
    num_obs: Optional[int]
        Max number of observations to process per run_update call. If not set,
        will run on all available observations.
    log_level: str
        Logging level for the logger. Defaults to "DEBUG".
    multiprocess_start_method: str
        Method to use to start child processes. Can be "spawn" or "fork".
        Defaults to "spawn".

    Raises
    -------
    ValueError
        If ``run_method`` is not "site" or "nersc", or if ``root_dir`` does
        not exist.
    """

    def __init__(
        self,
        root_dir: str,
        context_path: str,
        *,
        raise_exceptions: bool = False,
        fit_tau: bool = False,
        apply_cal_correction: bool = True,
        hwpss_subtraction: bool = False,
        metadata_list: Union[str, List[str]] = 'all',
        index_path: str = "det_cal.sqlite",
        h5_path: str = "det_cal.h5",
        h5_unix_digits: int = 0,
        cache_failed_obsids: bool = True,
        failed_cache_file: str = "failed_obsids.yaml",
        show_pb: bool = True,
        param_correction_config: Union[Dict[str, Any], None, tpc.AnalysisCfg] = None,
        run_method: str = "site",
        nprocs_obs_info: int = 1,
        nprocs_result_set: int = 10,
        num_obs: Optional[int] = None,
        log_level: str = "DEBUG",
        multiprocess_start_method: Literal["spawn", "fork"] = "spawn"
    ) -> None:
        self.context_path = os.path.expandvars(context_path)
        self.metadata_list = metadata_list
        self.raise_exceptions = raise_exceptions
        self.fit_tau = fit_tau
        self.apply_cal_correction = apply_cal_correction
        self.hwpss_subtraction = hwpss_subtraction
        self.cache_failed_obsids = cache_failed_obsids
        self.show_pb = show_pb

        self.run_method = run_method
        if self.run_method not in ["site", "nersc"]:
            raise ValueError("run_method must be in: ['site', 'nersc']")

        self.nprocs_obs_info = nprocs_obs_info
        self.nprocs_result_set = nprocs_result_set
        self.num_obs = num_obs
        self.log_level = log_level
        self.multiprocess_start_method = multiprocess_start_method

        self.root_dir = os.path.expandvars(root_dir)
        if not os.path.exists(self.root_dir):
            raise ValueError(f"Root dir does not exist: {self.root_dir}")

        def parse_path(path: str) -> str:
            "Expand vars and make path absolute"
            p = os.path.expandvars(path)
            if not os.path.isabs(p):
                p = os.path.join(self.root_dir, p)
            return p

        self.index_path = parse_path(index_path)
        self.h5_path = parse_path(h5_path)
        self.h5_unix_digits = h5_unix_digits
        self.failed_cache_file = parse_path(failed_cache_file)

        # Defaults for the TES param correction; a user-supplied dict
        # overrides them key by key.
        kw = {"show_pb": False, "default_nprocs": self.nprocs_result_set}
        if param_correction_config is None:
            self.param_correction_config = tpc.AnalysisCfg(**kw)  # type: ignore
        elif isinstance(param_correction_config, dict):
            kw.update(param_correction_config)
            self.param_correction_config = tpc.AnalysisCfg(**kw)  # type: ignore
        else:
            self.param_correction_config = param_correction_config

        self.setup_files()

    @classmethod
    def from_yaml(cls, path) -> "DetCalCfg":
        """Build a configuration from a yaml file of constructor arguments."""
        with open(path, "r") as f:
            d = yaml.safe_load(f)
        return cls(**d)

    def setup_files(self) -> None:
        """Create directories and databases if they don't exist"""
        # Make sure the parent directories exist so that nested relative
        # paths (e.g. "cache/failed_obsids.yaml") can be created below.
        for path in (self.failed_cache_file, self.index_path):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)

        if not os.path.exists(self.failed_cache_file):
            # If file doesn't exist yet, just create an empty one
            with open(self.failed_cache_file, "w") as f:
                yaml.dump({}, f)

        if not os.path.exists(self.index_path):
            scheme = core.metadata.ManifestScheme()
            scheme.add_exact_match("obs:obs_id")
            scheme.add_data_field("dataset")
            db = core.metadata.ManifestDb(scheme=scheme)
            db.to_file(self.index_path)
@dataclass
class CalInfo:
    """
    Per-detector calibration record that is written into the caldb.

    Attributes
    ----------
    readout_id: str
        Readout ID of the detector.
    r_tes: float
        Detector resistance [ohms], determined through bias steps while the
        detector is biased.
    r_frac: float
        Fractional resistance of TES, given by r_tes / r_n.
    p_bias: float
        Bias power on the TES [W] computed using bias steps at the bias point.
    s_i: float
        Current responsivity of the TES [1/V] computed using bias steps at the
        bias point.
    phase_to_pW: float
        Phase to power conversion factor [pW/rad] computed using s_i,
        pA_per_phi0, and detector polarity.
    v_bias: float
        Commanded bias voltage [V] on the bias line of the detector for the
        observation.
    tau_eff: float
        Effective thermal time constant [sec] of the detector, measured from
        bias steps.
    loopgain: float
        Loopgain of the detector.
    tes_param_correction_success: bool
        True if TES parameter corrections were successfully applied.
    bg: int
        Bias group of the detector. Taken from IV curve data, which contains
        bgmap data taken immediately prior to IV. This will be -1 if the
        detector is unassigned.
    polarity: int
        Polarity of the detector response for a positive change in bias
        current while the detector is superconducting. This is needed to
        correct for detectors that have reversed response.
    r_n: float
        Normal resistance of the TES [Ohms] calculated from IV curve data.
    p_sat: float
        "saturation power" of the TES [W] calculated from IV curve data. This
        is defined as the electrical bias power at which the TES resistance is
        90% of the normal resistance.
    naive_r_tes: float
        Detector resistance [ohms]. This is based on the naive bias step
        estimation without any additional corrections.
    naive_r_frac: float
        Fractional resistance of TES, given by r_tes / r_n. This is based on
        the naive bias step estimation without any additional corrections.
    naive_p_bias: float
        Bias power on the TES [W] computed using bias steps at the bias point.
        This is based on the naive bias step estimation without any additional
        corrections.
    naive_s_i: float
        Current responsivity of the TES [1/V] computed using bias steps at the
        bias point. This is based on the naive bias step estimation without
        using any additional corrections.
    bandpass: str
        Detector bandpass, computed from bias group information.
    """

    readout_id: str = ""

    r_tes: float = np.nan
    r_frac: float = np.nan
    p_bias: float = np.nan
    s_i: float = np.nan
    phase_to_pW: float = np.nan
    v_bias: float = np.nan
    tau_eff: float = np.nan
    loopgain: float = np.nan
    tes_param_correction_success: bool = False

    bg: int = -1
    polarity: int = 1
    r_n: float = np.nan
    p_sat: float = np.nan

    naive_r_tes: float = np.nan
    naive_r_frac: float = np.nan
    naive_p_bias: float = np.nan
    naive_s_i: float = np.nan

    bandpass: str = "NC"

    @classmethod
    def dtype(cls) -> List[Tuple[str, Any]]:
        """Returns ResultSet dtype for an item based on this class"""
        # Fields whose column name or storage type differs from the plain
        # (field name, annotated type) pair.
        special: Dict[str, Tuple[str, Any]] = {
            "readout_id": ("dets:readout_id", "<U40"),
            # Our bandpass str is max 4 characters
            "bandpass": ("bandpass", "<U4"),
        }
        return [
            special.get(item.name, (item.name, item.type))
            for item in fields(cls)
        ]
@dataclass
class ObsInfo:
    """
    Class containing observation gathered from obsdbs and the file system
    required to compute calibration results.

    Attributes
    ------------
    obs_id: str
        Obs id.
    iv_obsids: dict
        Dict mapping detset to iv obs-id.
    bs_obsids: dict
        Dict mapping detset to bias step obs-id.
    iva_files: dict
        Dict mapping detset to IV analysis file path.
    bsa_files: dict
        Dict mapping detset to bias step analysis file path.
    det_info: dict
        Dict which contains det readout id, band, channel, stream id, detset.
    bias_lines: list
        List which stores bias lines
    biases: ndarray
        Array of detector biases for each bias line
    tube_flavor: str
        Type of optics tube (MF, LF, or UHF)
    """

    obs_id: str

    iv_obsids: Dict[str, str]
    bs_obsids: Dict[str, str]
    iva_files: Dict[str, str]
    bsa_files: Dict[str, str]

    det_info: Dict[str, List[str]]
    bias_lines: List[str]
    biases: np.ndarray
    tube_flavor: str


@dataclass
class ObsInfoResult:
    """
    Results object for the get_obs_info function.

    ``obs_info`` is only set when ``success`` is True; otherwise ``traceback``
    holds the formatted traceback of the failure.
    """

    obs_id: str
    success: bool = False
    traceback: str = ""
    obs_info: Optional[ObsInfo] = None


def _apply_log_level(cfg: DetCalCfg) -> None:
    """Set the configured log level on the module logger and its handlers."""
    level = getattr(logging, cfg.log_level.upper())
    logger.setLevel(level)
    for ch in logger.handlers:
        ch.setLevel(level)


def _find_zsmurf_file(data_root: str, cal_obs_id: str, needle: str, label: str) -> str:
    """
    Return the first file in a calibration obs' ``Z_smurf`` directory whose
    name contains ``needle``; raise ValueError if there is none.
    """
    timecode = cal_obs_id.split("_")[1][:5]
    zsmurf_dir = os.path.join(data_root, "oper", timecode, cal_obs_id, "Z_smurf")
    for fname in os.listdir(zsmurf_dir):
        if needle in fname:
            return os.path.join(zsmurf_dir, fname)
    raise ValueError(f"{label} data not found for in cal obs {cal_obs_id}")


def get_obs_info(cfg: DetCalCfg, obs_id: str) -> ObsInfoResult:
    """
    Gather everything needed to compute the calibration of one observation.

    This loads a one-sample view of the observation, looks up the IV and bias
    step calibration obs-ids for each detset and locates the corresponding
    analysis files on disk.

    Args
    ------
    cfg: DetCalCfg
        DetCal configuration object.
    obs_id: str
        Observation id to gather info for.

    Returns
    --------
    ObsInfoResult
        On failure ``success`` is False and ``traceback`` is populated, unless
        ``cfg.raise_exceptions`` is set, in which case the error is re-raised.
    """
    res = ObsInfoResult(obs_id)

    try:
        ctx = core.Context(cfg.context_path, metadata_list=cfg.metadata_list)
        am = ctx.get_obs(
            obs_id,
            samples=(0, 1),
            ignore_missing=True,
            no_signal=True,
            no_headers=False,
            on_missing={"det_cal": "skip"},
        )

        if "smurf" not in am.det_info:
            raise ValueError(f"Missing smurf info for {obs_id}")

        logger.debug(f"Getting cal obsids ({obs_id})")

        iv_obsids = get_cal_obsids(ctx, obs_id, "iv")

        # Load in IVs
        logger.debug(f"Loading Bias step and IV data ({obs_id})")

        # Automatically determine paths based on data root instead of obsfiledb
        # because obsfiledb queries are slow on nersc.
        data_root = get_data_root(ctx, obs_id)

        iva_files = {}
        for dset, oid in iv_obsids.items():
            if oid is None:
                logger.debug("missing IV data for %s", dset)
                continue
            iva_files[dset] = _find_zsmurf_file(data_root, oid, "iv", "IV")

        if len(iva_files) == 0:
            raise ValueError(f"No IV data found for {obs_id}")

        # Load in bias steps
        bsa_files = {}
        bias_step_obsids = get_cal_obsids(ctx, obs_id, "bias_steps")
        for dset, oid in bias_step_obsids.items():
            if oid is None:
                logger.debug("missing bias step data for %s", dset)
                continue
            bsa_files[dset] = _find_zsmurf_file(
                data_root, oid, "bias_step", "Bias step"
            )

        # Pass through specific fields from our axismanager so that the result
        # is Picklable.
        res.obs_info = ObsInfo(
            obs_id=obs_id,
            det_info={
                "stream_id": am.det_info.stream_id,
                "detset": am.det_info.detset,
                "readout_id": am.det_info.readout_id,
                "band": am.det_info.smurf.band,
                "channel": am.det_info.smurf.channel,
            },
            bias_lines=am.bias_lines.vals,
            biases=am.biases,
            tube_flavor=am.obs_info.tube_flavor,
            iv_obsids=iv_obsids,
            bs_obsids=bias_step_obsids,
            iva_files=iva_files,
            bsa_files=bsa_files,
        )
        res.success = True
    except BaseException:
        # Deliberately broad (same as a bare except): the traceback text is
        # reported back to the caller, and add_to_failed_cache filters out
        # KeyboardInterrupt tracebacks.
        res.traceback = traceback.format_exc()
        if cfg.raise_exceptions:
            raise
    return res


@dataclass
class CalRessetResult:
    """
    Results object for the get_cal_resset function.
    """

    obs_info: ObsInfo
    success: bool = False
    traceback: Optional[str] = None
    fail_msg: Optional[str] = None

    correction_results: Optional[Dict[str, List[tpc.CorrectionResults]]] = None
    result_set: Optional[np.ndarray] = None


def biases_flags(bsa, buffer=200):
    """
    Make flags that mask bias steps

    Args
    ------
    bsa:
        sodetlib BiasStepAnalysis object
    buffer:
        Number of samples to buffer flags

    Returns
    --------
    RangesMatrix
    """
    mask = np.zeros((bsa.am.dets.count, bsa.am.samps.count), dtype=bool)
    for i, bg in enumerate(bsa.bgmap):
        if bg == -1:
            continue
        edges = bsa.edge_idxs[bg]
        if len(edges) == 0:
            continue
        # Flag everything between the first and the last bias edge.
        mask[i][edges[0]:edges[-1]] = 1
    flags = RangesMatrix.from_mask(mask).buffer(buffer)
    return flags


def fill_zeros_biases(am):
    """
    Patch zero entries of ``am.biases`` in place.

    Zeros are filled by the closest preceding non-zero value; zeros at the
    beginning are filled by the first non-zero value after them.
    """
    for bias in am.biases:
        # Forward pass: propagate the last non-zero value onto later zeros.
        last = None
        for i in range(len(bias)):
            if bias[i] != 0:
                last = bias[i]
            elif last is not None:
                bias[i] = last
        # Backward pass: only leading zeros are left to be filled.
        last = None
        for i in range(len(bias) - 1, -1, -1):
            if bias[i] != 0:
                last = bias[i]
            elif last is not None:
                bias[i] = last


def load_and_reanalyze_bs(bsa, ctx, obs_id):
    """
    Load raw data of biassteps and reanalyze it with hwpss subtraction.
    Nothing is done if the loaded hwp angle is zero everywhere.

    Args
    ------
    bsa:
        sodetlib BiasStepAnalysis object
    ctx:
        Context object
    obs_id:
        observation id of bias steps
    """
    am = ctx.get_obs(obs_id, special_channels=True, reindex_dets=True)
    am.wrap('hwp_angle', am.hwp_solution.hwp_angle, [(0, 'samps')])
    if np.all(am.hwp_angle == 0):
        return
    bsa.am = am
    zero_bias_count = sum([sum(bias == 0) for bias in am.biases])
    if zero_bias_count > 0:
        logger.warning(f'Patching {zero_bias_count} zero bias values in {obs_id}')
        fill_zeros_biases(am)
    bsa._find_bias_edges()
    flags = biases_flags(bsa)
    get_hwpss(am, flags=flags, merge_stats=True)
    subtract_hwpss(am, subtract_name='signal')
    bsa._get_step_response()
    bsa._compute_dc_params()
    bsa._fit_tau_effs()
    # Drop the reference to the loaded data once the reanalysis is done.
    del bsa.am


def get_cal_resset(cfg: DetCalCfg, obs_info: ObsInfo,
                   executor=None, as_completed_callable=None) -> CalRessetResult:
    """
    Returns calibration ResultSet for a given ObsId. This pulls IV and bias step
    data for each detset in the observation, and uses that to compute CalInfo
    for each detector in the observation.

    Args
    ------
    cfg: DetCalCfg
        DetCal configuration object.
    obs_info: ObsInfo
        ObsInfo object.
    executor: optional
        If specified, the TES param correction is run through
        ``tpc.run_corrections_parallel`` with this executor. If None, the
        corrections are run serially in this process.
    as_completed_callable: optional
        Passed through to ``tpc.run_corrections_parallel`` together with
        ``executor``.

    Returns
    --------
    CalRessetResult
        On failure ``success`` is False and ``traceback`` / ``fail_msg`` hold
        the formatted traceback, unless ``cfg.raise_exceptions`` is set, in
        which case the exception is re-raised.
    """
    obs_id = obs_info.obs_id
    res = CalRessetResult(obs_info)
    logger.debug("Computing Result set for %s", obs_info.obs_id)

    # Need to reset logger here because this may be created new for spawned process
    _apply_log_level(cfg)

    try:
        ivas = {
            dset: IVAnalysis.load(iva_file)
            for dset, iva_file in obs_info.iva_files.items()
        }
        bsas = {
            dset: BiasStepAnalysis.load(bsa_file)
            for dset, bsa_file in obs_info.bsa_files.items()
        }

        if cfg.apply_cal_correction:
            for iva in ivas.values():
                # Run R_L correction if analysis version is old...
                if getattr(iva, "analysis_version", 0) == 0:
                    # This will edit IVA dicts in place
                    logger.debug("Recomputing IV analysis for %s", obs_id)
                    tpc.recompute_ivpars(iva, cfg.param_correction_config)

        if cfg.fit_tau:
            for bsa in bsas.values():
                bsa._fit_tau_effs()

        if cfg.hwpss_subtraction:
            # Reanalyze biasstep with hwpss subtraction
            ctx = core.Context(cfg.context_path, metadata_list=cfg.metadata_list)
            bias_step_obsids = get_cal_obsids(ctx, obs_id, "bias_steps")
            for dset, bsa in bsas.items():
                oid = bias_step_obsids[dset]
                load_and_reanalyze_bs(bsa, ctx, oid)

        iva = list(ivas.values())[0]
        rtm_bit_to_volt = iva.meta["rtm_bit_to_volt"]
        pA_per_phi0 = iva.meta["pA_per_phi0"]
        cals = [CalInfo(rid) for rid in obs_info.det_info["readout_id"]]
        if len(cals) == 0:
            raise ValueError(f"No detectors found for {obs_id}")

        # Add IV info
        for i, cal in enumerate(cals):
            band = obs_info.det_info["band"][i]
            chan = obs_info.det_info["channel"][i]
            detset = obs_info.det_info["detset"][i]
            # Detsets without an IV file have no entry in ``ivas``.
            iva = ivas.get(detset)
            if iva is None:  # No IV analysis for this detset
                continue

            ridx = np.where((iva.bands == band) & (iva.channels == chan))[0]
            # Test the size explicitly: the truth value of an index array
            # would treat a match at index 0 as "not found".
            if ridx.size == 0:  # Channel doesn't exist in IV analysis
                continue

            ridx = ridx[0]
            cal.bg = iva.bgmap[ridx]
            cal.polarity = iva.polarity[ridx]
            cal.r_n = iva.R_n[ridx]  # type: ignore
            cal.p_sat = iva.p_sat[ridx]  # type: ignore

        obs_biases = dict(
            zip(obs_info.bias_lines, obs_info.biases[:, 0] * 2 * rtm_bit_to_volt)
        )
        bias_line_is_valid = {k: True for k in obs_biases.keys()}

        # check to see if biases have changed between bias steps and obs
        for bsa in bsas.values():
            if bsa is None:
                continue

            for bg, vb_bsa in enumerate(bsa.Vbias):
                bl_label = f"{bsa.meta['stream_id']}_b{bg:0>2}"
                # Usually we can count on bias voltages of bias lines >= 12 to be
                # Nan, however we have seen cases where they're not, so we also
                # restrict by count.
                if np.isnan(vb_bsa) or bg >= TES_BIAS_COUNT:
                    bias_line_is_valid[bl_label] = False
                    continue

                if np.abs(vb_bsa - obs_biases[bl_label]) > 0.1:
                    logger.debug(
                        "bias step and obs biases don't match for %s", bl_label
                    )
                    bias_line_is_valid[bl_label] = False

        # Add TES corrected params
        correction_results: Dict[str, List[tpc.CorrectionResults]] = {}
        if cfg.apply_cal_correction:
            logger.debug("Applying TES param corrections (%s)", obs_id)
            for dset in bsas:
                if dset not in ivas:
                    # The correction needs both IV and bias step data.
                    logger.debug("missing IV data for %s", dset)
                    continue
                rs = []
                if executor is None:
                    for b, c in zip(ivas[dset].bands, ivas[dset].channels):
                        chdata = tpc.RpFitChanData.from_data(
                            ivas[dset], bsas[dset], b, c
                        )
                        rs.append(
                            tpc.run_correction(chdata, cfg.param_correction_config)
                        )
                else:
                    rs = tpc.run_corrections_parallel(
                        ivas[dset], bsas[dset], cfg.param_correction_config,
                        executor=executor,
                        as_completed_callable=as_completed_callable)
                correction_results[dset] = rs

        res.correction_results = correction_results

        def find_correction_results(band, chan, dset):
            for r in correction_results.get(dset, []):
                if r.chdata.band == band and r.chdata.channel == chan:
                    return r
            return None

        for i, cal in enumerate(cals):
            band = obs_info.det_info["band"][i]
            chan = obs_info.det_info["channel"][i]
            detset = obs_info.det_info["detset"][i]
            stream_id = obs_info.det_info["stream_id"][i]
            bg = cal.bg
            # Detsets without a bias step file have no entry in ``bsas``.
            bsa = bsas.get(detset)

            if bsa is None or bg == -1:
                continue

            bl_label = f"{stream_id}_b{bg:0>2}"
            if not bias_line_is_valid[bl_label]:
                continue

            ridx = np.where((bsa.bands == band) & (bsa.channels == chan))[0]
            # See above: a match at index 0 must not count as "not found".
            if ridx.size == 0:  # Channel doesn't exist in bias step analysis
                continue

            if cfg.apply_cal_correction:
                correction = find_correction_results(band, chan, detset)
                if correction is None:
                    logger.warning(
                        "Unable to find correction result for %s %s %s (%s)",
                        band,
                        chan,
                        detset,
                        obs_id,
                    )
                    use_correction = False
                    cal.tes_param_correction_success = False
                else:
                    use_correction = correction.success
                    cal.tes_param_correction_success = correction.success
            else:
                use_correction = False

            ridx = ridx[0]
            cal.tau_eff = bsa.tau_eff[ridx]
            # bg == -1 has already been skipped above.
            cal.v_bias = bsa.Vbias[bg]

            if use_correction and correction.corrected_params is not None:
                cpars = correction.corrected_params
                cal.r_tes = cpars.corrected_R0
                cal.r_frac = cpars.corrected_R0 / cal.r_n
                cal.s_i = cpars.corrected_Si * 1e6
                cal.p_bias = cpars.corrected_Pj * 1e-12
                cal.loopgain = cpars.loopgain
            else:
                cal.r_tes = bsa.R0[ridx]
                cal.r_frac = bsa.Rfrac[ridx]
                cal.p_bias = bsa.Pj[ridx]
                cal.s_i = bsa.Si[ridx]

            # Save naive parameters even if we're using corrected version
            cal.naive_r_tes = bsa.R0[ridx]
            cal.naive_r_frac = bsa.Rfrac[ridx]
            cal.naive_s_i = bsa.Si[ridx]
            cal.naive_p_bias = bsa.Pj[ridx]

            if cal.s_i == 0:
                cal.phase_to_pW = np.nan
            else:
                cal.phase_to_pW = pA_per_phi0 / (2 * np.pi) / cal.s_i * cal.polarity

            # Add bandpass information from bias group
            tube_flavor = obs_info.tube_flavor
            if cal.bg in BGS['lb']:
                cal.bandpass = BAND_STR[tube_flavor]['lb']
            elif cal.bg in BGS['hb']:
                cal.bandpass = BAND_STR[tube_flavor]['hb']

        res.result_set = np.array([astuple(c) for c in cals], dtype=CalInfo.dtype())
        res.success = True
    except Exception:
        res.traceback = traceback.format_exc()
        res.fail_msg = res.traceback
        if cfg.raise_exceptions:
            raise
    return res


def get_obsids_to_run(cfg: DetCalCfg) -> List[str]:
    """
    Returns list of obs-ids to process, based on the configuration object.
    This will include non-processed obs-ids that are not found in the fail
    cache, sorted in descending order, and will be limited to cfg.num_obs.
    """
    ctx = core.Context(cfg.context_path, metadata_list=cfg.metadata_list)

    # Find all obs_ids that have not been processed
    with open(cfg.failed_cache_file, "r") as f:
        failed_cache = yaml.safe_load(f)

    # An empty yaml file loads as None.
    failed_obsids = set(failed_cache.keys()) if failed_cache is not None else set()

    db = core.metadata.ManifestDb(cfg.index_path)
    obs_ids_all = set(ctx.obsdb.query('type=="obs"')["obs_id"])
    processed_obsids = set(db.get_entries(["dataset"])["dataset"])
    obs_ids = sorted(obs_ids_all - processed_obsids - failed_obsids, reverse=True)
    if cfg.num_obs is not None:
        obs_ids = obs_ids[: cfg.num_obs]
    return obs_ids


def add_to_failed_cache(cfg: DetCalCfg, obs_id: str, msg: str) -> None:
    """
    Record ``obs_id`` with its failure message in the failed-obsid yaml file.

    Nothing is written for keyboard interrupts, for the transient
    metadata-loading errors listed below, or when ``cfg.cache_failed_obsids``
    is False.
    """
    if "KeyboardInterrupt" in msg:  # Don't cache keyboard interrupts
        return

    # Transient errors of metadata loading.
    # These can happen when hwpss_subtraction is True, but we can retry.
    transient_errors = [
        'sotodlib.core.metadata.loader.LoaderError',
        'BlockingIOError',
    ]
    for err in transient_errors:
        if err in msg:
            logger.error(f"obs_id {obs_id} failed to load metadata {err}."
                         " Try again later")
            return

    if not cfg.cache_failed_obsids:
        return

    logger.info(f"Adding {obs_id} to failed_file_cache")
    with open(cfg.failed_cache_file, "r") as f:
        d = yaml.safe_load(f)
    if d is None:
        d = {}
    d[str(obs_id)] = msg
    with open(cfg.failed_cache_file, "w") as f:
        yaml.dump(d, f)


def handle_result(result: CalRessetResult, cfg: DetCalCfg) -> None:
    """
    Handles a CalRessetResult. If successful, this will add to the manifestdb,
    if not this will add to the failed cache if cfg.cache_failed_obsids is True.
    """
    obs_id = str(result.obs_info.obs_id)
    if not result.success:
        logger.error(f"Failed on obs_id: {obs_id}")
        logger.error(result.traceback)

        msg = result.fail_msg
        if msg is None:
            msg = "unknown error"
        add_to_failed_cache(cfg, obs_id, msg)
        return

    logger.info(f"Adding obs_id {obs_id} to dataset")
    rset = ResultSet.from_friend(result.result_set)

    h5_path = cfg.h5_path
    if cfg.h5_unix_digits:
        # Split the output over several files keyed by the leading digits of
        # the unixtime field of the obs-id.
        name, ext = os.path.splitext(cfg.h5_path)
        unixtime = obs_id.split('_')[1][:cfg.h5_unix_digits]
        h5_path = f"{name}_{unixtime}{ext}"

    write_dataset(rset, h5_path, obs_id, overwrite=True)
    db = core.metadata.ManifestDb(cfg.index_path)
    # The index stores the h5 path relative to the index file's directory.
    relpath = os.path.relpath(h5_path, start=os.path.dirname(cfg.index_path))
    db.add_entry(
        {"obs:obs_id": obs_id, "dataset": obs_id}, filename=relpath, replace=True
    )


def run_update_site(
        cfg: DetCalCfg,
        executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
        as_completed_callable: Callable) -> None:
    """
    Main run script for computing det-cal results at the site. This will loop
    over obs-ids and serially gather the ObsInfo from the filesystem and sqlite
    dbs, and then compute the calibration results. The given executor is used
    to parallelize the TES correction computation.

    If you have lots of compute power, and are limited by filesystem or sqlite
    access, consider using the 'nersc' update function.

    Args
    ------
    cfg: DetCalCfg
        DetCal configuration object.
    executor: MPICommExecutor or ProcessPoolExecutor
        Executor handed to get_cal_resset for the TES param correction.
    as_completed_callable: Callable
        Handed to get_cal_resset together with ``executor``.
    """
    _apply_log_level(cfg)

    obs_ids = get_obsids_to_run(cfg)
    logger.info(f"Processing {len(obs_ids)} obsids...")

    for oid in tqdm(obs_ids, disable=(not cfg.show_pb)):
        res = get_obs_info(cfg, oid)
        if not res.success:
            logger.info(f"Could not get obs info for obs id: {oid}")
            logger.error(res.traceback)

        # NOTE(review): unlike run_update_nersc, obs-info failures are not
        # added to the failed cache here -- confirm this is intended.
        if res.obs_info is None:
            continue

        result_set = get_cal_resset(cfg, res.obs_info, executor=executor,
                                    as_completed_callable=as_completed_callable)
        handle_result(result_set, cfg)


def run_update_nersc(cfg: DetCalCfg) -> None:
    """
    Main run script for computing det-cal results. This does the same thing as
    ``run_update_site`` however instantiates two separate pools for gathering
    ObsInfo and computing the ResultSets.

    This is useful in situations where sqlite/filesystem access are bottlenecks
    (such as nersc) so that ObsInfo can be gathered in parallel, and this can
    be done while ResultSet computation is ongoing.

    Because concurrent sqlite access can be limited, it is recommended to keep
    cfg.nprocs_obs_info low (<10), while ``cfg.nprocs_result_set`` can be set
    arbitrarily large as to use remaining available resources.

    Args
    ------
    cfg: DetCalCfg
        DetCal configuration object.
    """
    _apply_log_level(cfg)

    obs_ids = get_obsids_to_run(cfg)
    logger.info(f"Processing {len(obs_ids)} obsids...")

    pb = tqdm(total=len(obs_ids), disable=(not cfg.show_pb))

    def callback(fut: Future):
        result = fut.result()
        pb.update()
        handle_result(result, cfg)

    def errback(e):
        logger.info(e)
        raise e

    # We split into multiple pools because:
    # - we don't want to overload sqlite files with too much concurrent access
    # - we want to be able to continue getting the next obs_info data while
    #   ressets are being computed
    executor1 = ProcessPoolExecutor(max_workers=cfg.nprocs_obs_info)
    executor2 = ProcessPoolExecutor(max_workers=cfg.nprocs_result_set)

    resset_futures: List[Future] = []

    try:
        obsinfo_futures = [
            executor1.submit(get_obs_info, cfg, obs_id) for obs_id in obs_ids
        ]

        # Submit the result-set jobs from this thread as obs-info results come
        # in. Doing this in a done-callback is racy: waiters are woken before
        # callbacks run, so the last job could be submitted after the wait on
        # ``resset_futures`` below has already started.
        for fut in as_completed(obsinfo_futures):
            try:
                result = fut.result()
            except Exception as e:
                # get_obs_info reports its own errors through the result, so
                # this is a failure of the pool itself; log it and move on.
                logger.info(e)
                pb.update()
                continue

            if result.success:
                future = executor2.submit(get_cal_resset, cfg, result.obs_info)
                future.add_done_callback(callback)
                resset_futures.append(future)
            else:
                pb.update()
                add_to_failed_cache(cfg, result.obs_id, result.traceback)
                logger.error(f"Failed to get obs_info for {result.obs_id}:\n{result.traceback}")

        # Wait for all resset tasks to complete
        for fut in as_completed(resset_futures):
            try:
                fut.result()  # Force exceptions to be raised here if any
            except Exception as e:
                errback(e)
    finally:
        executor1.shutdown(wait=True, cancel_futures=True)
        executor2.shutdown(wait=True, cancel_futures=True)
        pb.close()

    logger.info("Finished updates")


def get_parser(
    parser: Optional[argparse.ArgumentParser] = None,
) -> argparse.ArgumentParser:
    """Return ``parser`` (or a new ArgumentParser) with the script's arguments added."""
    if parser is None:
        p = argparse.ArgumentParser()
    else:
        p = parser
    p.add_argument(
        "config_file", type=str, help="yaml file with configuration for update script."
    )
    return p


def _main(
        cfg: DetCalCfg,
        executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
        as_completed_callable: Callable) -> None:
    """
    Run update function. This will chose the correct method to run based on
    ``cfg.run_method``.
    """
    if cfg.run_method == "site":
        run_update_site(cfg, executor, as_completed_callable)
    elif cfg.run_method == "nersc":
        # Instantiate our own separate pools when running on NERSC.
        executor.shutdown(wait=True)
        run_update_nersc(cfg)
    else:
        raise ValueError(f"Unknown run_method: {cfg.run_method}")


def main(config_file: str):
    """Load the yaml configuration and run the update on rank 0."""
    cfg = DetCalCfg.from_yaml(config_file)
    rank, executor, as_completed_callable = get_exec_env(cfg.nprocs_result_set)
    if rank == 0:
        _main(cfg, executor, as_completed_callable)


if __name__ == "__main__":
    main_launcher(main, get_parser)