Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Noa #151

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Noa #151

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions configs/benchmark_hyperspectral.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# python scripts/eval/benchmark_recon.py
#Hydra config
hydra:
run:
dir: "benchmark/${now:%Y-%m-%d}/${now:%H-%M-%S}"
job:
chdir: True


dataset: PolarLitis # DiffuserCam, DigiCamCelebA, HFDataset
seed: 0
batchsize: 1 # must be 1 for iterative approaches

huggingface:
repo: "noakraicer/polarlitisnpy"
cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`.
psf: psf.mat
mask: mask.npy # null for simulating PSF
image_res: [250, 250] # used during measurement
rotate: False # if measurement is upside-down
flipud: False
flip_lensed: False # if rotate or flipud is True, apply to lensed

alignment:
top_left: null
height: null

downsample: 1
downsample_lensed: 2
split_seed: null
single_channel_psf: True

device: "cuda"
# numbers of iterations to benchmark
n_iter_range: [2000]
# number of files to benchmark
n_files: null # null for all files
#How much should the image be downsampled
downsample: 2
#algorithm to benchmark
algorithms: ["HyperSpectralFISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]

# baseline from Monakhova et al. 2019, https://arxiv.org/abs/1908.11502
baseline: "MONAKHOVA 100iter"

save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10]
gamma_psf: 1.5 # gamma factor for PSF


# Hyperparameters
nesterov:
p: 0
mu: 0.9
fista:
tk: 1
admm:
mu1: 1e-6
mu2: 1e-5
mu3: 4e-5
tau: 0.0001


# for DigiCamCelebA
files:
test_size: 0.15
downsample: 1
celeba_root: /scratch/bezzam


# dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
# psf: data/psf/adafruit_random_2mm_20231907.png
# vertical_shift: null
# horizontal_shift: null
# crop: null

dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K
psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png
vertical_shift: -117
horizontal_shift: -25
crop:
vertical: [0, 525]
horizontal: [265, 695]

# for prepping ground truth data
#for simulated dataset
simulation:
grayscale: False
output_dim: null # should be set if no PSF is used
# random variations
object_height: 0.33 # [m], range for random height or scalar
flip: True # change the orientation of the object (from vertical to horizontal)
random_shift: False
random_vflip: 0.5
random_hflip: 0.5
random_rotate: False
# these distance parameters are typically fixed for a given PSF
# for DiffuserCam psf # for tape_rgb psf
# scene2mask: 10e-2 # scene2mask: 40e-2
# mask2sensor: 9e-3 # mask2sensor: 4e-3
# -- for CelebA
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
deadspace: True # whether to account for deadspace for programmable mask
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
snr_db: 10
# simulate different sensor resolution
# output_dim: [24, 32] # [H, W] or null
# Downsampling for PSF
downsample: 8
# max val in simulated measured (quantized 8 bits)
quantize: False # must be False for differentiability
max_val: 255
26 changes: 0 additions & 26 deletions configs/upload_dataset_huggingface.yaml

This file was deleted.

1 change: 1 addition & 0 deletions lensless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NesterovGradientDescent,
FISTA,
GradientDescentUpdate,
HyperSpectralFISTA
)
from .recon.tikhonov import CodedApertureReconstruction
from .hardware.sensor import VirtualSensor, SensorOptions
Expand Down
94 changes: 83 additions & 11 deletions lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class GradientDescent(ReconstructionAlgorithm):
Object for applying projected gradient descent.
"""

def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):
def __init__(self, psf,mask, dtype=None, proj=non_neg, **kwargs):
"""

Parameters
Expand All @@ -83,30 +83,30 @@ def __init__(self, psf, dtype=None, proj=non_neg, **kwargs):

assert callable(proj)
self._proj = proj
super(GradientDescent, self).__init__(psf, dtype, **kwargs)
super(GradientDescent, self).__init__(psf,mask, dtype, **kwargs)

if self._denoiser is not None:
print("Using denoiser in gradient descent.")
# redefine projection function
self._proj = self._denoiser

self.mask=mask
def reset(self):
if self.is_torch:
if self._initial_est is not None:
self._image_est = self._initial_est
else:
# initial guess, half intensity image
psf_flat = self._psf.reshape(-1, self._psf_shape[3])
pixel_start = (
torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
) / 2
# psf_flat = self._psf.reshape(-1, self._psf_shape[3])
# pixel_start = (
# torch.max(psf_flat, axis=0).values + torch.min(psf_flat, axis=0).values
# ) / 2
# initialize image estimate as [Batch, Depth, Height, Width, Channels]
self._image_est = torch.ones_like(self._psf[None, ...]) * pixel_start
self._image_est = torch.zeros((1,250,250,3)).to(self._psf.device)

# set step size as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha = torch.real(1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
self._alpha = 1/4770.13

else:
if self._initial_est is not None:
Expand All @@ -123,8 +123,8 @@ def reset(self):
self._alpha = np.real(1.8 / np.max(Hadj_flat * H_flat, axis=0))

def _grad(self):
diff = self._convolver.convolve(self._image_est) - self._data
return self._convolver.deconvolve(diff)
diff = torch.sum(self.mask * self._convolver.convolve(self._image_est), axis=-1, keepdims=True) - self._data # (H, W, 1)
return self._convolver.deconvolve(diff * self.mask) # (H, W, C) where C is number of hyperspectral channels

def _update(self, iter):
self._image_est -= self._alpha * self._grad()
Expand Down Expand Up @@ -238,6 +238,78 @@ def _update(self, iter):
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs)

# create reconstruction object
recon = GradientDescent(psf, n_iter=n_iter, proj=proj)

# set data
recon.set_data(data)

# perform reconstruction
start_time = time.time()
res = recon.apply(plot=False)
proc_time = time.time() - start_time

if verbose:
print(f"Reconstruction time : {proc_time} s")
print(f"Reconstruction shape: {res.shape}")
return res
class HyperSpectralFISTA(GradientDescent):
"""
Object for applying projected gradient descent with FISTA (Fast Iterative
Shrinkage-Thresholding Algorithm) for acceleration.

Paper: https://www.ceremade.dauphine.fr/~carlier/FISTA

"""

def __init__(self, psf,mask, dtype=None, proj=non_neg, tk=1.0, **kwargs):
"""

Parameters
----------
psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
Point spread function (PSF) that models forward propagation.
Must be of shape (depth, height, width, channels) even if
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
to load a PSF from a file such that it is in the correct format.
dtype : float32 or float64
Data type to use for optimization. Default is float32.
proj : :py:class:`function`
Projection function to apply at each iteration. Default is
non-negative.
tk : float
Initial step size parameter for FISTA. It is updated at each iteration
according to Eq. 4.2 of paper. By default, initialized to 1.0.

"""
self._initial_tk = tk

super(HyperSpectralFISTA, self).__init__(psf,mask, dtype, proj, **kwargs)

self._tk = tk
self._xk = self._image_est

def reset(self, tk=None):
super(HyperSpectralFISTA, self).reset()
if tk:
self._tk = tk
else:
self._tk = self._initial_tk
self._xk = self._image_est
def _update(self, iter):
self._image_est -= self._alpha * self._grad()
xk = self._form_image()
tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2
self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk)
self._tk = tk
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
Expand Down
17 changes: 11 additions & 6 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class ReconstructionAlgorithm(abc.ABC):
def __init__(
self,
psf,
mask,
dtype=None,
pad=True,
n_iter=100,
Expand Down Expand Up @@ -369,12 +370,13 @@ def set_data(self, data):
assert len(data.shape) >= 3, "Data must be at least 3D: [..., width, height, channel]."

# assert same shapes
assert np.all(
self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
), "PSF and data shape mismatch"

if len(data.shape) == 3:
self._data = data[None, None, ...]
# assert np.all(
# self._psf_shape[-3:-1] == np.array(data.shape)[-3:-1]
# ), "PSF and data shape mismatch"
if len(data.shape)==3:
self._data = data.unsqueeze(-1)
# if len(data.shape) == 3:
# self._data = data[None, None, ...]
elif len(data.shape) == 4:
self._data = data[None, ...]
else:
Expand Down Expand Up @@ -569,6 +571,9 @@ def apply(

for i in range(n_iter):
self._update(i)
if i%50==0:
img = self._form_image()

if self.compensation_branch is not None and i < self._n_iter - 1:
self.compensation_branch_inputs.append(self._form_image())

Expand Down
6 changes: 3 additions & 3 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class RealFFTConvolve2D:
def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
def __init__(self, psf, dtype=None, pad=True, norm=None, rgb=None, **kwargs):
"""
Linear operator that performs convolution in Fourier domain, and assumes
real-valued signals.
Expand Down Expand Up @@ -135,10 +135,10 @@ def convolve(self, x):
Convolve with pre-computed FFT of provided PSF.
"""
if self.pad:
self._padded_data = self._pad(x)
self._padded_data = self._pad(x).to(self._psf.device)
else:
if self.is_torch:
self._padded_data = x # .type(self.dtype).to(self._psf.device)
self._padded_data = x
else:
self._padded_data[:] = x # .astype(self.dtype)

Expand Down
Loading