import numpy as np
import pyfftw
import inspect
import scipy.signal as signal
import logging
from . import detrend_tod
from . import fft_ops
from sotodlib import core
logger = logging.getLogger(__name__)
[docs]
def fourier_filter(tod, filt_function,
                   detrend=None, resize='zero_pad',
                   axis_name='samps', signal_name='signal',
                   time_name='timestamps',
                   **kwargs):
    """Return a filtered tod.signal_name along the axis axis_name.

    Does not change the data in the axis manager.

    Arguments:
        tod: axis manager
        filt_function: filter object exposing
            ``apply(freqs, tod, target)`` (e.g. an instance of a class
            produced by the ``fft_filter`` / ``fft_apply_filter``
            decorators, or a FilterChain).  It is handed the FFT
            frequencies, the axis manager and the Fourier-space array to
            modify in place.
        detrend: Method of detrending to be done before ffting. Can
            be 'linear', 'mean', or None. Note that detrending here can
            be slow for large arrays.
        resize: How to resize the axis to increase fft speed.
            'zero_pad' will increase to the next nice number (a product
            of small primes compatible with the FFT implementation).
            'trim' will eliminate samples from the end so that axis has a
            nice length for FFTs. None will not change the axis length
            and might be quite slow. Trim will be kinda weird here,
            because signal will not be returned as the same size as it
            is input.
        axis_name: name of axis you would like to fft along
        signal_name: name of the variable in tod to fft
        time_name: name for getting time of data (in seconds) from tod
        **kwargs: forwarded verbatim to ``filt_function.apply``.

    Returns:
        signal: filtered tod.signal_name, in the same axis order as the
        input.  When filt_function is an identity_filter no FFT is run
        and a full-length copy of the (optionally detrended) data is
        returned regardless of ``resize``.

    Raises:
        ValueError: if the signal is not 1D or 2D, if a 2D signal does
            not include ``axis_name``, or if ``resize`` is not one of
            the supported values.
    """
    assignments = tod._assignments[signal_name]
    if len(assignments) not in (1, 2):
        raise ValueError('fourier_filter only works for 1D or 2D data streams')

    axis = getattr(tod, axis_name)
    times = getattr(tod, time_name)
    # NOTE(review): the span between first and last timestamp covers
    # (count - 1) sample intervals, but it is divided by count here.
    # Left as-is to keep results unchanged -- confirm whether intended.
    delta_t = (times[-1]-times[0])/axis.count

    if len(assignments) == 1:
        n_det = 1
        # No second axis; np.atleast_2d below makes the data (1, n_samps).
        other_idx = None
    else:
        if axis_name not in assignments:
            raise ValueError(
                f'{signal_name} is not assigned to axis "{axis_name}"')
        other_idx = 1 - list(assignments).index(axis_name)
        other_axis = getattr(tod, assignments[other_idx])
        n_det = other_axis.count

    if resize == 'zero_pad':
        n = fft_ops.find_superior_integer(axis.count)
        logger.info('fourier_filter: padding %i -> %i', axis.count, n)
    elif resize == 'trim':
        n = fft_ops.find_inferior_integer(axis.count)
        logger.info('fourier_filter: trimming %i -> %i', axis.count, n)
    elif resize is None:
        n = axis.count
    else:
        raise ValueError('resize must be "zero_pad", "trim", or None')

    if detrend is not None:
        logger.info('fourier_filter: detrending.')
        data = detrend_tod(tod, detrend, axis_name=axis_name,
                           signal_name=signal_name, in_place=False)
    else:
        data = tod[signal_name]
    data = np.atleast_2d(data)

    if isinstance(filt_function, identity_filter):
        logger.info('fourier_filter: filt_function is identity; skipping FFT.')
        # The data were never transposed on this path, so they must be
        # returned as-is (only the artificial leading axis of a 1D
        # signal is removed).
        filtered = data.copy()
        if other_idx is None:
            return filtered[0]
        return filtered

    logger.info('fourier_filter: initializing rfft object.')
    # a: time-domain buffer, b: Fourier-domain buffer, t_1/t_2: callables
    # running the forward / inverse transforms between them.
    a, b, t_1, t_2 = fft_ops.build_rfft_object(n_det, n, 'BOTH')

    # The sample axis must be last so the code can always work along
    # axis 1; remember whether we flipped so the result can be restored.
    transposed = other_idx is not None and other_idx != 0
    if transposed:
        data = data.transpose()

    # This copy is valid for all modes of "resize"
    logger.info('fourier_filter: copying in data.')
    n_keep = min(n, axis.count)
    a[:, :n_keep] = data[:, :n_keep]
    a[:, n_keep:] = 0

    ## FFT Signal
    logger.info('fourier_filter: FFT.')
    t_1()

    ## Get Filter
    logger.info('fourier_filter: applying filter.')
    freqs = np.fft.rfftfreq(n, delta_t)
    filt_function.apply(freqs, tod, b, **kwargs)

    ## FFT Back
    logger.info('fourier_filter: IFFT.')
    t_2()

    # Drop any zero padding.
    filtered = a[:, :n_keep]
    if transposed:
        return filtered.transpose()
    if other_idx is None:
        return filtered[0]
    return filtered
