Skip to content

Commit

Permalink
Refactor image/stamp workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
g-braeunlich committed Oct 10, 2023
1 parent fe7d976 commit c145f0c
Show file tree
Hide file tree
Showing 3 changed files with 493 additions and 228 deletions.
252 changes: 199 additions & 53 deletions imsim/lsst_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import numpy as np
import logging
import dataclasses
import numpy as np
import galsim
from galsim.config import RegisterImageType, GetAllParams, GetSky, AddNoise
from galsim.config.image_scattered import ScatteredImageBuilder
from galsim.wcs import PixelScale
from galsim.sensor import Sensor

from .sky_model import SkyGradient, CCD_Fringing
from .camera import get_camera
from .vignetting import Vignetting
from .stamp import StellarObject, ProcessingMode, build_obj

def merge_photon_arrays(arrays):
n_tot = sum(len(arr) for arr in arrays)
merged = galsim.PhotonArray(n_tot)
start = 0
for arr in arrays:
merged.assignAt(start, arr)
start += len(arr)
return merged


class LSST_ImageBuilder(ScatteredImageBuilder):
Expand Down Expand Up @@ -81,18 +95,15 @@ def setup(self, config, base, image_num, obj_num, ignore, logger):

self.camera_name = params.get('camera', 'LsstCam')

self.nbatch = params.get('nbatch', 10)
try:
self.checkpoint = galsim.config.GetInputObj('checkpoint', config, base, 'LSST_Image')
self.nbatch = params.get('nbatch', 10)
except galsim.config.GalSimConfigError:
self.checkpoint = None
self.nbatch = params.get('nbatch', 1)
# Note: This will probably also become 10 once we're doing the photon
# pooling stuff. But for now, let it be 1 if not checkpointing.

return xsize, ysize

