Skip to content

Commit

Permalink
Weak lensing updates (#1074)
Browse files Browse the repository at this point in the history
* Notebook to demonstrate making predictions with pretrained weights

* Split shear into two components in prior; add avg shear/conv to tile catalog and bypass checks in catalog.py

* updated reporting of train and val loss

* updated on_validation_epoch to include metric calculations and updates

* Fix names in lensing prior and decoder so they match what's expected in metrics

* Prep lensing config for training on simulated images

* Small tweaks to lensing maps

* Slightly harder setting: constant shear within image, variable shear between images

* Remove vmin and vmax in lensing maps

* Lensing metrics: fix MSE denominator and use zero as baseline shear prediction

* Split up config into two; use better paths for plots during training

* Use better output path

* Refactor lensing maps

* Finish refactoring lensing maps

* Small updates to dc2 config

* Fix bug in lensing plots by resetting to empty list

* Scatterplots of shear1, shear2, and convergence during training

* Add dashed y=x line to scatterplots

* Better matplotlib approach for training and val loss plots

* Specify random seed in generate

* Small syntax things in lensing encoder

* New random seed in generate_one_file to match master

* Remove empty line to match master

* Update logger names and patience/max_epochs in lensing configs

* Update encoder evaluation notebook

* Update weak lensing README

---------

Co-authored-by: Shreyas Chandrashekaran <[email protected]>
  • Loading branch information
timwhite0 and Shreyas Chandrashekaran authored Oct 2, 2024
1 parent 1868ae0 commit 1baf4d4
Show file tree
Hide file tree
Showing 11 changed files with 820 additions and 215 deletions.
2 changes: 2 additions & 0 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def to_full_catalog(self, tile_slen):
continue
if param_name == "locs": # full catalog uses plocs instead of locs
continue
if param_name in {"shear", "shear_1", "shear_2", "convergence"}:
continue
k = tile_param.shape[-1]
param = rearrange(tile_param, "b nth ntw s k -> b (nth ntw s) k", k=k)
indices_for_param = repeat(indices_to_retrieve, "b nth_ntw_s -> b nth_ntw_s k", k=k)
Expand Down
26 changes: 20 additions & 6 deletions case_studies/weak_lensing/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
### Mapping dark matter using astronomical images
### Winter 2024
#### Steve Fan and Tahseen Younus
#### Supervised by Tim White and Jeffrey Regier
### Neural posterior estimation of weak lensing shear and convergence for the LSST DESC DC2 simulated sky survey
#### Shreyas Chandrashekaran, Tim White, Camille Avestruz, and Jeffrey Regier, with assistance from Steve Fan and Tahseen Younus

This case study aims to estimate weak lensing shear and convergence for the DC2 simulated sky survey. See `notebooks/dc2/evaluate_encoder.ipynb` for our most recent results.

Gravitational lensing is an important probe of the large-scale structure of the universe, as it allows astronomers and physicists to create mass maps reflecting the distribution of dark matter across the sky. However, weak lensing, the most prevalent form of lensing, is difficult for the human eye to discern in astronomical images. We propose a novel method for simultaneously estimating shear and convergence, two weak lensing observables, from images using neural posterior estimation. After experimentation, we conclude that the current model architecture is suitable for learning complex spatial patterns of shear and convergence, but has trouble producing estimates on the correct scale.
Some useful commands:

- Train `lensing_encoder` on DC2 images

*This project was completed through the Undergraduate Research Program in Statistics (URPS), a competitive program that pairs promising undergraduates with Statistics faculty on a research project for the winter semester. For future URPS opportunities, see [here](https://lsa.umich.edu/stats/undergraduate-students/undergraduate-research-opportunities-.html).*
```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_dc2.yaml mode=train &> train_on_dc2.out &
```

- Generate synthetic images with shear and convergence, as specified in `lensing_prior`

```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_simulator.yaml mode=generate &> generate_synthetic.out &
```

- Train `lensing_encoder` on synthetic images:

```
nohup bliss -cp /home/twhit/bliss/case_studies/weak_lensing/ -cn lensing_config_simulator.yaml mode=train &> train_on_synthetic.out &
```
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,8 @@ defaults:
mode: train

paths:

dc2: /data/scratch/dc2local
output: /data/scratch/shreyasc/bliss_output

prior:
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
n_tiles_h: 12 # cropping 2 tiles from each side (4 total)
n_tiles_w: 12 # cropping 2 tiles from each side (4 total)
batch_size: 2
max_sources: 200
constant_shear: 0.2
constant_convergence: 0.2
prob_galaxy: 1.0
mean_sources: 82 # 0.02 * (256/4) * (256/4)

decoder:
_target_: case_studies.weak_lensing.lensing_decoder.LensingDecoder
tile_slen: 256
use_survey_background: false
with_dither: false
with_noise: false

cached_simulator:
batch_size: 2
train_transforms: []
output: /home/twhit/bliss/

variational_factors:
- _target_: bliss.encoder.variational_dist.NormalFactor
Expand All @@ -40,20 +17,11 @@ variational_factors:
- _target_: bliss.encoder.variational_dist.NormalFactor
name: shear_2
nll_gating: null
# - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
# name: shear
# nll_gating: null
# - _target_: bliss.encoder.variational_dist.NormalFactor
# name: convergence
# nll_gating: null
# high_clamp: 20.0
# low_clamp: -20.0
- _target_: bliss.encoder.variational_dist.NormalFactor
name: convergence
nll_gating: null

my_normalizers:
# asinh:
# _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
# q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
# sample_every_n: 4
nully:
_target_: bliss.encoder.image_normalizer.NullNormalizer

Expand All @@ -63,11 +31,9 @@ my_metrics:

my_render:
lensing_shear_conv:
_target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
_target_: case_studies.weak_lensing.lensing_plots.PlotLensingMaps
frequency: 1
restrict_batch: 0
tile_slen: 256
save_local: lensing_maps
save_local: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/lensing_maps

encoder:
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
Expand All @@ -82,7 +48,6 @@ encoder:
milestones: [32]
gamma: 0.1
image_normalizers: ${my_normalizers}

var_dist:
_target_: bliss.encoder.variational_dist.VariationalDist
tile_slen: ${encoder.tile_slen}
Expand All @@ -101,7 +66,7 @@ encoder:
metrics: ${my_render}
use_double_detect: false
use_checkerboard: false
train_loss_location: train_loss
loss_plots_location: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/loss_plots

surveys:
dc2:
Expand All @@ -116,19 +81,20 @@ surveys:
avg_ellip_kernel_sigma: 3
batch_size: 1
num_workers: 1
cached_data_path: ${paths.output}/dc2_corrected_shear_only_cd_fix
cached_data_path: ${paths.dc2}/dc2_corrected_shear_only_cd_fix

# generate:
# n_image_files: 50
# n_batches_per_file: 4
train:
trainer:
logger:
name: dc2_weak_lensing_exp
version: exp_09_16
devices: [0] # cuda:0 for gl
name: weak_lensing_experiments_dc2
version: october1
max_epochs: 250
devices: 1
use_distributed_sampler: false
precision: 32-true
callbacks:
early_stopping:
patience: 50
data_source: ${surveys.dc2}
pretrained_weights: null
seed: 123123
113 changes: 113 additions & 0 deletions case_studies/weak_lensing/lensing_config_simulator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
---
defaults:
- ../../bliss/conf@_here_: base_config
- _self_
- override hydra/job_logging: stdout

mode: train

paths:
cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_constantshear
output: /home/twhit/bliss/

prior:
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
n_tiles_h: 12 # cropping 2 tiles from each side (4 total)
n_tiles_w: 12 # cropping 2 tiles from each side (4 total)
batch_size: 2
max_sources: 200
shear_min: -0.1
shear_max: 0.1
convergence_min: -0.0001
convergence_max: 0.0001
prob_galaxy: 1.0
mean_sources: 82 # 0.02 * (256/4) * (256/4)

decoder:
_target_: case_studies.weak_lensing.lensing_decoder.LensingDecoder
tile_slen: 256
use_survey_background: false
with_dither: false
with_noise: false

cached_simulator:
batch_size: 1
train_transforms: []

variational_factors:
- _target_: bliss.encoder.variational_dist.NormalFactor
name: shear_1
nll_gating: null
- _target_: bliss.encoder.variational_dist.NormalFactor
name: shear_2
nll_gating: null
- _target_: bliss.encoder.variational_dist.NormalFactor
name: convergence
nll_gating: null

my_normalizers:
nully:
_target_: bliss.encoder.image_normalizer.NullNormalizer

my_metrics:
lensing_map:
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE

my_render:
lensing_shear_conv:
_target_: case_studies.weak_lensing.lensing_plots.PlotLensingMaps
frequency: 1
save_local: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/lensing_maps

encoder:
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
survey_bands: [u, g, r, i, z]
reference_band: 2 # r-band
tile_slen: 256
n_tiles: 8
nch_hidden: 64
optimizer_params:
lr: 1e-3
scheduler_params:
milestones: [32]
gamma: 0.1
image_normalizers: ${my_normalizers}
var_dist:
_target_: bliss.encoder.variational_dist.VariationalDist
tile_slen: ${encoder.tile_slen}
factors: ${variational_factors}
mode_metrics:
_target_: torchmetrics.MetricCollection
_convert_: partial
metrics: ${my_metrics}
sample_metrics:
_target_: torchmetrics.MetricCollection
_convert_: partial
metrics: ${my_metrics}
sample_image_renders:
_target_: torchmetrics.MetricCollection
_convert_: partial
metrics: ${my_render}
use_double_detect: false
use_checkerboard: false
loss_plots_location: ${paths.output}/${train.trainer.logger.name}/${train.trainer.logger.version}/loss_plots

generate:
n_image_files: 50
n_batches_per_file: 4

train:
trainer:
logger:
name: weak_lensing_experiments_simulator
version: october1
max_epochs: 250
devices: 1
use_distributed_sampler: false
precision: 32-true
callbacks:
early_stopping:
patience: 50
data_source: ${cached_simulator}
pretrained_weights: null
seed: 123123
1 change: 0 additions & 1 deletion case_studies/weak_lensing/lensing_convnet_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def forward(self, x):
out = self.conv1(x)
out = self.gn1(out)
out = self.silu(out)

out = self.conv2(out)
out = self.gn2(out)

Expand Down
6 changes: 3 additions & 3 deletions case_studies/weak_lensing/lensing_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def render_galaxy(self, psf, band, source_params):
"""
galaxy = self.render_bulge_plus_disk(band, source_params)

shear = source_params["shear"]
shear1, shear2 = shear
convergence = source_params["convergence"]
shear1 = source_params["shear_1_per_galaxy"]
shear2 = source_params["shear_2_per_galaxy"]
convergence = source_params["convergence_per_galaxy"]

reduced_shear1 = shear1 / (1 - convergence)
reduced_shear2 = shear2 / (1 - convergence)
Expand Down
Loading

0 comments on commit 1baf4d4

Please sign in to comment.