import os
from numba import njit, types
from numba.typed import Dict
import numpy as np
from scipy.interpolate import interp1d
from .template import Model
from .. import units as u
from .. import utils
from pysm.utils import trapz_step_inplace
import healpy as hp
class InterpolatingComponent(Model):
    def __init__(
        self,
        path,
        input_units,
        nside,
        interpolation_kind="linear",
        has_polarization=True,
        map_dist=None,
        verbose=False,
    ):
        """PySM component interpolating between precomputed maps

        In order to save memory, maps are converted to float32; if this is not
        acceptable, please open an issue on the PySM repository.

        When you create the model, PySM checks the folder of the templates and
        stores a list of available frequencies. Once you call `get_emission`,
        maps are read, ud_graded to the target nside and stored for future use.
        This is useful if you are running many channels with a similar
        bandpass. If not, you can call `cached_maps.clear()` to remove the
        cached maps.

        Parameters
        ----------
        path : str
            Path should contain maps named as the frequency in GHz,
            e.g. 20.fits or 20.5.fits or 00100.fits
        input_units : str
            Any unit available in PySM (see `pysm.convert_units`,
            e.g. `Jysr`, `MJsr`, `uK_RJ`, `K_CMB`).
        nside : int
            HEALPix NSIDE of the output maps
        interpolation_kind : string
            Currently only linear is implemented (the value is stored but
            not otherwise used)
        has_polarization : bool
            whether or not to simulate also polarization maps
        map_dist : pysm.MapDistribution
            Required for partial sky or MPI, see the PySM docs
        verbose : bool
            Control amount of output
        """
        super().__init__(nside=nside, map_dist=map_dist)
        # frequency in GHz (float) -> path of the corresponding map file
        self.maps = self.get_filenames(path)
        # use a numba typed Dict so it can be used in JIT compiled code
        self.cached_maps = Dict.empty(
            key_type=types.float32, value_type=types.float32[:, :]
        )
        # sorted array of the available input frequencies in GHz
        self.freqs = np.array(sorted(self.maps))
        self.input_units = input_units
        self.has_polarization = has_polarization
        self.interpolation_kind = interpolation_kind
        self.verbose = verbose

    def get_filenames(self, path):
        """Map each ``*.fits`` file in ``path`` to its frequency in GHz.

        The frequency is parsed from the file name without extension, so
        every ``.fits`` file in ``path`` must be named as a number.
        Override this to implement a different naming convention.
        """
        filenames = {}
        for f in os.listdir(path):
            if f.endswith(".fits"):
                freq = float(os.path.splitext(f)[0])
                filenames[freq] = os.path.join(path, f)
        return filenames

    @u.quantity_input
    def get_emission(self, freqs: u.GHz, weights=None) -> u.uK_RJ:
        """Emission integrated over a bandpass, interpolating the input maps.

        Parameters
        ----------
        freqs : Quantity
            Frequency or array of frequencies of the bandpass
        weights : array-like, optional
            Bandpass weights, normalized with `utils.normalize_weights`

        Returns
        -------
        Quantity
            IQU maps in uK_RJ; Q and U are zero if the component has no
            polarization
        """
        nu = freqs.to(u.GHz).value
        weights = utils.normalize_weights(freqs, weights)
        if not np.isscalar(nu) and len(nu) == 1:
            nu = nu[0]

        if np.isscalar(nu):
            # special case: we request only 1 frequency and that is among the
            # ones available as input, return that map without interpolation
            check_isclose = np.isclose(self.freqs, nu)
            if np.any(check_isclose):
                freq = self.freqs[check_isclose][0]
                out = self.read_map_by_frequency(freq)
                if self.has_polarization:
                    return out << u.uK_RJ
                zeros = np.zeros_like(out)
                return np.array([out, zeros, zeros]) << u.uK_RJ
            # otherwise continue with interpolation as with an array of nus
            nu = np.array([nu])
        else:
            nu = np.asarray(nu)

        freq_range = self._select_freq_range(nu)
        if self.verbose:
            print("Frequencies considered:", freq_range)

        self._cache_maps(freq_range)

        out = compute_interpolated_emission_numba(
            nu, weights, freq_range, self.cached_maps
        )
        if not self.has_polarization:
            # cached maps are (1, npix) without polarization: pad Q and U with
            # zeros so the output is always (IQU, npix), consistent with the
            # single-frequency branch above and safe to sum with other
            # components
            zeros = np.zeros_like(out[0])
            out = np.array([out[0], zeros, zeros])
        return out << u.uK_RJ

    def _select_freq_range(self, nu):
        """Return the sorted input frequencies needed to interpolate at ``nu``.

        The range spans from the closest input frequency at or below
        ``nu.min()`` to the closest one at or above ``nu.max()``.
        """
        nu_min, nu_max = nu.min(), nu.max()
        assert (
            nu_min >= self.freqs[0]
        ), "Frequency not supported, requested {} Ghz < lower bound {} GHz".format(
            nu_min, self.freqs[0]
        )
        assert (
            nu_max <= self.freqs[-1]
        ), "Frequency not supported, requested {} Ghz > upper bound {} GHz".format(
            nu_max, self.freqs[-1]
        )
        first_freq_i, last_freq_i = np.searchsorted(self.freqs, [nu_min, nu_max])
        # extend by one input frequency on each side for bracketing, clamped to
        # the available range: an unclamped -1 (bandpass starting exactly at
        # self.freqs[0]) would produce a wrapping, wrong slice
        first_freq_i = max(first_freq_i - 1, 0)
        last_freq_i = min(last_freq_i + 1, len(self.freqs))
        return self.freqs[first_freq_i:last_freq_i]

    def _cache_maps(self, freq_range):
        """Read into ``self.cached_maps`` every map of ``freq_range`` not cached yet."""
        for freq in freq_range:
            if freq in self.cached_maps:
                continue
            m = self.read_map_by_frequency(freq)
            if not self.has_polarization:
                # the cache stores 2D float32 arrays, (n_pol, npix)
                m = m.reshape((1, -1))
            self.cached_maps[freq] = m.astype(np.float32)
            if self.verbose:
                for i_pol, pol in enumerate("IQU" if self.has_polarization else "I"):
                    print(
                        "Mean emission at {} GHz in {}: {:.4g} uK_RJ".format(
                            freq, pol, self.cached_maps[freq][i_pol].mean()
                        )
                    )

    def read_map_by_frequency(self, freq):
        """Read the input map at frequency ``freq`` (a key of ``self.maps``)."""
        filename = self.maps[freq]
        return self.read_map_file(freq, filename)

    def read_map_file(self, freq, filename):
        """Read ``filename`` and convert it from ``input_units`` to uK_RJ at ``freq`` GHz.

        Reads fields (0, 1, 2) if the component has polarization, else field 0.
        Returns a plain array (no units attached).
        """
        if self.verbose:
            print("Reading map {}".format(filename))
        m = self.read_map(
            filename,
            field=(0, 1, 2) if self.has_polarization else 0,
            unit=self.input_units,
        )
        return m.to(u.uK_RJ, equivalencies=u.cmb_equivalencies(freq * u.GHz)).value