def fft_trim(tod, axis='samps', prefer='right'):
    """Restrict AxisManager sample range so that FFTs are efficient.

    The new length is chosen with fft_ops.find_inferior_integer.

    Args:
        tod (AxisManager): Target, which is modified in place.
        axis (str): Axis to target.
        prefer (str): One of ['left', 'right', 'center'].  'left' keeps
            the first samples (trimming from the end), 'right' keeps
            the last samples (trimming from the beginning), and
            'center' trims about equally from the beginning and end.

    Returns:
        The (start, stop) indices to use to slice an array and get these
        samples.

    Raises:
        ValueError: if ``prefer`` is not one of the accepted choices.
    """
    target_axis = tod[axis]
    n_before = target_axis.count
    n_after = fft_ops.find_inferior_integer(n_before)
    surplus = n_before - n_after

    if prefer == 'right':
        first = surplus
    elif prefer == 'center':
        first = surplus // 2
    elif prefer == 'left':
        first = 0
    else:
        raise ValueError(f'Invalid choice prefer="{prefer}"')

    # Indices relative to the un-trimmed axis; this is what is returned.
    bounds = (first, first + n_after)

    # OffsetAxis is indexed relative to its own offset, so the range
    # handed to restrict() has to be shifted accordingly.
    shift = target_axis.offset if isinstance(target_axis, core.OffsetAxis) else 0
    tod.restrict(axis, (first + shift, first + shift + n_after))
    return bounds
################################################################
# Base class... provides that a * b always returns a FilterChain.
class _chainable:
    """Base class whose ``*`` operator combines two filters.

    Multiplying by an identity_filter returns the other operand
    unchanged; any other product is wrapped in a FilterChain.
    """

    @staticmethod
    def _preference(f):
        """Return ``f.preference``, or 'compose' when f has no such attribute."""
        try:
            return f.preference
        except AttributeError:
            return 'compose'

    def __mul__(self, other):
        # An identity on either side drops out; ``other`` is checked first.
        for candidate, survivor in ((other, self), (self, other)):
            if isinstance(candidate, identity_filter):
                return survivor
        return FilterChain([self, other])
