import argparse as ap
import datetime as dt
import logging
import os
from copy import deepcopy
from importlib import import_module
from typing import List, Optional
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import git
import h5py
import megham.transform as mt
import megham.utils as mu
import numpy as np
import yaml
from scipy.cluster import vq
from scipy.optimize import minimize
from so3g.proj import quat
from sotodlib.coords import optics as op
from sotodlib.coords.fp_containers import (
FocalPlane,
OpticsTube,
Receiver,
Template,
Transform,
plot_by_gamma,
plot_ot,
plot_receiver,
plot_ufm,
)
from sotodlib.coords.pointing_model import apply_pointing_model
from sotodlib.core import AxisManager, Context, IndexAxis, metadata
from sotodlib.io.metadata import read_dataset
from sotodlib.site_pipeline.utils.logging import init_logger
# Module-level logger used by every function in this file.
# NOTE(review): the second argument is presumably a prefix added to each
# message -- confirm against init_logger.
logger = init_logger(__name__, "finalize_focal_plane: ")
def _create_db(filename, per_obs, obs_ids, start_time, stop_time):
if per_obs:
if len(obs_ids) != 1:
raise ValueError(f"Running in per_obs mode but {len(obs_ids)} found!")
base = {"obs:obs_id": obs_ids[0]}
group = obs_ids[0]
else:
base = {"obs:timestamp": (start_time, stop_time)}
group = str(start_time)
if os.path.isfile(filename):
return metadata.ManifestDb(filename), base, group
if not os.path.isdir(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename), exist_ok=True)
scheme = metadata.ManifestScheme()
scheme.add_exact_match("dets:stream_id")
if per_obs:
scheme.add_exact_match("obs:obs_id")
else:
scheme.add_range_match("obs:timestamp")
scheme.add_data_field("dataset")
metadata.ManifestDb(scheme=scheme).to_file(filename)
return metadata.ManifestDb(filename), base, group
def _avg_focalplane(full_fp, tot_weight):
# Figure out how many good pointings we have for each det
msk = np.isfinite(full_fp)
n_obs = np.sum(np.any(msk, axis=1), axis=-1)
n_point, _, n_gamma = tuple(np.sum(msk, axis=-1).T)
tot_weight[tot_weight[:, 0] == 0] = np.nan
avg_fp = np.nansum(full_fp, axis=-1) / tot_weight[:, 0][..., None]
avg_weight = tot_weight / n_obs[..., None]
# nansum all all nans is 0, addressing that case here
all_nan = ~np.any(
np.isfinite(np.swapaxes(full_fp, 0, 1)).reshape((full_fp.shape[1], -1)), axis=1
)
avg_fp[:, all_nan] = np.nan
return avg_fp, avg_weight, n_point, n_gamma
def _log_vals(shift, scale, shear, rot, axis):
    """
    Log the pieces of a decomposed transform.

    Arguments:

        shift: Iterable of shifts, paired up with the names in axis.

        scale: Iterable of scale factors, paired up with the names in axis.

        shear: The shear parameter.

        rot: The rotation of the plane made by the first two axes, logged as radians.

        axis: Names of the axes, at least two are needed for the rotation message.
    """
    rad_per_deg = np.pi / 180.0
    deg_per_rad = 180.0 / np.pi

    for name, value in zip(axis, shift):
        logger.info("\tShift along %s axis is %f", name, value)

    for name, value in zip(axis, scale):
        logger.info("\tScale along %s axis is %f", name, value)
        # A scale equal to a unit conversion factor usually means mixed units
        if np.isclose(value, rad_per_deg):
            logger.warning(
                "\tScale factor for %s looks like a degrees to radians conversion",
                name,
            )
            continue
        if np.isclose(value, deg_per_rad):
            logger.warning(
                "\tScale factor for %s looks like a radians to degrees conversion",
                name,
            )

    logger.info("\tShear param is %f", shear)
    first, second = axis[0], axis[1]
    logger.info("\tRotation of the %s-%s plane is %f radians", first, second, rot)
# [docs] -- Sphinx source-link marker left over from HTML extraction, not code
def gamma_fit(src, dst):
    """
    Fit the transformation for gamma.
    Note that the periodicity here assumes things are in radians.

    Entries where either src or dst is not finite are ignored.
    Without this a single nan makes the objective nan everywhere
    and the minimizer just hands back its initial guess.

    Arguments:

        src: Source gamma in radians

        dst: Destination gamma in radians

    Returns:

        scale: Scale applied to src

        shift: Shift applied to scale*src

        Both are returned together as a length 2 array.
        If there are no finite pairs to fit this is the identity, (1.0, 0.0).
    """

    def _gamma_min(pars, src, dst):
        # RMS of the difference in sin space so that we are periodic
        scale, shift = pars
        transformed = np.sin(src * scale + shift)
        diff = np.sin(dst) - transformed

        return np.sqrt(np.mean(diff**2))

    src, dst = np.broadcast_arrays(
        np.asarray(src, dtype=float), np.asarray(dst, dtype=float)
    )
    good = np.isfinite(src) & np.isfinite(dst)
    if not np.any(good):
        # Nothing to fit, same as what the minimizer would give with all nans
        return np.array([1.0, 0.0])

    res = minimize(_gamma_min, (1.0, 0.0), (src[good], dst[good]))
    return res.x
def _load_template(template_path, ufm, pointing_cfg):
    """
    Read the template for one UFM from disk and pack it into a Template.

    Arguments:

        template_path: Path handed to read_dataset.

        ufm: Name of the dataset to read, main passes the stream_id here.

        pointing_cfg: Pointing configuration passed through to Template untouched.

    Returns:

        template: Template built from the det_ids, the (xi, eta, gamma) columns
                  and the is_optical flags found in the dataset.
    """
    rset = read_dataset(template_path, ufm)

    ids = rset["dets:det_id"]
    # One column each for xi, eta, and gamma
    coords = np.column_stack(
        [np.array(rset[col]) for col in ("xi", "eta", "gamma")]
    )
    optical = rset["is_optical"]

    return Template(np.array(ids), coords, np.array(optical), pointing_cfg)
