Skip to content

Commit

Permalink
Merge pull request #201 from EmmaRenauld/Add_weight_to_EOS_loss
Browse files Browse the repository at this point in the history
Add the possibility to put more weight to EOS (except classif)
  • Loading branch information
EmmaRenauld authored Sep 21, 2023
2 parents 1bc864c + c5a93d3 commit 29f8f84
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 30 deletions.
77 changes: 47 additions & 30 deletions dwi_ml/models/direction_getter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,33 +72,45 @@ def __init__(self, input_size: int, key: str,
supports_compressed_streamlines: bool, dropout: float = None,
compress_loss: bool = False, compress_eps: float = 1e-3,
weight_loss_with_angle: bool = False,
loss_description: str = '', add_eos: bool = False):
loss_description: str = '', add_eos: bool = False,
eos_weight: float = 1.0):
"""
Parameters
----------
input_size: int
Should be computed directly. Probably the output size of the first
layers of your main model.
dropout: float
Dropout rate.
key: str
The (children) class's key.
supports_compressed_streamlines: bool
Whether this model supports compressed streamlines.
dropout: float
Dropout rate. Usage depends on the child class.
compress_loss: bool
If set, compress the loss. This is used idependently of the state
If set, compress the loss. This is used independently of the state
of the streamlines received (compressed or resampled).
weight_loss_with_angle: bool
If set, weight loss with local angle. Can't be used together with
compress_loss.
compress_eps: float
Compression threshold. As long as the angle is smaller than eps
(in rad), the next points' loss are averaged together.
weight_loss_with_angle: bool
If set, weight loss with local angle. Can't be used together with
compress_loss.
loss_description: str
Only meant to help users.
add_eos: bool
If true, child class should manage EOS.
eos_weight: float
If add_eos, proportion of the loss for EOS when it is calculated
separately. ** Cannot be used with classification.
Final loss will be: loss + eos_weight * eos.
"""
super().__init__()

self.input_size = input_size
self.output_size = None
self.device = None
self.add_eos = add_eos
self.eos_weight = eos_weight
self.compress_loss = compress_loss
self.compress_eps = compress_eps
self.weight_loss_with_angle = weight_loss_with_angle
Expand Down Expand Up @@ -143,6 +155,12 @@ def params(self):
'input_size': self.input_size,
'dropout': self.dropout,
'key': self.key,
'add_eos': self.add_eos,
'eos_weight': self.eos_weight,
'compress_loss': self.compress_loss,
'compress_eps': self.compress_eps,
'weight_loss_with_angle': self.weight_loss_with_angle,
'loss_description': self.loss_description
}

return params
Expand Down Expand Up @@ -182,26 +200,33 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
if self.compress_loss and not average_results:
logging.warning("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")
raise ValueError("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")

# Compute directions
target_dirs = compute_directions(target_streamlines)

# For compress_loss: remember raw target dirs.
line_dirs = None
if self.weight_loss_with_angle or self.compress_loss:
line_dirs = [t.detach().clone() for t in target_dirs]

# Modify directions based on child model requirements.
# Ex: Add eos label. Convert to classes. Etc.
target_dirs = self._prepare_dirs_for_loss(target_dirs)
lengths = [len(t) for t in target_dirs]

# For compress_loss and weight_with_angle, do not average now, we
# will do our own averaging.
tmp_average_results = average_results
if self.compress_loss or self.weight_loss_with_angle:
# Compress: we will do our own average.
# Weight: we will weight before averaging.
tmp_average_results = False

# Stack and compute loss based on child model's loss definition.
outputs, target_dirs = self.stack_batch(outputs, target_dirs)
loss = self._compute_loss(outputs, target_dirs, tmp_average_results)

# Finalize
if self.weight_loss_with_angle:
loss = list(torch.split(loss, lengths))
if self.add_eos:
Expand Down Expand Up @@ -321,6 +346,8 @@ class AbstractRegressionDG(AbstractDirectionGetterModel):
layers are:
1. Linear1: output size = ceil(input_size/2)
2. Linear2: output size = 3
EOS usage: uses a 4th dimension to targets to learn the SOS label.
"""
def __init__(self, normalize_targets: float = False,
normalize_outputs: float = False, **kwargs):
Expand All @@ -332,8 +359,6 @@ def __init__(self, normalize_targets: float = False,
normalize_outputs: float
Value to which to normalize the learned direction.
Default: 0 (no normalization).
add_eos: bool
If true, add a 4th dimension to learn the SOS label.
"""
super().__init__(**kwargs)