class FilterFunc(_chainable):
    """Class to support chaining of Fourier filters.

    FilterFunc.deco may be used to decorate functions with signatures
    like::

        function_name(freqs, tod, *args, **kwargs)

    Instantiating the resulting class stores ``*args`` / ``**kwargs``;
    calling the instance with ``(freqs, tod)`` evaluates the wrapped
    function with those stored arguments.
    """

    # Note self._fun is set in the subclass, as a class variable,
    # e.g.: _fun = staticmethod(gaussian_filter)

    # Number of leading positional parameters of the wrapped function
    # that are supplied at call time rather than at construction.
    _fun_nargs = 2
    preference = 'compose'

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args, self.kwargs = args, kwargs

    def __call__(self, freqs, tod):
        """Return the transfer function evaluated at freqs."""
        return self._fun(freqs, tod, *self.args, **self.kwargs)

    def apply(self, freqs, tod, target):
        """Multiply target in place by the transfer function."""
        transfer = self(freqs, tod)
        target *= transfer

    @classmethod
    def deco(cls, fun):
        """Wrap fun in a new subclass of cls and return that subclass.

        The subclass carries fun's docstring, and a signature made of
        fun's parameters after the first ``cls._fun_nargs`` ones.
        """
        every_param = list(inspect.signature(fun).parameters.values())
        user_params = every_param[cls._fun_nargs:]

        class filter_func(cls):
            __doc__ = fun.__doc__
            __signature__ = inspect.Signature(parameters=user_params)
            _fun = staticmethod(fun)

        return filter_func
class FilterApplyFunc(FilterFunc):
    """Class to support chaining of Fourier filters.

    @FilterApplyFunc.deco may be used to decorate functions with
    signatures like::

        function_name(target, freqs, tod, *, **)

    Such filter functions must return a transfer function array, if
    target=None.  But if target is a fourier transform array, the
    function must apply the transfer function to that array and return
    None.
    """

    # The wrapped function takes (target, freqs, tod) ahead of the
    # user-supplied arguments.
    _fun_nargs = 3
    preference = 'apply'

    def __call__(self, freqs, tod):
        """Ask the wrapped function (target=None) for its transfer function."""
        transfer = self._fun(None, freqs, tod, *self.args, **self.kwargs)
        return transfer

    def apply(self, freqs, tod, target):
        """Hand target to the wrapped function and return its result."""
        outcome = self._fun(target, freqs, tod, *self.args, **self.kwargs)
        return outcome
class FilterChain(_chainable):
    """A chain of Fourier filters.

    Nested FilterChain items are flattened, so ``links`` is always a
    flat list of individual filters.
    """

    def __init__(self, items):
        super().__init__()
        self.links = []
        for a in items:
            if isinstance(a, FilterChain):
                self.links.extend(a.links)
            else:
                self.links.append(a)

    @staticmethod
    def _combine(filt, other):
        """Return the product of two transfer-function arrays.

        The higher-rank array is reused as the output buffer when the
        product fits in it (same shape, and the result dtype can be
        stored under numpy's 'same_kind' in-place casting rule).
        Otherwise a new array is allocated, so that e.g. a real filter
        followed by a complex one is promoted instead of raising a
        casting error.
        """
        # Swap so the higher-rank array is the accumulation target.
        if other.ndim > filt.ndim:
            filt, other = other, filt
        fits_shape = np.broadcast(filt, other).shape == filt.shape
        fits_dtype = np.can_cast(np.result_type(filt, other), filt.dtype,
                                 casting='same_kind')
        if fits_shape and fits_dtype:
            filt *= other
            return filt
        return filt * other

    def __call__(self, freqs, tod):
        """Return the combined transfer function of all links."""
        if not self.links:
            # An empty chain is a unit-gain filter.
            return np.ones(len(freqs))
        # We could do better by handling self.links out of order.
        filt = self.links[0](freqs, tod)
        for f in self.links[1:]:
            if self._preference(f) == 'apply' and filt.ndim == 2:
                f.apply(freqs, tod, filt)
            else:
                filt = self._combine(filt, f(freqs, tod))
        return filt

    def apply(self, freqs, tod, target):
        """Apply every link to target, in place.

        Links whose preference is 'apply' act on target directly; the
        remaining transfer functions are multiplied together and
        applied to target once at the end.
        """
        filt = None
        for f in self.links:
            if self._preference(f) == 'apply':
                f.apply(freqs, tod, target)
            elif filt is None:
                filt = f(freqs, tod)
            else:
                filt = self._combine(filt, f(freqs, tod))
        if filt is not None:
            target *= filt
