"""
Power Spectra Computation Module
=================================
This module provides classes for computing polarization power spectra from
CMB observations using the pseudo-Cℓ method with NaMaster (pymaster).
The module implements:
- Cross-correlation power spectra between different frequency channels
- Foreground x CMB cross-spectra for component separation
- Foreground x foreground auto-spectra
- Mode coupling correction using pre-computed coupling matrices
- LAT x SAT cross-spectrum analysis
- Parallel computation support
Classes
-------
Spectra
Main class for computing power spectra from LAT or SAT observations.
Handles observed x observed, dust x observed, sync x observed, and
foreground auto-spectra calculations.
SpectraCross
Class for computing cross-spectra between LAT and SAT observations,
used for calibration analysis and cosmic birefringence studies.
Features
--------
- Automatic mask apodization
- Pure B-mode estimation support
- Template bandpass integration
- CO and point source masking
- Parallel computation options
- Efficient caching of intermediate results
- Memory-efficient streaming computation
Example
-------
from cobi.spectra import Spectra
from cobi.simulation import LATsky, Mask
# Initialize LAT sky simulation
lat_sky = LATsky(libdir, nside=512)
# Create spectra computation object
spec = Spectra(
lat_lib=lat_sky,
common_dir=common_libdir,
aposcale=2.0,
pureB=True,
lmax=2000
)
# Compute all spectra for simulation index 0
spec.compute(idx=0, sync=True)
# Retrieve computed spectra
spectra_dict = spec.get_spectra(idx=0, sync=True)
Notes
-----
This implementation is optimized for the COBI cosmic birefringence analysis pipeline.
"""
# object oriented version of Patricia's code
import numpy as np
import healpy as hp
import pymaster as nmt
import os
from tqdm import tqdm
from cobi.simulation import LATsky, Foreground,Mask, SATsky, LATskyC, SATskyC
from cobi.utils import Logger
from cobi import mpi
from typing import Dict, Optional, Any, Union, List, Tuple
from concurrent.futures import ThreadPoolExecutor
import pickle as pl
import operator
import gc
# PDP: eventually we might want to also mask Galactic dust
[docs]
class Spectra:
[docs]
def __init__(self,
lat_lib: LATsky | SATsky,
common_dir: str,
aposcale: float = 2.0,
template_bandpass: bool = False,
pureB: bool = False,
CO: bool = True,
PS: bool = True,
galcut: int | str | float = 0,
verbose: bool = True,
cache: bool = True,
parallel: int = 0,
dust_model: int = -1,
sync_model: int = -1,
binwidth: int = 1,
lmax: int = 0,
) -> None:
"""
Initializes the Spectra class for computing and handling power spectra of observed CMB maps.
Parameters:
libdir (str): Directory where the spectra will be stored.
lat_lib (LATsky): An instance of the LATsky class containing LAT-related configurations.
aposcale (float, optional): Apodisation scale in degrees. Defaults to 2 deg
template_bandpass (bool, optional): Apply bandpass integration to the foreground template. Defaults to False.
pureB (bool, optional): Apply B-mode purification. Defaults to False
CO (bool, optional): Mask the brightest regions of CO emission. Defautls to True.
PS (bool, optional): Mask the brightest polarised extragalactic point sources. Defaults to True.
"""
self.logger = Logger(self.__class__.__name__, verbose)
self.lat = lat_lib
self.nside = self.lat.nside
libdir = self.lat.libdir
if dust_model != -1:
assert sync_model != -1, "Both dust and sync models must be specified"
self.dust_model = dust_model
self.sync_model = sync_model
self.logger.log(f"Evaluating special case: Simulation uses 'd{self.lat.dust_model}s{self.lat.sync_model}' FG model",'warning')
self.logger.log(f"The template foreground is set to d{dust_model}s{sync_model}",'warning')
fld_ext = f"_temp{dust_model}{sync_model}"
else:
self.dust_model = self.lat.dust_model
self.sync_model = self.lat.sync_model
fld_ext = ""
self.__fld_ext__ = fld_ext
if lmax > 0:
self.lmax = lmax
else:
self.lmax = min(2000,3 * self.lat.nside - 1)
deconv = self.lat.deconv_maps
if self.lmax == 2000:
self.logger.log(f"Setting lmax to 2000",'info')
libdiri = os.path.join(libdir, f"spectra_{self.nside}{'_d' if deconv else ''}_aposcale{str(aposcale).replace('.','p')}{'_pureB' if pureB else ''}" + fld_ext)
comdir = os.path.join(common_dir, f"spectra_{self.nside}{'_d' if deconv else ''}_aposcale{str(aposcale).replace('.','p')}{'_pureB' if pureB else ''}" + fld_ext)
else:
self.logger.log(f"Setting lmax to {self.lmax}",'info')
libdiri = os.path.join(libdir, f"spectra_N{self.nside}_lmax{self.lmax}{'_d' if deconv else ''}_aposcale{str(aposcale).replace('.','p')}{'_pureB' if pureB else ''}" + fld_ext)
comdir = os.path.join(common_dir, f"spectra_N{self.nside}_lmax{self.lmax}{'_d' if deconv else ''}_aposcale{str(aposcale).replace('.','p')}{'_pureB' if pureB else ''}" + fld_ext)
self.__set_dir__(libdiri, comdir)
self.temp_bp = template_bandpass
self.fg = Foreground(self.lat.foreground.basedir, self.nside, self.dust_model, self.sync_model, self.temp_bp, verbose=False)
self.binwidth = binwidth
self.binInfo = nmt.NmtBin.from_lmax_linear(self.lmax, binwidth)
self.Nell = self.binInfo.get_n_bands()
self.pureB = pureB
self.aposcale = aposcale
self.CO = CO
self.PS = PS
self.galcut = galcut
self.mask = self.get_apodised_mask()
self.fsky = np.mean(self.mask**2)**2/np.mean(self.mask**4)
# PDP: saving the spectra in this order makes the indexing of the mle easier
self.freqs = self.lat.freqs
self.Nfreq = len(self.freqs)
self.bands = []
for nu in self.freqs:
for split in range(self.lat.nsplits):
self.bands.append(f'{nu}-{split+1}')
self.Nbands = len(self.bands)
self.obs_qu_maps = None
self.dust_qu_maps = None
self.sync_qu_maps = None
self.workspace = nmt.NmtWorkspace()
self.get_coupling_matrix()
self.cache = cache
self.parallel = parallel
match self.parallel:
case 0:
msg = "No parallelization"
case 1:
msg = "Parallelized single loop"
case 2:
msg = "Parallelized double loop"
case _:
raise ValueError("Invalid parallelization option")
self.logger.log(msg,'info')
[docs]
def get_apodised_mask(self) -> np.ndarray:
    """
    Load the apodised analysis mask from ``self.wdir``, building and saving it if absent.

    When no cached file exists, the mask is built with ``Mask`` from a string made of the
    first three letters of the sky-library class name plus the requested cuts
    ('xCO', 'xPS', 'xGAL'), written as float32 FITS, and returned.

    Returns:
        np.ndarray: The apodised mask map.
    """
    fname = os.path.join(
        self.wdir,
        f"mask_N{self.nside}_aposcale{str(self.aposcale).replace('.','p')}{'_CO' if self.CO else ''}{'_PS' if self.PS else ''}{'_G'+str(self.galcut).replace('.','p') if self.galcut != 0 else ''}.fits",
    )
    if os.path.isfile(fname):
        self.logger.log(f"Reading apodised mask from {fname}",'info')
        return hp.read_map(fname)
    # NOTE(review): the telescope prefix used in mask_str is not part of fname; if LAT and
    # SAT libraries share common_dir at the same nside they would share this file — confirm.
    mask_str = self.lat.__class__.__name__[:3]
    if self.CO:
        mask_str += 'xCO'
    if self.PS:
        mask_str += 'xPS'
    if self.galcut != 0:
        mask_str += 'xGAL'
    maskobj = Mask(self.lat.basedir, self.nside, mask_str, self.aposcale,gal_cut=self.galcut)
    mask = maskobj.mask
    hp.write_map(fname, mask, dtype=np.float32)
    # Log only once the write has succeeded, so the message is truthful.
    self.logger.log(f"Apodised mask saved to {fname}",'info')
    return mask
[docs]
def get_coupling_matrix(self) -> None:
    """
    Fill ``self.workspace`` with the mode-coupling matrix.

    A file in ``self.wdir`` whose name encodes nside, the rounded fsky, the apodisation
    scale, the bin width and the CO / PS / pure-B / Galactic-cut options is read when it
    exists. Otherwise the matrix is computed from a spin-2 field built on the mask itself,
    using ``self.binInfo``, and written to that file.
    """
    fsky_tag = str(np.round(self.fsky, 2)).replace('.', 'p')
    apo_tag = str(self.aposcale).replace('.', 'p')
    co_tag = '_CO' if self.CO else ''
    ps_tag = '_PS' if self.PS else ''
    pure_tag = '_pureB' if self.pureB else ''
    gal_tag = '_G' + str(self.galcut).replace('.', 'p') if self.galcut != 0 else ''
    fname = os.path.join(
        self.wdir,
        f"coupling_matrix_N{self.nside}_fsky{fsky_tag}_aposcale{apo_tag}_bw{self.binwidth}{co_tag}{ps_tag}{pure_tag}{gal_tag}.fits",
    )
    if os.path.isfile(fname):
        self.logger.log(f"Reading coupling Matrix from {fname}",'info')
        self.workspace.read_from(fname)
        return
    self.logger.log("Computing coupling Matrix",'info')
    mask_field = nmt.NmtField(
        self.mask, [self.mask, self.mask], lmax=self.lmax, purify_b=self.pureB
    )
    self.workspace.compute_coupling_matrix(mask_field, mask_field, self.binInfo)
    del mask_field
    self.workspace.write_to(fname)
    self.logger.log(f"Coupling Matrix saved to {fname}",'info')
