From c145f0c4374af5b63c556d2cc0e52c29d58400da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20Br=C3=A4unlich?= Date: Wed, 27 Sep 2023 13:28:15 +0200 Subject: [PATCH] Refactor image/stamp workflow --- imsim/lsst_image.py | 252 +++++++++++++++++++----- imsim/photon_ops.py | 1 + imsim/stamp.py | 468 +++++++++++++++++++++++++++----------------- 3 files changed, 493 insertions(+), 228 deletions(-) diff --git a/imsim/lsst_image.py b/imsim/lsst_image.py index 63fbcddb..49a00785 100644 --- a/imsim/lsst_image.py +++ b/imsim/lsst_image.py @@ -1,11 +1,25 @@ -import numpy as np import logging +import dataclasses +import numpy as np import galsim from galsim.config import RegisterImageType, GetAllParams, GetSky, AddNoise from galsim.config.image_scattered import ScatteredImageBuilder +from galsim.wcs import PixelScale +from galsim.sensor import Sensor + from .sky_model import SkyGradient, CCD_Fringing from .camera import get_camera from .vignetting import Vignetting +from .stamp import StellarObject, ProcessingMode, build_obj + +def merge_photon_arrays(arrays): + n_tot = sum(len(arr) for arr in arrays) + merged = galsim.PhotonArray(n_tot) + start = 0 + for arr in arrays: + merged.assignAt(start, arr) + start += len(arr) + return merged class LSST_ImageBuilder(ScatteredImageBuilder): @@ -81,18 +95,15 @@ def setup(self, config, base, image_num, obj_num, ignore, logger): self.camera_name = params.get('camera', 'LsstCam') + self.nbatch = params.get('nbatch', 10) try: self.checkpoint = galsim.config.GetInputObj('checkpoint', config, base, 'LSST_Image') - self.nbatch = params.get('nbatch', 10) except galsim.config.GalSimConfigError: self.checkpoint = None - self.nbatch = params.get('nbatch', 1) - # Note: This will probably also become 10 once we're doing the photon - # pooling stuff. But for now, let it be 1 if not checkpointing. return xsize, ysize - def buildImage(self, config, base, image_num, obj_num, logger): + def buildImage(self, config, base, image_num, _obj_num, logger): """Build the Image. This is largely the same as the GalSim Scattered image type. @@ -140,67 +151,57 @@ def buildImage(self, config, base, image_num, obj_num, logger): full_image = None current_var = 0 - start_num = obj_num # For cases where there is noise in individual stamps, we need to keep track of the # stamp bounds and their current variances. When checkpointing, we don't need to # save the pixel values for this, just the bounds and the current_var value of each. all_stamps = [] all_vars = [] + all_obj_nums = [] + photon_batch_num = 0 if self.checkpoint is not None: - chk_name = 'buildImage_%s'%(self.det_name) - saved = self.checkpoint.load(chk_name) - if saved is not None: - full_image, all_bounds, all_vars, start_num, extra_builder = saved - if extra_builder is not None: - base['extra_builder'] = extra_builder - all_stamps = [galsim._Image(np.array([]), b, full_image.wcs) for b in all_bounds] - logger.warning('File %d: Loaded checkpoint data from %s.', - base.get('file_num', 0), self.checkpoint.file_name) - if start_num == obj_num + self.nobjects: - logger.warning('All objects already rendered for this image.') - else: - logger.warning("Objects %d..%d already rendered", obj_num, start_num-1) - logger.warning('Starting at obj_num %d', start_num) - nobj_tot = self.nobjects - (start_num - obj_num) + chk_name = "buildImage_" + self.det_name + full_image, all_vars, all_stamps, all_obj_nums, photon_batch_num = load_checkpoint(self.checkpoint, chk_name, base, logger) + remaining_obj_nums = sorted(frozenset(range(self.nobjects)) - frozenset(all_obj_nums)) if full_image is None: full_image = galsim.Image(full_xsize, full_ysize, dtype=dtype) full_image.setOrigin(base['image_origin']) full_image.wcs = wcs full_image.setZero() - start_batch = 0 base['current_image'] = full_image - nbatch = min(self.nbatch, nobj_tot) - for batch in range(nbatch): - start_obj_num = start_num + (nobj_tot * batch // nbatch) - end_obj_num = start_num + (nobj_tot * (batch+1) // nbatch) - nobj_batch = end_obj_num - start_obj_num - if nbatch > 1: - logger.warning("Start batch %d/%d with %d objects [%d, %d)", - batch+1, nbatch, nobj_batch, start_obj_num, end_obj_num) - stamps, current_vars = galsim.config.BuildStamps( - nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False) + n_fft_batch = min(self.nbatch, len(remaining_obj_nums)) + sensor = base.get('sensor', None) + if sensor is not None: + rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon") + sensor.updateRNG(rng) + fft_objects, phot_objects, faint_objects = partition_objects(load_objects(remaining_obj_nums, config, base, logger)) + if self.checkpoint is not None: + if not fft_objects: + logger.warning('All FFT objects already rendered for this image.') + else: + logger.warning("%d objects already rendered", len(all_obj_nums)) + + # Handle FFT objects first: + for batch_num, batch in enumerate(make_batches(fft_objects, n_fft_batch), start=1): + if n_fft_batch > 1: + logger.warning("Start FFT batch %d/%d with %d objects", + batch_num, n_fft_batch, len(batch)) + stamps, current_vars = build_stamps(base, logger, batch, stamp_type="LSST_Silicon") base['index_key'] = 'image_num' - for k in range(nobj_batch): - # This is our signal that the object was skipped. - if stamps[k] is None: - continue - bounds = stamps[k].bounds & full_image.bounds - if not bounds.isDefined(): # pragma: no cover - # These noramlly show up as stamp==None, but technically it is possible - # to get a stamp that is off the main image, so check for that here to - # avoid an error. But this isn't covered in the imsim test suite. + for stamp_obj, stamp in zip(batch, stamps): + bounds = stamp_bounds(stamp, full_image.bounds) + if bounds is None: continue - logger.debug('image %d: full bounds = %s', image_num, str(full_image.bounds)) logger.debug('image %d: stamp %d bounds = %s', - image_num, k+start_obj_num, str(stamps[k].bounds)) + image_num, stamp_obj.index, str(stamp.bounds)) logger.debug('image %d: Overlap = %s', image_num, str(bounds)) - full_image[bounds] += stamps[k][bounds] + full_image[bounds] += stamp[bounds] + all_obj_nums.append(stamp_obj.index) # Note: in typical imsim usage, all current_vars will be 0. So this normally doens't # add much to the checkpointing data. @@ -209,16 +210,47 @@ def buildImage(self, config, base, image_num, obj_num, logger): all_vars.extend([current_vars[k] for k in nz_var]) if self.checkpoint is not None: - # Don't save the full stamps. All we need for FlattenNoiseVariance is the bounds. - # Everything else about the stamps has already been handled above. - all_bounds = [stamp.bounds for stamp in all_stamps] - data = (full_image, all_bounds, all_vars, end_obj_num, - base.get('extra_builder',None)) - self.checkpoint.save(chk_name, data) - logger.warning('File %d: Completed batch %d with objects [%d, %d), and wrote ' + save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, photon_batch_num) + logger.warning('File %d: Completed batch %d, and wrote ' 'checkpoint data to %s', - base.get('file_num', 0), batch+1, start_obj_num, end_obj_num, + base.get('file_num', 0), batch_num, self.checkpoint.file_name) + # Handle photons: + phot_batches = make_photon_batches( + config, base, logger, phot_objects, faint_objects, self.nbatch + ) + + if photon_batch_num > 0: + logger.warning( + "Photon batches [0, %d) / %d already rendered - skipping", + photon_batch_num, + self.nbatch, + ) + phot_batches = phot_batches[photon_batch_num:] + stamps = (build_stamps(base, logger, batch, stamp_type="PhotonStampBuilder") for batch in phot_batches) + for batch_num, batch in enumerate(phot_batches, start=photon_batch_num): + if not batch: + continue + base['index_key'] = 'image_num' + stamps, current_vars = build_stamps(base, logger, batch, stamp_type="PhotonStampBuilder") + photons = merge_photon_arrays(stamps) + photon_ops_cfg = {"photon_ops": base.get("stamp", {}).get("photon_ops", [])} + base["image_pos"].x = full_image.center.x + base["image_pos"].y = full_image.center.y + photon_ops = galsim.config.BuildPhotonOps(photon_ops_cfg, 'photon_ops', base, logger) + # TODO: Can this be done using public API? + local_wcs = wcs.local(galsim.position._PositionD(0., 0.)) + for op in photon_ops: + op.applyTo(photons, local_wcs, rng) + accumulate_photons(photons, full_image, sensor, full_image.center) + + # Note: in typical imsim usage, all current_vars will be 0. So this normally doens't + # add much to the checkpointing data. + nz_var = np.nonzero(current_vars)[0] + all_vars.extend([current_vars[k] for k in nz_var]) + + if self.checkpoint is not None: + save_checkpoint(self.checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, batch_num+1) # Bring the image so far up to a flat noise variance current_var = galsim.config.FlattenNoiseVariance( @@ -297,4 +329,118 @@ def addNoise(self, image, config, base, image_num, obj_num, current_var, logger) AddNoise(base,image,current_var,logger) +def accumulate_photons(photons, image, sensor, center): + if sensor is None: + sensor = Sensor() + imview = image._view() + imview._shift(-center) # equiv. to setCenter(), but faster + imview.wcs = PixelScale(1.0) + if imview.dtype in (np.float32, np.float64): + sensor.accumulate(photons, imview, imview.center) + else: + # Need a temporary + im1 = galsim.image.ImageD(bounds=imview.bounds) + sensor.accumulate(photons, im1, imview.center) + imview += im1 + +def make_batches(objects, nbatch: int): + per_batch = len(objects) // nbatch + o_iter = iter(objects) + for _ in range(nbatch): + yield [obj for _, obj in zip(range(per_batch), o_iter)] + + +def build_stamps(base, logger, objects: list[StellarObject], stamp_type: str): + base["stamp"]["type"] = stamp_type + if not objects: + return [], [] + base["_objects"] = {obj.index: obj for obj in objects} + + images, current_vars = zip( + *( + galsim.config.BuildStamp( + base, obj.index, xsize=0, ysize=0, do_noise=False, logger=logger + ) + for obj in objects + ) + ) + return images, current_vars + + +def make_photon_batches(config, base, logger, phot_objects: list[StellarObject], faint_objects: list[StellarObject], nbatch: int): + if not phot_objects and not faint_objects: + return [] + batches = [ + [dataclasses.replace(obj, realized_flux=obj.realized_flux / nbatch) for obj in phot_objects] + ] * nbatch + rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon") + ud = galsim.UniformDeviate(rng) + # Shuffle faint objects into the batches randomly: + for obj in faint_objects: + batch_index = int(ud() * nbatch) + batches[batch_index].append(obj) + + return batches + +def stamp_bounds(stamp, full_image_bounds): + if stamp is None: + return None + bounds = stamp.bounds & full_image_bounds + if not bounds.isDefined(): # pragma: no cover + # These noramlly show up as stamp==None, but technically it is possible + # to get a stamp that is off the main image, so check for that here to + # avoid an error. But this isn't covered in the imsim test suite. + return None + return bounds + + +def partition_objects(objects): + objects_by_mode = { + ProcessingMode.FFT: [], + ProcessingMode.PHOT: [], + ProcessingMode.FAINT: [], + } + for obj in objects: + objects_by_mode[obj.mode].append(obj) + return ( + objects_by_mode[ProcessingMode.FFT], + objects_by_mode[ProcessingMode.PHOT], + objects_by_mode[ProcessingMode.FAINT], + ) + + +def load_objects(obj_numbers, config, base, logger): + gsparams = {} + stamp = base['stamp'] + if 'gsparams' in stamp: + gsparams = galsim.gsobject.UpdateGSParams(gsparams, stamp['gsparams'], config) + + for obj_num in obj_numbers: + galsim.config.SetupConfigObjNum(config, obj_num, logger) + obj = build_obj(stamp, base, logger) + if obj is not None: + yield build_obj(stamp, base, logger) + + +def load_checkpoint(checkpoint, chk_name, base, logger): + saved = checkpoint.load(chk_name) + if saved is not None: + full_image, all_bounds, all_vars, all_obj_nums, extra_builder, photon_batch_num = saved + if extra_builder is not None: + base['extra_builder'] = extra_builder + all_stamps = [galsim._Image(np.array([]), b, full_image.wcs) for b in all_bounds] + logger.warning('File %d: Loaded checkpoint data from %s.', + base.get('file_num', 0), checkpoint.file_name) + return full_image, all_vars, all_stamps, all_obj_nums, photon_batch_num + return (None,)*5 + +def save_checkpoint(checkpoint, chk_name, base, full_image, all_stamps, all_vars, all_obj_nums, photon_batch_num): + # Don't save the full stamps. All we need for FlattenNoiseVariance is the bounds. + # Everything else about the stamps has already been handled above. + all_bounds = [stamp.bounds for stamp in all_stamps] + data = (full_image, all_bounds, all_vars, all_obj_nums, + base.get('extra_builder',None), photon_batch_num) + checkpoint.save(chk_name, data) + + RegisterImageType('LSST_Image', LSST_ImageBuilder()) diff --git a/imsim/photon_ops.py b/imsim/photon_ops.py index 2565af4c..d998414f 100644 --- a/imsim/photon_ops.py +++ b/imsim/photon_ops.py @@ -380,6 +380,7 @@ def deserialize_rubin_optics(config, base, _logger): @photon_op_type("RubinDiffractionOptics", input_type="telescope") def deserialize_rubin_diffraction_optics(config, base, _logger): kwargs = config_kwargs(config, base, RubinDiffractionOptics, _rubin_optics_base_args) + print("RubinOptics", kwargs["image_pos"].x, kwargs["image_pos"].y) telescope = base["det_telescope"] rubin_diffraction = RubinDiffraction( telescope=telescope, diff --git a/imsim/stamp.py b/imsim/stamp.py index 198ce39f..a1af6b60 100644 --- a/imsim/stamp.py +++ b/imsim/stamp.py @@ -1,3 +1,4 @@ +from enum import Enum, auto from functools import lru_cache from dataclasses import dataclass, fields, MISSING import numpy as np @@ -10,6 +11,21 @@ from .camera import get_camera +class ProcessingMode(Enum): + FFT = auto() + PHOT = auto() + FAINT = auto() + + +@dataclass +class StellarObject: + index: int + gal: object + psf: object + realized_flux: float + mode: ProcessingMode + + @dataclass class DiffractionFFT: exptime: float @@ -44,6 +60,169 @@ def from_config(cls, config: dict, base: dict) -> "DiffractionFFT": return cls(**kwargs) +def build_obj(stamp_config, base, logger): + PIXEL_SCALE = LSST_SiliconBuilder._pixel_scale + obj_num = base.get('obj_num', 0) + gal, _ = galsim.config.BuildGSObject(base, 'gal', logger=logger) + if gal is None: + return None + if not hasattr(gal, 'flux'): + # In this case, the object flux has not been precomputed + # or cached by the skyCatalogs code. + gal.flux = gal.calculateFlux(base['bandpass']) + rng = galsim.config.GetRNG(stamp_config, base, logger, "LSST_Silicon") + realized_flux = galsim.PoissonDeviate(rng, mean=gal.flux)() + if realized_flux == 0.: + return None + # For very bright things, we might want to change this for FFT drawing. + if 'fft_sb_thresh' in stamp_config: + fft_sb_thresh = galsim.config.ParseValue(stamp_config,'fft_sb_thresh',base,float)[0] + else: + fft_sb_thresh = 0. + set_image_pos(stamp_config, base, logger) + psf, _ = galsim.config.BuildGSObject(base, 'psf', gsparams={}, logger=logger) + if realized_flux < 1.e6 or not fft_sb_thresh or realized_flux < fft_sb_thresh: + use_fft = False + else: + bandpass = base['bandpass'] + fft_psf = make_fft_psf(psf.evaluateAtWavelength(bandpass.effective_wavelength), logger) + logger.info('Object %d has flux = %s. Check if we should switch to FFT', + obj_num, realized_flux) + # Now this object should have a much better estimate of the real maximum surface brightness + # than the original psf did. + # However, the max_sb feature gives an over-estimate, whereas to be conservative, we would + # rather an under-estimate. For this kind of profile, dividing by 2 does a good job + # of giving us an underestimate of the max surface brightness. + # Also note that `max_sb` is in photons/arcsec^2, so multiply by pixel_scale**2 + # to get photons/pixel, which we compare to fft_sb_thresh. + gal_achrom = gal.evaluateAtWavelength(bandpass.effective_wavelength) + fft_obj = galsim.Convolve(gal_achrom, fft_psf).withFlux(realized_flux) + max_sb = fft_obj.max_sb/2. * PIXEL_SCALE**2 + use_fft = max_sb > fft_sb_thresh + if use_fft: + psf = fft_psf + max_flux_simple = stamp_config.get('max_flux_simple', 100) + if use_fft: + mode = ProcessingMode.FFT + elif realized_flux < max_flux_simple: + mode = ProcessingMode.FAINT + else: + mode = ProcessingMode.PHOT + if use_fft: + logger.info('Yes. Use FFT for object %d. max_sb = %.0f > %.0f', + obj_num, max_sb, fft_sb_thresh) + else: + logger.info('No. Use photon shooting for object %d. ' + 'max_sb = %.0f <= %.0f', + base.get('obj_num'), max_sb, fft_sb_thresh) + + return StellarObject(obj_num, gal, psf, realized_flux, mode) + + +def set_image_pos(stamp_config, base, logger): + builder = LSST_SiliconBuilder() + xsize, ysize, image_pos, world_pos = builder.setup(stamp_config, base, 0.0, 0.0, galsim.config.stamp.stamp_ignore, logger) + builder.locateStamp(stamp_config, base, xsize, ysize, image_pos, world_pos, logger) + + +class PhotonStampBuilder(StampBuilder): + def setup(self, config, base, xsize, ysize, ignore, logger): + return LSST_SiliconBuilder().setup(config, base, xsize, ysize, ignore, logger) + + def getDrawMethod(self, config, base, logger): + return "phot" + + def updateOrigin(self, stamp, config, image): + return + + def draw(self, prof, image, method, offset, config, base, logger): + """Draw the profile on the postage stamp image. + + Parameters: + prof: The profile to draw. + image: The image onto which to draw the profile (which may be None). + method: The method to use in drawImage. + offset: The offset to apply when drawing. + config: The configuration dict for the stamp field. + base: The base configuration dict. + logger: A logger object to log progress. + + Returns: + the resulting image + """ + if prof is None: + # If was decide to do any rejection steps, this could be set to None, in which case, + # don't draw anything. + return image + + # Prof is normally a convolution here with obj_list being [gal, psf1, psf2,...] + # for some number of component PSFs. + gal, *psfs = prof.obj_list if hasattr(prof,'obj_list') else [prof] + obj_num = base.get('obj_num',0) + stellar_obj = base.get("_objects", {})[obj_num] # Use cached object + bandpass = base['bandpass'] + + def fix_seds(prof): + # If any SEDs are not currently using a LookupTable for the function or if they are + # using spline interpolation, then the codepath is quite slow. + # Better to fix them before doing WavelengthSampler. + if isinstance(prof, galsim.ChromaticObject): + wave_list, _, _ = galsim.utilities.combine_wave_list(prof.SED, bandpass) + sed = prof.SED + # TODO: This bit should probably be ported back to Galsim. + # Something like sed.make_tabulated() + if (not isinstance(sed._spec, galsim.LookupTable) + or sed._spec.interpolant != 'linear'): + # Workaround for https://github.com/GalSim-developers/GalSim/issues/1228 + f = np.broadcast_to(sed(wave_list), wave_list.shape) + new_spec = galsim.LookupTable(wave_list, f, interpolant='linear') + new_sed = galsim.SED( + new_spec, + 'nm', + 'fphotons' if sed.spectral else '1' + ) + prof.SED = new_sed + + # Also recurse onto any components. + if hasattr(prof, 'obj_list'): + for obj in prof.obj_list: + fix_seds(obj) + if hasattr(prof, 'original'): + fix_seds(prof.original) + + faint = stellar_obj.mode == ProcessingMode.FAINT + if faint: + logger.info("Flux = %.0f Using trivial sed", stellar_obj.realized_flux) + gal = gal.evaluateAtWavelength(bandpass.effective_wavelength) + gal = gal * LSST_SiliconBuilder._trivial_sed + else: + fix_seds(gal) + gal = gal.withFlux(stellar_obj.realized_flux, bandpass) + + # Put the psfs at the start of the photon_ops. + # Probably a little better to put them a bit later than the start in some cases + # (e.g. after TimeSampler, PupilAnnulusSampler), but leave that as a todo for now. + rng = galsim.config.GetRNG(config, base, logger, "LSST_Silicon") + gal.drawImage(bandpass, + method='phot', + offset=offset, + rng=rng, + maxN=None, + n_photons=stellar_obj.realized_flux, + image=image, + wcs=base['wcs'], + sensor=NullSensor(), + photon_ops=psfs, + add_to_image=True, + poisson_flux=False, + save_photons=True) + + return image.photons + +# Register this as a valid stamp type +RegisterStampType('PhotonStampBuilder', PhotonStampBuilder()) + + class LSST_SiliconBuilder(StampBuilder): """This performs the tasks necessary for building the stamp for a single object. @@ -100,7 +279,10 @@ def setup(self, config, base, xsize, ysize, ignore, logger): ignore = ['fft_sb_thresh', 'max_flux_simple'] + ignore params = galsim.config.GetAllParams(config, base, req=req, opt=opt, ignore=ignore)[0] - gal = galsim.config.BuildGSObject(base, 'gal', logger=logger)[0] + obj_num = base.get('obj_num',0) + stellar_obj = base.get("_objects", {}).get(obj_num) # Use cached object + gal = galsim.config.BuildGSObject(base, 'gal', logger=logger)[0] if stellar_obj is None else stellar_obj.gal + if gal is None: raise galsim.config.SkipThisObject('gal is None (invalid parameters)') self.gal = gal @@ -117,11 +299,14 @@ def setup(self, config, base, xsize, ysize, ignore, logger): camera = get_camera(params.get('camera', 'LsstCam')) if self.vignetting: self.det = camera[params['det_name']] - if not hasattr(gal, 'flux'): - # In this case, the object flux has not been precomputed - # or cached by the skyCatalogs code. - gal.flux = gal.calculateFlux(bandpass) - self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() + if stellar_obj is not None: + self.realized_flux = stellar_obj.realized_flux + else: + if not hasattr(gal, 'flux'): + # In this case, the object flux has not been precomputed + # or cached by the skyCatalogs code. + gal.flux = gal.calculateFlux(bandpass) + self.realized_flux = galsim.PoissonDeviate(self.rng, mean=gal.flux)() # Check if the realized flux is 0. if self.realized_flux == 0: @@ -192,7 +377,7 @@ def setup(self, config, base, xsize, ysize, ignore, logger): base['current_noise_image'] = base['current_image'] noise_var = galsim.config.CalculateNoiseVariance(base) keep_sb_level = np.sqrt(noise_var)/8. - self._large_object_sb_level = 3*keep_sb_level + _large_object_sb_level = 3*keep_sb_level image_size = self._getGoodPhotImageSize([gal_achrom, psf], keep_sb_level, pixel_scale=self._pixel_scale) @@ -200,11 +385,11 @@ def setup(self, config, base, xsize, ysize, ignore, logger): # a somewhat brighter surface brightness limit. if image_size > self._Nmax: image_size = self._getGoodPhotImageSize([gal_achrom, psf], - self._large_object_sb_level, + _large_object_sb_level, pixel_scale=self._pixel_scale) image_size = min(image_size, self._Nmax) - logger.info('Object %d will use stamp size = %s',base.get('obj_num',0),image_size) + logger.info('Object %d will use stamp size = %s', obj_num, image_size) # Determine where this object is going to go: # This is the same as what the base StampBuilder does: @@ -394,101 +579,19 @@ def buildPSF(self, config, base, gsparams, logger): Returns: the PSF """ - psf = galsim.config.BuildGSObject(base, 'psf', gsparams=gsparams, logger=logger)[0] - - # For very bright things, we might want to change this for FFT drawing. - if 'fft_sb_thresh' in config: - fft_sb_thresh = galsim.config.ParseValue(config,'fft_sb_thresh',base,float)[0] - else: - fft_sb_thresh = 0. + # Use cached psf and mode (fft / phot): + obj_num = base['obj_num'] + stellar_obj = base.get('_objects', {})[obj_num] + self.use_fft = stellar_obj.mode == ProcessingMode.FFT + self.realized_flux = stellar_obj.realized_flux + if self.use_fft and self.vignetting is not None: + pix_to_fp = self.det.getTransform(cameraGeom.PIXELS, + cameraGeom.FOCAL_PLANE) + vignetted_flux = self.realized_flux*self.vignetting.at_sky_coord( + base['sky_pos'], self.image.wcs, pix_to_fp) + self.realized_flux = round(vignetted_flux) + return stellar_obj.psf - base['realized_flux'] = self.realized_flux - if self.realized_flux < 1.e6 or not fft_sb_thresh or self.realized_flux < fft_sb_thresh: - self.use_fft = False - return psf - - # Otherwise (high flux object), we might want to switch to fft. So be a little careful. - bandpass = base['bandpass'] - fft_psf = self.make_fft_psf(psf.evaluateAtWavelength(bandpass.effective_wavelength), logger) - logger.info('Object %d has flux = %s. Check if we should switch to FFT', - base['obj_num'], self.realized_flux) - - # Now this object should have a much better estimate of the real maximum surface brightness - # than the original psf did. - # However, the max_sb feature gives an over-estimate, whereas to be conservative, we would - # rather an under-estimate. For this kind of profile, dividing by 2 does a good job - # of giving us an underestimate of the max surface brightness. - # Also note that `max_sb` is in photons/arcsec^2, so multiply by pixel_scale**2 - # to get photons/pixel, which we compare to fft_sb_thresh. - gal_achrom = self.gal.evaluateAtWavelength(bandpass.effective_wavelength) - fft_obj = galsim.Convolve(gal_achrom, fft_psf).withFlux(self.realized_flux) - max_sb = fft_obj.max_sb/2. * self._pixel_scale**2 - logger.debug('max_sb = %s. cf. %s',max_sb,fft_sb_thresh) - if max_sb > fft_sb_thresh: - self.use_fft = True - # For FFT-rendered objects, the telescope vignetting isn't - # emergent as it is for the ray-traced objects, so use the - # empirical vignetting function, if it's available, to - # scale the realized flux. - if self.vignetting is not None: - pix_to_fp = self.det.getTransform(cameraGeom.PIXELS, - cameraGeom.FOCAL_PLANE) - vignetted_flux = self.realized_flux*self.vignetting.at_sky_coord( - base['sky_pos'], self.image.wcs, pix_to_fp) - self.realized_flux = round(vignetted_flux) - - logger.info('Yes. Use FFT for object %d. max_sb = %.0f > %.0f', - base.get('obj_num'), max_sb, fft_sb_thresh) - return fft_psf - else: - self.use_fft = False - logger.info('No. Use photon shooting for object %d. ' - 'max_sb = %.0f <= %.0f', - base.get('obj_num'), max_sb, fft_sb_thresh) - return psf - - def make_fft_psf(self, psf, logger): - """Swap out any PhaseScreenPSF component with a roughly equivalent analytic approximation. - """ - if isinstance(psf, galsim.Transformation): - return galsim.Transformation(self.make_fft_psf(psf.original, logger), - psf.jac, psf.offset, psf.flux_ratio, psf.gsparams) - elif isinstance(psf, galsim.Convolution): - obj_list = [self.make_fft_psf(p, logger) for p in psf.obj_list] - return galsim.Convolution(obj_list, gsparams=psf.gsparams) - elif isinstance(psf, galsim.SecondKick): - # The Kolmogorov version of the phase screen gets most of the second kick. - # The only bit that it missing is the Airy part, so convert the SecondKick to that. - return galsim.Airy(lam=psf.lam, diam=psf.diam, obscuration=psf.obscuration) - elif isinstance(psf, galsim.PhaseScreenPSF): - # If psf is a PhaseScreenPSF, then make a simpler one the just convolves - # a Kolmogorov profile with an OpticalPSF. - r0_500 = psf.screen_list.r0_500_effective - L0 = psf.screen_list[0].L0 - atm_psf = galsim.VonKarman(lam=psf.lam, r0_500=r0_500, L0=L0, gsparams=psf.gsparams) - - opt_screens = [s for s in psf.screen_list if isinstance(s, galsim.OpticalScreen)] - logger.info('opt_screens = %r',opt_screens) - if len(opt_screens) >= 1: - # Should never be more than 1, but if there weirdly is, just use the first. - # Note: Technically, if you have both a SecondKick and an optical screen, this - # will add the Airy part twice, since it's also part of the OpticalPSF. - # It doesn't usually matter, since we usually set doOpt=False, so we don't usually - # do this branch. If it is found to matter for someone, it will require a bit - # of extra logic to do it right. - opt_screen = opt_screens[0] - optical_psf = galsim.OpticalPSF( - lam=psf.lam, - diam=opt_screen.diam, - aberrations=opt_screen.aberrations, - annular_zernike=opt_screen.annular_zernike, - obscuration=opt_screen.obscuration, - gsparams=psf.gsparams) - return galsim.Convolve([atm_psf, optical_psf], gsparams=psf.gsparams) - else: - return atm_psf - else: - return psf def getDrawMethod(self, config, base, logger): """Determine the draw method to use. @@ -587,81 +690,96 @@ def fix_seds(prof): if 'maxN' in config: maxN = galsim.config.ParseValue(config, 'maxN', base, int)[0] - if method == 'fft': - fft_image = image.copy() - fft_offset = offset - kwargs = dict( - method='fft', - offset=fft_offset, - image=fft_image, - wcs=wcs, - ) - if not faint and config.get('fft_photon_ops'): - kwargs.update({ - "photon_ops": galsim.config.BuildPhotonOps(config, 'fft_photon_ops', base, logger), - "maxN": maxN, - "rng": self.rng, - "n_subsample": 1, - }) - - # Go back to a combined convolution for fft drawing. - prof = galsim.Convolve([gal] + psfs) - try: - prof.drawImage(bandpass, **kwargs) - except galsim.errors.GalSimFFTSizeError as e: - # I think this shouldn't happen with the updates I made to how the image size - # is calculated, even for extremely bright things. So it should be ok to - # just report what happened, give some extra information to diagonose the problem - # and raise the error. - logger.error('Caught error trying to draw using FFT:') - logger.error('%s',e) - logger.error('You may need to add a gsparams field with maximum_fft_size to') - logger.error('either the psf or gal field to allow larger FFTs.') - logger.info('prof = %r',prof) - logger.info('fft_image = %s',fft_image) - logger.info('offset = %r',offset) - logger.info('wcs = %r',wcs) - raise - # Some pixels can end up negative from FFT numerics. Just set them to 0. - fft_image.array[fft_image.array < 0] = 0. - if self.diffraction_fft: - self.diffraction_fft.apply(fft_image, bandpass.effective_wavelength) - fft_image.addNoise(galsim.PoissonNoise(rng=self.rng)) - # In case we had to make a bigger image, just copy the part we need. - image += fft_image[image.bounds] - - else: - if not faint and 'photon_ops' in config: - photon_ops = galsim.config.BuildPhotonOps(config, 'photon_ops', base, logger) - else: - photon_ops = [] - # Put the psfs at the start of the photon_ops. - # Probably a little better to put them a bit later than the start in some cases - # (e.g. after TimeSampler, PupilAnnulusSampler), but leave that as a todo for now. - photon_ops = psfs + photon_ops - - if faint: - sensor = None - else: - sensor = base.get('sensor', None) - if sensor is not None: - sensor.updateRNG(self.rng) - - gal.drawImage(bandpass, - method='phot', - offset=offset, - rng=self.rng, - maxN=maxN, - n_photons=self.realized_flux, - image=image, - wcs=wcs, - sensor=sensor, - photon_ops=photon_ops, - add_to_image=True, - poisson_flux=False) + fft_image = image.copy() + fft_offset = offset + kwargs = dict( + method='fft', + offset=fft_offset, + image=fft_image, + wcs=wcs, + ) + if not faint and config.get('fft_photon_ops'): + kwargs.update({ + "photon_ops": galsim.config.BuildPhotonOps(config, 'fft_photon_ops', base, logger), + "maxN": maxN, + "rng": self.rng, + "n_subsample": 1, + }) + + # Go back to a combined convolution for fft drawing. + prof = galsim.Convolve([gal] + psfs) + try: + prof.drawImage(bandpass, **kwargs) + except galsim.errors.GalSimFFTSizeError as e: + # I think this shouldn't happen with the updates I made to how the image size + # is calculated, even for extremely bright things. So it should be ok to + # just report what happened, give some extra information to diagonose the problem + # and raise the error. + logger.error('Caught error trying to draw using FFT:') + logger.error('%s',e) + logger.error('You may need to add a gsparams field with maximum_fft_size to') + logger.error('either the psf or gal field to allow larger FFTs.') + logger.info('prof = %r',prof) + logger.info('fft_image = %s',fft_image) + logger.info('offset = %r',offset) + logger.info('wcs = %r',wcs) + raise + # Some pixels can end up negative from FFT numerics. Just set them to 0. + fft_image.array[fft_image.array < 0] = 0. + print(self.diffraction_fft) + if self.diffraction_fft: + self.diffraction_fft.apply(fft_image, bandpass.effective_wavelength) + fft_image.addNoise(galsim.PoissonNoise(rng=self.rng)) + # In case we had to make a bigger image, just copy the part we need. + image += fft_image[image.bounds] return image +def make_fft_psf(psf, logger): + """Swap out any PhaseScreenPSF component with a roughly equivalent analytic approximation. + """ + if isinstance(psf, galsim.Transformation): + return galsim.Transformation(make_fft_psf(psf.original, logger), + psf.jac, psf.offset, psf.flux_ratio, psf.gsparams) + if isinstance(psf, galsim.Convolution): + obj_list = [make_fft_psf(p, logger) for p in psf.obj_list] + return galsim.Convolution(obj_list, gsparams=psf.gsparams) + if isinstance(psf, galsim.SecondKick): + # The Kolmogorov version of the phase screen gets most of the second kick. + # The only bit that it missing is the Airy part, so convert the SecondKick to that. + return galsim.Airy(lam=psf.lam, diam=psf.diam, obscuration=psf.obscuration) + if isinstance(psf, galsim.PhaseScreenPSF): + # If psf is a PhaseScreenPSF, then make a simpler one the just convolves + # a Kolmogorov profile with an OpticalPSF. + r0_500 = psf.screen_list.r0_500_effective + L0 = psf.screen_list[0].L0 + atm_psf = galsim.VonKarman(lam=psf.lam, r0_500=r0_500, L0=L0, gsparams=psf.gsparams) + + opt_screens = [s for s in psf.screen_list if isinstance(s, galsim.OpticalScreen)] + logger.info('opt_screens = %r',opt_screens) + if len(opt_screens) >= 1: + # Should never be more than 1, but if there weirdly is, just use the first. + # Note: Technically, if you have both a SecondKick and an optical screen, this + # will add the Airy part twice, since it's also part of the OpticalPSF. + # It doesn't usually matter, since we usually set doOpt=False, so we don't usually + # do this branch. If it is found to matter for someone, it will require a bit + # of extra logic to do it right. + opt_screen = opt_screens[0] + optical_psf = galsim.OpticalPSF( + lam=psf.lam, + diam=opt_screen.diam, + aberrations=opt_screen.aberrations, + annular_zernike=opt_screen.annular_zernike, + obscuration=opt_screen.obscuration, + gsparams=psf.gsparams) + return galsim.Convolve([atm_psf, optical_psf], gsparams=psf.gsparams) + return atm_psf + return psf + + +class NullSensor(galsim.Sensor): + def accumulate(self, photons, image, orig_center=None, resume=False): + return 0. # Register this as a valid stamp type RegisterStampType('LSST_Silicon', LSST_SiliconBuilder())