# Alias the decorators...
# fft_filter wraps a function in a FilterFunc subclass;
# fft_apply_filter wraps a function in a FilterApplyFunc subclass.
fft_filter = FilterFunc.deco
fft_apply_filter = FilterApplyFunc.deco
# Filtering Functions
#################
@fft_filter
def counter_1_over_f(freqs, tod, fk, n):
    """
    Counter 1/f filter for noise w/ PSD that follows:

        w*(1 + (fk/f)**n)

    where w is the white noise level, fk is the knee frequency, and
    n is the 1/f index.

    Args:
        freqs: array of frequencies at which to evaluate the filter.
        tod: axis manager (unused).
        fk: knee frequency, in the same units as freqs.
        n: 1/f index.

    Returns:
        Array with the filter gain ``1/(1 + (fk/f)**n)`` at each frequency.
    """
    # A zero-frequency (DC) entry makes fk/freqs divide by zero.  The
    # resulting inf is the intended limit (gain 0 for n > 0), so only
    # the divide warning is silenced; the values are unchanged.
    with np.errstate(divide='ignore'):
        return 1/(1+(fk/freqs)**n)
@fft_filter
def identity_filter(freqs, tod, invert=False):
    """Identity filter (gain=1 at all frequencies).

    This filter has some special handling -- the FFT will not take
    place if your only intention is to apply an identity filter; also
    identity_filter() * other_filter() will simply return the
    other_filter().

    The ``tod`` and ``invert`` arguments do not affect the result.
    """
    n_freq = len(freqs)
    return np.ones(n_freq)
[docs]
@fft_filter
def low_pass_butter4(freqs, tod, fc):
    """4th-order low-pass filter with f3db at fc (Hz).

    Returns the magnitude of the analog Butterworth response evaluated
    at freqs.
    """
    # scipy's analog filter design works in angular frequency.
    omega_c = 2*np.pi*fc
    num, den = signal.butter(4, omega_c, 'lowpass', analog=True)
    _, response = signal.freqs(num, den, 2*np.pi*freqs)
    return np.abs(response)
[docs]
@fft_filter
def high_pass_butter4(freqs, tod, fc):
    """4th-order high-pass filter with f3db at fc (Hz).

    Returns the magnitude of the analog Butterworth response evaluated
    at freqs.
    """
    # scipy's analog filter design works in angular frequency.
    omega_c = 2*np.pi*fc
    num, den = signal.butter(4, omega_c, 'highpass', analog=True)
    _, response = signal.freqs(num, den, 2*np.pi*freqs)
    return np.abs(response)
[docs]
@fft_filter
def tau_filter(freqs, tod, tau_name='timeconst', do_inverse=True):
    """tau_filter is deprecated; use timeconst_filter.

    Args:
        tau_name (str): name of the attribute of tod holding the time
            constants (indexed as one value per row of the filter).
        do_inverse (bool): if True (default), return
            ``1 + 2j*pi*tau*f``; otherwise return its reciprocal.

    Returns:
        2D complex array, one row per time constant and one column per
        frequency.
    """
    # Use the module logger: the root-level logging.warning() helper may
    # implicitly configure the root logger of the host application.
    logger.warning('tau_filter is deprecated; use timeconst_filter.')
    taus = getattr(tod, tau_name)
    filt = 1 + 2.0j*np.pi*taus[:,None]*freqs[None,:]
    if not do_inverse:
        return 1.0/filt
    return filt