def _get_obs_ids(
    ctx,
    metalist,
    start_time,
    stop_time,
    query=None,
    obs_ids=None,
    tags=None,
    stream_ids=None,
    min_dets=0,
):
    """
    Work out which observations to run on.

    The candidates are the obs_ids passed in, or if none are given the result
    of an obsdb query. These are then cut down to observations that show up in
    at least one of the requested metadata dbs, that contain one of the
    requested stream_ids (if any were given), and that have at least
    min_dets rows in their det table.

    None of the arguments are modified.

    Arguments:

        ctx: Context to pull the obsdb, obsfiledb, and metadata entries from.

        metalist: Names or labels of the metadata entries to check for.

        start_time: Start time used in the default query.

        stop_time: Stop time used in the default query.

        query: Optional obsdb query. If None a query on the time range is used.

        obs_ids: Optional explicit list of obs_ids.

        tags: Tags passed to the obsdb query when no obs_ids are provided.

        stream_ids: If not empty, only keep observations that include
                    at least one of these stream_ids.

        min_dets: Minimum number of rows in the det table of an observation.

    Returns:

        obs_ids: Array of the obs_ids that survive all of the cuts.

    Raises:

        ValueError: If an obsdb query is needed but the context has no obsdb,
                    or if no metadata entry in the context matches metalist.
    """
    # Copy everything, the defaults used to be shared mutable lists
    # and the list from the caller's config was being extended in place
    obs_ids = [] if obs_ids is None else list(obs_ids)
    tags = [] if tags is None else list(tags)
    stream_ids = [] if stream_ids is None else list(stream_ids)

    all_obs = obs_ids
    if len(obs_ids) == 0:
        query_all = query
        if query is None:
            query_all = (
                f"type=='obs' and start_time>{start_time} and stop_time<{stop_time}"
            )
        if ctx.obsdb is None:
            raise ValueError("No obsdb!")
        all_obs = ctx.obsdb.query(query_all, tags=tags)["obs_id"]

    # Only keep things that have an entry in at least one metadata db
    dbs = [
        metadata.ManifestDb(md["db"])
        for md in ctx["metadata"]
        if md.get("name", "") in metalist or md.get("label", "") in metalist
    ]
    if len(dbs) == 0:
        # np.hstack would raise a much less helpful ValueError here
        raise ValueError(f"No metadata entries matching {metalist} in context")
    with_meta = np.unique(
        np.hstack(
            [np.array([entry["obs:obs_id"] for entry in db.inspect()]) for db in dbs]
        )
    )
    all_obs = np.intersect1d(all_obs, with_meta)

    if query is not None:
        obs_ids = obs_ids + list(ctx.obsdb.query(query)["obs_id"])

    if len(stream_ids) > 0:
        all_obs = [
            obs_id
            for obs_id in all_obs
            if len(
                np.intersect1d(
                    ctx.obsdb.get(obs_id)["stream_ids_list"].split(","), stream_ids
                )
            )
        ]

    all_obs = np.array(
        [
            obs_id
            for obs_id in all_obs
            if len(ctx.obsfiledb.get_det_table(obs_id)) >= min_dets
        ]
    )

    if len(obs_ids) == 0 and query is None:
        return all_obs
    return np.intersect1d(obs_ids, all_obs)
def _load_ctx(config):
    """
    Load the pointing metadata of every usable observation from a context.

    Observations with no roll_center, a roll outside of roll_range, or whose
    metadata fails to load are skipped. Each observation contributes one
    AxisManager per kind of pointing (tod and/or map) that it has, with that
    pointing moved to the "pointing" field.

    Arguments:

        config: The configuration dict. Uses the "context" section along with
                "start_time", "stop_time", and optionally "roll_range" and "min_points".

    Returns:

        amans: List of loaded AxisManagers.

        obs_ids: List with the obs_id of each entry in amans.

        ot_sid: Array of the unique (telescope_flavor, tube_slot, wafer_slot, stream_id)
                combinations found in the obsdb within the time range.

    Raises:

        ValueError: If the context has no obsdb, or if a loaded observation
                    has no det_info, no detmap, or no pointing.
    """
    ctx = Context(config["context"]["path"])
    if ctx.obsdb is None:
        raise ValueError("No obsdb!")
    tod_pointing_name = config["context"].get("tod_pointing", "tod_pointing")
    map_pointing_name = config["context"].get("map_pointing", "map_pointing")
    pol_name = config["context"].get("polarization", "polarization")
    dm_name = config["context"].get("detmap", "detmap")
    roll_range = config.get("roll_range", [-1 * np.inf, np.inf])
    obs_ids = _get_obs_ids(
        ctx,
        [tod_pointing_name, map_pointing_name, pol_name],
        config["start_time"],
        config["stop_time"],
        config["context"].get("query", None),
        config["context"].get("obs_ids", []),
        config["context"].get("tags", ["timing_issues=0"]),
        config["context"].get("stream_ids", []),
        config.get("min_points", 0),
    )
    if len(obs_ids) == 0:
        logger.warning("No observations provided in configuration")

    amans = []
    dets = config["context"].get("dets", {})
    failed = []
    for obs_id in obs_ids:
        roll = ctx.obsdb.get(obs_id)["roll_center"]
        if roll is None:
            continue
        if roll < roll_range[0] or roll > roll_range[1]:
            logger.info("%s has a roll that is out of range", obs_id)
            continue
        try:
            aman = ctx.get_meta(obs_id, dets=dets)
        except Exception as err:
            # Typically a metadata.loader.LoaderError, but any failure to load
            # should only cost us this observation.
            # This was a bare except which also ate KeyboardInterrupt.
            logger.error("Failed to load %s, skipping", obs_id)
            logger.debug("\tLoad failure was: %r", err)
            failed.append(obs_id)
            continue
        if aman.obs_info.tube_slot == "stp1":
            aman.obs_info.tube_slot = "st1"
        if "det_info" not in aman:
            raise ValueError(f"No det_info in {obs_id}")

        # Fill in the detmap from the context if det_info doesn't have it
        if "wafer" not in aman.det_info and dm_name in aman:
            dm_aman = aman[dm_name].copy()
            aman.det_info.wrap("wafer", dm_aman)
            if "det_id" not in aman.det_info:
                aman.det_info.wrap(
                    "det_id", aman.det_info.wafer.det_id, [(0, aman.dets)]
                )
        if "det_id" in aman.det_info:
            aman.restrict("dets", ~np.isin(aman.det_info.det_id, ["", "NO_MATCH"]))
        else:
            raise ValueError(f"No detmap for {obs_id}")

        if pol_name in aman:
            aman.move(pol_name, "polarization")
        else:
            logger.warning("No polarization data in context")

        # One AxisManager per type of pointing, both called "pointing"
        if tod_pointing_name in aman:
            _aman = aman.copy()
            _aman.move(tod_pointing_name, "pointing")
            amans.append(_aman)
        if map_pointing_name in aman:
            _aman = aman.copy()
            _aman.move(map_pointing_name, "pointing")
            amans.append(_aman)
        elif tod_pointing_name not in aman:
            raise ValueError(f"No pointing found in {obs_id}")
    obs_ids = [aman.obs_info.obs_id for aman in amans]
    if len(failed) > 0:
        logger.error("Failed to load %s", str(failed))

    # Figure out stream_ids and OTs
    query_all = f"type=='obs' and start_time>{config['start_time']} and stop_time<{config['stop_time']}"
    ot_sid = []
    for obs in ctx.obsdb.query(query_all):
        if obs["wafer_slots_list"] is None or obs["stream_ids_list"] is None:
            continue
        ot = obs["tube_slot"]
        if ot == "stp1":
            ot = "st1"
        ot_sid += [
            (obs["telescope_flavor"], ot, ws, sid)
            for ws, sid in zip(
                obs["wafer_slots_list"].split(","), obs["stream_ids_list"].split(",")
            )
        ]
    ot_sid = np.unique(np.array(ot_sid), axis=0)

    return amans, obs_ids, ot_sid
