
Commit

Add entropy option
EmmaRenauld committed Oct 5, 2023
1 parent 08d3939 commit 36b1c00
Showing 3 changed files with 27 additions and 14 deletions.
4 changes: 2 additions & 2 deletions dwi_ml/models/direction_getter_models.py
@@ -807,15 +807,15 @@ class SingleGaussianDG(AbstractDirectionGetterModel):
     USE GRADIENT CLIPPING **AND** low learning rate.
     """
     def __init__(self, normalize_targets: float = None,
-                 **kwargs):
+                 entropy_weight: float = 0.0, **kwargs):
         # 3D gaussian supports compressed streamlines
         super().__init__(key='gaussian',
                          supports_compressed_streamlines=True,
                          loss_description='negative log-likelihood',
                          **kwargs)

         self.normalize_targets = normalize_targets
-        self.entropy_weight = 0.0  # Not tested yet.
+        self.entropy_weight = entropy_weight

         # Layers
         # 3 values as mean, 3 values as sigma
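For context on what the new entropy_weight parameter does: the loss code itself is outside this diff, so the following is only an illustrative sketch of an entropy-regularized Gaussian negative log-likelihood, with hypothetical names (gaussian_nll_with_entropy, means, sigmas). Per the new option's help text, the weighted entropy is added to the NLL, which penalizes very diffuse (high-entropy) Gaussians.

# Illustrative sketch only (not code from this commit). Shows one way a
# weighted entropy term can be added to a Gaussian NLL loss, as the new
# --add_entropy_to_gauss help text describes. All names are hypothetical.
import torch
from torch.distributions import MultivariateNormal

def gaussian_nll_with_entropy(means, sigmas, targets, entropy_weight=0.0):
    # Diagonal 3D Gaussians: means / sigmas / targets have shape (n, 3).
    dist = MultivariateNormal(
        means, covariance_matrix=torch.diag_embed(sigmas ** 2))
    nll = -dist.log_prob(targets)                  # shape (n,)
    # Adding the entropy penalizes high-entropy (very wide) Gaussians.
    loss = nll + entropy_weight * dist.entropy()
    return loss.mean()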
36 changes: 25 additions & 11 deletions dwi_ml/models/utils/direction_getters.py
@@ -35,23 +35,29 @@ def add_direction_getter_args(p: ArgumentParser, gaussian_fisher_args=True):

     # Gaussian models, Fisher-von-Mises models
     if gaussian_fisher_args:
+        p.add_argument(
+            '--add_entropy_to_gauss', nargs='?', const=1.0, type=float,
+            metavar='f',
+            help="For GAUSSIAN models: If set, adds the entropy to the negative "
+                 "log-likelihood \nloss. By default, weight is 1.0, but a "
+                 "value > 1 can be given \nto increase its influence.")
         p.add_argument(
             '--dg_nb_gaussians', type=int, metavar='n',
-            help="Number of gaussians in the case of a Gaussian Mixture model "
-                 "for the direction \ngetter. [3]")
+            help="For GAUSSIAN models: Number of gaussians in the case of a "
+                 "mixture model. [3]")
         p.add_argument(
             '--dg_nb_clusters', type=int,
-            help="Number of clusters in the case of a Fisher von Mises "
-                 "Mixture model for the direction \ngetter. [3].")
+            help="For FISHER VON MISES models: Number of clusters in the case "
+                 "of a mixture model for the direction \ngetter. [3]")
         p.add_argument(
             '--normalize_targets', const=1., nargs='?', type=float,
             metavar='norm',
-            help="For REGRESSION models: If set, target directions will be "
+            help="For REGRESSION models: If set, target directions will be "
                  "normalized before \ncomputing the loss. Default norm: 1.")
         p.add_argument(
             '--normalize_outputs', const=1., nargs='?', type=float,
             metavar='norm',
-            help="For REGRESSION models: If set, model outputs will be "
+            help="For REGRESSION models: If set, model outputs will be "
                  "normalized. Default norm: 1.")

     # EOS
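As a side note on the argparse pattern above: with nargs='?' and const=1.0, the flag can be omitted, given bare, or given with a value. A minimal standalone sketch (only the flag definition itself is taken from the diff):

from argparse import ArgumentParser

p = ArgumentParser()
p.add_argument('--add_entropy_to_gauss', nargs='?', const=1.0, type=float,
               metavar='f')

print(p.parse_args([]))                                 # Namespace(add_entropy_to_gauss=None)
print(p.parse_args(['--add_entropy_to_gauss']))         # 1.0 (bare flag uses const)
print(p.parse_args(['--add_entropy_to_gauss', '2.5']))  # 2.5 (explicit weight)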
@@ -82,14 +88,22 @@ def check_args_direction_getter(args):
     if args.dg_dropout < 0 or args.dg_dropout > 1:
         raise ValueError('The dg dropout rate must be between 0 and 1.')

-    # Gaussian additional arg = nb_gaussians.
+    # Gaussian additional args = nb_gaussians and entropy_weight.
     if args.dg_key == 'gaussian-mixture':
         if args.dg_nb_gaussians:
             dg_args.update({'nb_gaussians': args.dg_nb_gaussians})
-    elif args.dg_nb_gaussians:
-        logging.warning("You have provided a value for --dg_nb_gaussians but "
-                        "the chosen direction getter is not the gaussian "
-                        "mixture. Ignored.")
+        if args.add_entropy_to_gauss:
+            dg_args.update({'entropy_weight': args.add_entropy_to_gauss})
+
+    else:
+        if args.dg_nb_gaussians:
+            logging.warning("You have provided a value for --dg_nb_gaussians "
+                            "but the chosen direction getter is not the "
+                            "gaussian mixture. Ignored.")
+        if args.add_entropy_to_gauss:
+            logging.warning("You have provided a value for --add_entropy_to_gauss "
+                            "but the chosen direction getter is not the "
+                            "gaussian mixture. Ignored.")

     # Fisher additional arg = nb_clusters
     if args.dg_key == 'fisher-von-mises-mixture':
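To make the dispatch concrete, here is a hedged sketch of what the updated check produces for a Gaussian-mixture configuration. SimpleNamespace stands in for the parsed arguments, the values are illustrative, and the real dg_args may carry additional keys not shown in this hunk:

from types import SimpleNamespace

# Stand-in for parsed CLI arguments (illustrative values only).
args = SimpleNamespace(dg_key='gaussian-mixture', dg_dropout=0.1,
                       dg_nb_gaussians=4, add_entropy_to_gauss=2.0)
# After the block above, dg_args would include:
#   {'nb_gaussians': 4, 'entropy_weight': 2.0}
# With any other dg_key, both options are ignored and a warning is logged.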
1 change: 0 additions & 1 deletion dwi_ml/training/trainers.py
@@ -796,7 +796,6 @@ def back_propagation(self, loss):

         # Supervising the gradient's norm.
         grad_norm = compute_gradient_norm(self.model.parameters())
-        logging.info(" Grad norm {}. Loss: {}".format(grad_norm, loss))

         # Update parameters
         # Future work: We could update only every n steps.
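The gradient norm is still computed after this change; only the per-step INFO log line is dropped. For readers unfamiliar with the helper, a typical total-gradient-norm computation looks like the sketch below (illustrative only, not necessarily dwi_ml's compute_gradient_norm):

import torch

def total_grad_norm(parameters, norm_type: float = 2.0) -> float:
    # p-norm of all gradients, as if concatenated into a single vector.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    per_param = torch.stack([torch.norm(g, norm_type) for g in grads])
    return torch.norm(per_param, norm_type).item()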