[docs]
@fft_apply_filter
def timeconst_filter(target, freqs, tod, timeconst=None, invert=False):
    """One-pole time constant filter for fourier_filter.

    Builds filter for applying or removing time constants from signal
    data.

    Args:
        timeconst: Array of time constant values (one per detector).
            Alternately, a string indicating what member of tod to use
            for the time constants array.  Defaults to 'timeconst'.
        invert (bool): If true, returns the inverse transfer function,
            to deconvolve the time constants.

    Returns:
        If target is None, the 2D transfer function array (one row per
        time constant, one column per frequency).  Otherwise target is
        modified in place and None is returned.

    Raises:
        ValueError: if target is given and its length differs from the
            number of time constants.

    Example::

        # Deconvolve time constants.
        fourier_filter(tod, timeconst_filter(invert=True),
                       detrend='linear', resize='zero_pad')

    """
    if timeconst is None:
        timeconst = 'timeconst'
    if isinstance(timeconst, str):
        timeconst = tod[timeconst]

    if target is None:
        filt = 1 + 2.0j*np.pi*timeconst[:,None]*freqs[None,:]
        if invert:
            return filt
        return 1.0/filt

    # Apply filter directly to FFT in target.
    # zip() below would silently truncate on a length mismatch, and an
    # assert would vanish under ``python -O``; so check explicitly.
    if len(timeconst) != len(target):
        raise ValueError(
            f'timeconst has {len(timeconst)} entries but target has '
            f'{len(target)} rows.')

    # Work one row at a time so only a 1D temporary is needed.
    if invert:
        for tau, dest in zip(timeconst, target):
            dest *= 1.+ 2.j*np.pi*tau*freqs
    else:
        for tau, dest in zip(timeconst, target):
            dest /= 1.+ 2.j*np.pi*tau*freqs
[docs]
@fft_filter
def timeconst_filter_single(freqs, tod, timeconst, invert=False):
    """One-pole time constant filter for fourier_filter.

    This version accepts a single time constant value, in seconds.  To
    use different time constants for each detector, see
    timeconst_filter.

    With invert=True the reciprocal response ``1 + 2j*pi*timeconst*f``
    is returned instead of ``1 / (1 + 2j*pi*timeconst*f)``.

    Example::

        # Apply a 1ms time constant.
        fourier_filter(tod, timeconst_filter_single(timeconst=0.001),
                       detrend=None, resize='zero_pad')

    """
    pole = 1. + 2.j * np.pi * timeconst * freqs
    return pole if invert else 1. / pole
[docs]
@fft_filter
def gaussian_filter(freqs, tod, fc=0., f_sigma=None, gain=1.0, t_sigma=None):
    """Gaussian bandpass filter

    Parameters:
        fc (float): Central frequency of the filter (peak of
            passband), in Hz.
        f_sigma (float): Standard deviation of the filter kernel, in
            Hz.
        gain (float): Gain of the filter.
        t_sigma (float): Instead of f_sigma, set t_sigma and f_sigma =
            1/(2 pi t_sigma) will be used.

    The filter kernel has the shape of a normal distribution, centered
    on fc with standard deviation f_sigma, and peak height gain.

    Raises:
        ValueError: if both or neither of f_sigma and t_sigma are given.
    """
    if t_sigma is not None:
        if f_sigma is not None:
            raise ValueError("User must not specify both f_sigma and t_sigma.")
        f_sigma = 1.0 / (2*np.pi*t_sigma)
    elif f_sigma is None:
        raise ValueError('User must specify either f_sigma or t_sigma.')

    offset = np.abs(freqs) - fc
    return gain * np.exp(-0.5*offset**2/f_sigma**2)
[docs]
@fft_filter
def low_pass_sine2(freqs, tod, cutoff, width=None):
    """Low-pass filter.  Response falls from 1 to 0 between frequencies
    (cutoff - width/2, cutoff + width/2), with a sine-squared shape.

    If width is None, ``2 * cutoff`` is used.
    """
    width = cutoff * 2 if width is None else width
    # Position within the transition band, clamped to [-1/2, 1/2].
    position = np.clip((abs(freqs) - cutoff) / width, -0.5, 0.5)
    phase = np.pi * position
    return 0.5 - 0.5 * np.sin(phase)