def buildImage(self, config, base, image_num, obj_num, logger):
def buildImage(self, config, base, image_num, _obj_num, logger):
"""Build the Image.
This is largely the same as the GalSim Scattered image type.
Expand Down Expand Up @@ -140,67 +151,57 @@ def buildImage(self, config, base, image_num, obj_num, logger):

full_image = None
current_var = 0
start_num = obj_num

# For cases where there is noise in individual stamps, we need to keep track of the
# stamp bounds and their current variances. When checkpointing, we don't need to
# save the pixel values for this, just the bounds and the current_var value of each.
all_stamps = []
all_vars = []
all_obj_nums = []
photon_batch_num = 0

if self.checkpoint is not None:
chk_name = 'buildImage_%s'%(self.det_name)
saved = self.checkpoint.load(chk_name)
if saved is not None:
full_image, all_bounds, all_vars, start_num, extra_builder = saved
if extra_builder is not None:
base['extra_builder'] = extra_builder
all_stamps = [galsim._Image(np.array([]), b, full_image.wcs) for b in all_bounds]
logger.warning('File %d: Loaded checkpoint data from %s.',
base.get('file_num', 0), self.checkpoint.file_name)
if start_num == obj_num + self.nobjects:
logger.warning('All objects already rendered for this image.')
else:
logger.warning("Objects %d..%d already rendered", obj_num, start_num-1)
logger.warning('Starting at obj_num %d', start_num)
nobj_tot = self.nobjects - (start_num - obj_num)
chk_name = "buildImage_" + self.det_name
full_image, all_vars, all_stamps, all_obj_nums, photon_batch_num = load_checkpoint(self.checkpoint, chk_name, base, logger)
remaining_obj_nums = sorted(frozenset(range(self.nobjects)) - frozenset(all_obj_nums))

if full_image is None:
full_image = galsim.Image(full_xsize, full_ysize, dtype=dtype)
full_image.setOrigin(base['image_origin'])
full_image.wcs = wcs
full_image.setZero()
start_batch = 0
base['current_image'] = full_image

nbatch = min(self.nbatch, nobj_tot)
for batch in range(nbatch):
start_obj_num = start_num + (nobj_tot * batch // nbatch)
end_obj_num = start_num + (nobj_tot * (batch+1) // nbatch)
nobj_batch = end_obj_num - start_obj_num
if nbatch > 1:
logger.warning("Start batch %d/%d with %d objects [%d, %d)",
batch+1, nbatch, nobj_batch, start_obj_num, end_obj_num)
stamps, current_vars = galsim.config.BuildStamps(
nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False)
n_fft_batch = min(self.nbatch, len(remaining_obj_nums))
sensor = base.get('sensor', None)
if sensor is not None:
rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon")
sensor.updateRNG(rng)
fft_objects, phot_objects, faint_objects = partition_objects(load_objects(remaining_obj_nums, config, base, logger))
if self.checkpoint is not None:
if not fft_objects:
logger.warning('All FFT objects already rendered for this image.')
else:
logger.warning("%d objects already rendered", len(all_obj_nums))

# Handle FFT objects first:
for batch_num, batch in enumerate(make_batches(fft_objects, n_fft_batch), start=1):
if n_fft_batch > 1:
logger.warning("Start FFT batch %d/%d with %d objects",
batch_num, n_fft_batch, len(batch))
stamps, current_vars = build_stamps(base, logger, batch, stamp_type="LSST_Silicon")
base['index_key'] = 'image_num'

for k in range(nobj_batch):
# This is our signal that the object was skipped.
if stamps[k] is None:
continue
bounds = stamps[k].bounds & full_image.bounds
if not bounds.isDefined(): # pragma: no cover
# These noramlly show up as stamp==None, but technically it is possible
# to get a stamp that is off the main image, so check for that here to
# avoid an error. But this isn't covered in the imsim test suite.
for stamp_obj, stamp in zip(batch, stamps):
bounds = stamp_bounds(stamp, full_image.bounds)
if bounds is None:
continue

logger.debug('image %d: full bounds = %s', image_num, str(full_image.bounds))
logger.debug('image %d: stamp %d bounds = %s',
image_num, k+start_obj_num, str(stamps[k].bounds))
image_num, stamp_obj.index, str(stamp.bounds))
logger.debug('image %d: Overlap = %s', image_num, str(bounds))
full_image[bounds] += stamps[k][bounds]
full_image[bounds] += stamp[bounds]
all_obj_nums.append(stamp_obj.index)

# Note: in typical imsim usage, all current_vars will be 0. So this normally doens't
# add much to the checkpointing data.
Expand All @@ -209,16 +210,47 @@ def buildImage(self, config, base, image_num, obj_num, logger):
all_vars.extend([current_vars[k] for k in nz_var])

if self.checkpoint is not None:
# Don't save the full stamps. All we need for FlattenNoiseVariance is the bounds.
# Everything else about the stamps has already been handled above.
all_bounds = [stamp.bounds for stamp in all_stamps]
data = (full_image, all_bounds, all_vars, end_obj_num,
base.get('extra_builder',None))
self.checkpoint.save(chk_name, data)
logger.warning('File %d: Completed batch %d with objects [%d, %d), and wrote '
save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, photon_batch_num)
logger.warning('File %d: Completed batch %d, and wrote '
'checkpoint data to %s',
base.get('file_num', 0), batch+1, start_obj_num, end_obj_num,
base.get('file_num', 0), batch_num,
self.checkpoint.file_name)
# Handle photons:
phot_batches = make_photon_batches(
config, base, logger, phot_objects, faint_objects, self.nbatch
)

if photon_batch_num > 0:
logger.warning(
"Photon batches [0, %d) / %d already rendered - skipping",
photon_batch_num,
self.nbatch,
)
phot_batches = phot_batches[photon_batch_num:]
stamps = (build_stamps(base, logger, batch, stamp_type="PhotonStampBuilder") for batch in phot_batches)
for batch_num, batch in enumerate(phot_batches, start=photon_batch_num):
if not batch:
continue
base['index_key'] = 'image_num'
stamps, current_vars = build_stamps(base, logger, batch, stamp_type="PhotonStampBuilder")
photons = merge_photon_arrays(stamps)
photon_ops_cfg = {"photon_ops": base.get("stamp", {}).get("photon_ops", [])}
base["image_pos"].x = full_image.center.x
base["image_pos"].y = full_image.center.y
photon_ops = galsim.config.BuildPhotonOps(photon_ops_cfg, 'photon_ops', base, logger)
# TODO: Can this be done using public API?
local_wcs = wcs.local(galsim.position._PositionD(0., 0.))
for op in photon_ops:
op.applyTo(photons, local_wcs, rng)
accumulate_photons(photons, full_image, sensor, full_image.center)

# Note: in typical imsim usage, all current_vars will be 0. So this normally doens't
# add much to the checkpointing data.
nz_var = np.nonzero(current_vars)[0]
all_vars.extend([current_vars[k] for k in nz_var])

if self.checkpoint is not None:
save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, batch_num+1)

# Bring the image so far up to a flat noise variance
current_var = galsim.config.FlattenNoiseVariance(
Expand Down Expand Up @@ -297,4 +329,118 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger)
AddNoise(base,image,current_var,logger)


def accumulate_photons(photons, image, sensor, center):
if sensor is None:
sensor = Sensor()
imview = image._view()
imview._shift(-center) # equiv. to setCenter(), but faster
imview.wcs = PixelScale(1.0)
if imview.dtype in (np.float32, np.float64):
sensor.accumulate(photons, imview, imview.center)
else:
# Need a temporary
im1 = galsim.image.ImageD(bounds=imview.bounds)
sensor.accumulate(photons, im1, imview.center)
imview += im1

def make_batches(objects, nbatch: int):
per_batch = len(objects) // nbatch
o_iter = iter(objects)
for _ in range(nbatch):
yield [obj for _, obj in zip(range(per_batch), o_iter)]


def build_stamps(base, logger, objects: list[StellarObject], stamp_type: str):
base["stamp"]["type"] = stamp_type
if not objects:
return [], []
base["_objects"] = {obj.index: obj for obj in objects}

images, current_vars = zip(
*(
galsim.config.BuildStamp(
base, obj.index, xsize=0, ysize=0, do_noise=False, logger=logger
)
for obj in objects
)
)
return images, current_vars


def make_photon_batches(config, base, logger, phot_objects: list[StellarObject], faint_objects: list[StellarObject], nbatch: int):
if not phot_objects and not faint_objects:
return []
batches = [
[dataclasses.replace(obj, realized_flux=obj.realized_flux / nbatch) for obj in phot_objects]
] * nbatch
rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon")
ud = galsim.UniformDeviate(rng)
# Shuffle faint objects into the batches randomly:
for obj in faint_objects:
batch_index = int(ud() * nbatch)
batches[batch_index].append(obj)

return batches

def stamp_bounds(stamp, full_image_bounds):
if stamp is None:
return None
bounds = stamp.bounds & full_image_bounds
if not bounds.isDefined(): # pragma: no cover
# These noramlly show up as stamp==None, but technically it is possible
# to get a stamp that is off the main image, so check for that here to
# avoid an error. But this isn't covered in the imsim test suite.
return None
return bounds


def partition_objects(objects):
objects_by_mode = {
ProcessingMode.FFT: [],
ProcessingMode.PHOT: [],
ProcessingMode.FAINT: [],
}
for obj in objects:
objects_by_mode[obj.mode].append(obj)
return (
objects_by_mode[ProcessingMode.FFT],
objects_by_mode[ProcessingMode.PHOT],
objects_by_mode[ProcessingMode.FAINT],
)


def load_objects(obj_numbers, config, base, logger):
gsparams = {}
stamp = base['stamp']
if 'gsparams' in stamp:
gsparams = galsim.gsobject.UpdateGSParams(gsparams, stamp['gsparams'], config)

for obj_num in obj_numbers:
galsim.config.SetupConfigObjNum(config, obj_num, logger)
obj = build_obj(stamp, base, logger)
if obj is not None:
yield build_obj(stamp, base, logger)


def load_checkpoint(checkpoint, chk_name, base, logger):
saved = checkpoint.load(chk_name)
if saved is not None:
full_image, all_bounds, all_vars, all_obj_nums, extra_builder, photon_batch_num = saved
if extra_builder is not None:
base['extra_builder'] = extra_builder
all_stamps = [galsim._Image(np.array([]), b, full_image.wcs) for b in all_bounds]
logger.warning('File %d: Loaded checkpoint data from %s.',
base.get('file_num', 0), checkpoint.file_name)
return full_image, all_vars, all_stamps, all_obj_nums, photon_batch_num
return (None,)*5

def save_checkpoint(checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, photon_batch_num):
# Don't save the full stamps. All we need for FlattenNoiseVariance is the bounds.
# Everything else about the stamps has already been handled above.
all_bounds = [stamp.bounds for stamp in all_stamps]
data = (full_image, all_bounds, all_vars, all_obj_nums,
base.get('extra_builder',None), photon_batch_num)
checkpoint.save(chk_name, data)


RegisterImageType('LSST_Image', LSST_ImageBuilder())
1 change: 1 addition & 0 deletions imsim/photon_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def deserialize_rubin_optics(config, base, _logger):
@photon_op_type("RubinDiffractionOptics", input_type="telescope")
def deserialize_rubin_diffraction_optics(config, base, _logger):
kwargs = config_kwargs(config, base, RubinDiffractionOptics, _rubin_optics_base_args)
print("RubinOptics", kwargs["image_pos"].x, kwargs["image_pos"].y)
telescope = base["det_telescope"]
rubin_diffraction = RubinDiffraction(
telescope=telescope,
Expand Down
Loading

0 comments on commit c145f0c

Please sign in to comment.