diff --git a/dwi_ml/models/direction_getter_models.py b/dwi_ml/models/direction_getter_models.py index a7841d91..cf535700 100644 --- a/dwi_ml/models/direction_getter_models.py +++ b/dwi_ml/models/direction_getter_models.py @@ -807,7 +807,7 @@ class SingleGaussianDG(AbstractDirectionGetterModel): USE GRADIENT CLIPPING **AND** low learning rate. """ def __init__(self, normalize_targets: float = None, - **kwargs): + entropy_weight: float = 0.0, **kwargs): # 3D gaussian supports compressed streamlines super().__init__(key='gaussian', supports_compressed_streamlines=True, @@ -815,7 +815,7 @@ def __init__(self, normalize_targets: float = None, **kwargs) self.normalize_targets = normalize_targets - self.entropy_weight = 0.0 # Not tested yet. + self.entropy_weight = entropy_weight # Layers # 3 values as mean, 3 values as sigma diff --git a/dwi_ml/models/utils/direction_getters.py b/dwi_ml/models/utils/direction_getters.py index 3ce65a3c..a537a80a 100644 --- a/dwi_ml/models/utils/direction_getters.py +++ b/dwi_ml/models/utils/direction_getters.py @@ -35,23 +35,29 @@ def add_direction_getter_args(p: ArgumentParser, gaussian_fisher_args=True): # Gaussian models, Fisher-von-Mises models if gaussian_fisher_args: + p.add_argument( '--add_entropy_to_gauss', nargs='?', const=1.0, type=float, metavar='f', help="For GAUSSIAN models: If set, adds the entropy to the negative " "log-likelihood \nloss. By default, weight is 1.0, but a " "value >1 can be added \nto increase its influence.") p.add_argument( '--dg_nb_gaussians', type=int, metavar='n', - help="Number of gaussians in the case of a Gaussian Mixture model " "for the direction \ngetter. [3]") + help="For GAUSSIAN models: Number of gaussians in the case of a " "mixture model. [3]") p.add_argument( '--dg_nb_clusters', type=int, - help="Number of clusters in the case of a Fisher von Mises " "Mixture model for the direction \ngetter. 
[3].") + help="For FISHER VON MISES models: Number of clusters in the case " "of a mixture model for the direction \ngetter. [3]") p.add_argument( '--normalize_targets', const=1., nargs='?', type=float, metavar='norm', - help="For REGRESSION models: If set, target directions will be " + help="For REGRESSION models: If set, target directions will be " "normalized before \ncomputing the loss. Default norm: 1.") p.add_argument( '--normalize_outputs', const=1., nargs='?', type=float, metavar='norm', - help="For REGRESSION models: If set, model outputs will be " + help="For REGRESSION models: If set, model outputs will be " "normalized. Default norm: 1.") # EOS @@ -82,14 +88,22 @@ def check_args_direction_getter(args): if args.dg_dropout < 0 or args.dg_dropout > 1: raise ValueError('The dg dropout rate must be between 0 and 1.') - # Gaussian additional arg = nb_gaussians. + # Gaussian additional arg = nb_gaussians and entropy_weight. if args.dg_key == 'gaussian-mixture': if args.dg_nb_gaussians: dg_args.update({'nb_gaussians': args.dg_nb_gaussians}) - elif args.dg_nb_gaussians: - logging.warning("You have provided a value for --dg_nb_gaussians but " - "the chosen direction getter is not the gaussian " - "mixture. Ignored.") + if args.add_entropy_to_gauss: + dg_args.update({'entropy_weight': args.add_entropy_to_gauss}) + + else: + if args.dg_nb_gaussians: + logging.warning("You have provided a value for --dg_nb_gaussians " + "but the chosen direction getter is not the " + "gaussian mixture. Ignored.") + if args.add_entropy_to_gauss: + logging.warning("You have provided a value for --add_entropy_to_gauss " + "but the chosen direction getter is not the " + "gaussian mixture. 
Ignored.") # Fisher additional arg = nb_clusters if args.dg_key == 'fisher-von-mises-mixture': diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index f4a082b9..ecb2558a 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -796,7 +796,6 @@ def back_propagation(self, loss): # Supervizing the gradient's norm. grad_norm = compute_gradient_norm(self.model.parameters()) - logging.info(" Grad norm {}. Loss: {}".format(grad_norm, loss)) # Update parameters # Future work: We could update only every n steps.