Merge pull request scil-vital#230 from EmmaRenauld/fisher
Finish Fisher von Mises loss
EmmaRenauld authored Apr 8, 2024
2 parents a415f87 + 284f1c8 commit f8166eb
Showing 4 changed files with 154 additions and 78 deletions.
127 changes: 92 additions & 35 deletions dwi_ml/models/direction_getter_models.py
@@ -41,9 +41,7 @@ def init_2layer_fully_connected(input_size: int, output_size: int):


def binary_cross_entropy_eos(learned_eos, target_eos, average_results=True):
reduction = 'none'
if average_results:
reduction = 'mean'
reduction = 'mean' if average_results else 'none'

learned_eos = torch.sigmoid(learned_eos)
losses_eos = torch.nn.functional.binary_cross_entropy(
@@ -80,7 +78,7 @@ class AbstractDirectionGetterModel(torch.nn.Module):
-----------------------
"""
def __init__(self, input_size: int, key: str,
supports_compressed_streamlines: bool, dropout: float = None,
supports_compressed_streamlines: bool, dropout: float = None,
compress_loss: bool = False, compress_eps: float = 1e-3,
weight_loss_with_angle: bool = False,
loss_description: str = '', add_eos: bool = False,
@@ -210,8 +208,8 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")
raise ValueError("Current implementation of compress_loss does "
"not allow returning non-averaged loss.")

# Compute directions
target_dirs = compute_directions(target_streamlines)
@@ -272,8 +270,9 @@ def stack_batch(outputs, target_dirs):
outputs = torch.vstack(outputs)
return outputs, target_dirs

def _compute_loss(self, outputs: Tensor, target_dirs: Tensor,
average_results=True) -> Union[Tuple[Tensor, int], Tensor]:
def _compute_loss(
self, outputs: Tensor, target_dirs: Tensor,
average_results=True) -> Union[Tuple[Tensor, int], Tensor]:
"""
Expecting a single tensor.
@@ -762,7 +761,7 @@ def _compute_loss(self, logits_per_class: Tensor, targets_probs: Tensor,
# buggy: reduction is supposed to be a str but if I send 'none', it
# says that it expects an int.)
# Gives the same result as above, but averaged instead of summed.
# The real definition is integral (i.e. sum). Typically for our
# The real definition is integral (i.e. sum). Typically, for our
# data (724 classes), that's a big difference: from values ~7 to values
# around 0.04. Nicer for visu with sum.
# So, avoiding torch's 'mean' reduction; reducing ourselves.
@@ -775,7 +774,8 @@ def _compute_loss(self, logits_per_class: Tensor, targets_probs: Tensor,

# Integral over classes per point.
kl_loss = KLDivLoss(reduction='none', log_target=False)
nll_losses = torch.sum(kl_loss(logits_per_class, targets_probs), dim=-1)
nll_losses = torch.sum(kl_loss(logits_per_class, targets_probs),
dim=-1)

if average_results:
return _mean_and_weight(nll_losses)
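A minimal, self-contained sketch of the reduction choice discussed in the comments above (made-up shapes and values, not the model's code): with reduction='none', KLDivLoss returns element-wise terms, and summing over the class dimension gives one KL value per point, which can then be averaged by hand.

import torch
from torch.nn import KLDivLoss

# Hypothetical batch: 5 points, 724 direction classes (as mentioned above).
log_probs = torch.log_softmax(torch.randn(5, 724), dim=-1)        # predicted log-probabilities
target_probs = torch.softmax(torch.randn(5, 724), dim=-1)         # target probabilities

kl_loss = KLDivLoss(reduction='none', log_target=False)
nll_losses = torch.sum(kl_loss(log_probs, target_probs), dim=-1)  # shape (5,): one KL value per point
mean_loss = nll_losses.mean()                                      # averaging done "by hand"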
@@ -893,8 +893,10 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],
"""
# 1. Main loss
means, sigmas = learned_gaussian_params
learned_eos = means[:, -1]
means = means[:, 0:3]
learned_eos = None
if self.add_eos:
learned_eos = means[:, -1]
means = means[:, 0:3]

# Create an official function-probability distribution from the means
# and variances
@@ -905,7 +907,8 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],
if self.entropy_weight > 0:
# Trying to ensure that sigma values are not too small.
# Entropy values range between 0 and log(K). 0 = high probability.
# We want a high entropy / low certainty = we will minimize -entropy.
# We want a high entropy / low certainty = we will minimize
# -entropy.
entropy = distribution.entropy()
logging.info("Computing batch loss with sigma {}, entropy: {}"
.format(torch.mean(sigmas), torch.mean(entropy)))
@@ -918,11 +921,11 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],
# 2. EOS loss:
if self.add_eos:
# Binary cross-entropy
loss_eos = binary_cross_entropy_eos(learned_eos, target_dirs[:, -1],
loss_eos = binary_cross_entropy_eos(learned_eos,
target_dirs[:, -1],
average_results)
return nll_loss + self.eos_weight * loss_eos, n
else:
n = 1
return nll_loss, n

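A rough, standalone sketch of the Gaussian NLL plus entropy pattern described in the comments above (a diagonal Gaussian and an arbitrary entropy_weight are assumed here; the model's actual distribution construction sits in the collapsed part of this diff):

import torch
from torch.distributions import Normal, Independent

means = torch.randn(10, 3)               # hypothetical predicted means
sigmas = torch.rand(10, 3) + 0.1         # hypothetical predicted standard deviations
target_dirs = torch.randn(10, 3)

distribution = Independent(Normal(means, sigmas), 1)    # diagonal Gaussian over 3D directions
nll_losses = -distribution.log_prob(target_dirs)         # shape (10,): one NLL value per point
entropy = distribution.entropy()                          # higher entropy <=> larger sigmas
entropy_weight = 0.1                                      # arbitrary value for this sketch
loss = (nll_losses - entropy_weight * entropy).mean()     # minimizing -entropy keeps sigmas from collapsing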
def _sample_tracking_direction_prob(
@@ -956,6 +959,7 @@ def _get_tracking_direction_det(self, learned_gaussian_params: Tensor,
Get the predicted class with highest logits (=probabilities).
"""
# Returns the direction of the max of the Gaussian = the mean.
# Not using sigma
means, sigmas = learned_gaussian_params
dirs = means[:, 0:3]

@@ -1156,43 +1160,68 @@ def __init__(self, **kwargs):
loss_description='negative log-likelihood',
**kwargs)

if self.add_eos:
raise NotImplementedError
self.layers_mean = init_2layer_fully_connected(self.input_size, 3)
# Layers
# 3 values as mean, 1 value as kappa
# If EOS: Adding it to the mean layer. Could be separated.
oneifeos = 1 if self.add_eos else 0
self.layers_mean = init_2layer_fully_connected(self.input_size,
3 + oneifeos)
self.layers_kappa = init_2layer_fully_connected(self.input_size, 1)

self.output_size = 4
# Loss will be defined in _compute_loss, using torch distribution

def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
"""
Should be called before _compute_loss, before concatenating your
streamlines.
Returns: list[Tensors], the directions.
"""
# Need to normalize before adding EOS labels (dir = 0,0,0)
target_dirs = normalize_directions(target_dirs)
return add_label_as_last_dim(target_dirs, add_sos=False,
add_eos=self.add_eos)

def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
"""Run the inputs through the fully-connected layer.
Returns
-------
means : torch.Tensor with shape [batch_size x 3]
?
mus : torch.Tensor with shape [batch_size x 3]
The 3D coordinate of the mean.
kappas : torch.Tensor with shape [batch_size x 1]
?
The kappa concentration parameter.
"""
means = self.loop_on_layers(inputs, self.layers_mean)
mu = self.loop_on_layers(inputs, self.layers_mean)
kappas = self.loop_on_layers(inputs, self.layers_kappa)

# mean should be a unit vector for Fisher Von-Mises distribution
means = torch.nn.functional.normalize(means, dim=-1)
# (Using [0:3] only; EOS value does not need to be normalized).
# Simple code line raises an error: inplace operation
# mu[0:3] = torch.nn.functional.normalize(mu[0:3], dim=-1)
learned_eos = None
if self.add_eos:
learned_eos = mu[:, 3][:, None]
mu = mu[:, 0:3]
mu = torch.nn.functional.normalize(mu, dim=-1)
if self.add_eos:
mu = torch.hstack((mu, learned_eos))

# Need to restrict kappa to a certain range, e.g. [0, 20]
unbound_kappa = self.loop_on_layers(inputs, self.layers_kappa)
kappas = torch.sigmoid(unbound_kappa) * 20
kappas = torch.sigmoid(kappas) * 20

# Squeeze the trailing dim, the kappa parameter is a scalar
kappas = kappas.squeeze(dim=-1)

return means, kappas
return mu, kappas

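A standalone sketch of the normalize-without-in-place pattern used in the forward above (made-up shapes; it only illustrates why the split, normalize, hstack route avoids the in-place error mentioned in the comments):

import torch

out = torch.randn(10, 4)                                    # 3 direction components + 1 EOS logit
eos = out[:, 3][:, None]                                     # keep the EOS column as-is
dirs = torch.nn.functional.normalize(out[:, 0:3], dim=-1)    # unit vectors for the Fisher von Mises mean
out = torch.hstack((dirs, eos))                              # re-attach EOS; no in-place write on the graph

kappas = torch.sigmoid(torch.randn(10, 1)) * 20              # concentration bounded to [0, 20], as above
kappas = kappas.squeeze(dim=-1)                              # shape (10,)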
@staticmethod
def stack_batch(outputs, target_dirs):
target_dirs = torch.vstack(target_dirs)
mus = torch.vstack(outputs[0])
kappas = torch.vstack(outputs[1])
return (mus, kappas), target_dirs
mu = torch.vstack(outputs[0])
kappa = torch.hstack(outputs[1]) # Not vstack: they are vectors
return (mu, kappa), target_dirs

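A toy shape check for the hstack versus vstack distinction above (arbitrary streamline lengths):

import torch

mus = [torch.randn(5, 3), torch.randn(7, 3)]     # per-streamline means, shape (n_i, 3)
kappas = [torch.randn(5), torch.randn(7)]        # per-streamline kappas, shape (n_i,) after squeeze

print(torch.vstack(mus).shape)     # torch.Size([12, 3])
print(torch.hstack(kappas).shape)  # torch.Size([12]); vstack would try to stack them as rows and fail here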
def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
target_dirs, average_results=True):
@@ -1202,16 +1231,31 @@ def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
See the doc for explanation on the formulas:
https://dwi-ml.readthedocs.io/en/latest/formulas.html
"""
# mu.shape : [flattened_sequences, 3]
# mu.shape : [all_points, 4]. First 3 values are x, y, z; the last is EOS.
mu, kappa = learned_fisher_params
learned_eos = None
if self.add_eos:
learned_eos = mu[:, 3]
mu = mu[:, 0:3]

log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs)
nll_losses = -log_prob
# 1. Main loss
# Note. Mu was already normalized through the forward method.
log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs[:, 0:3])
nll_loss = -log_prob

n = 1
if average_results:
return _mean_and_weight(nll_losses)
nll_loss, n = _mean_and_weight(nll_loss)

# 2. EOS loss:
if self.add_eos:
# Binary cross-entropy
loss_eos = binary_cross_entropy_eos(learned_eos,
target_dirs[:, -1],
average_results)
return nll_loss + self.eos_weight * loss_eos, n
else:
return nll_losses
return nll_loss, n

def _sample_tracking_direction_prob(
self, learned_fisher_params: Tuple[Tensor, Tensor],
@@ -1247,7 +1291,20 @@ def _sample_tracking_direction_prob(

def _get_tracking_direction_det(self, learned_fisher_params: Tensor,
eos_stopping_thresh):
raise NotImplementedError
"""
Get the predicted class with highest logits (=probabilities).
"""
# Returns the direction of the max of the Gaussian = the mean.
# Not using sigma
mus, kappas = learned_fisher_params
dirs = mus[:, 0:3]

if self.add_eos:
eos_prob = torch.sigmoid(mus[:, -1])
eos_prob = torch.gt(eos_prob, eos_stopping_thresh)
return torch.masked_fill(dirs, eos_prob[:, None], torch.nan)
else:
return dirs

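A tiny illustration of the NaN-masking convention above, which flags streamlines whose EOS probability exceeds the threshold (toy values; eos_stopping_thresh taken as 0.5 here):

import torch

dirs = torch.tensor([[1., 0., 0.],
                     [0., 1., 0.]])
eos_prob = torch.tensor([0.9, 0.1])             # hypothetical sigmoid(EOS logits)
stop = torch.gt(eos_prob, 0.5)                  # eos_stopping_thresh = 0.5
print(torch.masked_fill(dirs, stop[:, None], torch.nan))
# tensor([[nan, nan, nan],
#         [0., 1., 0.]])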
@staticmethod
def _sample_weight(kappa):
29 changes: 27 additions & 2 deletions dwi_ml/models/utils/fisher_von_mises.py
@@ -10,6 +10,19 @@


def fisher_von_mises_log_prob_vector(mus, kappa, targets, eps=1e-5):
"""
Same as below, but for a single vector.
Parameters
----------
mus: torch.Tensor
Shape: (3, )
kappa: torch.Tensor
Shape: (1, )
targets: torch.Tensor
Directions. Shape (3, )
eps: float
"""
log_diff_exp_kappa = np.log(
np.maximum(eps, np.exp(kappa) - np.exp(-kappa)))
log_c = np.log(kappa) - np.log(2 * np.pi) - log_diff_exp_kappa
@@ -18,12 +31,24 @@ def fisher_von_mises_log_prob_vector(mus, kappa, targets, eps=1e-5):


def fisher_von_mises_log_prob(mus, kappa, targets, eps=1e-5):
"""
Fisher von Mises log-probability for a batch.
Parameters
----------
mus: torch.Tensor
Shape: (n, 3)
kappa: torch.Tensor
Shape: (n, 1)
targets: torch.Tensor
Directions. Shape (n, 3)
eps: float
"""
log_2pi = np.log(2 * np.pi).astype(np.float32)

eps = torch.as_tensor(eps, device=kappa.device, dtype=torch.float32)

# Add an epsilon in case kappa is too small (i.e. a uniform
# distribution)
# Add an epsilon in case kappa is too small (i.e. a uniform distribution)
log_diff_exp_kappa = torch.log(
torch.maximum(eps, torch.exp(kappa) - torch.exp(-kappa)))

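For reference, the density these two functions implement (consistent with the code above and the formulas page linked earlier) is the 3D Fisher von Mises distribution; a sketch of the log-probability actually computed:

% Fisher von Mises (von Mises-Fisher, p = 3) log-density, as computed above,
% with C(\kappa) = \kappa / \left( 2\pi \, (e^{\kappa} - e^{-\kappa}) \right)
\log f(x \mid \mu, \kappa)
    = \log C(\kappa) + \kappa \, \mu^{\top} x
    = \log \kappa - \log(2\pi) - \log\!\left(e^{\kappa} - e^{-\kappa}\right) + \kappa \, \mu^{\top} x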
15 changes: 6 additions & 9 deletions dwi_ml/tracking/io_utils.py
@@ -64,13 +64,10 @@ def add_tracking_options(p):
help="Tracking mask's volume group in the hdf5.")
stop_g.add_argument('--theta', metavar='t', type=float,
default=90,
help="The tracking direction at each step being "
"defined by the model, \ntheta arg can't define "
"allowed directions in the tracking field.\n"
"Rather, this new equivalent angle, is used as "
"\na stopping criterion during propagation: "
help="Stopping criterion during propagation: "
"tracking \nis stopped when a direction is more "
"than an angle t from preceding direction")
"than an angle t from \npreceding direction."
"[%(default)s]")
stop_g.add_argument('--eos_stop', metavar='prob',
help="Stopping criterion if a EOS value was learned "
"during training. \nCan either be a probability "
@@ -80,9 +77,9 @@
"probability, no mather its value.")
stop_g.add_argument(
'--discard_last_point', action='store_true',
help="If set, discard the last point (once out of the tracking mask) \n"
"of the streamline. Default: append them. This is the default in \n"
"Dipy too. Note that points obtained after an invalid direction \n"
help="If set, discard the last point (once out of the tracking mask)\n"
"of the streamline. Default: append them. This is the default in\n"
"Dipy too. Note that points obtained after an invalid direction\n"
"(based on the propagator's definition of invalid; ex when \n"
"angle is too sharp of sh_threshold not reached) are never added.")

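A minimal usage sketch for the options touched above (the parser setup is hypothetical; add_tracking_options also defines arguments not visible in this diff, so only the flags changed here are mentioned):

import argparse
from dwi_ml.tracking.io_utils import add_tracking_options

p = argparse.ArgumentParser(description="Hypothetical tracking script")
add_tracking_options(p)
p.print_help()
# Flags touched by this change: --theta 45, --eos_stop 0.5, --discard_last_point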