def _load_rset_single(config):
    """
    Build the AxisManager for a single observation from ResultSets on disk.

    Arguments:

        config: Configuration dict. config["resultsets"] holds the arguments for
                read_dataset under "pointing", "detmap", and optionally
                "polarization", along with an optional "obs_id".
                "stream_id", "wafer_slot", "telescope_flavor", and "tube_slot"
                are also read and lower cased.

    Returns:

        aman: AxisManager with pointing, det_info, and obs_info wrapped,
              plus polarization if it was provided.

        obs_id: The obs_id from the config, an empty string if there was none.
    """
    rsets = config["resultsets"]
    obs_id = rsets.get("obs_id", "")

    # The pointing fits set the dets axis
    pointing_aman = read_dataset(*rsets["pointing"]).to_axismanager(
        axis_key="dets:readout_id"
    )
    aman = AxisManager(pointing_aman.dets)
    aman = aman.wrap("pointing", pointing_aman)
    if "polarization" in rsets:
        pol_aman = read_dataset(*rsets["polarization"]).to_axismanager(
            axis_key="dets:readout_id"
        )
        aman = aman.wrap("polarization", pol_aman)

    # Build up det_info around the detmap
    det_info = AxisManager(aman.dets)
    detmap_aman = read_dataset(*rsets["detmap"]).to_axismanager(
        axis_key="readout_id"
    )
    det_info.wrap("wafer", detmap_aman)
    det_info.wrap("readout_id", det_info.dets.vals, [(0, det_info.dets)])
    det_info.wrap("det_id", det_info.wafer.det_id, [(0, det_info.dets)])
    # These come from the config and are the same for every det
    for field in ("stream_id", "wafer_slot"):
        det_info.wrap(
            field,
            np.array([config[field].lower()] * det_info.dets.count),
            [(0, det_info.dets)],
        )
    has_id = det_info.det_id != ""
    det_info.restrict("dets", det_info.dets.vals[has_id])
    det_info.det_id = np.char.strip(det_info.det_id)  # Needed for some old results
    aman = aman.wrap("det_info", det_info)
    matched = aman.det_info.det_id != "NO_MATCH"
    aman.restrict("dets", aman.dets.vals[matched])

    obs_info = AxisManager()
    for field in ("telescope_flavor", "tube_slot"):
        obs_info.wrap(field, config[field].lower())
    aman.wrap("obs_info", obs_info)

    # Readout info, prefer what is in the pointing fits over the detmap
    smurf = AxisManager(aman.dets)
    for field, wafer_field in (("band", "smurf_band"), ("channel", "smurf_channel")):
        if field in aman.pointing:
            vals = aman.pointing[field]
        elif "wafer" in det_info and wafer_field in det_info.wafer:
            vals = det_info.wafer[wafer_field]
        else:
            continue
        smurf.wrap(field, np.array(vals, dtype=int), [(0, smurf.dets)])
    aman.det_info.wrap("smurf", smurf)

    return aman, obs_id