@njit(parallel=False)
def compute_interpolated_emission_numba(freqs, weights, freq_range, all_maps):
    """Linearly interpolate the input maps at each frequency and integrate.

    Parameters
    ----------
    freqs : np.ndarray
        Frequencies in GHz of the bandpass
    weights : np.ndarray
        Bandpass weights, passed unchanged to `trapz_step_inplace`
    freq_range : np.ndarray
        Sorted input frequencies bracketing ``freqs``; every element must be a
        key of ``all_maps``
    all_maps : numba typed Dict
        Input maps keyed by frequency

    Returns
    -------
    output : np.ndarray
        Array with the same shape and dtype as the input maps, accumulated by
        `trapz_step_inplace` over all ``freqs``
    """
    first_map = all_maps[freq_range[0]]
    output = np.zeros(first_map.shape, dtype=first_map.dtype)
    last = len(freq_range) - 1
    index_range = np.arange(len(freq_range))
    for i in range(len(freqs)):
        # fractional position of freqs[i] along freq_range
        position = np.interp(freqs[i], freq_range, index_range)
        # clamp the bracket so a frequency equal to the last element of
        # freq_range (or a single-element range) never indexes past the end
        low = min(int(position), max(last - 1, 0))
        high = min(low + 1, last)
        frac = position - low
        # linear interpolation: (1 - frac) to the map below, frac to the one above
        m = (1.0 - frac) * all_maps[freq_range[low]] + frac * all_maps[
            freq_range[high]
        ]
        trapz_step_inplace(freqs, weights, i, m, output)
    return output