[docs]
@fft_filter
def high_pass_sine2(freqs, tod, cutoff, width=None):
    """High-pass filter.  Response rises from 0 to 1 between frequencies
    (cutoff - width/2, cutoff + width/2), with a sine-squared shape.

    If width is None, ``2 * cutoff`` is used.
    """
    width = cutoff * 2 if width is None else width
    # Position within the transition band, clamped to [-1/2, 1/2].
    position = np.clip((abs(freqs) - cutoff) / width, -0.5, 0.5)
    phase = np.pi * position
    return 0.5 + 0.5 * np.sin(phase)
[docs]
@fft_filter
def iir_filter(freqs, tod, b=None, a=None, fscale=1., iir_params=None,
               invert=False):
    """Infinite impulse response (IIR) filter.  This sort of filter is
    used in digital applications as a low-pass filter prior to
    decimation.  The Smurf and MCE readout filters can both be
    expressed in this form.

    Args:
        b: numerator polynomial filter coefficients (z^0,z^1, ...)
        a: denominator coefficients
        fscale: scalar used to compute z = exp(-2j*pi*freqs*fscale).
            In general this should be equal to 1/f_orig, where f_orig
            is the original sampling frequency (before downsampling).
        iir_params: Alternative way to specify b, a, and fscale (see
            notes).
        invert: If true, returns denom/num instead of num/denom.

    Raises:
        ValueError: if the parameters found in the sub-entries of
            iir_params are not all identical, or if no usable
            ("a", "b", "fscale") set can be extracted from iir_params.

    Notes:
        The `b` and `a` coefficients are as implemented in
        scipy.signal.freqs, scipy.signal.butter, etc.  The "angular
        frequencies", `w`, are computed as 2*pi*freqs*fscale.

        If `a` is not passed in explicitly, all of (b, a, fscale) are
        taken from iir_params, which must be a dict or AxisManager
        with keys "b", "a", and "fscale", or a container holding
        the sub-iir_params of each stream_id.  In the latter case, if
        the filter parameters of each stream_id are different, an
        error is raised.

        But note that:

        - If iir_params is a string, tod[iir_params] is used (and must
          be an AxisManager or dict).
        - If iir_params is None, that's the same as passing
          iir_params='iir_params'.

    """
    if a is None:
        # Get params from TOD?
        if iir_params is None:
            iir_params = 'iir_params'
        if isinstance(iir_params, str):
            iir_params = tod[iir_params]

        def _fields(container):
            # A plain dict is its own field mapping; otherwise use the
            # container's ._fields mapping.
            if isinstance(container, dict):
                return container
            return container._fields

        if 'a' not in _fields(iir_params):
            # Parameters live one level down; all sub-entries that
            # carry them must agree, and the first one is then used.
            reference = None
            for candidate in _fields(iir_params).values():
                if not isinstance(candidate, (dict, core.AxisManager)):
                    continue
                if 'a' not in _fields(candidate):
                    continue
                if reference is None:
                    reference = candidate
                elif not all(np.array_equal(candidate[key], reference[key])
                             for key in ('a', 'b', 'fscale')):
                    raise ValueError('iir parameters are not uniform.')
            if reference is None:
                raise ValueError("Failed to extract filter parameters from "
                                 f"iir_params={iir_params}.")
            iir_params = reference

        try:
            a = iir_params['a']
            b = iir_params['b']
            fscale = iir_params['fscale']
        except Exception as e:
            raise ValueError("Failed to extract filter parameters from "
                             f"iir_params={iir_params}.") from e

    z = np.exp(-2j*np.pi*freqs * fscale)
    # np.polyval wants the highest power first; b and a are stored
    # lowest power first.
    B, A = np.polyval(b[::-1], z), np.polyval(a[::-1], z)
    if invert:
        return A / B
    return B / A