def _load_rset(config):
stream_id = config["stream_id"]
telescope_flavor = config["telescope_flavor"].lower()
ot = config["tube_slot"].lower()
ws = config["wafer_slot"].lower()
obs = config["resultsets"]
_config = config.copy()
obs_ids = np.array(list(obs.keys()))
amans: List[Optional[AxisManager]] = [None] * len(obs_ids)
obs_info = AxisManager()
obs_info.wrap("stream_id", stream_id)
for i, (obs_id, rsets) in enumerate(obs.items()):
_config["resultsets"] = rsets
_config["resultsets"]["obs_id"] = obs_id
aman, _ = _load_rset_single(_config)
if "det_info" not in aman or "det_id" not in aman.det_info:
raise ValueError(f"No detmap for {obs_id}")
amans[i] = aman
return (
amans,
obs_ids,
[(telescope_flavor, ot, ws, stream_id)],
)
def _mk_pointing_config(telescope_flavor, tube_slot, wafer_slot, config):
config_dir = config.get("pipeline_config_dir", os.environ["PIPELINE_CONFIG_DIR"])
config_path = os.path.join(config_dir, "shared/focalplane/ufm_to_fp.yaml")
ot_config_path = os.path.join(config_dir, "shared/focalplane/optics_tubes.yaml")
zemax_path = config.get("zemax_path", None)
pointing_cfg = {
"telescope_flavor": telescope_flavor,
"tube_slot": tube_slot,
"wafer_slot": wafer_slot,
"config_path": config_path,
"ot_config_path": config_path,
"zemax_path": zemax_path,
"return_fp": False,
}
return pointing_cfg
def _restrict_inliers(aman, focal_plane):
    """
    Cut detectors whose measured xi-eta look like outliers.

    Three cuts are applied in order: if a 2 cluster kmeans finds clusters whose
    medians are too far apart only the bigger one is kept, then anything too far
    from the median of what is left is cut, and finally dets that are unlikely
    matches to their template position after a rigid alignment are cut.

    Arguments:

        aman: AxisManager with the measured pointing, restricted by this function.

        focal_plane: FocalPlane that provides the template and the mapping to it.

    Returns:

        aman: Whatever aman.restrict returns.
    """
    # TODO: Use gamma as well
    template = focal_plane.template

    # Map to template
    measured, _, _, template_msk = focal_plane.map_by_det_id(aman)
    xy = measured[:, :2]
    keep = np.ones(len(xy), dtype=bool)

    # Slightly more than the biggest distance from the template center
    offsets = template.fp[:, :2] - template.center[:, :2]
    rad_thresh = 1.05 * np.nanmax(np.linalg.norm(offsets, axis=1))

    # Use kmeans to kill any ghosts
    whitened = vq.whiten(xy[keep])
    codebook, _ = vq.kmeans(whitened, 2)
    labels, _ = vq.vq(whitened, codebook)
    in_first = labels == 0
    in_second = labels == 1
    med_first = np.median(xy[keep][in_first], axis=0)
    med_second = np.median(xy[keep][in_second], axis=0)
    separation = np.linalg.norm(med_first - med_second)

    # If centroids are too far from each other use the bigger one
    if separation < rad_thresh:
        cluster = in_first + in_second
    elif np.sum(in_first) >= np.sum(in_second):
        cluster = in_first
    else:
        cluster = in_second

    # Flag anything too far away from the center
    center = np.median(xy[keep][cluster], axis=0)
    radius = np.linalg.norm(xy[keep] - center, axis=1)
    keep[keep] *= cluster * (radius <= rad_thresh)

    # Now kill dets that seem too far from their match
    xy[~keep] = np.nan
    rot, sft = mt.get_rigid(xy, template.fp[template_msk, :2])
    aligned = mt.apply_transform(xy, rot, sft)
    likelihood = mu.gen_weights(aligned, template.fp[template_msk, :2])
    keep *= likelihood > 0.61  # ~1 sigma cut

    # Now restrict the AxisManager
    good_det_ids = template.det_ids[template_msk][keep]
    in_good = np.isin(aman.det_info.det_id, good_det_ids)
    return aman.restrict("dets", aman.dets.vals[in_good])
def _apply_pointing_model(config, aman):
    """
    Correct the fit xi, eta (and gamma) using a pointing model.

    Does nothing if the config has no "pointing_model" section or if "apply"
    in that section is not set. Otherwise the boresight from the pointing model
    is swapped in for the one the pointing was fit with, and aman.pointing is
    updated in place.

    Arguments:

        config: Configuration dict. The "pointing_model" section can contain
                "apply", "function" (module name and function name),
                "params", and "force_zero_roll".

        aman: AxisManager with the pointing fits. If az, el, or roll are
              missing from the fits they are filled from obs_info.

    Returns:

        aman: The same AxisManager that was passed in.
    """
    if "pointing_model" not in config:
        logger.info("\t\tNo pointing model specified!")
        return aman
    model_cfg = config["pointing_model"]
    if not model_cfg.get("apply", False):
        logger.info("\t\tNot applying pointing model")
        return aman

    if "function" in model_cfg:
        func = getattr(
            import_module(model_cfg["function"][0]),
            model_cfg["function"][1],
        )
    else:
        logger.info("\t\tUsing default pointing model function")
        func = apply_pointing_model

    # Anything missing from the fits gets the observation's center value
    fallbacks = (
        (
            "az",
            "az_center",
            "\t\tNeed to have az in pointing fits to apply pointing model! Filling from obsdb",
        ),
        (
            "el",
            "el_center",
            "\t\tNeed to have el in pointing fits to apply pointing model! Filling from obsdb",
        ),
        (
            "roll",
            "roll_center",
            "\t\tNeed to have roll in pointing fits to apply pointing model! Filling from obsdb",
        ),
    )
    for field, center_name, message in fallbacks:
        if field in aman.pointing:
            continue
        logger.warning(message)
        filled = np.deg2rad(aman.obs_info[center_name]) * np.ones(aman.dets.count)
        aman.pointing.wrap(field, filled, [(0, aman.dets)])

    # Params from the config override what is in the AxisManager
    params = model_cfg.get("params", {})
    if "pointing_model" in aman:
        for key, val in params.items():
            if key in aman.pointing_model:
                aman.pointing_model[key] = val
            else:
                aman.pointing_model.wrap(key, val)
        params = aman.pointing_model

    # One "sample" per det
    ancil = AxisManager(IndexAxis("samps", aman.dets.count))
    ancil.wrap("az_enc", np.rad2deg(aman.pointing.az))
    ancil.wrap("el_enc", np.rad2deg(aman.pointing.el))
    ancil.wrap("roll_enc", np.rad2deg(aman.pointing.roll))
    ancil.wrap("boresight_enc", -1 * np.rad2deg(aman.pointing.roll))  # for SATs
    bs = func(aman, params, ancil, False)

    # Only use gamma if it is there and fully populated
    have_gamma = False
    if "gamma" in aman.pointing:
        have_gamma = not np.any(np.isnan(aman.pointing.gamma))
        if not have_gamma:
            logger.warning(
                "\t\tnans in gamma, not including in pointing model correction"
            )
    if have_gamma:
        q_fp = quat.rotation_xieta(
            aman.pointing.xi, aman.pointing.eta, aman.pointing.gamma
        )
    else:
        q_fp = quat.rotation_xieta(aman.pointing.xi, aman.pointing.eta)

    # Composed strictly left to right
    use_roll = not model_cfg.get("force_zero_roll", False)
    q_tot = ~quat.euler(2, bs.roll)
    q_tot = q_tot * ~quat.rotation_lonlat(-bs.az, bs.el)
    q_tot = q_tot * quat.rotation_lonlat(-1 * aman.pointing.az, aman.pointing.el)
    q_tot = q_tot * quat.euler(2, use_roll * aman.pointing.roll)
    q_tot = q_tot * q_fp
    xi, eta, gamma = quat.decompose_xieta(q_tot)

    aman.pointing.xi[:] = xi
    aman.pointing.eta[:] = eta
    if have_gamma:
        aman.pointing.gamma[:] = gamma

    return aman