[docs]
def compute_master(self, f_a: nmt.NmtField, f_b: nmt.NmtField) -> np.ndarray:
    """
    Return the mode-decoupled cross power spectrum of two NaMaster fields.

    The coupled pseudo-spectrum of ``f_a`` x ``f_b`` is obtained with
    ``nmt.compute_coupled_cell`` and decoupled with the pre-computed ``self.workspace``.

    Parameters:
        f_a (nmt.NmtField): First field.
        f_b (nmt.NmtField): Second field.

    Returns:
        np.ndarray: Decoupled (binned) power spectrum.
    """
    return self.workspace.decouple_cell(nmt.compute_coupled_cell(f_a, f_b))
[docs]
def __set_dir__(self, idir: str, cdir: str) -> None:
    """
    Define (and, on MPI rank 0, create) the output directories for spectra and workspaces.

    Parameters:
        idir (str): Run-specific directory; holds the obs x obs, dust x obs and sync x obs spectra.
        cdir (str): Common directory; holds the foreground x foreground spectra and the workspaces.
    """
    self.oxo_dir, self.dxo_dir, self.sxo_dir = (
        os.path.join(idir, sub) for sub in ("obs_x_obs", "dust_x_obs", "sync_x_obs")
    )
    self.dxd_dir, self.sxs_dir, self.sxd_dir, self.wdir = (
        os.path.join(cdir, sub) for sub in ("dust_x_dust", "sync_x_sync", "sync_x_dust", "workspaces")
    )
    if mpi.rank == 0:
        for path in (self.oxo_dir, self.dxo_dir, self.sxo_dir,
                     self.dxd_dir, self.sxs_dir, self.sxd_dir, self.wdir):
            os.makedirs(path, exist_ok=True)
    # All ranks wait until rank 0 has created the directories.
    mpi.barrier()
[docs]
def load_obsQUmaps(self, idx: int) -> None:
    """
    Store the observed Q/U maps of every band in ``self.obs_qu_maps``.

    The result has shape (Nbands, 2, npix), in the order of ``self.bands``.

    Parameters:
        idx (int): Index for the realization of the CMB map.
    """
    npix = hp.nside2npix(self.nside)
    stacked = np.zeros((self.Nbands, 2, npix), dtype=np.float64)
    for row, band_name in zip(stacked, self.bands):
        row[:] = self.lat.obsQU(idx, band_name)
    self.obs_qu_maps = stacked
[docs]
def Obs_qu_maps(self, idx: int, ii: int) -> np.ndarray:
    """
    Return the observed Q/U maps of band ``self.bands[ii]`` for realisation ``idx``.

    The maps are requested from the sky library (``self.lat.obsQU``) on every call.
    """
    band_label = self.bands[ii]
    return self.lat.obsQU(idx, band_label)