# Functions to derive low/high/band pass filter from configuration
##################################################################
[docs]
def get_lpf(cfg):
    """
    Returns a low-pass filter based on the configuration.

    Args:
        cfg (dict): A dictionary containing the low-pass filter configuration.
            It must have the following keys:

            - "type": A string specifying the type of low-pass filter.
              Supported values are "identity", "butter4" and "sine2".
            - "cutoff": A float specifying the cutoff frequency of the
              low-pass filter (not read for "identity").
            - "trans_width": A float specifying the transition width of
              the low-pass filter (only for "sine2" type).

    Returns:
        The low-pass filter object: identity_filter(),
        low_pass_butter4(...) or low_pass_sine2(...).

    Raises:
        ValueError: if cfg["type"] is not a supported filter type.
    """
    kind = cfg['type']
    if kind == 'identity':
        return identity_filter()
    if kind == 'butter4':
        return low_pass_butter4(fc=cfg['cutoff'])
    if kind == 'sine2':
        return low_pass_sine2(cutoff=cfg['cutoff'], width=cfg['trans_width'])
    raise ValueError('Unsupported filter type. Supported filters are `identity`, `butter4` and `sine2`')
[docs]
def get_hpf(cfg):
    """
    Returns a high-pass filter based on the configuration.

    Args:
        cfg (dict): A dictionary containing the high-pass filter configuration.
            It must have the following keys:

            - "type": A string specifying the type of high-pass filter.
              Supported values are "identity", "butter4" and "sine2".
            - "cutoff": A float specifying the cutoff frequency of the
              high-pass filter (not read for "identity").
            - "trans_width": A float specifying the transition width of
              the high-pass filter (only for "sine2" type).

    Returns:
        The high-pass filter object: identity_filter(),
        high_pass_butter4(...) or high_pass_sine2(...).

    Raises:
        ValueError: if cfg["type"] is not a supported filter type.
    """
    kind = cfg['type']
    if kind == 'identity':
        return identity_filter()
    if kind == 'butter4':
        return high_pass_butter4(fc=cfg['cutoff'])
    if kind == 'sine2':
        return high_pass_sine2(cutoff=cfg['cutoff'], width=cfg['trans_width'])
    raise ValueError('Unsupported filter type. Supported filters are `identity`, `butter4` and `sine2`')
[docs]
def get_bpf(cfg):
    """
    Returns a band-pass filter based on the configuration.

    Args:
        cfg (dict): A dictionary containing the band-pass filter configuration.
            It must have the following keys:

            - "type": A string specifying the type of band-pass filter.
              Supported values are "identity", "butter4" and "sine2".
            - "center": A float specifying the center frequency of the
              band-pass filter (not read for "identity").
            - "width": A float specifying the width of the band-pass
              filter (not read for "identity").
            - "trans_width": A float specifying the transition width of
              the band-pass filter (only for "sine2" type).

    Returns:
        The band-pass filter: identity_filter(), or the product of a
        low-pass at ``center + width/2`` and a high-pass at
        ``center - width/2`` of the requested type.

    Raises:
        ValueError: if cfg["type"] is not a supported filter type.
    """
    kind = cfg['type']
    if kind == 'identity':
        return identity_filter()

    if kind == 'butter4':
        center, width = cfg['center'], cfg['width']
        upper_edge = center + width/2.
        lower_edge = center - width/2.
        low_pass = low_pass_butter4(fc=upper_edge)
        high_pass = high_pass_butter4(fc=lower_edge)
        return low_pass * high_pass

    if kind == 'sine2':
        center, width = cfg['center'], cfg['width']
        trans_width = cfg['trans_width']
        upper_edge = center + width/2.
        lower_edge = center - width/2.
        low_pass = low_pass_sine2(cutoff=upper_edge, width=trans_width)
        high_pass = high_pass_sine2(cutoff=lower_edge, width=trans_width)
        return low_pass * high_pass

    raise ValueError('Unsupported filter type. Supported filters are `identity`, `butter4` and `sine2`')