Improve visu of loss #225

Closed · wants to merge 5 commits into from
3 changes: 2 additions & 1 deletion .gitignore
@@ -71,4 +71,5 @@ target/
 *.swo
 
 # dwi_ml stuff
-.ipynb_config/
+.ipynb_config/
+.ipynb_checkpoints/
3 changes: 2 additions & 1 deletion dwi_ml/data/processing/dwi/dwi.py
@@ -5,8 +5,9 @@
 from dipy.reconst.shm import sph_harm_lookup
 import nibabel as nib
 import numpy as np
+
 from scilpy.io.utils import validate_sh_basis_choice
-from scilpy.reconst.raw_signal import compute_sh_coefficients
+from scilpy.reconst.sh import compute_sh_coefficients
 
 eps = 1e-6
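Note on this import move: compute_sh_coefficients migrated from scilpy.reconst.raw_signal to scilpy.reconst.sh between scilpy releases, and this PR (like the other import updates below) simply tracks the new layout. If code had to tolerate both layouts at once, a fallback import would be one option; this is a hypothetical sketch, not something the PR does:

try:
    # Newer scilpy layout, the one this PR targets.
    from scilpy.reconst.sh import compute_sh_coefficients
except ImportError:
    # Older scilpy layout, kept as a fallback.
    from scilpy.reconst.raw_signal import compute_sh_coefficients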
3 changes: 2 additions & 1 deletion dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -8,7 +8,8 @@
 from nibabel.streamlines.tractogram import (PerArrayDict, PerArraySequenceDict)
 import numpy as np
 
-from scilpy.tracking.tools import resample_streamlines_step_size
+from scilpy.tractograms.streamline_operations import \
+    resample_streamlines_step_size
 from scilpy.utils.streamlines import compress_sft
 
 
2 changes: 1 addition & 1 deletion dwi_ml/data/processing/streamlines/post_processing.py
@@ -6,7 +6,7 @@
 
 from scilpy.tractanalysis.tools import \
     extract_longest_segments_from_profile as segmenting_func
-from scilpy.tractanalysis.uncompress import uncompress
+from scilpy.tractograms.uncompress import uncompress
 
 # We could try using nan instead of zeros for non-existing previous dirs...
 DEFAULT_UNEXISTING_VAL = torch.zeros((1, 3), dtype=torch.float32)
2 changes: 1 addition & 1 deletion dwi_ml/io_utils.py
@@ -43,7 +43,7 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
         add_processes_arg(ram_options)
         ram_options.add_argument(
             '--use_gpu', action='store_true',
-            help="If set, use GPU for processing. Cannot be used \ntogether "
+            help="If set, use GPU for processing. Cannot be used together "
                  "with --processes.")
     else:
         p.add_argument('--use_gpu', action='store_true',
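Side note on the removed "\n": argparse's default HelpFormatter re-wraps help strings to the terminal width and collapses embedded whitespace, so the literal newline never survived anyway; it only matters under RawTextHelpFormatter. A minimal standalone demonstration (plain argparse, no dwi_ml imports):

import argparse

# Default formatter: the help text below is re-wrapped automatically,
# so manual "\n" characters in it would be collapsed into spaces.
p = argparse.ArgumentParser()
p.add_argument('--use_gpu', action='store_true',
               help="If set, use GPU for processing. Cannot be used together "
                    "with --processes.")
p.print_help()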
3 changes: 2 additions & 1 deletion dwi_ml/testing/projects/transformer_visualisation_utils.py
@@ -10,7 +10,8 @@
 
 from scilpy.io.streamlines import load_tractogram_with_reference
 from scilpy.io.utils import add_reference_arg, add_overwrite_arg, add_bbox_arg
-from scilpy.tracking.tools import resample_streamlines_step_size
+from scilpy.tractograms.streamline_operations import \
+    resample_streamlines_step_size
 from scilpy.utils.streamlines import compress_sft
 
 from dwi_ml.io_utils import add_logging_arg
31 changes: 18 additions & 13 deletions dwi_ml/testing/testers.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import torch
+from tqdm import tqdm
 
 from dwi_ml.data.processing.streamlines.data_augmentation import \
     resample_or_compress
@@ -23,6 +24,7 @@ class Tester:
     from the hdf5. This choice allows to test the loss on various bundles for
     a better interpretation of the models' performances.
     """
+
     def __init__(self, experiment_path: str, model: ModelWithDirectionGetter,
                  batch_size: int = None, device: torch.device = None):
         """
@@ -101,8 +103,7 @@ def _volume_groups(self):
 
     def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
                          compute_loss=True, uncompress_loss=False,
-                         force_compress_loss=False,
-                         weight_with_angle=False):
+                         force_compress_loss=False, weight_with_angle=False):
         """
         Equivalent of one validation pass.
 
@@ -153,9 +154,9 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
         batch_start = 0
         batch_end = batch_size
         with torch.no_grad():
-            for batch in range(nb_batches):
-                logging.info("  Batch #{}: {} - {}"
-                             .format(batch + 1, batch_start, batch_end))
+            for batch in tqdm(range(nb_batches),
+                              desc="Batches", total=nb_batches):
+
                 # 1. Prepare batch
                 streamlines = [
                     torch.as_tensor(s, dtype=torch.float32, device=self.device)
@@ -201,24 +202,28 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
 
         if self.model.direction_getter.compress_loss:
             total_n = sum(compressed_n)
-            total_loss = sum([loss * n for loss, n in
-                              zip(losses, compressed_n)]) / total_n
+            mean_loss = sum([loss * n for loss, n in
+                             zip(losses, compressed_n)]) / total_n
+            print("Loss function, averaged over all {} compressed points "
+                  "in the chosen SFT, is: {}.".format(total_n, mean_loss))
         else:
             total_n = sum([len(line_loss) for line_loss in losses])
-            total_loss = torch.mean(torch.hstack(losses))
-
-        print("Loss function, averaged over all {} points in the chosen "
-              "SFT, is: {}.".format(total_n, total_loss))
+            tmp = torch.hstack(losses)
+            # \u00B1 is the plus or minus sign.
+            print(u"Loss function, averaged over all {} points in the "
+                  "chosen SFT, is: {} \u00B1 {}. Min: {}. Max: {}"
+                  .format(total_n, torch.mean(tmp), torch.std(tmp),
+                          torch.min(tmp), torch.max(tmp)))
 
         if (not self.model.direction_getter.compress_loss) and \
                 add_zeros_if_no_eos and \
                 not self.model.direction_getter.add_eos:
             zero = torch.zeros(1)
             losses = [torch.hstack([line, zero]) for line in losses]
             total_n = sum([len(line_loss) for line_loss in losses])
-            total_loss = torch.mean(torch.hstack(losses))
+            mean_loss = torch.mean(torch.hstack(losses))
             print("When adding a 0 loss at the EOS position, the mean "
-                  "loss for {} points is {}.".format(total_n, total_loss))
+                  "loss for {} points is {}.".format(total_n, mean_loss))
 
         self.model.direction_getter.compress_loss = save_val
         self.model.direction_getter.weight_loss_with_angle = save_val_angle
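The pattern adopted above, in isolation: tqdm replaces per-batch logging with a single self-updating progress bar, and the per-point losses are stacked into one tensor so the report can give mean \u00B1 std plus extrema instead of a bare mean. A runnable sketch with toy data (nb_batches and run_one_batch are stand-ins, not the repo's API):

import torch
from tqdm import tqdm

nb_batches = 4

def run_one_batch(i):
    # Stand-in for the real forward pass: fake per-point losses.
    return torch.rand(10)

losses = []
with torch.no_grad():
    for batch in tqdm(range(nb_batches), desc="Batches", total=nb_batches):
        losses.append(run_one_batch(batch))

tmp = torch.hstack(losses)
# \u00B1 renders as the plus or minus sign, as in the diff above.
print(u"Mean loss over {} points: {} \u00B1 {}. Min: {}. Max: {}".format(
    tmp.numel(), torch.mean(tmp), torch.std(tmp),
    torch.min(tmp), torch.max(tmp)))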
19 changes: 15 additions & 4 deletions dwi_ml/testing/visu_loss.py
@@ -28,6 +28,9 @@ def prepare_args_visu_loss(p: ArgumentParser, use_existing_experiment=True):
     if use_existing_experiment:
         # Should only be False for debugging tests.
         add_arg_existing_experiment_path(p)
+        p.add_argument('--use_latest_epoch', action='store_true',
+                       help="If true, use model at latest epoch rather than "
+                            "default (best model).")
     p.add_argument('--uncompress_loss', action='store_true',
                    help="If model uses compressed loss, take uncompressed "
                         "equivalent.")
@@ -45,8 +48,9 @@ def prepare_args_visu_loss(p: ArgumentParser, use_existing_experiment=True):
     add_args_testing_subj_hdf5(p)
 
     # Options
-    p.add_argument('--batch_size', type=int)
-    add_memory_args(p)
+    g = add_memory_args(p)
+    g.add_argument('--batch_size', type=int, metavar='n',
+                   help="Batch size in number of streamlines. Default: None.")
 
     g = p.add_argument_group("Options to save loss as a colored SFT")
     g.add_argument('--save_colored_tractogram', metavar='out_name.trk',
@@ -221,7 +225,14 @@ def combine_displacement_with_ref(out_dirs, sft, step_size_mm=None):
 def run_visu_save_colored_displacement(
         args, model: ModelWithDirectionGetter, losses: List[torch.Tensor],
         outputs: List[torch.Tensor], sft: StatefulTractogram,
-        colorbar_name: str, best_sft_name: str, worst_sft_name: str):
+        colorbar_name: str, best_sft_name: str, worst_sft_name: str,
+        show_histogram: bool = True):
+
+    if show_histogram:
+        tmp = torch.hstack(losses)
+        plt.figure()
+        _ = plt.hist(tmp.numpy(), bins='auto')
+        plt.title("Histogram of losses")
 
     if model.direction_getter.compress_loss:
         if not ('uncompress_loss' in args and args.uncompress_loss):
@@ -291,5 +302,5 @@ def run_visu_save_colored_displacement(
 
     save_tractogram(sft, args.out_displacement_sft, bbox_valid_check=False)
 
-    if args.show_colorbar:
+    if args.show_colorbar or show_histogram:
         plt.show()
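The histogram addition in isolation, with random values standing in for the real per-point losses (a sketch; bins='auto' delegates bin selection to numpy):

import matplotlib.pyplot as plt
import torch

losses = [torch.rand(50), torch.rand(80)]  # stand-in per-streamline losses
tmp = torch.hstack(losses)                 # one flat tensor of point losses
plt.figure()
_ = plt.hist(tmp.numpy(), bins='auto')     # let numpy choose the bin edges
plt.title("Histogram of losses")
plt.show()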
6 changes: 3 additions & 3 deletions dwi_ml/tracking/io_utils.py
@@ -129,18 +129,18 @@ def prepare_seed_generator(parser, args, hdf_handle):
     seed_generator = SeedGenerator(seed_data, seed_res, space=ALWAYS_VOX_SPACE,
                                    origin=ALWAYS_CORNER)
 
-    if len(seed_generator.seeds_vox) == 0:
+    if len(seed_generator.seeds_vox_corner) == 0:
         parser.error('Seed mask "{}" does not have any voxel with value > 0.'
                      .format(args.in_seed))
 
     if args.npv:
         # Note. Not really nb seed per voxel, just in average.
-        nbr_seeds = len(seed_generator.seeds_vox) * args.npv
+        nbr_seeds = len(seed_generator.seeds_vox_corner) * args.npv
     elif args.nt:
         nbr_seeds = args.nt
     else:
         # Setting npv = 1.
-        nbr_seeds = len(seed_generator.seeds_vox)
+        nbr_seeds = len(seed_generator.seeds_vox_corner)
 
     seed_header = nib.Nifti1Image(seed_data, affine).header
 
5 changes: 2 additions & 3 deletions dwi_ml/tracking/tracker.py
@@ -399,10 +399,9 @@ def _gpu_simultaneous_tracking(self):
             if seed_count + nb_next_seeds > self.nbr_seeds:
                 nb_next_seeds = self.nbr_seeds - seed_count
 
-            next_seeds = np.arange(seed_count, seed_count + nb_next_seeds)
-
             n_seeds = self.seed_generator.get_next_n_pos(
-                random_generator, indices, next_seeds)
+                random_generator, indices, which_seed_start=seed_count,
+                n=nb_next_seeds)
 
             tmp_lines, tmp_seeds = \
                 self._get_multiple_lines_both_directions(n_seeds)
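The tracker no longer builds the seed index array by hand: per the diff, get_next_n_pos now takes a start index (which_seed_start) and a count (n) and derives the ids internally. A sketch of the equivalence, with the generator call left commented since it needs a scilpy SeedGenerator instance:

import numpy as np

seed_count, nb_next_seeds = 0, 100

# Before: an explicit index array was handed to the generator.
next_seeds = np.arange(seed_count, seed_count + nb_next_seeds)

# After: the same range is described by a start index and a count.
# n_seeds = seed_generator.get_next_n_pos(
#     random_generator, indices, which_seed_start=seed_count,
#     n=nb_next_seeds)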
24 changes: 16 additions & 8 deletions requirements.txt
@@ -1,30 +1,38 @@
 # Supported for python 3.10
 # Should work for python > 3.8.
 
-# Scilpy and comet_ml both require requests. In comet: >=2.18.*,
-# which installs a version >2.28. Adding request version explicitely.
+# -------
+# Main dependency: scilpy
+# Scilpy and comet_ml both require requests. In comet: >=2.18.*,
+# which installs a version >2.28. Adding request version explicitely.
+# -------
 requests==2.28.*
+-e git+https://github.com/scilus/scilpy.git#egg=scilpy
 
 
+# -------
+# Other important dependencies
+# -------
 bertviz~=1.4.0 # For transformer's visu
 torch==1.13.*
 tqdm==4.64.*
 comet-ml==3.21.*
 contextlib2==21.6.0
 jupyterlab>=3.6.2 # For transformer's visu
 IProgress>=0.4 # For jupyter with tdqm
-nested_lookup==0.2.25
 nose==1.3.*
-scilpy==1.5.post2
+nested_lookup==0.2.25 # For lists management
 
-## Necessary but should be installed with scilpy (Last check: 09/2023):
+# -------
+# Necessary but should be installed with scilpy (Last check: 01/2024):
+# -------
 future==0.18.*
 h5py==3.7.* # h5py must absolutely be >2.4: that's when it became thread-safe
 matplotlib==3.6.* # Hint: If matplotlib fails, you may try to install pyQt5.
-nibabel==4.0.*
+nibabel==5.2.*
 numpy==1.23.*
 scipy==1.9.*
 
 
 
 # --------------- Notes to developers
 # If we upgrade torch, verify if code copied in
 # models.projects.transformers_from_torch has changed.
10 changes: 8 additions & 2 deletions scripts_python/l2t_visualize_loss.py
@@ -56,8 +56,14 @@ def main():
 
     # 1. Load model
     logging.debug("Loading model.")
-    model = Learn2TrackModel.load_model_from_params_and_state(
-        args.experiment_path + '/best_model', log_level=sub_logger_level)
+    if args.use_latest_epoch:
+        model = Learn2TrackModel.load_model_from_params_and_state(
+            args.experiment_path + '/checkpoint/model',
+            log_level=sub_logger_level)
+    else:
+        model = Learn2TrackModel.load_model_from_params_and_state(
+            args.experiment_path + '/best_model',
+            log_level=sub_logger_level)
 
     # 2. Compute loss
     tester = TesterOneInput(args.experiment_path, model, args.batch_size, device)