[docs]
def __get_fg_QUmap__(self, nu: str, fg: str) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return the beam-smoothed, masked Q and U template maps of a foreground at one frequency.

    If a cached FITS file exists in ``self.fg.libdir`` it is read. Otherwise the template is
    taken from ``self.fg`` (``dustQU`` or ``syncQU``), convolved with the band's Gaussian
    beam and the polarisation pixel window up to ``self.lmax``, multiplied by the analysis
    mask, written to disk and returned.

    Parameters:
        nu (str): Frequency identifier (an element of ``self.freqs``).
        fg (str): Foreground type, either 'dust' or 'sync'.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Q and U maps of the requested foreground.

    Raises:
        ValueError: If ``fg`` is neither 'dust' nor 'sync'.
    """
    if fg not in ('dust', 'sync'):
        raise ValueError('Unknown foreground')
    fname = os.path.join(self.fg.libdir, f"{fg}QU_N{self.nside}_{nu}_template{'_bp' if self.temp_bp else ''}.fits")
    # NOTE(review): the cache name encodes neither the mask nor lmax/beam, although the stored
    # map is already multiplied by self.mask; runs with other mask settings sharing
    # fg.libdir would reuse it — confirm this is intended.
    if os.path.isfile(fname):
        m = hp.read_map(fname, field=(0, 1))
        return m[0], m[1]
    m = getattr(self.fg, f"{fg}QU")(nu)  # self.fg.dustQU or self.fg.syncQU
    E, B = hp.map2alm_spin(m, 2, lmax=self.lmax)
    # Pick this frequency's beam; np.asarray keeps the comparison element-wise even if
    # freqs/fwhm are plain Python sequences (a list would compare to a single bool).
    fwhm = np.asarray(self.lat.fwhm)[np.asarray(self.freqs) == nu][0]
    bl = hp.gauss_beam(np.radians(fwhm / 60), pol=True, lmax=self.lmax)
    pwf = np.array(hp.pixwin(self.nside, pol=True, lmax=self.lmax))
    # gauss_beam(pol=True) columns 1/2 are the E/B beams; pixwin(pol=True)[1] is the
    # polarisation pixel window.
    hp.almxfl(E, bl[:, 1] * pwf[1, :], inplace=True)
    hp.almxfl(B, bl[:, 2] * pwf[1, :], inplace=True)
    m = hp.alm2map_spin([E, B], self.nside, 2, self.lmax) * self.mask
    hp.write_map(fname, m, dtype=np.float64)
    return m[0], m[1]
[docs]
def load_dustQUmaps(self) -> None:
    """
    Store the dust Q/U templates of every frequency in ``self.dust_qu_maps``.

    The result has shape (Nfreq, 2, npix), in the order of ``self.freqs``.
    """
    npix = hp.nside2npix(self.nside)
    templates = np.zeros((self.Nfreq, 2, npix), dtype=np.float64)
    for slot, freq in zip(templates, self.freqs):
        slot[:] = self.__get_fg_QUmap__(freq, 'dust')
    self.dust_qu_maps = templates
[docs]
def Dust_qu_maps(self, ii: int) -> np.ndarray:
    """Return the dust Q/U template of frequency ``self.freqs[ii]`` via ``__get_fg_QUmap__``."""
    frequency = self.freqs[ii]
    return self.__get_fg_QUmap__(frequency, 'dust')
[docs]
def load_syncQUmaps(self) -> None:
    """
    Store the synchrotron Q/U templates of every frequency in ``self.sync_qu_maps``.

    The result has shape (Nfreq, 2, npix), in the order of ``self.freqs``.
    """
    npix = hp.nside2npix(self.nside)
    templates = np.zeros((self.Nfreq, 2, npix), dtype=np.float64)
    for slot, freq in zip(templates, self.freqs):
        slot[:] = self.__get_fg_QUmap__(freq, 'sync')
    self.sync_qu_maps = templates
[docs]
def Sync_qu_maps(self, ii: int) -> np.ndarray:
    """Return the synchrotron Q/U template of frequency ``self.freqs[ii]`` via ``__get_fg_QUmap__``."""
    frequency = self.freqs[ii]
    return self.__get_fg_QUmap__(frequency, 'sync')
[docs]
def obs_x_obs_check(self, idx: int, read_test: bool = False) -> List[bool]:
    """
    Report, band by band, whether the observed x observed spectra file exists.

    The directory is resolved exactly as in the obs x obs helpers: when a template
    foreground override is active (``self.__fld_ext__`` non-empty), the spectra are read
    from and written to the directory name without that suffix, so this check looks there too.

    Parameters:
        idx (int): Index for the realization of the CMB map.
        read_test (bool, optional): Also try to load every existing file and print the
            ones that fail. Defaults to False.

    Returns:
        List[bool]: One flag per band in ``self.bands``; True if its file exists.
    """
    oxo_dir = self.oxo_dir.replace(self.__fld_ext__, '') if self.__fld_ext__ != "" else self.oxo_dir
    suffix = f"{'_obsBP' if self.lat.bandpass else ''}{'_d' if self.lat.deconv_maps else ''}_{idx:03d}.npy"
    flags = []
    for ii in range(self.Nbands):
        fname = os.path.join(oxo_dir, f"obs_x_obs_{self.bands[ii]}{suffix}")
        exists = os.path.isfile(fname)
        flags.append(exists)
        if read_test and exists:
            try:
                np.load(fname)
            except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
                print(f"Error loading {fname}")
    return flags
[docs]
def __obs_x_obs_helper_series__(self, ii: int, idx: int, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Computes or loads the observed x observed power spectra of band ``ii`` against every band ``jj >= ii``.

    The (Nbands, Nbands, 3, Nell + 2) result only has row ``ii`` and column ``ii`` filled,
    so summing the results over all ``ii`` gives the full matrix. Axis 2 holds
    (EiEj, BiBj, EiBj); the first two entries along the last axis are left at zero.

    Parameters:
        ii (int): Index for the current frequency band.
        idx (int): Index for the realization of the CMB map.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the observed x observed fields.
    """
    if self.__fld_ext__ != "":
        self.logger.log(f"Special case: Assumes that there exsist a previous run",'warning')
        self.logger.log(f"Special case: the default obsxobs directory is {self.oxo_dir}",'warning')
        oxo_dir = self.oxo_dir.replace(self.__fld_ext__,'')
        self.logger.log(f"Special case: the obsxobs directory is set to {oxo_dir}",'warning')
    else:
        oxo_dir = self.oxo_dir
    fname = os.path.join(
        oxo_dir,
        f"obs_x_obs_{self.bands[ii]}{'_obsBP' if self.lat.bandpass else ''}{'_d' if self.lat.deconv_maps else ''}_{idx:03d}.npy",
    )
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            # Unreadable cache file: recompute it. Not a bare except, so Ctrl-C
            # (KeyboardInterrupt) still stops the run instead of triggering a recompute.
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii},simulation:{idx}",'info')
            return self.__obs_x_obs_helper_series__(ii, idx, recache=True)
    cl = np.zeros((self.Nbands, self.Nbands, 3, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask, self.Obs_qu_maps(idx, ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False,
    )
    for jj in range(ii, self.Nbands):
        fp_j = nmt.NmtField(
            self.mask, self.Obs_qu_maps(idx, jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        if ii != jj:
            cl[jj, ii, 0, 2:] = cl_ij[0, :]  # EjEi = EiEj
            cl[jj, ii, 1, 2:] = cl_ij[3, :]  # BjBi = BiBj
            cl[jj, ii, 2, 2:] = cl_ij[2, :]  # EjBi
        del fp_j
    if self.cache:
        np.save(fname, cl)
    return cl
[docs]
def __obs_x_obs_helper_parallel__(self, ii: int, idx: int, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Threaded variant of the serial obs x obs helper: computes or loads the observed x observed
    power spectra of band ``ii`` against every band ``jj >= ii``, one task per ``jj``.

    The (Nbands, Nbands, 3, Nell + 2) result only has row ``ii`` and column ``ii`` filled.
    Axis 2 holds (EiEj, BiBj, EiBj); the first two entries along the last axis are zero.

    Parameters:
        ii (int): Index for the current frequency band.
        idx (int): Index for the realization of the CMB map.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the observed x observed fields.
    """
    if self.__fld_ext__ != "":
        self.logger.log(f"Special case: Assumes that there exsist a previous run",'warning')
        self.logger.log(f"Special case: the default obsxobs directory is {self.oxo_dir}",'warning')
        oxo_dir = self.oxo_dir.replace(self.__fld_ext__,'')
        self.logger.log(f"Special case: the obsxobs directory is set to {oxo_dir}",'warning')
    else:
        oxo_dir = self.oxo_dir
    fname = os.path.join(
        oxo_dir,
        f"obs_x_obs_{self.bands[ii]}{'_obsBP' if self.lat.bandpass else ''}{'_d' if self.lat.deconv_maps else ''}_{idx:03d}.npy",
    )
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            # Unreadable cache file: recompute it (KeyboardInterrupt still propagates).
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii},simulation:{idx}",'info')
            return self.__obs_x_obs_helper_parallel__(ii, idx, recache=True)
    cl = np.zeros((self.Nbands, self.Nbands, 3, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask, self.Obs_qu_maps(idx, ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False,
    )

    def compute_for_band(jj: int) -> None:
        # Each jj writes only cl[ii, jj] and cl[jj, ii]; these slices are disjoint
        # across tasks.
        fp_j = nmt.NmtField(
            self.mask, self.Obs_qu_maps(idx, jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        if ii != jj:
            cl[jj, ii, 0, 2:] = cl_ij[0, :]  # EjEi = EiEj
            cl[jj, ii, 1, 2:] = cl_ij[3, :]  # BjBi = BiBj
            cl[jj, ii, 2, 2:] = cl_ij[2, :]  # EjBi
        del fp_j

    with ThreadPoolExecutor() as executor:
        # Drain the iterator: Executor.map only re-raises a task's exception when its
        # result is retrieved. Without this, a failed band would leave zeros in `cl`
        # that are then cached to disk.
        for _ in executor.map(compute_for_band, range(ii, self.Nbands)):
            pass
    if self.cache:
        np.save(fname, cl)
    return cl
def __obs_x_obs_helper__(self, ii: int, idx: int) -> np.ndarray:
if self.parallel == 2:
return self.__obs_x_obs_helper_parallel__(ii, idx)
else:
return self.__obs_x_obs_helper_series__(ii, idx)
[docs]
def obs_x_obs(self, idx: int, progress: bool = False,) -> np.ndarray:
    """
    Computes or loads the observed x observed power spectra for all frequency bands.

    Per-band results from ``__obs_x_obs_helper__`` are summed into the full matrix. Bands
    are processed serially when ``self.parallel == 0`` and by a thread pool otherwise.

    Parameters:
        idx (int): Index for the realization of the CMB map.
        progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nbands, Nbands, 3, Nell + 2).
    """
    total = np.zeros((self.Nbands, self.Nbands, 3, self.Nell + 2), dtype=np.float64)
    band_indices = range(self.Nbands)
    bar_opts = dict(desc="obs x obs spectra", unit="band", disable=not progress)
    if self.parallel == 0:
        for band_index in tqdm(band_indices, **bar_opts):
            total += self.__obs_x_obs_helper__(band_index, idx)
        return total
    with ThreadPoolExecutor() as pool:
        per_band = pool.map(lambda band_index: self.__obs_x_obs_helper__(band_index, idx), band_indices)
        for part in tqdm(per_band, total=self.Nbands, **bar_opts):
            total += part
    return total
[docs]
def __fg_x_obs_helper_series__(self, ii: int, idx: int, fg: str, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Computes or loads the foreground (dust or sync) x observed power spectra of frequency
    ``ii`` against every band.

    Only row ``ii`` of the (Nfreq, Nbands, 4, Nell + 2) result is filled. Axis 2 holds
    (EiEj, BiBj, EiBj, BiEj), where i is the foreground and j the observed band; the first two
    entries along the last axis are zero.

    Parameters:
        ii (int): Index for the current frequency.
        idx (int): Index for the realization of the CMB map.
        fg (str): Type of foregrounds, either 'dust' or 'sync'.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the foreground x observed fields.

    Raises:
        ValueError: If ``fg`` is neither 'dust' nor 'sync'.
    """
    if fg not in ('dust', 'sync'):
        raise ValueError('Unknown foreground')
    base_dir = self.dxo_dir if fg == 'dust' else self.sxo_dir
    fname = os.path.join(
        base_dir,
        f"{fg}_x_obs_{self.freqs[ii]}{'_obsBP' if self.lat.bandpass else ''}{'_tempBP' if self.temp_bp else ''}_{idx:03d}.npy",
    )
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            # Unreadable cache file: recompute it (KeyboardInterrupt still propagates).
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii},simulation:{idx}, FG: {fg}",'info')
            return self.__fg_x_obs_helper_series__(ii, idx, fg, recache=True)
    cl = np.zeros((self.Nfreq, self.Nbands, 4, self.Nell + 2), dtype=np.float64)
    template = self.Dust_qu_maps(ii) if fg == 'dust' else self.Sync_qu_maps(ii)
    fp_i = nmt.NmtField(
        self.mask, template, lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False,
    )
    for jj in range(self.Nbands):
        fp_j = nmt.NmtField(
            self.mask, self.Obs_qu_maps(idx, jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        cl[ii, jj, 3, 2:] = cl_ij[2, :]  # BiEj
        del fp_j
    if self.cache:
        np.save(fname, cl)
    return cl
[docs]
def __fg_x_obs_helper_parallel__(self, ii: int, idx: int, fg: str, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Threaded variant of the serial foreground x obs helper: computes or loads the
    foreground (dust or sync) x observed power spectra of frequency ``ii`` against every
    band, one task per band.

    Only row ``ii`` of the (Nfreq, Nbands, 4, Nell + 2) result is filled. Axis 2 holds
    (EiEj, BiBj, EiBj, BiEj).

    Parameters:
        ii (int): Index for the current frequency.
        idx (int): Index for the realization of the CMB map.
        fg (str): Type of foregrounds, either 'dust' or 'sync'.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the foreground x observed fields.

    Raises:
        ValueError: If ``fg`` is neither 'dust' nor 'sync'.
    """
    if fg not in ('dust', 'sync'):
        raise ValueError('Unknown foreground')
    base_dir = self.dxo_dir if fg == 'dust' else self.sxo_dir
    fname = os.path.join(
        base_dir,
        f"{fg}_x_obs_{self.freqs[ii]}{'_obsBP' if self.lat.bandpass else ''}{'_tempBP' if self.temp_bp else ''}_{idx:03d}.npy",
    )
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            # Unreadable cache file: recompute it (KeyboardInterrupt still propagates).
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii},simulation:{idx}, FG: {fg}",'info')
            return self.__fg_x_obs_helper_parallel__(ii, idx, fg, recache=True)
    cl = np.zeros((self.Nfreq, self.Nbands, 4, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask,
        self.Dust_qu_maps(ii) if fg == 'dust' else self.Sync_qu_maps(ii),
        lmax=self.lmax,
        purify_b=self.pureB,
        masked_on_input=False,
    )

    def compute_for_band(jj: int) -> None:
        # Each task writes only cl[ii, jj]; slices are disjoint across tasks.
        fp_j = nmt.NmtField(
            self.mask, self.Obs_qu_maps(idx, jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        cl[ii, jj, 3, 2:] = cl_ij[2, :]  # BiEj
        del fp_j

    with ThreadPoolExecutor() as executor:
        # Drain the iterator so a failing task re-raises here; otherwise its zeros
        # would silently end up in the cached file.
        for _ in executor.map(compute_for_band, range(self.Nbands)):
            pass
    if self.cache:
        np.save(fname, cl)
    return cl
def __fg_x_obs_helper__(self, ii: int, idx: int, fg: str) -> np.ndarray:
if self.parallel == 2:
return self.__fg_x_obs_helper_parallel__(ii, idx, fg)
else:
return self.__fg_x_obs_helper_series__(ii, idx, fg)
[docs]
def dust_x_obs(self, idx: int, progress: bool = False) -> np.ndarray:
    """
    Computes or loads the dust x observed power spectra for all frequency bands.

    Per-frequency results from ``__fg_x_obs_helper__`` are summed into one array.
    Frequencies are processed serially when ``self.parallel == 0`` and by a thread pool otherwise.

    Parameters:
        idx (int): Index for the realization of the CMB map.
        progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nbands, 4, Nell + 2).
    """
    combined = np.zeros((self.Nfreq, self.Nbands, 4, self.Nell + 2), dtype=np.float64)
    freq_indices = range(self.Nfreq)
    bar_opts = dict(desc="dust x obs spectra", unit="band", disable=not progress)
    if self.parallel == 0:
        for freq_index in tqdm(freq_indices, **bar_opts):
            combined += self.__fg_x_obs_helper__(freq_index, idx, 'dust')
        return combined
    with ThreadPoolExecutor() as pool:
        per_freq = pool.map(lambda freq_index: self.__fg_x_obs_helper__(freq_index, idx, 'dust'), freq_indices)
        for part in tqdm(per_freq, total=self.Nfreq, **bar_opts):
            combined += part
    return combined
[docs]
def dust_x_obs_check(self, idx: int) -> List[bool]:
    """
    Report, per frequency, whether the dust x observed spectra file exists.

    Parameters:
        idx (int): Index for the realization of the CMB map.

    Returns:
        List[bool]: One flag per entry of ``self.freqs``; True if its file exists.
    """
    suffix = f"{'_obsBP' if self.lat.bandpass else ''}{'_tempBP' if self.temp_bp else ''}_{idx:03d}.npy"
    return [
        os.path.isfile(os.path.join(self.dxo_dir, f"dust_x_obs_{self.freqs[ii]}{suffix}"))
        for ii in range(self.Nfreq)
    ]
[docs]
def sync_x_obs_check(self, idx: int) -> List[bool]:
    """
    Report, per frequency, whether the synchrotron x observed spectra file exists.

    Parameters:
        idx (int): Index for the realization of the CMB map.

    Returns:
        List[bool]: One flag per entry of ``self.freqs``; True if its file exists.
    """
    suffix = f"{'_obsBP' if self.lat.bandpass else ''}{'_tempBP' if self.temp_bp else ''}_{idx:03d}.npy"
    return [
        os.path.isfile(os.path.join(self.sxo_dir, f"sync_x_obs_{self.freqs[ii]}{suffix}"))
        for ii in range(self.Nfreq)
    ]
[docs]
def sync_x_obs(self, idx: int, progress: bool = False) -> np.ndarray:
    """
    Computes or loads the synchrotron x observed power spectra for all frequency bands.

    Per-frequency results from ``__fg_x_obs_helper__`` are summed into one array.
    Frequencies are processed serially when ``self.parallel == 0`` and by a thread pool otherwise.

    Parameters:
        idx (int): Index for the realization of the CMB map.
        progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nbands, 4, Nell + 2).
    """
    combined = np.zeros((self.Nfreq, self.Nbands, 4, self.Nell + 2), dtype=np.float64)
    freq_indices = range(self.Nfreq)
    bar_opts = dict(desc="sync x obs spectra", unit="band", disable=not progress)
    if self.parallel == 0:
        for freq_index in tqdm(freq_indices, **bar_opts):
            combined += self.__fg_x_obs_helper__(freq_index, idx, 'sync')
        return combined
    with ThreadPoolExecutor() as pool:
        per_freq = pool.map(lambda freq_index: self.__fg_x_obs_helper__(freq_index, idx, 'sync'), freq_indices)
        for part in tqdm(per_freq, total=self.Nfreq, **bar_opts):
            combined += part
    return combined
[docs]
def __fg_x_fg_helper_series__(self, ii: int, fg: str, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Computes or loads the foreground auto spectra (dust x dust or sync x sync) of
    frequency ``ii`` against every frequency ``jj >= ii``.

    The (Nfreq, Nfreq, 3, Nell + 2) result only has row ``ii`` and column ``ii`` filled.
    Axis 2 holds (EiEj, BiBj, EiBj).

    Parameters:
        ii (int): Index for the current frequency.
        fg (str): Type of foregrounds, either 'dust' or 'sync'.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the foreground x foreground fields.

    Raises:
        ValueError: If ``fg`` is neither 'dust' nor 'sync'.
    """
    if fg not in ('dust', 'sync'):
        raise ValueError('Unknown foreground')
    if fg == 'dust':
        base_dir, model, get_template = self.dxd_dir, self.dust_model, self.Dust_qu_maps
    else:
        base_dir, model, get_template = self.sxs_dir, self.sync_model, self.Sync_qu_maps
    fname = os.path.join(
        base_dir,
        f"{fg}_x_{fg}_{model}{self.freqs[ii]}{'_tempBP' if self.temp_bp else ''}.npy",
    )
    # Honour `recache`: previously it was ignored, so an unreadable file made the
    # recovery call below retry the same failing load until RecursionError.
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii}, FG: {fg}",'info')
            return self.__fg_x_fg_helper_series__(ii, fg, recache=True)
    cl = np.zeros((self.Nfreq, self.Nfreq, 3, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask, get_template(ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False,
    )
    for jj in range(ii, self.Nfreq):
        fp_j = nmt.NmtField(
            self.mask, get_template(jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        if ii != jj:
            cl[jj, ii, 0, 2:] = cl_ij[0, :]  # EjEi = EiEj
            cl[jj, ii, 1, 2:] = cl_ij[3, :]  # BjBi = BiBj
            cl[jj, ii, 2, 2:] = cl_ij[2, :]  # EjBi
        del fp_j
    if self.cache:
        np.save(fname, cl)
    return cl
[docs]
def __fg_x_fg_helper_parallel__(self, ii: int, fg: str, recache: bool = False) -> np.ndarray:
    """
    Helper function:
    Threaded variant of the serial fg x fg helper: computes or loads the foreground auto
    spectra (dust x dust or sync x sync) of frequency ``ii`` against every ``jj >= ii``,
    one task per ``jj``.

    The (Nfreq, Nfreq, 3, Nell + 2) result only has row ``ii`` and column ``ii`` filled.
    Axis 2 holds (EiEj, BiBj, EiBj).

    Parameters:
        ii (int): Index for the current frequency.
        fg (str): Type of foregrounds, either 'dust' or 'sync'.
        recache (bool, optional): Ignore an existing file and recompute. Defaults to False.

    Returns:
        np.ndarray: Power spectra for the foreground x foreground fields.

    Raises:
        ValueError: If ``fg`` is neither 'dust' nor 'sync'.
    """
    if fg not in ('dust', 'sync'):
        raise ValueError('Unknown foreground')
    base_dir = self.dxd_dir if fg == 'dust' else self.sxs_dir
    model = self.dust_model if fg == 'dust' else self.sync_model
    get_template = self.Dust_qu_maps if fg == 'dust' else self.Sync_qu_maps
    fname = os.path.join(
        base_dir,
        f"{fg}_x_{fg}_{model}{self.freqs[ii]}{'_tempBP' if self.temp_bp else ''}.npy",
    )
    # Honour `recache`: without it an unreadable file would recurse until RecursionError.
    if os.path.isfile(fname) and not recache:
        try:
            return np.load(fname)
        except Exception:
            self.logger.log(f"Error loading {fname}",'error')
            self.logger.log(f"Recomputing band:{ii}, FG: {fg}",'info')
            return self.__fg_x_fg_helper_parallel__(ii, fg, recache=True)
    cl = np.zeros((self.Nfreq, self.Nfreq, 3, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask, get_template(ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False,
    )

    def process_jj(jj: int) -> None:
        # Each jj writes only cl[ii, jj] and cl[jj, ii]; these slices are disjoint across tasks.
        fp_j = nmt.NmtField(
            self.mask, get_template(jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False,
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
        cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
        cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
        if ii != jj:
            cl[jj, ii, 0, 2:] = cl_ij[0, :]  # EjEi = EiEj
            cl[jj, ii, 1, 2:] = cl_ij[3, :]  # BjBi = BiBj
            cl[jj, ii, 2, 2:] = cl_ij[2, :]  # EjBi
        del fp_j

    with ThreadPoolExecutor() as executor:
        # Drain the iterator so a failing task re-raises here instead of leaving
        # zeros in `cl` that would be cached to disk.
        for _ in executor.map(process_jj, range(ii, self.Nfreq)):
            pass
    if self.cache:
        np.save(fname, cl)
    return cl
def __fg_x_fg_helper__(self, ii: int, fg: str) -> np.ndarray:
if self.parallel == 2:
return self.__fg_x_fg_helper_parallel__(ii, fg)
else:
return self.__fg_x_fg_helper_series__(ii, fg)
[docs]
def sync_x_sync(self, progress: bool = False) -> np.ndarray:
    """
    Computes or loads the synchrotron x synchrotron power spectra for all frequency bands.

    Per-frequency results from ``__fg_x_fg_helper__`` are summed into the full matrix.
    Frequencies are processed serially when ``self.parallel == 0`` and by a thread pool otherwise.

    Parameters:
        progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nfreq, 3, Nell + 2).
    """
    combined = np.zeros((self.Nfreq, self.Nfreq, 3, self.Nell + 2), dtype=np.float64)
    freq_indices = range(self.Nfreq)
    bar_opts = dict(desc="sync x sync spectra", unit="band", disable=not progress)
    if self.parallel == 0:
        for freq_index in tqdm(freq_indices, **bar_opts):
            combined += self.__fg_x_fg_helper__(freq_index, 'sync')
        return combined
    with ThreadPoolExecutor() as pool:
        per_freq = pool.map(lambda freq_index: self.__fg_x_fg_helper__(freq_index, 'sync'), freq_indices)
        for part in tqdm(per_freq, total=self.Nfreq, **bar_opts):
            combined += part
    return combined
[docs]
def dust_x_dust(self, progress: bool = False) -> np.ndarray:
    """
    Computes or loads the dust x dust power spectra for all frequency bands.

    Per-frequency results from ``__fg_x_fg_helper__`` are summed into the full matrix.
    Frequencies are processed serially when ``self.parallel == 0`` and by a thread pool otherwise.

    Parameters:
        progress (bool, optional): If True, displays a progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nfreq, 3, Nell + 2).
    """
    combined = np.zeros((self.Nfreq, self.Nfreq, 3, self.Nell + 2), dtype=np.float64)
    freq_indices = range(self.Nfreq)
    bar_opts = dict(desc="dust x dust spectra", unit="band", disable=not progress)
    if self.parallel == 0:
        for freq_index in tqdm(freq_indices, **bar_opts):
            combined += self.__fg_x_fg_helper__(freq_index, 'dust')
        return combined
    with ThreadPoolExecutor() as pool:
        per_freq = pool.map(lambda freq_index: self.__fg_x_fg_helper__(freq_index, 'dust'), freq_indices)
        for part in tqdm(per_freq, total=self.Nfreq, **bar_opts):
            combined += part
    return combined
[docs]
def __sync_x_dust_helper_series__(self, ii: int) -> np.ndarray:
    """
    Helper: load or compute the synchrotron x dust spectra for band ``ii``,
    looping serially over the dust bands.

    The synchrotron field of band ``ii`` is correlated with the dust field of
    every band ``jj``, so only row ``ii`` of the result is filled. An existing
    cache file is returned as-is; a fresh result is saved when ``self.cache``
    is set.

    Parameters:
        ii (int): Index of the synchrotron frequency band.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nfreq, 4, Nell + 2). Axis 2 holds
        (EiEj, BiBj, EiBj, BiEj), assuming ``compute_master`` returns rows in
        NaMaster order (EE, EB, BE, BB). The first two multipoles stay zero.
    """
    bp_tag = '_tempBP' if self.temp_bp else ''
    fname = os.path.join(self.sxd_dir, f"sync{self.sync_model}_x_dust{self.dust_model}_{self.freqs[ii]}{bp_tag}.npy")
    if os.path.isfile(fname):
        return np.load(fname)

    spectra = np.zeros((self.Nfreq, self.Nfreq, 4, self.Nell + 2), dtype=np.float64)
    field_sync = nmt.NmtField(
        self.mask, self.Sync_qu_maps(ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False
    )
    # Output slot -> compute_master row: EE<-0, BB<-3, EB<-1, BE<-2.
    slot_rows = (0, 3, 1, 2)
    for jj in range(self.Nfreq):
        field_dust = nmt.NmtField(
            self.mask, self.Dust_qu_maps(jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False
        )
        cl_ij = self.compute_master(field_sync, field_dust)
        for slot, row in enumerate(slot_rows):
            spectra[ii, jj, slot, 2:] = cl_ij[row, :]
        del field_dust  # release each dust field before building the next one
    if self.cache:
        np.save(fname, spectra)
    return spectra
[docs]
def __sync_x_dust_helper_parallel__(self, ii: int) -> np.ndarray:
    """
    Helper: load or compute the synchrotron x dust spectra for band ``ii``,
    distributing the loop over dust bands across a thread pool.

    The synchrotron field of band ``ii`` is built once and shared by all
    workers. Each worker builds one dust field and returns its decoupled
    spectra, and the results are written into row ``ii`` of the output in the
    calling thread. An existing cache file is returned as-is. A fresh result is
    saved when ``self.cache`` is set.

    Parameters:
        ii (int): Index of the synchrotron frequency band.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nfreq, 4, Nell + 2). Axis 2 holds
        (EiEj, BiBj, EiBj, BiEj), assuming ``compute_master`` returns rows in
        NaMaster order (EE, EB, BE, BB). The first two multipoles stay zero.

    Raises:
        Any exception raised while building a field or computing a spectrum.
        It propagates to the caller, and nothing is cached in that case.
    """
    fname = os.path.join(self.sxd_dir, f"sync{self.sync_model}_x_dust{self.dust_model}_{self.freqs[ii]}{'_tempBP' if self.temp_bp else ''}.npy")
    if os.path.isfile(fname):
        return np.load(fname)

    cl = np.zeros((self.Nfreq, self.Nfreq, 4, self.Nell + 2), dtype=np.float64)
    fp_i = nmt.NmtField(
        self.mask, self.Sync_qu_maps(ii), lmax=self.lmax, purify_b=self.pureB,
        masked_on_input=False
    )

    def process_jj(jj):
        # Build the dust field for band jj and return its spectra against fp_i.
        fp_j = nmt.NmtField(
            self.mask, self.Dust_qu_maps(jj), lmax=self.lmax, purify_b=self.pureB,
            masked_on_input=False
        )
        cl_ij = self.compute_master(fp_i, fp_j)  # (EiEj, EiBj, BiEj, BiBj)
        del fp_j
        return jj, cl_ij

    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        # Consuming the results is what re-raises worker exceptions. If the
        # iterator were discarded, failures would vanish and a partially-zero
        # array could be written to the cache.
        for jj, cl_ij in executor.map(process_jj, range(self.Nfreq)):
            cl[ii, jj, 0, 2:] = cl_ij[0, :]  # EiEj
            cl[ii, jj, 1, 2:] = cl_ij[3, :]  # BiBj
            cl[ii, jj, 2, 2:] = cl_ij[1, :]  # EiBj
            cl[ii, jj, 3, 2:] = cl_ij[2, :]  # BiEj
    if self.cache:
        np.save(fname, cl)
    return cl
def __sync_x_dust_helper__(self, ii: int) -> np.ndarray:
    """
    Dispatch the per-band synchrotron x dust computation.

    ``self.parallel == 2`` selects the thread-pooled inner loop over dust
    bands; every other value uses the serial inner loop.

    Parameters:
        ii (int): Index of the synchrotron frequency band.

    Returns:
        np.ndarray: Whatever the selected helper returns.
    """
    helper = (
        self.__sync_x_dust_helper_parallel__
        if self.parallel == 2
        else self.__sync_x_dust_helper_series__
    )
    return helper(ii)
[docs]
def sync_x_dust(self, progress: bool = False) -> np.ndarray:
    """
    Accumulate the synchrotron x dust spectra over all frequency bands.

    Each band ``ii`` contributes ``self.__sync_x_dust_helper__(ii)``; the
    per-band arrays are summed element-wise into one result.

    Parameters:
        progress (bool, optional): Show a tqdm progress bar. Defaults to False.

    Returns:
        np.ndarray: Array of shape (Nfreq, Nfreq, 4, Nell + 2) holding the sum
        of the per-band synchrotron x dust contributions.
    """
    total = np.zeros((self.Nfreq, self.Nfreq, 4, self.Nell + 2), dtype=np.float64)
    label = "sync x dust spectra"
    if self.parallel == 0:
        # Serial: the bar is always constructed, just hidden unless requested.
        for band in tqdm(range(self.Nfreq), desc=label, unit="band", disable=not progress):
            total += self.__sync_x_dust_helper__(band)
        return total
    # Any non-zero ``parallel`` setting: spread the bands over a thread pool.
    with ThreadPoolExecutor() as pool:
        per_band = pool.map(lambda band: self.__sync_x_dust_helper__(band), range(self.Nfreq))
        if progress:
            per_band = tqdm(per_band, total=self.Nfreq, desc=label, unit="band")
        for contribution in per_band:
            total += contribution
    return total
[docs]
def clear_obs_qu_maps(self) -> None:
    """Drop the reference to the observed Q/U maps by resetting ``obs_qu_maps`` to None."""
    setattr(self, 'obs_qu_maps', None)
[docs]
def clear_dust_qu_maps(self) -> None:
    """Drop the reference to the dust Q/U maps by resetting ``dust_qu_maps`` to None."""
    setattr(self, 'dust_qu_maps', None)
[docs]
def clear_sync_qu_maps(self) -> None:
    """Drop the reference to the synchrotron Q/U maps by resetting ``sync_qu_maps`` to None."""
    setattr(self, 'sync_qu_maps', None)
[docs]
def compute(self, idx: int, sync: bool = False) -> None:
    """
    Run every spectrum computation for realization ``idx`` while managing map memory.

    The dust Q/U maps are loaded first, then the observed maps for ``idx``, and
    the synchrotron maps when ``sync`` is True. Each map buffer is cleared once
    the last spectrum that needs it has run. The return values of the spectrum
    methods are not kept here. Presumably those methods persist their own
    results (e.g. via cache files); TODO confirm.

    Parameters:
        idx (int): Realization index of the CMB map.
        sync (bool, optional): Also compute the synchrotron spectra. Defaults to False.
    """
    self.load_dustQUmaps()
    self.dust_x_dust(progress=True)
    self.load_obsQUmaps(idx)
    self.obs_x_obs(idx, progress=True)
    self.dust_x_obs(idx, progress=True)
    if not sync:
        self.clear_dust_qu_maps()
        self.clear_obs_qu_maps()
        return
    self.load_syncQUmaps()
    self.sync_x_dust(progress=True)
    self.clear_dust_qu_maps()  # sync x dust is the last consumer of the dust maps
    self.sync_x_sync(progress=True)
    self.sync_x_obs(idx, progress=True)
    self.clear_obs_qu_maps()
    self.clear_sync_qu_maps()
[docs]
def Compute(self, idx: int, sync: bool = False) -> None:
    """
    Trigger every spectrum method for realization ``idx`` without loading or
    clearing any Q/U map buffers explicitly (unlike ``compute``).

    Order: dust x dust, obs x obs, dust x obs, then (if ``sync``) sync x dust,
    sync x sync, sync x obs. Return values are not kept.

    Parameters:
        idx (int): Realization index of the CMB map.
        sync (bool, optional): Also run the synchrotron spectra. Defaults to False.
    """
    steps = [
        lambda: self.dust_x_dust(progress=True),
        lambda: self.obs_x_obs(idx, progress=True),
        lambda: self.dust_x_obs(idx, progress=True),
    ]
    if sync:
        steps.extend([
            lambda: self.sync_x_dust(progress=True),
            lambda: self.sync_x_sync(progress=True),
            lambda: self.sync_x_obs(idx, progress=True),
        ])
    for step in steps:
        step()
[docs]
def _compute_keep_idx_bands(self, avoid_bands):
    """
    Indices to keep along axes of length ``self.Nbands`` (e.g. 12).

    A band is dropped when the part of its name before the first '-' (e.g.
    '93' in '93-1') matches, as a string, any entry of ``avoid_bands``.

    Parameters:
        avoid_bands: Iterable of frequency labels to drop; None/empty keeps all.

    Returns:
        np.ndarray: Integer indices into ``self.bands`` of the bands to keep.

    Raises:
        ValueError: If every band would be removed.
    """
    if not avoid_bands:
        return np.arange(self.Nbands)
    excluded = {str(label) for label in avoid_bands}
    selected = np.fromiter(
        (pos for pos, band in enumerate(self.bands) if band.split('-')[0] not in excluded),
        dtype=int,
    )
    if selected.size == 0:
        raise ValueError("All bands filtered out via self.bands.")
    return selected
[docs]
def _compute_keep_idx_freq(self, avoid_bands):
    """
    Indices to keep along axes of length ``Nbands // self.lat.nsplits``
    (one entry per frequency, e.g. 6), based on ``self.freqs``.

    Parameters:
        avoid_bands: Iterable of frequencies to drop, compared as strings with
            ``self.freqs``. None or empty keeps every frequency.

    Returns:
        np.ndarray: Ascending integer indices into ``self.freqs`` to keep.

    Raises:
        ValueError: If Nbands is not divisible by the number of splits, if
            ``self.freqs`` is missing/None or has the wrong length, or if every
            frequency would be removed.
    """
    nsplits = self.lat.nsplits
    if self.Nbands % nsplits != 0:
        raise ValueError(
            f"Nbands ({self.Nbands}) must be divisible by nsplits ({nsplits}) "
            "to use a per-frequency axis of length Nbands // nsplits."
        )
    Nfreq = self.Nbands // nsplits
    freqs = getattr(self, 'freqs', None)
    if freqs is None or len(freqs) != Nfreq:
        raise ValueError(f"self.freqs must exist and have length Nbands // nsplits = {Nfreq}.")
    if not avoid_bands:
        return np.arange(Nfreq)
    avoid = {str(b) for b in avoid_bands}
    keep = [i for i, f in enumerate(freqs) if str(f) not in avoid]
    if not keep:
        raise ValueError("All freqs filtered out via self.freqs.")
    return np.asarray(keep, dtype=int)
[docs]
def _filter_bands_and_freq_axes(self, arr, keep_idx_bands, keep_idx_freq):
    """
    Subset every band-sized and frequency-sized axis of ``arr``.

    Axes are classified once, from the ORIGINAL shape:
      - an axis of length ``self.Nbands`` is indexed with ``keep_idx_bands``;
      - an axis of length ``self.Nbands // self.lat.nsplits`` is indexed with
        ``keep_idx_freq``.
    If the two lengths coincide, the band rule wins.

    NOTE(review): any other axis whose length happens to equal one of those
    sizes (e.g. a small component axis) is filtered as well. Confirm this
    cannot occur for the spectra passed in.

    Parameters:
        arr: Spectrum array. Non-ndarray inputs are returned unchanged.
        keep_idx_bands (array-like of int): Indices kept on band-sized axes.
        keep_idx_freq (array-like of int): Indices kept on frequency-sized axes.

    Returns:
        The filtered array (``arr`` itself if it is not an ndarray or no axis matches).
    """
    if not isinstance(arr, np.ndarray):
        return arr
    nbands = self.Nbands
    nfreq = nbands // self.lat.nsplits
    # 1-D index arrays keep every axis in place, so axis positions taken from
    # arr.shape remain valid after each take().
    keep_bands = np.atleast_1d(np.asarray(keep_idx_bands, dtype=np.intp))
    keep_freq = np.atleast_1d(np.asarray(keep_idx_freq, dtype=np.intp))
    out = arr
    # Classifying from the original shape, instead of re-scanning after every
    # step, avoids an endless loop when nothing is removed. It also stops a
    # band axis that shrank to nfreq entries from being filtered a second time.
    for axis, size in enumerate(arr.shape):
        if size == nbands:
            out = np.take(out, keep_bands, axis=axis)
        elif size == nfreq:
            out = np.take(out, keep_freq, axis=axis)
    return out
[docs]
def get_spectra(self, idx: int,
                sync: bool = False,
                avoid_bands: Optional[List[str]] = None
                ) -> Dict:
    """
    Retrieve all relevant spectra for a given realization index.

    Parameters:
        idx (int): Index of the CMB realization.
        sync (bool, optional): If True, also return the synchrotron spectra. Defaults to False.
        avoid_bands (list of str, optional): Frequencies to drop (e.g. ['93']).
            Every axis whose length equals Nbands or Nbands // nsplits is
            filtered accordingly. None or an empty list returns the spectra
            unfiltered.

    Returns:
        Dict[str, np.ndarray]: Keys 'oxo', 'dxd', 'dxo' and, when ``sync`` is
        True, also 'sxs', 'sxo', 'sxd'.
    """
    # Call order is kept stable: obs x obs, dust x obs, dust x dust, then sync.
    oxo = self.obs_x_obs(idx)
    dxo = self.dust_x_obs(idx)
    dxd = self.dust_x_dust()
    if sync:
        sxo = self.sync_x_obs(idx)
        sxs = self.sync_x_sync()
        sxd = self.sync_x_dust()
        out = {'oxo': oxo, 'dxd': dxd, 'sxs': sxs, 'dxo': dxo, 'sxo': sxo, 'sxd': sxd}
    else:
        out = {'oxo': oxo, 'dxd': dxd, 'dxo': dxo}
    # None or an empty list means "keep everything": skip filtering entirely.
    if not avoid_bands:
        return out
    keep_idx_bands = self._compute_keep_idx_bands(avoid_bands)
    keep_idx_freq = self._compute_keep_idx_freq(avoid_bands)
    return {
        key: self._filter_bands_and_freq_axes(value, keep_idx_bands, keep_idx_freq)
        for key, value in out.items()
    }
[docs]
class SpectraCross:
[docs]
def __init__(self, libdir:str, lat:LATskyC, sat:SATskyC|None=None, binwidth:int=1, galcut:int=40, aposcale:int=2,lmax:int=3000):
    """
    Set up LAT (and optionally SAT) cross-spectrum estimation.

    The method does the following, in order:
      - validates that the two instruments are compatible;
      - builds the linear binning and the map tags;
      - creates the output directories;
      - loads or computes the NaMaster coupling workspaces for the available
        telescope combinations.

    Parameters:
        libdir (str): Base directory; a configuration-specific subdirectory is created inside it.
        lat (LATskyC): LAT sky simulation.
        sat (SATskyC or None, optional): SAT sky simulation; None selects LAT-only mode.
        binwidth (int, optional): Width of the linear multipole bins. Defaults to 1.
        galcut (int, optional): Galactic cut passed to the mask builder. Defaults to 40.
        aposcale (int, optional): Apodization scale passed to the mask builder. Defaults to 2.
        lmax (int, optional): Maximum multipole. Defaults to 3000.

    Raises:
        ValueError: If LAT and SAT disagree on nside, CMB beta, frequencies,
            number of splits or noise sensitivity mode.
    """
    self.lat = lat
    self.sat = sat
    self.lat_only = (sat is None)
    if not self.lat_only:
        if lat.nside != sat.nside:
            raise ValueError("LAT and SAT nside must be the same")
        if lat.cmb.beta != sat.cmb.beta:
            raise ValueError("LAT and SAT cmb beta must be the same")
        if list(lat.freqs) != list(sat.freqs):
            raise ValueError("LAT and SAT frequencies must be the same for cross spectra")
        # Checked on its own: nested inside the frequency-mismatch branch it
        # could never fire when the frequencies matched.
        if lat.nsplits != sat.nsplits:
            raise ValueError("Number of splits must be the same for LAT and SAT")
        if lat.noise.sensitivity_mode != sat.noise.sensitivity_mode:
            raise ValueError("LAT and SAT sensitivity modes must match")
    self.nside = lat.nside
    self.binwidth = binwidth
    self.galcut = galcut
    self.aposcale = aposcale
    self.lmax = lmax
    self.binInfo = nmt.NmtBin.from_lmax_linear(self.lmax, binwidth)
    self.freqs = lat.freqs
    self.nsplits = lat.nsplits
    self.__create_maptags__()

    def _p(value):
        # Path-safe number encoding, e.g. 0.35 -> '0p35'.
        return str(value).replace('.', 'p')

    laerr = lat.alpha_err
    if self.lat_only:
        suffix = f"LatOnly_laerr{_p(laerr)}"
    else:
        suffix = f"laerr{_p(laerr)}_saerr{_p(sat.alpha_err)}"
    sens_mode = lat.noise.sensitivity_mode
    self.libdir = os.path.join(
        libdir,
        f"SpectraCross_d{lat.dust_model}s{lat.sync_model}_{sens_mode}_s{_p(lat.cmb.beta)}"
        f"_n{self.nside}_b{self.binwidth}_g{self.galcut}_a{_p(self.aposcale)}_l{self.lmax}_{suffix}",
    )
    self.specdir = os.path.join(self.libdir, 'spectra')
    os.makedirs(self.libdir, exist_ok=True)
    os.makedirs(self.specdir, exist_ok=True)
    if not self.lat_only:
        self.__sat_workspace__ = self.__get_coupling_matrix__('SAT')
        self.__satlat_workspace__ = self.__get_coupling_matrix__('SATxLAT')
    self.__lat_workspace__ = self.__get_coupling_matrix__('LAT')
def __create_maptags__(self)-> None:
    """
    Build ``self.maptags``: one '<TEL>_<freq>-<split>' label per map.

    Splits are numbered from 1 and vary fastest. LAT tags always come first;
    SAT tags are appended unless running in LAT-only mode.
    """
    def tags_for(telescope):
        return [f'{telescope}_{freq}-{split + 1}'
                for freq in self.freqs
                for split in range(self.nsplits)]

    tags = tags_for('LAT')
    if not self.lat_only:
        tags = tags + tags_for('SAT')
    self.maptags = tags
def __get_mask__(self,tel:str)-> Mask:
    """
    Construct a fresh ``Mask`` object for ``tel``.

    'LAT' and 'SAT' use the first three letters of the sky object's class name
    followed by 'xGAL'. 'SATxLAT' uses 'SATxLATxGAL' with the SAT base
    directory. All use ``self.nside``, ``self.aposcale`` and ``self.galcut``.

    Raises:
        ValueError: For an unknown ``tel``, or a SAT-based mask in LAT-only mode.
    """
    if tel == 'LAT':
        return Mask(self.lat.basedir, self.nside, self.lat.__class__.__name__[:3] + 'xGAL',
                    self.aposcale, gal_cut=self.galcut)
    if tel == 'SAT':
        if self.lat_only:
            raise ValueError("SAT mask requested but in LAT-only mode")
        return Mask(self.sat.basedir, self.nside, self.sat.__class__.__name__[:3] + 'xGAL',
                    self.aposcale, gal_cut=self.galcut)
    if tel == 'SATxLAT':
        if self.lat_only:
            raise ValueError("SATxLAT mask requested but in LAT-only mode")
        return Mask(self.sat.basedir, self.nside, 'SATxLATxGAL',
                    self.aposcale, gal_cut=self.galcut)
    raise ValueError(f"Unknown telescope: {tel}")
[docs]
def get_mask(self,tel:str)-> np.ndarray:
    """
    Return the mask for ``tel`` ('LAT', 'SAT' or 'SATxLAT'), cached as a FITS
    file in ``self.libdir``.

    On a cache miss the mask is built via ``__get_mask__``, written to disk and
    returned directly. On a hit it is read back with healpy as float64.
    """
    cache_file = os.path.join(self.libdir, f'mask_{tel}_galcut{self.galcut}_aposcale{self.aposcale}.fits')
    if not os.path.exists(cache_file):
        fresh = self.__get_mask__(tel).mask
        hp.write_map(cache_file, fresh, dtype=np.float64)
        return fresh
    return hp.read_map(cache_file, dtype=np.float64)
@property
def satmask(self) -> np.ndarray:
    """SAT mask as returned by ``self.get_mask('SAT')``; raises ValueError in LAT-only mode."""
    if not self.lat_only:
        return self.get_mask('SAT')
    raise ValueError("SAT mask not available in LAT-only mode")
@property
def latmask(self) -> np.ndarray:
    """LAT mask as returned by ``self.get_mask('LAT')``, fetched on every access."""
    mask = self.get_mask('LAT')
    return mask
@property
def satlatmask(self) -> np.ndarray:
    """Common SAT x LAT mask from ``self.get_mask('SATxLAT')``; raises ValueError in LAT-only mode."""
    if not self.lat_only:
        return self.get_mask('SATxLAT')
    raise ValueError("SATxLAT mask not available in LAT-only mode")
def __get_coupling_matrix__(self, tel) -> "nmt.NmtWorkspace":
    """
    Load, or compute and store, the NaMaster workspace for ``tel``.

    The workspace is cached at ``<libdir>/coupling_matrix_<tel>.fits``.

    Parameters:
        tel (str): 'LAT', 'SAT' or 'SATxLAT'; selects which mask is used.

    Returns:
        nmt.NmtWorkspace: Workspace holding the mode-coupling matrix for ``self.binInfo``.

    Raises:
        ValueError: For an unknown ``tel`` (checked before any cache lookup).
    """
    if tel not in ('LAT', 'SAT', 'SATxLAT'):
        raise ValueError(f"Unknown telescope: {tel}")
    wrk = nmt.NmtWorkspace()
    fname = os.path.join(self.libdir, f'coupling_matrix_{tel}.fits')
    if os.path.isfile(fname):
        wrk.read_from(fname)
        return wrk
    if tel == 'LAT':
        mask = self.latmask
    elif tel == 'SAT':
        mask = self.satmask
    else:
        mask = self.satlatmask
    # The coupling matrix depends on the mask and the field spin, not the map
    # values, so the mask itself is used as placeholder Q/U content.
    mask_f = nmt.NmtField(mask, [mask, mask], lmax=self.lmax)
    wrk.compute_coupling_matrix(mask_f, mask_f, self.binInfo)
    del mask_f
    wrk.write_to(fname)
    return wrk
[docs]
def compute_master(self, tel:str, f_a: nmt.NmtField, f_b: nmt.NmtField) -> np.ndarray:
    """
    Pseudo-Cl estimate of the cross-spectrum of two fields.

    The coupled spectrum of ``f_a`` x ``f_b`` is decoupled with the
    pre-computed workspace selected by ``tel``.

    Parameters:
        tel (str): 'LAT', 'SAT' or 'SATxLAT'.
        f_a, f_b (nmt.NmtField): Fields to correlate.

    Returns:
        np.ndarray: Decoupled, binned spectra.

    Raises:
        ValueError: Unknown ``tel``, or a SAT-based workspace requested in LAT-only mode.
    """
    coupled = nmt.compute_coupled_cell(f_a, f_b)
    if tel == 'LAT':
        return self.__lat_workspace__.decouple_cell(coupled)
    if tel == 'SAT':
        if self.lat_only:
            raise ValueError("SAT workspace not available in LAT-only mode")
        return self.__sat_workspace__.decouple_cell(coupled)
    if tel == 'SATxLAT':
        if self.lat_only:
            raise ValueError("SATxLAT workspace not available in LAT-only mode")
        return self.__satlat_workspace__.decouple_cell(coupled)
    raise ValueError(f"Unknown telescope: {tel}")
def __get_QU_maps__(self, idx:int, maptag:str)-> tuple:
    """
    Fetch the observed Q/U maps and matching mask for one map tag.

    ``maptag`` is '<TEL>_<rest>'. The part after the underscore is passed
    unchanged to the telescope's ``obsQU``.

    Returns:
        tuple: (qmap, umap, mask).

    Raises:
        ValueError: Unknown telescope, a malformed tag, or a SAT tag in LAT-only mode.
    """
    tel, freq = maptag.split('_')
    if tel not in ('LAT', 'SAT'):
        raise ValueError(f"Unknown telescope: {tel}")
    if tel == 'SAT' and self.lat_only:
        raise ValueError(f"SAT maps requested but in LAT-only mode")
    is_lat = tel == 'LAT'
    sky = self.lat if is_lat else self.sat
    qmap, umap = sky.obsQU(idx, freq)
    mask = self.latmask if is_lat else self.satmask
    return qmap, umap, mask
def __get_nmt_index__(self, which:str)-> tuple:
if which=='EB':
ij = 1
ji = 2
elif which=='EE':
ij = 0
ji = 0
elif which=='BB':
ij = 3
ji = 3
else:
raise ValueError(f"Unknown spectra type: {which}")
return ij, ji
def __spectra_matrix__fname__(self, idx:int, which='EB', check=False,checker='e')-> str | bool:
if check:
fname = os.path.join(self.specdir,f'spectra_matrix_{which}_{idx:03d}.pkl')
if checker=='e':
return os.path.isfile(fname)
elif checker=='r':
try:
pl.load(open(fname,'rb'))
return True
except:
return False
else:
raise ValueError("checker must be 'e' or 'r'")
return os.path.join(self.specdir,f'spectra_matrix_{which}_{idx:03d}.pkl')
def __spectra_matrix_core__(self, idx:int, which='EB')->np.ndarray:
    """
    Load, or compute and cache, the matrix of decoupled spectra between every
    pair of map tags for realization ``idx``.

    For ``i >= j``, entry ``[i, j]`` holds NaMaster output row ``ij`` of
    ``maptags[i]`` x ``maptags[j]``. Entry ``[j, i]`` holds row ``ji`` of the
    same computation, so each unordered pair is computed once. The row pair
    comes from ``__get_nmt_index__(which)``.

    Parameters:
        idx (int): Realization index.
        which (str, optional): 'EB', 'EE' or 'BB'. Defaults to 'EB'.

    Returns:
        np.ndarray: Array of shape (Ntags, Ntags, Nbins).

    Raises:
        ValueError: Unknown ``which``, or a SAT pair encountered in LAT-only mode.
    """
    fname = self.__spectra_matrix__fname__(idx, which)
    if os.path.exists(fname):
        with open(fname, 'rb') as fh:
            return pl.load(fh)

    # Resolve the output rows before any map is loaded, so an invalid
    # ``which`` fails immediately rather than after the first expensive pair.
    ij, ji = self.__get_nmt_index__(which)

    def make_field(maptag):
        # Pre-masked spin-2 field for one tag; Q/U/mask are freed on return.
        qmap, umap, mask = self.__get_QU_maps__(idx, maptag)
        return nmt.NmtField(mask, [qmap * mask, umap * mask], lmax=self.lmax, masked_on_input=True)

    def pair_telescope(tag_a, tag_b):
        # Workspace label: same-telescope pairs use their own, mixed pairs 'SATxLAT'.
        if tag_a.startswith('LAT') and tag_b.startswith('LAT'):
            return 'LAT'
        if tag_a.startswith('SAT') and tag_b.startswith('SAT'):
            if self.lat_only:
                raise ValueError("SAT-SAT correlation not available in LAT-only mode")
            return 'SAT'
        if self.lat_only:
            raise ValueError("LAT-SAT correlation not available in LAT-only mode")
        return 'SATxLAT'

    ntags = len(self.maptags)
    matrix = np.zeros((ntags, ntags, self.binInfo.get_n_bands()))
    for i in tqdm(range(ntags), desc='Outer loop', position=0):
        maptag_i = self.maptags[i]
        f_i = make_field(maptag_i)
        for j in tqdm(range(i + 1), desc='Inner loop', position=1, leave=False):
            maptag_j = self.maptags[j]
            f_j = f_i if i == j else make_field(maptag_j)
            tel = pair_telescope(maptag_i, maptag_j)
            cl_decoupled = self.compute_master(tel, f_i, f_j)
            matrix[i, j, :] = cl_decoupled[ij]
            if i != j:
                matrix[j, i, :] = cl_decoupled[ji]
            del f_j
        del f_i
        gc.collect()

    # Write to a temporary file and rename, so an interrupted run never leaves
    # a truncated pickle that the existence check above would try to load.
    tmp_fname = f"{fname}.{os.getpid()}.tmp"
    with open(tmp_fname, 'wb') as fh:
        pl.dump(matrix, fh)
    os.replace(tmp_fname, fname)
    return matrix
[docs]
def data_matrix(self, idx:int,
                which:str = 'EB',
                sat_lrange:tuple = (None, None),
                lat_lrange:tuple = (None, None),
                avg_splits:bool = False,
                common_mask_op = operator.and_)->np.ndarray:
    """
    Map-by-map spectra matrix for realization ``idx`` with multipole cuts
    applied, optionally averaged over splits.

    Bins whose effective multipole lies outside the allowed range are set to
    zero:
      - SAT x SAT pairs use ``sat_lrange``;
      - LAT x LAT pairs use ``lat_lrange``;
      - SAT x LAT and LAT x SAT pairs use
        ``common_mask_op(sat_bin_mask, lat_bin_mask)``. The default
        ``operator.and_`` keeps only bins allowed by BOTH ranges.

    Parameters:
        idx (int): Realization index.
        which (str, optional): 'EB', 'EE' or 'BB'. Defaults to 'EB'.
        sat_lrange (tuple, optional): (lmin, lmax) for SAT x SAT; None leaves that side open.
        lat_lrange (tuple, optional): (lmin, lmax) for LAT x LAT; None leaves that side open.
        avg_splits (bool, optional): If True, average over all split
            combinations of each '<TEL>_<freq>' group (e.g. 'LAT_93-1' and
            'LAT_93-2' -> 'LAT_93'). Auto-split pairs are included. Groups are
            ordered by first appearance in ``self.maptags``.
        common_mask_op (callable, optional): Binary operator combining the SAT and
            LAT boolean bin masks for cross pairs. Defaults to ``operator.and_``.

    Returns:
        np.ndarray: (Ntags, Ntags, Nbins), or (Ngroups, Ngroups, Nbins) when
        ``avg_splits`` is True. NaN entries are ignored by the split average.
    """
    matrix = self.__spectra_matrix_core__(idx, which)
    bin_centers = self.binInfo.get_effective_ells()
    n_bins = self.binInfo.get_n_bands()
    maptags = list(self.maptags)
    sat_idx = np.asarray([i for i, tag in enumerate(maptags) if tag.startswith('SAT')], dtype=int)
    lat_idx = np.asarray([i for i, tag in enumerate(maptags) if tag.startswith('LAT')], dtype=int)

    def ell_mask(lrange):
        # True for bins whose effective ell lies inside [lo, hi]; None = open side.
        lo, hi = lrange
        keep = np.ones(n_bins, dtype=bool)
        if lo is not None:
            keep &= (bin_centers >= lo)
        if hi is not None:
            keep &= (bin_centers <= hi)
        return keep

    def zero_outside(rows, cols, keep):
        # Zero the excluded bins for every (row, col) pair in one vectorised assignment.
        drop = np.flatnonzero(~keep)
        if rows.size and cols.size and drop.size:
            matrix[np.ix_(rows, cols, drop)] = 0

    sat_ell_mask = ell_mask(sat_lrange)
    lat_ell_mask = ell_mask(lat_lrange)
    cross_ell_mask = common_mask_op(sat_ell_mask, lat_ell_mask)

    zero_outside(sat_idx, sat_idx, sat_ell_mask)
    zero_outside(lat_idx, lat_idx, lat_ell_mask)
    zero_outside(sat_idx, lat_idx, cross_ell_mask)
    zero_outside(lat_idx, sat_idx, cross_ell_mask)

    if not avg_splits:
        return matrix

    # Group indices by tag without the trailing '-<split>' suffix.
    groups = {}
    for i, tag in enumerate(maptags):
        groups.setdefault(tag.rsplit('-', 1)[0], []).append(i)
    members = list(groups.values())
    averaged = np.zeros((len(members), len(members), n_bins))
    for gi, rows in enumerate(members):
        for gj, cols in enumerate(members):
            averaged[gi, gj, :] = np.nanmean(matrix[np.ix_(rows, cols)], axis=(0, 1))
    return averaged