Visu transfo #221

Closed
wants to merge 10 commits into from
3 changes: 2 additions & 1 deletion .gitignore
@@ -71,4 +71,5 @@ target/
*.swo

# dwi_ml stuff
.ipynb_config/
.ipynb_config/
.ipynb_checkpoints/
11 changes: 11 additions & 0 deletions dwi_ml/data/dataset/multi_subject_containers.py
@@ -208,6 +208,16 @@ def get_mri_data(self, subj_idx: int, group_idx: int,

Unlike get_volume_verify_cache, this does not send the data to the
cache for later use.

Parameters
----------
subj_idx: int
The subject id.
group_idx: int
The volume group index.
load_it: bool
If the data is lazy, whether to return the volume as a LazyMRIData
(False) or to load it as a non-lazy volume (True).
"""
if self.subjs_data_list.is_lazy:
if load_it:
@@ -484,6 +494,7 @@ def load_data(self, load_training=True, load_validation=True,
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

self.streamline_groups = list(self.streamline_groups)
group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
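
As a side note on the new get_mri_data docstring above, a minimal usage sketch of the load_it flag; the dataset object and the indices are placeholders, not taken from this PR:

# Sketch only: 'dataset' stands for any loaded multi-subject container from
# this module; its construction is not shown here.
def peek_volume(dataset, subj_idx=0, group_idx=0):
    # Lazy handle: the volume is returned as a LazyMRIData, not loaded yet.
    lazy_ref = dataset.get_mri_data(subj_idx, group_idx, load_it=False)
    # Fully loaded, non-lazy volume.
    volume = dataset.get_mri_data(subj_idx, group_idx, load_it=True)
    return lazy_ref, volume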
3 changes: 2 additions & 1 deletion dwi_ml/data/processing/dwi/dwi.py
@@ -5,8 +5,9 @@
from dipy.reconst.shm import sph_harm_lookup
import nibabel as nib
import numpy as np

from scilpy.io.utils import validate_sh_basis_choice
from scilpy.reconst.raw_signal import compute_sh_coefficients
from scilpy.reconst.sh import compute_sh_coefficients

eps = 1e-6

3 changes: 2 additions & 1 deletion dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -8,7 +8,8 @@
from nibabel.streamlines.tractogram import (PerArrayDict, PerArraySequenceDict)
import numpy as np

from scilpy.tracking.tools import resample_streamlines_step_size
from scilpy.tractograms.streamline_operations import \
resample_streamlines_step_size
from scilpy.utils.streamlines import compress_sft


2 changes: 1 addition & 1 deletion dwi_ml/data/processing/streamlines/post_processing.py
@@ -6,7 +6,7 @@

from scilpy.tractanalysis.tools import \
extract_longest_segments_from_profile as segmenting_func
from scilpy.tractanalysis.uncompress import uncompress
from scilpy.tractograms.uncompress import uncompress

# We could try using nan instead of zeros for non-existing previous dirs...
DEFAULT_UNEXISTING_VAL = torch.zeros((1, 3), dtype=torch.float32)
2 changes: 1 addition & 1 deletion dwi_ml/io_utils.py
@@ -43,7 +43,7 @@ def add_memory_args(p: ArgumentParser, add_lazy_options=False,
add_processes_arg(ram_options)
ram_options.add_argument(
'--use_gpu', action='store_true',
help="If set, use GPU for processing. Cannot be used \ntogether "
help="If set, use GPU for processing. Cannot be used together "
"with --processes.")
else:
p.add_argument('--use_gpu', action='store_true',
20 changes: 20 additions & 0 deletions dwi_ml/models/direction_getter_models.py
@@ -209,6 +209,26 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):

def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
"""
Parameters
----------
outputs: List[Tensor]
Your model's outputs
target_streamlines: List[Tensor]
The streamlines. Directions will be computed and formatted based
on child class requirements.
average_results: bool
If True, returns the average over all values.

Returns
-------
If compress_loss or average_results:
Tuple(tensor, n)
The average loss and n, the number of points averaged.
Else:
List[Tensor]
The loss for each point in each streamline.
"""
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")
40 changes: 40 additions & 0 deletions dwi_ml/models/main_models.py
@@ -207,6 +207,22 @@ def forward(self, inputs, streamlines):
def compute_loss(self, model_outputs, target_streamlines):
raise NotImplementedError

def merge_batches_outputs(self, all_outputs, new_batch):
"""
To be used at testing time. At training or validation time, outputs are
discarded after each batch; only the loss is measured. At testing time,
it may be necessary to merge the outputs of all batches. How to do so
depends on your model's output format.

Parameters
----------
all_outputs: Any or None
All previous outputs from previous batches, already combined.
new_batch: Any
The new batch of outputs, to be merged with all_outputs.
"""
raise NotImplementedError


class ModelWithNeighborhood(MainModelAbstract):
"""
@@ -752,3 +768,27 @@ def compute_loss(self, model_outputs: List[Tensor], target_streamlines,
def move_to(self, device):
super().move_to(device)
self.direction_getter.move_to(device)

def merge_batches_outputs(self, all_outputs, new_batch, device=None):
# 1. Send new batch to device
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
new_batch = ([m.to(device=device) for m in new_batch[0]],
[s.to(device=device) for s in new_batch[1]])
else:
new_batch = [a.to(device) for a in new_batch]

# 2. Concat
if all_outputs is None:
return new_batch
else:
if 'gaussian' in self.direction_getter.key:
# all_outputs = (means, sigmas)
all_outputs[0].extend(new_batch[0])
all_outputs[1].extend(new_batch[1])
elif 'fisher' in self.direction_getter.key:
raise NotImplementedError
else:
# all_outputs = a list of tensors per streamline.
all_outputs.extend(new_batch)
return all_outputs
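
For reference, a minimal sketch of how this merge could be driven from a testing loop; the model instance and the batch iterator are assumptions, not part of this diff:

# Sketch only: 'model' is an instance of the class above and 'test_batches'
# yields the outputs of successive forward passes at testing time.
all_outputs = None
for new_batch in test_batches:
    all_outputs = model.merge_batches_outputs(all_outputs, new_batch,
                                              device='cpu')
# For 'gaussian' direction getters, all_outputs ends up as a (means, sigmas)
# pair of lists; otherwise it is a flat list with one tensor per streamline.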
7 changes: 2 additions & 5 deletions dwi_ml/models/projects/copy_previous_dirs.py
@@ -52,8 +52,5 @@ def compute_loss(self, model_outputs: List[torch.Tensor],
if self.skip_first_point:
target_streamlines = [t[1:] for t in target_streamlines]

if self._context == 'visu':
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
else:
raise NotImplementedError
return self.direction_getter.compute_loss(
model_outputs, target_streamlines, average_results)
15 changes: 10 additions & 5 deletions dwi_ml/models/projects/learn2track_model.py
@@ -189,6 +189,11 @@ def __init__(self, experiment_name,
self.instantiate_direction_getter(self.rnn_model.output_size)

def set_context(self, context):
# Training, validation: Used by the trainer. Nothing special.
# Tracking: Used by the tracker. Returns only the last point.
# Preparing_backward: Used by the tracker, but returns outputs for all
# points (not only the last).
# Visu: Used by the tester. Nothing special.
assert context in ['training', 'validation', 'tracking', 'visu',
'preparing_backward']
self._context = context
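
To make the context names above concrete, a small sketch listing the accepted values; 'model' is an instance of the model in this file, built elsewhere:

# Sketch only: each call switches how forward() behaves, per the comments above.
model.set_context('training')            # trainer: full forward pass
model.set_context('validation')          # trainer: full forward pass
model.set_context('tracking')            # tracker: returns only the last point
model.set_context('preparing_backward')  # tracker: returns all points
model.set_context('visu')                # tester: full forward pass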
@@ -253,7 +258,7 @@ def forward(self, x: List[torch.tensor],
"""
# Reminder.
# Correct interpolation and management of points should be done before.
if self._context is None:
if self.context is None:
raise ValueError("Please set context before usage.")

# Right now input is always flattened (interpolation is implemented
Expand All @@ -268,7 +273,7 @@ def forward(self, x: List[torch.tensor],
# Making sure we can use default 'enforce_sorted=True' with packed
# sequences.
unsorted_indices = None
if not self._context == 'tracking':
if not self.context == 'tracking':
# Ordering streamlines per length.
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
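
For reference, a self-contained illustration of the sorting requirement mentioned in the comment above (toy tensors, not the model's real inputs):

import torch
from torch.nn.utils.rnn import pack_sequence

# Three fake "streamlines" of different lengths, 3 features per point.
x = [torch.rand(5, 3), torch.rand(8, 3), torch.rand(2, 3)]
lengths = torch.as_tensor([len(s) for s in x])
_, sorted_indices = torch.sort(lengths, descending=True)
x_sorted = [x[i] for i in sorted_indices]
# Default enforce_sorted=True is now valid because lengths are decreasing.
packed = pack_sequence(x_sorted, enforce_sorted=True)
# Inverse permutation, to restore the original streamline order afterwards.
unsorted_indices = torch.argsort(sorted_indices)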
@@ -351,7 +356,7 @@ def forward(self, x: List[torch.tensor],
x = x + copy_prev_dir

# Unpacking.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (during tracking: keeping as one single tensor.)
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
@@ -370,7 +375,7 @@ def forward(self, x: List[torch.tensor],
if return_hidden:
# Return the hidden states too. Necessary for the generative
# (tracking) part, done step by step.
if not self._context == 'tracking':
if not self.context == 'tracking':
# (ex: when preparing backward tracking.
# Must also re-sort hidden states.)
if self.rnn_model.rnn_torch_key == 'lstm':
@@ -402,7 +407,7 @@ def copy_prev_dir(self, dirs):
# Converting the input directions into classes the same way as
# during loss, but convert to one-hot.
# The first previous dir (0) converts to index 0.
if self._context == 'tracking':
if self.context == 'tracking':
if dirs[0].shape[0] == 0:
copy_prev_dir = torch.zeros(
len(dirs),