Skip to content

Commit

Permalink
Ensure reprojection mask is compatible with output map
Browse files Browse the repository at this point in the history
Also some formatting cleanup
  • Loading branch information
arahlin committed Dec 18, 2024
1 parent fdb5169 commit a993940
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
62 changes: 42 additions & 20 deletions maps/python/map_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,8 +880,15 @@ class ReprojectMaps(object):
output map. For numpy array, all zeros/inf/nan/hp.UNSEEN pixels are skipped.
"""

def __init__(self, map_stub=None, rebin=1, interp=False, weighted=True,
partial=False, mask=None):
def __init__(
self,
map_stub=None,
rebin=1,
interp=False,
weighted=True,
partial=False,
mask=None,
):
assert map_stub is not None, "map_stub argument required"
self.stub = map_stub.clone(False)
self.stub.pol_type = None
Expand All @@ -890,7 +897,6 @@ def __init__(self, map_stub=None, rebin=1, interp=False, weighted=True,
self.weighted = weighted
self._mask = None
self.partial = partial

self.mask = mask

def __call__(self, frame):
Expand All @@ -917,25 +923,35 @@ def __call__(self, frame):

if key in "TQUH":
mnew = self.stub.clone(False)
maps.reproj_map(m, mnew, rebin=self.rebin, interp=self.interp,
mask=self.mask)
maps.reproj_map(
m, mnew, rebin=self.rebin, interp=self.interp, mask=self.mask
)

elif key in ["Wpol", "Wunpol"]:
mnew = maps.G3SkyMapWeights(self.stub)
for wkey in mnew.keys():
maps.reproj_map(
m[wkey], mnew[wkey], rebin=self.rebin, interp=self.interp,
mask=self.mask
m[wkey],
mnew[wkey],
rebin=self.rebin,
interp=self.interp,
mask=self.mask,
)

frame[key] = mnew
self.mask = mnew
return frame

@property
def mask(self):
"""
The mask to be used for partial reprojection, of the same shape as the
output map. Masked (1) pixels are handled by the reprojection code, and
unmasked (0) pixels are skipped, effectively setting their value to 0 in
the output map.
"""
return self._mask

@mask.setter
def mask(self, mask):
if mask is None:
Expand All @@ -944,21 +960,27 @@ def mask(self, mask):
if isinstance(mask, maps.G3SkyMapMask):
self._mask = mask
elif isinstance(mask, maps.G3SkyMap):
self._mask = maps.G3SkyMapMask(mask, use_data=True, zero_nans=True,
zero_infs=True)
self._mask = maps.G3SkyMapMask(
mask, use_data=True, zero_nans=True, zero_infs=True
)
elif isinstance(mask, np.ndarray):
from healpy import UNSEEN
import healpy as hp

tmp = self.stub.clone(False)
mask_copy = np.ones(mask.shape, dtype=int)
bad = np.logical_or.reduce([
np.isnan(mask),
np.isinf(mask),
mask==0,
mask==UNSEEN
])
bad = np.logical_or.reduce(
[
np.isnan(mask),
np.isinf(mask),
mask == 0,
hp.mask_bad(mask),
]
)
mask_copy[bad] = 0
tmp[:] = mask_copy
self._mask = maps.G3SkyMapMask(tmp, use_data=True)
else:
raise TypeError("Mask must be a G3SkyMapMask, G3SkyMap, "
"or numpy array")
raise TypeError("Mask must be a G3SkyMapMask, G3SkyMap, or numpy array")

if not self._mask.compatible(self.stub):
raise ValueError("Mask is not compatible with output map")
5 changes: 4 additions & 1 deletion maps/src/maputils.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ void FlattenPol(FlatSkyMapPtr Q, FlatSkyMapPtr U, G3SkyMapWeightsPtr W, double h


void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool interp,
G3SkyMapMaskConstPtr out_map_mask)
G3SkyMapMaskConstPtr out_map_mask)
{
bool rotate = false; // no transform
Quat q_rot; // quaternion for rotating from output to input coordinate system
Expand Down Expand Up @@ -311,6 +311,9 @@ void ReprojMap(G3SkyMapConstPtr in_map, G3SkyMapPtr out_map, int rebin, bool int
out_map->pol_conv = in_map->pol_conv;
}

if (!!out_map_mask && !out_map_mask->IsCompatible(*out_map))
log_fatal("Mask is not compatible with output map");

size_t stop = out_map->size();
if (rebin > 1) {
for (size_t i = 0; i < stop; i++) {
Expand Down

0 comments on commit a993940

Please sign in to comment.