Skip to content

Commit

Permalink
Merge pull request #220 from arnaudbore/add_autoencoder_streamlines
Browse files Browse the repository at this point in the history
[NF] Auto-encoders - streamlines - FINTA
  • Loading branch information
EmmaRenauld authored Oct 2, 2024
2 parents 3ebc194 + fc55338 commit 04639fa
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 17 deletions.
4 changes: 1 addition & 3 deletions dwi_ml/data/hdf5/hdf5_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])
Expand Down Expand Up @@ -669,8 +669,6 @@ def _process_one_streamline_group(
Reference used to load and send the streamlines in voxel space and
to create final merged SFT. If the file is a .trk, 'same' is used
instead.
remove_invalid : bool
If True, invalid streamlines will be removed
Returns
-------
Expand Down
6 changes: 3 additions & 3 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def resample_or_compress(sft, step_size_mm: float = None,
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False):
if step_size_mm is not None:
if step_size_mm:
# Note. No matter the chosen space, resampling is done in mm.
logging.debug(" Resampling (step size): {}mm".format(step_size_mm))
sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
elif nb_points is not None:
elif nb_points:
logging.debug(" Resampling: " +
"{} points per streamline".format(nb_points))
sft = resample_streamlines_num_points(sft, nb_points)
elif compress is not None:
elif compress:
logging.debug(" Compressing: {}".format(compress))
sft = compress_sft(sft, compress)

Expand Down
1 change: 1 addition & 0 deletions dwi_ml/models/main_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def params_for_checkpoint(self):
'experiment_name': self.experiment_name,
'step_size': self.step_size,
'compress_lines': self.compress_lines,
'nb_points': self.nb_points,
}

@property
Expand Down
173 changes: 173 additions & 0 deletions dwi_ml/models/projects/ae_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import torch
from torch.nn import functional as F

from dwi_ml.models.main_models import MainModelAbstract