def _reverse_roll(fp, aff, sft, aman):
if "obs_info" not in aman:
raise ValueError("Can't reverse roll without obs information")
if "roll_center" not in aman.obs_info:
raise ValueError("Can't reverse roll without roll information")
roll = -1 * np.deg2rad(aman.obs_info.roll_center)
# We want to shift so we rotating about the origin
# To get to nominal we do fp@aff + sft
# So if we just want to recenter we do fp + sft@aff^-1
inv_aff, _ = mt.invert_transform(aff, np.zeros_like(sft))
sft_adj = sft @ inv_aff
fp_sft = fp[:, :2] + sft_adj
# Now lets reverse the roll
# The transpose is the inverse
rot = np.array([[np.cos(roll), -1 * np.sin(roll)], [np.sin(roll), np.cos(roll)]])
fp_rot = fp_sft @ rot
# And undo the shift, keeping track of rotations
fp_rot -= sft_adj @ rot
# Make sure its set
fp[:, :2] = fp_rot
# For gamma lets just shift by the roll
fp[:, 2] -= roll
return fp
# [docs] -- Sphinx source-link marker left over from HTML extraction, not code
def main():
    """
    Entry point for finalize_focal_plane.

    Reads the yaml config given on the command line, loads the pointing fits
    (from a context or from ResultSets), and then for each batch of
    observations:

    * Fits a transform from the template to the weighted average measured
      focal plane of each UFM.
    * Fits a common mode transform for each optics tube and for the receiver.
    * Removes the common modes to get correction only transforms.
    * Saves the result to an h5 file and registers it in a ManifestDb.

    Plots are only made if "plot_dir" is in the config.

    Raises:

        ValueError: If the config has no valid inputs, no stream_ids are found,
                    the inputs disagree on the telescope,
                    or a template can be neither loaded nor generated.
    """
    # Read in input pars
    parser = ap.ArgumentParser()
    parser.add_argument("config_path", help="Location of the config file")
    parser.add_argument(
        "--per_obs", "-p", action="store_true", help="Run in per observation mode"
    )
    parser.add_argument(
        "--include_cm",
        "-i",
        action="store_true",
        help="Include the common mode in the final detector positions",
    )
    args = parser.parse_args()

    # Open config file
    with open(args.config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # The config takes priority over the command line
    per_obs = config.get("per_obs", args.per_obs)
    include_cm = config.get("include_cm", args.include_cm)

    # Build output path
    append = config.get("append", "")
    dbroot = f"db{bool(append)*'_'}{append}"
    froot = f"focal_plane{bool(append)*'_'}{append}"
    subdir = config.get("subdir", "")
    subdir = subdir + (subdir == "") * (
        per_obs * "per_obs" + (not per_obs) * "combined"
    )
    outdir = os.path.join(config["outdir"], subdir)
    outpath = os.path.abspath(os.path.join(outdir, f"{froot}.h5"))
    dbpath = os.path.join(outdir, f"{dbroot}.sqlite")
    logpath = os.path.join(outdir, f"{froot}.log")
    os.makedirs(outdir, exist_ok=True)
    plot_dir_base = config.get("plot_dir", None)
    if plot_dir_base is not None:
        plot_dir_base = os.path.join(
            plot_dir_base, subdir + bool(append) * "_" + append
        )
        plot_dir_base = os.path.abspath(plot_dir_base)
        os.makedirs(plot_dir_base, exist_ok=True)

    # Log file
    logfile = logging.FileHandler(logpath)
    logger.addHandler(logfile)
    if plot_dir_base is None:
        logger.warning("No plot_dir in config, plots will not be made")

    # Time range
    config["start_time"] = config.get("start_time", 0)
    config["stop_time"] = config.get("stop_time", 2**32)
    logger.info(
        "Running on time range %s to %s",
        dt.datetime.fromtimestamp(config["start_time"]),
        dt.datetime.fromtimestamp(config["stop_time"]),
    )

    # Load data
    if "context" in config:
        amans, obs_ids, ot_sids = _load_ctx(config)
    elif "resultsets" in config:
        amans, obs_ids, ot_sids = _load_rset(config)
    else:
        raise ValueError("No valid inputs provided")
    if len(ot_sids) == 0:
        raise ValueError("No stream_ids found!")
    # Make sure we can index in 2d below even if a loader gave us a list
    ot_sids = np.asarray(ot_sids)
    if np.any(ot_sids[:, 0] != ot_sids[0][0]):
        raise ValueError("Not all AxisManagers agree on telescope!")

    weight_factor = config.get("weight_factor", 1000)
    min_points = config.get("min_points", 50)

    gen_template = "template" not in config
    template_path = config.get("template", "nominal.h5")
    have_template = os.path.exists(template_path)
    if not gen_template and not have_template:
        logger.error("Provided template doesn't exist, trying to generate one")
        gen_template = True

    # Serialize config
    repo = git.Repo(
        os.path.abspath(os.path.dirname(__file__)), search_parent_directories=True
    )
    sha = f"{repo.head.object.hexsha}{'_dirty'*repo.is_dirty()}"
    config["git_sha"] = sha
    cfg_str = str(yaml.dump(config))

    # Need to move installed OT and WS of array to template for this
    # if config.get("pad", False):
    #     logger.info("Padding missing arrays with template, getting complete list of arrays from template")
    #     if not have_template:
    #         logger.warning("\tNo template provided, arrays not found in any observations will be missing")
    #     with h5py.File(template_path) as f:
    #         stream_ids = list(f.keys())

    # Split up into batches
    # Right now either per_obs or all at once
    # Maybe allow for batch by encoder angle later?
    if per_obs:
        logger.info("Running in per_obs mode")
        batches = [([aman], [obs_id]) for aman, obs_id in zip(amans, obs_ids)]
    else:
        batches = [(amans, obs_ids)]

    stream_ids = config.get("context", {}).get("stream_ids", [])
    for amans, obs_ids in batches:
        # Each batch gets its own plot dir, but only if plotting is set up.
        # Joining onto a missing plot_dir used to crash here.
        if plot_dir_base is None:
            plot_dir = None
        else:
            if per_obs:
                plot_dir = os.path.join(plot_dir_base, obs_ids[0])
            else:
                plot_dir = os.path.join(plot_dir_base, str(config["start_time"]))
            os.makedirs(plot_dir, exist_ok=True)
        logger.info("Working on batch containing: %s", str(obs_ids))

        # Setup db and Receiver
        db, base, group = _create_db(
            dbpath,
            per_obs=per_obs,
            obs_ids=obs_ids,
            start_time=config["start_time"],
            stop_time=config["stop_time"],
        )
        rx = Receiver()
        if config.get("in_place", True):
            with h5py.File(outpath, "a") as f:
                if group in f:
                    rx = Receiver.load(f, group)

        for tel, ot, ws, stream_id in ot_sids:
            if len(stream_ids) > 0 and stream_id not in stream_ids:
                continue
            logger.info("Working on %s", stream_id)

            # Limit ourselves to amans with this stream_id and restrict
            amans_restrict = [
                aman.copy().restrict(
                    "dets", aman.dets.vals[aman.det_info.stream_id == stream_id]
                )
                for aman in amans
                if aman is not None and stream_id in aman.det_info.stream_id
            ]
            obs_ids_restrict = [
                obs_id
                for aman, obs_id in zip(amans, obs_ids)
                if aman is not None and stream_id in aman.det_info.stream_id
            ]
            if len(amans_restrict) == 0:
                message = "\tSomehow no AxisManagers with stream_id %s"
                if per_obs:
                    logger.info(message, stream_id)
                    continue
                else:
                    # Not skipping here, this stream_id may still get padded
                    logger.error(message, stream_id)

            # Make pointing config
            logger.info("\t%s is in %s %s %s", stream_id, tel, ot, ws)
            pointing_cfg = _mk_pointing_config(tel, ot, ws, config)

            # Constructing the OT if we need to
            if ot in rx.ot_dict:
                rx.ot_dict[ot].delete_fp(stream_id)
            else:
                optics_tube = OpticsTube.from_pointing_cfg(pointing_cfg)
                rx.optics_tubes = rx.optics_tubes + [optics_tube]

            # If a template is provided load it, otherwise generate one
            if gen_template:
                logger.info("\tGenerating template for %s", stream_id)
                if "wafer_info" not in config:
                    raise ValueError("Need wafer_info to generate template")
                template_det_ids, template, is_optical = op.gen_template(
                    config["wafer_info"], stream_id, **pointing_cfg
                )
                template = Template(
                    template_det_ids,
                    template,
                    is_optical,
                    pointing_cfg,
                )
            elif have_template:
                logger.info("\tLoading template from %s", template_path)
                template = _load_template(template_path, stream_id, pointing_cfg)
            else:
                raise ValueError(
                    "No template provided and unable to generate one for some reason"
                )

            focal_plane = FocalPlane.empty(
                template, stream_id, ws, len(amans), config=cfg_str
            )
            if focal_plane.template is None:
                raise ValueError("Template is somehow None")

            n_obs = 0
            for i, (aman, obs_id) in enumerate(zip(amans_restrict, obs_ids_restrict)):
                logger.info("\tWorking on %s", obs_id)
                if aman.dets.count < min_points:
                    logger.info("\t\tToo few dets found, skipping")
                    continue
                if config.get("faked_gamma", False):
                    aman.pointing.gamma[:] = np.nan

                # Restrict to optical dets
                optical = np.isin(
                    aman.det_info.det_id, focal_plane.template.det_ids[template.optical]
                )
                aman.restrict("dets", aman.dets.vals[optical])
                if aman.dets.count == 0:
                    logger.info("\t\tNo optical dets, skipping...")
                    continue

                # Apply pointing model if we want to
                aman = _apply_pointing_model(config, aman)

                # Do some outlier cuts
                if "hits" in aman.pointing:
                    aman.restrict(
                        "dets", aman.pointing.hits >= config.get("min_hits", 5)
                    )
                if aman.dets.count == 0:
                    logger.info("\t\tNo high hits dets, skipping...")
                    continue
                if "R2" in aman.pointing:
                    aman.restrict("dets", aman.pointing.R2 > config.get("min_r2", 0.7))
                if aman.dets.count == 0:
                    logger.info("\t\tNo high R2 dets, skipping...")
                    continue
                if aman.dets.count < min_points:
                    logger.info("\t\tToo few dets found, skipping")
                    continue
                _restrict_inliers(aman, focal_plane)

                # Mapping to template
                fp, r2, det_boresight, template_msk = focal_plane.map_by_det_id(aman)
                focal_plane.template.add_wafer_info(aman, template_msk)

                # Try an initial alignment and get weights
                try:
                    aff, sft = mt.get_rigid(
                        fp[:, :2], focal_plane.template.fp[template_msk, :2]
                    )
                except ValueError as e:
                    logger.error("\t\t%s", e)
                    continue
                aligned = mt.apply_transform(fp[:, :2], aff, sft)
                if config.get("reverse_roll", False):
                    fp = _reverse_roll(fp, aff, sft, aman)
                if np.any(np.isfinite(fp[:, 2])):
                    gscale, gsft = gamma_fit(
                        fp[:, 2], focal_plane.template.fp[template_msk, 2]
                    )
                    weights = mu.gen_weights(
                        np.column_stack((aligned, gscale * fp[:, 2] + gsft)),
                        focal_plane.template.fp[template_msk],
                        focal_plane.template.spacing.ravel() / weight_factor,
                    )
                else:
                    weights = mu.gen_weights(
                        aligned,
                        focal_plane.template.fp[template_msk, :2],
                        focal_plane.template.spacing[:2].ravel() / weight_factor,
                    )
                # ~1 sigma cut
                weights[weights < 0.61] = np.nan
                if np.sum(np.isfinite(weights)) < min_points / 2:
                    logger.error("\t\tToo few points! Skipping...")
                    # Actually skip, this used to fall through and add the obs
                    continue

                # Store weighted values
                weights = np.column_stack((weights, r2))
                focal_plane.add_fp(i, fp, weights, det_boresight, template_msk)
                n_obs += 1

            # Compute the average focal plane with weights
            (
                focal_plane.avg_fp,
                focal_plane.weights,
                focal_plane.n_point,
                focal_plane.n_gamma,
            ) = _avg_focalplane(focal_plane.full_fp, focal_plane.tot_weight)
            tot_points = np.sum((focal_plane.n_point > 0).astype(int))
            focal_plane.id_strs = focal_plane.template.id_strs
            logger.info("\t%d points from %d obs in fit", tot_points, n_obs)

            if tot_points < min_points:
                logger.error("\tToo few points! Skipping...")
                if config.get("pad", False):
                    logger.info("\tPadding output with template")
                    focal_plane.transformed = focal_plane.template.fp
                    # tot_weight of None is how we mark padded focal planes
                    focal_plane.tot_weight = None
                    rx.ot_dict[ot].focal_planes = rx.ot_dict[ot].focal_planes + [
                        focal_plane
                    ]
                continue

            try:
                affine, shift = mt.get_affine_two_stage(
                    focal_plane.template.fp[:, :2],
                    focal_plane.avg_fp[:, :2],
                    focal_plane.weights[:, 0],
                )
            except ValueError as e:
                logger.error("\t%s", e)
                continue
            focal_plane.transformed[:, :2] = mt.apply_transform(
                focal_plane.template.fp[:, :2], affine, shift
            )
            focal_plane.center_transformed[:, :2] = mt.apply_transform(
                focal_plane.template.center[:, :2], affine, shift
            )

            # Compute transformation between the two nominal and measured pointing
            focal_plane.have_gamma = np.sum(focal_plane.n_gamma) > 0
            if focal_plane.have_gamma:
                gamma_scale, gamma_shift = gamma_fit(
                    focal_plane.template.fp[:, 2], focal_plane.avg_fp[:, 2]
                )
            else:
                logger.warning(
                    "\tNo polarization data availible, gammas will be based on the nominal values."
                )
                logger.warning(
                    "\tSetting gamma shift to the xi-eta rotation and scale to 1.0"
                )
                transform = Transform.from_split(np.array((*shift, 0.0)), affine, 1.0)
                gamma_scale = 1.0
                gamma_shift = transform.rot
            focal_plane.transformed[:, 2] = (
                focal_plane.template.fp[:, 2] * gamma_scale + gamma_shift
            )
            focal_plane.center_transformed[:, 2] = (
                focal_plane.template.center[:, 2] * gamma_scale + gamma_shift
            )

            # Only include gamma in the RMS if we actually fit it
            rms = np.sqrt(
                np.nanmean(
                    (
                        focal_plane.avg_fp[:, : (2 + focal_plane.have_gamma)]
                        - focal_plane.transformed[:, : (2 + focal_plane.have_gamma)]
                    )
                    ** 2
                )
            )
            logger.info("\tRMS after transformation is %f", rms)

            shift = np.array((*shift, gamma_shift))
            focal_plane.transform = Transform.from_split(shift, affine, gamma_scale)
            _log_vals(
                focal_plane.transform.shift,
                focal_plane.transform.scale,
                focal_plane.transform.shear,
                focal_plane.transform.rot,
                ("xi", "eta", "gamma"),
            )

            if config.get("plot", False) and plot_dir is not None:
                plot_ufm(focal_plane, plot_dir)
                plot_by_gamma(focal_plane, plot_dir)

            # Add to the receiver
            rx.ot_dict[ot].focal_planes = rx.ot_dict[ot].focal_planes + [focal_plane]

        # Per OT common mode
        todel = []
        for name, ot in rx.ot_dict.items():
            logger.info("Fitting common mode for %s", ot.name)
            # Padded focal planes (tot_weight is None) don't go into the fit
            centers = np.atleast_2d(
                np.array(
                    [
                        fp.template_center
                        for fp in ot.focal_planes
                        if fp.tot_weight is not None
                    ]
                )
            )
            centers_transformed = np.atleast_2d(
                np.array(
                    [
                        fp.center_transformed
                        for fp in ot.focal_planes
                        if fp.tot_weight is not None
                    ]
                )
            )
            if ot.num_fps == 0 or centers.size == 0 or centers_transformed.size == 0:
                logger.error("\tNo focal planes found! Skipping...")
                if not config.get("pad", False):
                    todel.append(name)
                continue
            if plot_dir is not None:
                plot_ot(ot, plot_dir)
            centers = centers.reshape((-1, 3))
            centers_transformed = centers_transformed.reshape((-1, 3))
            if centers.shape[0] < 3:
                logger.warning(
                    "\tToo few wafers fit to compute common mode, transform will be approximated"
                )
                # Push three reference points through the transform of each fp
                centers = np.vstack([ot.center, ot.center - 1, ot.center + 1])
                centers_transformed = np.vstack(
                    [
                        mt.apply_transform(
                            centers, fp.transform.affine, fp.transform.shift
                        )
                        for fp in ot.focal_planes
                    ],
                )
                # The vstack above is grouped by focal plane so the reference
                # points need to be tiled. np.repeat interleaves them and
                # pairs points with the wrong rows when there are 2 fps.
                centers = np.tile(centers, (len(ot.focal_planes), 1))
            rot, sft = mt.get_rigid(centers[:, :2], centers_transformed[:, :2])
            gamma_shift = np.mean(centers_transformed[:, 2] - centers[:, 2])
            ot.transform_fullcm = Transform.from_split(
                np.array((*sft.ravel(), gamma_shift)), rot, 1.0
            )
            ot.center_transformed = mt.apply_transform(
                ot.center, ot.transform_fullcm.affine, ot.transform_fullcm.shift
            )
            _log_vals(
                ot.transform_fullcm.shift,
                ot.transform_fullcm.scale,
                ot.transform_fullcm.shear,
                ot.transform_fullcm.rot,
                ("xi", "eta", "gamma"),
            )
        logger.info("Deleting OTs: %s", str(todel))
        for ot in todel:
            rx.delete_ot(ot)

        # Full receiver common mode
        logger.info("Fitting receiver common mode")
        centers = np.atleast_2d(
            np.array([ot.center for ot in rx.optics_tubes if ot.num_fps > 0])
        )
        centers_transformed = np.atleast_2d(
            np.array(
                [ot.center_transformed for ot in rx.optics_tubes if ot.num_fps > 0]
            )
        )
        no_ots = (
            len(rx.optics_tubes) == 0
            or centers.size == 0
            or centers_transformed.size == 0
        )
        if no_ots and not config.get("pad", False):
            logger.error("\tNo optics tubes found! Skipping...")
            continue
        elif len(rx.optics_tubes) == 1 or (no_ots and config.get("pad", False)):
            logger.info(
                "\tOnly one OT found, receiver common mode will be from this tube"
            )
            recv_transform = deepcopy(rx.optics_tubes[0].transform_fullcm)
        else:
            centers = centers.reshape((-1, 3))
            centers_transformed = centers_transformed.reshape((-1, 3))
            if len(rx.optics_tubes) < 3:
                logger.info(
                    "\tNot enough OTs to fit receiver common mode, transform will be approximated"
                )
                centers = np.vstack(
                    [np.roll(np.arange(3, dtype=float), i) for i in range(3)]
                )
                # NOTE(review): this uses ot.transform, which isn't updated
                # until further down, rather than ot.transform_fullcm.
                # Confirm that is what is wanted here.
                centers_transformed = np.vstack(
                    [
                        mt.apply_transform(
                            centers, ot.transform.affine, ot.transform.shift
                        )
                        for ot in rx.optics_tubes
                    ],
                )
                # Tiled to match the per OT grouping of the vstack above
                centers = np.tile(centers, (len(rx.optics_tubes), 1))
            rot, sft = mt.get_rigid(centers[:, :2], centers_transformed[:, :2])
            gamma_shift = np.mean(centers_transformed[:, 2] - centers[:, 2])
            recv_transform = Transform.from_split(
                np.array((*sft.ravel(), gamma_shift)), rot, 1.0
            )
        receiver = Receiver(
            rx.optics_tubes,
            transform=recv_transform,
            include_cm=include_cm,
            valid_ids=np.unique(rx.valid_ids + rx.fp_valid_ids).tolist(),
        )
        if not (no_ots and config.get("pad", False)):
            _log_vals(
                recv_transform.shift,
                recv_transform.scale,
                recv_transform.shear,
                recv_transform.rot,
                ("xi", "eta", "gamma"),
            )
            if plot_dir is not None:
                plot_receiver(receiver, plot_dir)

        # Now compute correction only transform for each ufm
        # Transforms are composed as ufm(ot(rx(focal_plane)))
        for ot in receiver.optics_tubes:
            # Now remove the receiver CM from the OT
            ot.transform.affine, ot.transform.shift = mt.decompose_transform(
                ot.transform_fullcm.affine,
                ot.transform_fullcm.shift,
                recv_transform.affine,
                recv_transform.shift,
            )
            ot.transform.decompose()
            # Now for each fp remove the CM
            for fp in ot.focal_planes:
                (
                    fp.transform_nocm.affine,
                    fp.transform_nocm.shift,
                ) = mt.decompose_transform(
                    fp.transform.affine,
                    fp.transform.shift,
                    ot.transform_fullcm.affine,
                    ot.transform_fullcm.shift,
                )
                fp.transform_nocm.decompose()
                # Remove the common mode if desired
                if not include_cm and fp.template is not None:
                    fp.with_cm = False
                    fp.transformed = mt.apply_transform(
                        fp.template_fp,
                        fp.transform_nocm.affine,
                        fp.transform_nocm.shift,
                    )

        # Make final outputs and save
        logger.info("Saving data to %s", outpath)
        logger.info("Writing to database at %s", dbpath)
        if config.get("pad", False):
            logger.info("Padding missing arrays with values from template")
        with h5py.File(outpath, "a") as f:
            # Replace anything already saved for this batch
            if group in f:
                del f[group]
            f.create_group(group)
            receiver.save(f, (db, base), group)
# Only run the pipeline when executed as a script, not when imported
if __name__ == "__main__":
    main()