Fix pep8
EmmaRenauld committed Apr 9, 2024
1 parent af63707 commit 5989e17
Showing 8 changed files with 34 additions and 29 deletions.
15 changes: 9 additions & 6 deletions dwi_ml/models/projects/transformer_models.py
@@ -205,8 +205,8 @@ def __init__(self,
        self.dropout_rate = dropout_rate
        self.activation = activation
        self.norm_first = norm_first
-        self.ffnn_hidden_size = ffnn_hidden_size if ffnn_hidden_size is not None \
-            else self.d_model // 2
+        self.ffnn_hidden_size = ffnn_hidden_size if ffnn_hidden_size is not \
+            None else self.d_model // 2

        # ----------- Checks
        if self.d_model // self.nheads != float(self.d_model) / self.nheads:
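[Editor's note] The context line above enforces that d_model is evenly divisible by nheads, since each attention head works on d_model / nheads dimensions. A minimal sketch, not part of this commit, showing that the floor-division comparison is just a divisibility test:

# Sketch: the diff's check, compared with the equivalent modulo test.
d_model, nheads = 64, 8
ok_original = (d_model // nheads == float(d_model) / nheads)  # as in the code
ok_modulo = (d_model % nheads == 0)                           # more direct
assert ok_original == ok_modulo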
@@ -233,7 +233,8 @@ def __init__(self,

        # 2. positional encoding layer
        cls_p = keys_to_positional_encodings[self.positional_encoding_key]
-        self.position_encoding_layer = cls_p(self.d_model, dropout_rate, max_len)
+        self.position_encoding_layer = cls_p(self.d_model, dropout_rate,
+                                             max_len)

        # 3. target embedding layer: See child class with Target
@@ -869,7 +870,8 @@ def params_for_checkpoint(self):
    def _run_embeddings(self, data, use_padding, batch_max_len):
        # input, targets = data
        inputs = self._run_input_embedding(data[0], use_padding, batch_max_len)
-        targets = self._run_target_embedding(data[1], use_padding, batch_max_len)
+        targets = self._run_target_embedding(data[1], use_padding,
+                                             batch_max_len)
        return inputs, targets

    def _run_position_encoding(self, data):
@@ -963,7 +965,8 @@ def params_for_checkpoint(self):
    def _run_embeddings(self, data, use_padding, batch_max_len):
        # inputs, targets = data
        inputs = self._run_input_embedding(data[0], use_padding, batch_max_len)
-        targets = self._run_target_embedding(data[1], use_padding, batch_max_len)
+        targets = self._run_target_embedding(data[1], use_padding,
+                                             batch_max_len)
        inputs = torch.cat((inputs, targets), dim=-1)

        return inputs
@@ -1003,4 +1006,4 @@ def find_transformer_class(model_type: str):
        raise ValueError("Model type is not a recognized Transformer"
                         "({})".format(model_type))

-    return transformers_dict[model_type]
\ No newline at end of file
+    return transformers_dict[model_type]
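[Editor's note] find_transformer_class is a plain dictionary dispatch. A hedged sketch of the pattern with placeholder entries (the real registry and class names live elsewhere in transformer_models.py and are not shown in this diff):

# Sketch only: keys and classes here are hypothetical stand-ins.
class PlaceholderTransformerA: pass
class PlaceholderTransformerB: pass

transformers_dict = {'TTO': PlaceholderTransformerA,
                     'TTS': PlaceholderTransformerB}

def find_transformer_class(model_type: str):
    if model_type not in transformers_dict:
        # A trailing space is added here so the message reads
        # "... Transformer (TTO)" rather than "...Transformer(TTO)".
        raise ValueError("Model type is not a recognized Transformer "
                         "({})".format(model_type))
    return transformers_dict[model_type]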
2 changes: 1 addition & 1 deletion dwi_ml/testing/projects/tt_visu_argparser.py
@@ -94,7 +94,7 @@ def build_argparser_transformer_visu():
                   help="See description above.")
    g.add_argument('--color_x_y_summary', action='store_true',
                   help="See description above.")
-    gg =g.add_mutually_exclusive_group()
+    gg = g.add_mutually_exclusive_group()
    gg.add_argument('--bertviz', action='store_true',
                    help="See description above.")
    gg.add_argument('--bertviz_locally', action='store_true',
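[Editor's note] The fixed line builds an argparse mutually exclusive group inside an argument group. A self-contained sketch of the behavior (flag names taken from the diff; the group title and everything else are illustrative):

import argparse

p = argparse.ArgumentParser()
g = p.add_argument_group('visualization options')    # illustrative title
gg = g.add_mutually_exclusive_group()
gg.add_argument('--bertviz', action='store_true')
gg.add_argument('--bertviz_locally', action='store_true')

p.parse_args(['--bertviz'])                          # accepted
# p.parse_args(['--bertviz', '--bertviz_locally'])   # argparse errors out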
@@ -54,7 +54,6 @@ def load_data_run_model(parser, args, model: AbstractTransformerModel,
    else:
        streamline_ids = 0

-    streamline_ids = 3544  # ---------------------------------------------------------------- TO DELETE
    # Bug in dipy. If accessing only one streamline, then sft.streamlines is
    # not a list of streamlines anymore, but the streamline itself. Thus,
    # len(sft) = nb_points instead of 1.
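[Editor's note] A minimal sketch of the indexing quirk described in that comment, assuming sft.streamlines is a nibabel ArraySequence (which, to our understanding, is what backs dipy's StatefulTractogram):

import numpy as np
from nibabel.streamlines import ArraySequence

streamlines = ArraySequence([np.zeros((10, 3)), np.ones((25, 3))])
print(len(streamlines[1]))    # 25: int indexing yields one streamline's points
print(len(streamlines[[1]]))  # 1: list indexing keeps a sequence of length 1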
@@ -132,9 +131,9 @@ def reformat_attention(attention, this_seq_len, resample_attention: int):
    # (nb_heads = 1 if average_heads).

    # Normalizing weight. Without it, we rapidly see nothing!
-    # Easier to see when we normalize on the x axis.
+    # Easier to see when we normalize on the x-axis.

-    # Version 1) L_p pnormalization ==> V = V / norm(v)
+    # Version 1) L_p normalization ==> V = V / norm(v)
    # attention[ll] = torch.nn.functional.normalize(attention[ll], dim=2)
    # attention[ll] = torch.nn.functional.normalize(attention[ll], dim=3)
@@ -149,7 +148,6 @@ def reformat_attention(attention, this_seq_len, resample_attention: int):
    max_ = np.max(att, axis=3)
    att = att / max_[:, :, :, None]

-
    if resample_attention < this_seq_len:
        print("RESAMPLING ATTENTION TO A SEQUENCE OF LENGTH {}\n"
              "(uses the max per block)"
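[Editor's note] The surviving lines above normalize each attention row by its own maximum. A standalone sketch with a toy array (the 4-D shape is assumed from the surrounding code):

import numpy as np

att = np.random.rand(1, 2, 4, 4)   # (n, nb_heads, seq_len, seq_len), assumed
max_ = np.max(att, axis=3)         # maximum over each row (last axis)
att = att / max_[:, :, :, None]    # broadcast; every row now peaks at 1.0
assert np.allclose(np.max(att, axis=3), 1.0)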
@@ -167,9 +165,10 @@ def reformat_attention(attention, this_seq_len, resample_attention: int):
        att = np.pad(att, ((0, 0), (0, 0), (0, missing), (0, missing)),
                     mode='edge')

+        # 1000: to see if bug. There should never be padding.
        att = block_reduce(
            att, block_size=(1, 1, nb_together, nb_together),
-            func=np.max, cval=1000.0)  # 1000: to see if bug.
+            func=np.max, cval=1000.0)

    else:
        ind = None
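[Editor's note] block_reduce here is scikit-image's block-wise pooling: each nb_together x nb_together block of the (padded) attention matrix collapses to its maximum, and cval only fills incomplete blocks, so 1000.0 acts as a loud sentinel that padding was actually needed. A sketch with toy shapes (editorial, values illustrative):

import numpy as np
from skimage.measure import block_reduce

att = np.random.rand(1, 2, 6, 6)   # (n, nb_heads, seq_len, seq_len), assumed
nb_together = 2
out = block_reduce(att, block_size=(1, 1, nb_together, nb_together),
                   func=np.max, cval=1000.0)
print(out.shape)                   # (1, 2, 3, 3)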
@@ -355,4 +354,4 @@ def ttst_show_model_view(encoder_attention, tokens):

    else:
        print("ENCODER ATTENTION: ")
-        show_model_view_as_imshow(encoder_attention, tokens)
\ No newline at end of file
+        show_model_view_as_imshow(encoder_attention, tokens)
3 changes: 2 additions & 1 deletion dwi_ml/testing/utils.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from typing import List

-from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset, MultisubjectSubset
+from dwi_ml.data.dataset.multi_subject_containers import (MultiSubjectDataset,
+                                                          MultisubjectSubset)


def add_args_testing_subj_hdf5(p, ask_input_group=False,
File renamed without changes.
9 changes: 5 additions & 4 deletions scripts_python/tests/test_all_steps_tto.py
@@ -14,7 +14,8 @@

data_dir = fetch_testing_data()
tmp_dir = tempfile.TemporaryDirectory()
-MAX_LEN = 400  # During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+# During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+MAX_LEN = 400


def test_help_option(script_runner):
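[Editor's note] Spelled out, the comment's arithmetic: a 200 mm maximum track length at a 0.5 mm step size yields 200 / 0.5 = 400 points. A one-liner sketch (editorial, values taken from the comment):

MAX_LEN = int(200.0 / 0.5)   # 200 mm / 0.5 mm per step = 400 points
assert MAX_LEN == 400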
@@ -84,9 +85,9 @@ def test_execution(script_runner, experiments_path):
                           '--max_batches_per_epoch_validation', '1',
                           '--nheads', '2', '--max_len', str(MAX_LEN),
                           '--input_embedding_key', 'nn_embedding',
-                           '--input_embedded_size', '6', '--n_layers_e', '1',
-                           '--ffnn_hidden_size', '3', '--logging', 'INFO',
-                           '--use_gpu')
+                           '--input_embedded_size', '6',
+                           '--n_layers_e', '1', '--ffnn_hidden_size', '3',
+                           '--logging', 'INFO', '--use_gpu')
    assert ret.success

    logging.info("************ TESTING TRACKING FROM MODEL ************")
14 changes: 7 additions & 7 deletions scripts_python/tests/test_all_steps_tts.py
@@ -6,15 +6,17 @@
import tempfile

from dwi_ml.unit_tests.utils.expected_values import \
-    (TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS, TEST_EXPECTED_SUBJ_NAMES)
+    (TEST_EXPECTED_VOLUME_GROUPS, TEST_EXPECTED_STREAMLINE_GROUPS,
+     TEST_EXPECTED_SUBJ_NAMES)
from dwi_ml.unit_tests.utils.data_and_models_for_tests import fetch_testing_data

data_dir = fetch_testing_data()
tmp_dir = tempfile.TemporaryDirectory()
-MAX_LEN = 400  # During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+# During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+MAX_LEN = 400


-def test_help_option(script_runner):
+def test_help_option():
    # All help tests already tested in test_all_steps_tto.
    pass

@@ -31,8 +33,8 @@ def test_execution(script_runner, experiments_path):
    input_group_name = TEST_EXPECTED_VOLUME_GROUPS[0]
    streamline_group_name = TEST_EXPECTED_STREAMLINE_GROUPS[0]

-    # Here, testing default values only. See dwi_ml.unit_tests.test_trainer for more
-    # various testing.
+    # Here, testing default values only. See dwi_ml.unit_tests.test_trainer for
+    # more various testing.
    # Max length in current testing dataset is 108. Setting max length to 115
    # for faster testing. Also decreasing other default values.
    logging.info("************ TESTING TRAINING ************")
@@ -103,5 +105,3 @@ def test_execution(script_runner, experiments_path):
                           '--subset', 'training', '--logging', 'INFO',
                           '--resample_plots', '15', '--rescale_0')
    assert ret.success
-
-
9 changes: 5 additions & 4 deletions scripts_python/tests/test_all_steps_ttst.py
@@ -12,10 +12,11 @@

data_dir = fetch_testing_data()
tmp_dir = tempfile.TemporaryDirectory()
-MAX_LEN = 400  # During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+# During tracking, if we allow 200 mm at a 0.5 mm step size = 400 points.
+MAX_LEN = 400


-def test_help_option(script_runner):
+def test_help_option():
    # All help tests already tested in test_all_steps_tto.
    pass

@@ -32,8 +33,8 @@ def test_execution(script_runner, experiments_path):
    input_group_name = TEST_EXPECTED_VOLUME_GROUPS[0]
    streamline_group_name = TEST_EXPECTED_STREAMLINE_GROUPS[0]

-    # Here, testing default values only. See dwi_ml.unit_tests.test_trainer for more
-    # various testing.
+    # Here, testing default values only. See dwi_ml.unit_tests.test_trainer for
+    # more various testing.
    # Max length in current testing dataset is 108. Setting max length to 115
    # for faster testing. Also decreasing other default values.
    logging.info("************ TESTING TRAINING ************")