class ModelAE(MainModelAbstract):
    """
    Convolutional auto-encoder for streamlines.

    The encoder is a stack of six 1D convolutions (each preceded by a
    reflection padding of 1) followed by a linear layer projecting the
    flattened features to a latent vector of size `latent_space_dims`.
    The decoder mirrors it: a linear layer, then six 1D convolutions
    interleaved with five linear upsamplings (x2 each).

    Input streamlines must all have the same number of points (they are
    stacked into a single tensor). The linear layers are hard-coded to
    8192 = 1024 * 8 features, i.e. 1024 channels of length 8 at the end of
    the encoder, which is what a 256-point streamline yields after the five
    stride-2 convolutions. The decoder always outputs 256 points
    (8 * 2**5).
    """
    def __init__(self,
                 experiment_name: str,
                 step_size: float = None,
                 nb_points: int = None,
                 compress_lines: float = False,
                 # Other
                 log_level=logging.root.level):
        """
        Parameters
        ----------
        experiment_name: str
            Forwarded to the parent class.
        step_size: float
            Forwarded to the parent class.
        nb_points: int
            Forwarded to the parent class.
            NOTE(review): the architecture below expects 256 points per
            streamline; confirm that callers always set nb_points
            accordingly.
        compress_lines: float
            Forwarded to the parent class.
        log_level:
            Forwarded to the parent class.
        """
        super().__init__(experiment_name,
                         step_size=step_size,
                         nb_points=nb_points,
                         compress_lines=compress_lines,
                         log_level=log_level)

        self.kernel_size = 3
        self.latent_space_dims = 32

        # A single padding layer (no parameters) shared by all convolutions.
        self.pad = torch.nn.ReflectionPad1d(1)

        def pre_pad(m):
            # Reflection-pad by 1 on each side, then apply layer m.
            return torch.nn.Sequential(self.pad, m)

        self.fc1 = torch.nn.Linear(8192,
                                   self.latent_space_dims)  # 8192 = 1024*8
        self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192)

        # (channels, length) of the last encoder feature map. Initialized
        # here so that decode() can be used without a prior call to
        # encode(); (1024, 8) is the only shape compatible with the 8192
        # features of fc1 / fc2. encode() refreshes it from the real data.
        self.encoder_out_size = (1024, 8192 // 1024)

        # Encode convolutions
        self.encod_conv1 = pre_pad(
            torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv2 = pre_pad(
            torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv3 = pre_pad(
            torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv4 = pre_pad(
            torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv5 = pre_pad(
            torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0)
        )
        self.encod_conv6 = pre_pad(
            torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0)
        )

        # Decode convolutions
        self.decod_conv1 = pre_pad(
            torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl1 = torch.nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv2 = pre_pad(
            torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl2 = torch.nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv3 = pre_pad(
            torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl3 = torch.nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv4 = pre_pad(
            torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl4 = torch.nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv5 = pre_pad(
            torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0)
        )
        self.upsampl5 = torch.nn.Upsample(
            scale_factor=2, mode="linear", align_corners=False
        )
        self.decod_conv6 = pre_pad(
            torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0)
        )

    def forward(self,
                input_streamlines: List[torch.Tensor],
                ):
        """Run the auto-encoder (encode, then decode) on a batch.

        Parameters
        ----------
        input_streamlines: List[torch.Tensor],
            Batch of streamlines, each of shape (nb_points, 3). All
            streamlines must have the same number of points.

        Returns
        -------
        model_outputs : Tensor
            Reconstructed streamlines, of shape (batch, 3, nb_points_out).
            Ready to be passed to `compute_loss()`.
        """

        x = self.decode(self.encode(input_streamlines))
        return x

    def encode(self, x):
        """Project a batch of streamlines to the latent space.

        Parameters
        ----------
        x: List[torch.Tensor]
            Streamlines, each of shape (nb_points, 3), all of equal length.

        Returns
        -------
        z: Tensor
            Latent vectors, of shape (batch, latent_space_dims).
        """
        # x: list of tensors -> (batch, nb_points, 3) -> (batch, 3, nb_points)
        # (Conv1d expects channels on axis 1.)
        x = torch.stack(x)
        x = torch.swapaxes(x, 1, 2)

        h1 = F.relu(self.encod_conv1(x))
        h2 = F.relu(self.encod_conv2(h1))
        h3 = F.relu(self.encod_conv3(h2))
        h4 = F.relu(self.encod_conv4(h3))
        h5 = F.relu(self.encod_conv5(h4))
        h6 = self.encod_conv6(h5)

        # Remember the feature map shape; decode() uses it to un-flatten.
        self.encoder_out_size = (h6.shape[1], h6.shape[2])

        # Flatten
        h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1])

        fc1 = self.fc1(h7)

        return fc1

    def decode(self, z):
        """Reconstruct streamlines from latent vectors.

        Parameters
        ----------
        z: Tensor
            Latent vectors, of shape (batch, latent_space_dims).

        Returns
        -------
        x: Tensor
            Reconstructed streamlines, of shape (batch, 3, nb_points_out),
            where nb_points_out is 32 times the encoder's final length
            (five x2 upsamplings).
        """
        fc = self.fc2(z)
        fc_reshape = fc.view(
            -1, self.encoder_out_size[0], self.encoder_out_size[1]
        )
        h1 = F.relu(self.decod_conv1(fc_reshape))
        h2 = self.upsampl1(h1)
        h3 = F.relu(self.decod_conv2(h2))
        h4 = self.upsampl2(h3)
        h5 = F.relu(self.decod_conv3(h4))
        h6 = self.upsampl3(h5)
        h7 = F.relu(self.decod_conv4(h6))
        h8 = self.upsampl4(h7)
        h9 = F.relu(self.decod_conv5(h8))
        h10 = self.upsampl5(h9)
        # No activation on the last layer: outputs are raw coordinates.
        h11 = self.decod_conv6(h10)

        return h11

    def compute_loss(self, model_outputs, targets, average_results=True):
        """Compute the reconstruction loss (summed squared error).

        Parameters
        ----------
        model_outputs: Tensor
            Output of `forward()`, of shape (batch, 3, nb_points).
        targets: List[torch.Tensor]
            Target streamlines, each of shape (nb_points, 3), all of equal
            length.
        average_results: bool
            Currently unused: the loss is always the sum over all elements.

        Returns
        -------
        mse: Tensor
            Sum of squared errors over the whole batch (single value).
        n: int
            Always 1.
            NOTE(review): presumably the count used by the trainer to
            average the loss; confirm against the caller.
        """

        # Same layout as in encode(): (batch, 3, nb_points).
        targets = torch.stack(targets)
        targets = torch.swapaxes(targets, 1, 2)
        reconstruction_loss = torch.nn.MSELoss(reduction="sum")
        mse = reconstruction_loss(model_outputs, targets)
        return mse, 1
4 changes: 3 additions & 1 deletion dwi_ml/models/projects/learn2track_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def __init__(self, experiment_name,
neighborhood_type: Optional[str] = None,
neighborhood_radius: Optional[int] = None,
neighborhood_resolution: Optional[float] = None,
log_level=logging.root.level):
log_level=logging.root.level,
nb_points: Optional[int] = None):
"""
Params
------
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(self, experiment_name,
"""
super().__init__(
experiment_name=experiment_name, step_size=step_size,
nb_points=nb_points,
compress_lines=compress_lines, log_level=log_level,
# For modelWithNeighborhood
neighborhood_type=neighborhood_type,
Expand Down
8 changes: 5 additions & 3 deletions dwi_ml/models/projects/transformer_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
import logging
from typing import Union, List, Tuple, Optional
from typing import Union, List, Optional

from dipy.data import get_sphere
import numpy as np
Expand Down Expand Up @@ -137,7 +137,8 @@ def __init__(self,
neighborhood_type: Optional[str] = None,
neighborhood_radius: Optional[int] = None,
neighborhood_resolution: Optional[float] = None,
log_level=logging.root.level):
log_level=logging.root.level,
nb_points: Optional[int] = None):
"""
Note about embedding size:
In the original model + SrcOnly model: defines d_model.
Expand Down Expand Up @@ -185,6 +186,7 @@ def __init__(self,
super().__init__(
# MainAbstract
experiment_name=experiment_name, step_size=step_size,
nb_points=nb_points,
compress_lines=compress_lines, log_level=log_level,
# Neighborhood
neighborhood_type=neighborhood_type,
Expand Down Expand Up @@ -610,7 +612,7 @@ def _prepare_data(self, inputs, _):

def _run_embeddings(self, inputs, use_padding, batch_max_len):
return self._run_input_embedding(inputs, use_padding, batch_max_len)

def _run_position_encoding(self, inputs):
inputs = self.position_encoding_layer(inputs)
inputs = self.dropout(inputs)
Expand Down
12 changes: 10 additions & 2 deletions dwi_ml/training/batch_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
logger = logging.getLogger('batch_loader_logger')


class DWIMLAbstractBatchLoader:
class DWIMLStreamlinesBatchLoader:
def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
split_ratio: float = 0.,
Expand Down Expand Up @@ -197,7 +197,14 @@ def _data_augmentation_sft(self, sft):
self.context_subset.compress == self.model.compress_lines:
logger.debug("Compression rate is the same as when creating "
"the hdf5 dataset. Not compressing again.")
elif self.model.nb_points is not None and self.model.nb_points == self.context_subset.nb_points:
logging.debug("Number of points per streamline is the same"
" as when creating the hdf5. Not resampling again.")
else:
logger.debug("Resample streamlines using: \n" +
"- step_size: {}\n".format(self.model.step_size) +
"- compress_lines: {}".format(self.model.compress_lines) +
"- nb_points: {}".format(self.model.nb_points))
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)
Expand Down Expand Up @@ -314,6 +321,7 @@ def load_batch_streamlines(
sft.to_vox()
sft.to_corner()
batch_streamlines.extend(sft.streamlines)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]

return batch_streamlines, final_s_ids_per_subj
Expand Down Expand Up @@ -351,7 +359,7 @@ def load_batch_connectivity_matrices(
connectivity_nb_blocs, connectivity_labels)


class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader):
class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader):
"""
Loads:
input = one volume group
Expand Down
48 changes: 43 additions & 5 deletions dwi_ml/training/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dwi_ml.models.main_models import (MainModelAbstract,
ModelWithDirectionGetter)
from dwi_ml.training.batch_loaders import (
DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput)
DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput)
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.utils.gradient_norm import compute_gradient_norm
from dwi_ml.training.utils.monitoring import (
Expand Down Expand Up @@ -53,7 +53,7 @@ class DWIMLAbstractTrainer:
def __init__(self,
model: MainModelAbstract, experiments_path: str,
experiment_name: str, batch_sampler: DWIMLBatchIDSampler,
batch_loader: DWIMLAbstractBatchLoader,
batch_loader: DWIMLStreamlinesBatchLoader,
learning_rates: Union[List, float] = None,
weight_decay: float = 0.01,
optimizer: str = 'Adam', max_epochs: int = 10,
Expand All @@ -78,7 +78,7 @@ def __init__(self,
batch_sampler: DWIMLBatchIDSampler
Instantiated class used for sampling batches.
Data in batch_sampler.dataset must be already loaded.
batch_loader: DWIMLAbstractBatchLoader
batch_loader: DWIMLStreamlinesBatchLoader
Instantiated class with a load_batch method able to load data
associated to sampled batch ids. Data in batch_sampler.dataset must
be already loaded.
Expand Down Expand Up @@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict:
def init_from_checkpoint(
cls, model: MainModelAbstract, experiments_path, experiment_name,
batch_sampler: DWIMLBatchIDSampler,
batch_loader: DWIMLAbstractBatchLoader,
batch_loader: DWIMLStreamlinesBatchLoader,
checkpoint_state: dict, new_patience, new_max_epochs, log_level):
"""
Loads checkpoint information (parameters and states) to instantiate
Expand Down Expand Up @@ -1013,7 +1013,45 @@ def run_one_batch(self, data):
Any other data returned when computing loss. Not used in the
trainer, but could be useful anywhere else.
"""
raise NotImplementedError
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
targets = [s.to(self.device, non_blocking=True, dtype=torch.float)
for s in targets]

# Uses the model's method, with the batch_loader's data.
# Possibly skipping the last point if not useful.
streamlines_f = targets

# Possibly add noise to inputs here.
logger.debug('*** Computing forward propagation')

# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(streamlines_f)
del streamlines_f

logger.debug('*** Computing loss')
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=True)

if self.use_gpu:
log_gpu_memory_usage(logger)

# The mean tensor is a single value. Converting to float using item().
return results

def fix_parameters(self):
"""
Expand Down
Loading

0 comments on commit 04639fa

Please sign in to comment.