Expand All @@ -351,7 +376,6 @@ def params(self):
p.update({
'normalize_targets': self.normalize_targets,
'normalize_outputs': self.normalize_outputs,
'add_eos': self.add_eos
})
return p

Expand All @@ -376,10 +400,6 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
return add_label_as_last_dim(target_dirs, add_sos=False,
add_eos=self.add_eos)

def _compute_loss_dir(self, learned_directions: Tensor,
target_directions: Tensor):
raise NotImplementedError

def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor,
average_results=True):

Expand All @@ -401,11 +421,11 @@ def _compute_loss(self, learned_directions: Tensor, target_dirs: Tensor,
mean_loss_eos = torch.nn.functional.binary_cross_entropy(
learned_eos, target_dirs[:, 3])
mean_loss_dir, n = _mean_and_weight(losses_dirs)
return mean_loss_dir + mean_loss_eos, n
return mean_loss_dir + self.eos_weight * mean_loss_eos, n
else:
losses_eos = torch.nn.functional.binary_cross_entropy(
learned_eos, target_dirs[:, 3], reduction='none')
return losses_eos + losses_dirs
return losses_dirs + self.eos_weight * losses_eos
elif average_results:
return _mean_and_weight(losses_dirs)
else:
Expand Down Expand Up @@ -528,8 +548,6 @@ def __init__(self, sphere: str = 'symmetric724', **kwargs):
"""
sphere: str
An choice of dipy's Sphere.
add_eos_class: bool
If true, adds a class for EOS.
"""
super().__init__(**kwargs)

Expand All @@ -540,6 +558,12 @@ def __init__(self, sphere: str = 'symmetric724', **kwargs):
self.output_size = sphere.vertices.shape[0] # nb_classes

# EOS
if self.eos_weight != 1.0:
raise NotImplementedError(
"Current EOS computation when using classification cannot "
"be used with eos_weight: Loss is computed all at once on "
"all classes, including EOS.")

if self.add_eos:
self.output_size += 1
self.eos_class_idx = self.output_size - 1 # Last class. Idx -1.
Expand All @@ -558,20 +582,12 @@ def move_to(self, device):
super().move_to(device)
self.torch_sphere.move_to(device)

def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
"""
Needs to convert directions to classes.
To be defined by child classes.
"""
raise NotImplementedError

@property
def params(self):
params = super().params

params.update({
'sphere': self.sphere_name,
'add_eos': self.add_eos,
})
return params

Expand Down Expand Up @@ -790,6 +806,7 @@ def __init__(self, **kwargs):
self.layers_sigmas = init_2layer_fully_connected(self.input_size, 3)

if self.add_eos:
# Don't forget to add eos_weight.
raise NotImplementedError

self.output_size = 6
Expand Down
6 changes: 6 additions & 0 deletions dwi_ml/models/utils/direction_getters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,17 @@ def add_direction_getter_args(p: ArgumentParser, gaussian_fisher_args=True):
" 2) In REGRESSION models: Adds a fourth dimension during "
"prediction.\n"
" In CLASSIFICATION models, adds an additional EOS class.\n")
p.add_argument(
'--eos_weight', type=float, default=1.0, metavar='w',
help="In the case of regression, Gaussian and Fisher von Mises models: "
"defines the \nweight of the EOS loss: "
"final_loss = loss + weight * eos_loss")


def check_args_direction_getter(args):
dg_args = {'dropout': args.dg_dropout,
'add_eos': args.add_eos,
'eos_weight': args.eos_weight,
'compress_loss': args.compress_loss is not None,
'compress_eps': args.compress_loss,
'weight_loss_with_angle': args.weight_loss_with_angle,
Expand Down

0 comments on commit 29f8f84

Please sign in to comment.