New test file and fix failing tests
levje committed Oct 31, 2024
1 parent 20f2e38 commit 5e4b9ea
Showing 3 changed files with 38 additions and 11 deletions.
32 changes: 23 additions & 9 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -85,6 +85,17 @@ def _get_one_streamline(self, idx: int):
data = self.hdf_group['data'][offset:offset + length]

return data

def _assert_dps(self, dps_dict, n_streamlines):
    for key, value in dps_dict.items():
        if len(value) != n_streamlines:
            raise ValueError(
                f"Length of data_per_streamline {key} is {len(value)} "
                f"but should be {n_streamlines}.")
        elif not isinstance(value, np.ndarray):
            raise ValueError(
                f"data_per_streamline {key} should be a numpy array, "
                f"not a {type(value)}.")

def get_array_sequence(self, item=None):
if item is None:
@@ -134,21 +145,24 @@ def get_array_sequence(self, item=None):
streamlines.append(streamline, cache_build=True)

for dps_key in hdf_dps_group.keys():
# Indexing with a list (e.g. [idx]) will preserve the
# shape of the array. Crucial for concatenation below.
dps_data = hdf_dps_group[dps_key][[idx]]
data_per_streamline[dps_key].append(dps_data)
streamlines.finalize_append()

else:
raise ValueError('Item should be either an int, list, '
                 'np.ndarray or slice, but we received {}'
                 .format(type(item)))

# The accumulated data_per_streamline is a list of numpy arrays.
# We need to merge them into a single numpy array so it can be
# reused in the StatefulTractogram.
for key in data_per_streamline.keys():
data_per_streamline[key] = np.concatenate(data_per_streamline[key])

self._assert_dps(data_per_streamline, len(streamlines))
return streamlines, data_per_streamline

@property
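For reference, the list-indexing trick used in get_array_sequence above can be illustrated with a minimal NumPy-only sketch. Here, dps_store is a hypothetical stand-in for an hdf_dps_group dataset (h5py datasets support the same fancy indexing):

import numpy as np

# Hypothetical stand-in for hdf_dps_group[dps_key]:
# 5 streamlines with 3 features each.
dps_store = np.arange(15).reshape(5, 3)

# Plain integer indexing drops the first axis...
print(dps_store[2].shape)    # (3,)
# ...while indexing with a one-element list keeps it.
print(dps_store[[2]].shape)  # (1, 3)

# Accumulating (1, n_features) slices lets np.concatenate rebuild a
# single (n_streamlines, n_features) array, as done before returning.
collected = [dps_store[[i]] for i in (0, 2, 4)]
print(np.concatenate(collected).shape)  # (3, 3)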
15 changes: 14 additions & 1 deletion dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -81,6 +81,12 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState,
for i in range(len(sft.streamlines)):
old_streamline = sft.streamlines[i]
old_dpp = sft.data_per_point[i]

# Note: this getter returns lists of numpy arrays of shape
# (n_features, n_streamlines), which is why the all_dps arrays
# are transposed at the end. The reason is unclear, since the
# arrays are stored within the PerArrayDict as numpy arrays of
# shape (n_streamlines, n_features).
old_dps = sft.data_per_streamline[i]

# Cut if at least min_nb_points
@@ -106,8 +112,15 @@ def split_streamlines(sft: StatefulTractogram, rng: np.random.RandomState,
# Since _extend_dict appends many numpy arrays into a list,
# we need to merge them into a single array that can be fed
# to StatefulTractogram as data_per_streamline.
#
# Note: at this point, all_dps is a dict of lists of numpy arrays
# of shape (n_features, n_streamlines). The StatefulTractogram
# expects a dict of numpy arrays of shape (n_streamlines,
# n_features), so we concatenate along the second axis and
# transpose to get the correct shape.

for key in sft.data_per_streamline.keys():
all_dps[key] = np.concatenate(all_dps[key], axis=1).transpose()

new_sft = StatefulTractogram.from_sft(all_streamlines, sft,
data_per_point=all_dpp,
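As a sanity check on the shapes described in the comments above, a small NumPy sketch (with made-up feature values) shows that concatenating (n_features, n_streamlines) blocks along the second axis and transposing yields the (n_streamlines, n_features) layout StatefulTractogram expects:

import numpy as np

# Two hypothetical per-streamline blocks as returned by the getter,
# each of shape (n_features, n_streamlines) = (2, 1).
block_a = np.array([[1.0], [2.0]])
block_b = np.array([[3.0], [4.0]])

merged = np.concatenate([block_a, block_b], axis=1).transpose()
print(merged.shape)  # (2, 2): one row per streamline
print(merged)        # [[1. 2.]  -> features of streamline a
                     #  [3. 4.]] -> features of streamline b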
2 changes: 1 addition & 1 deletion dwi_ml/unit_tests/utils/data_and_models_for_tests.py
@@ -31,7 +31,7 @@ def fetch_testing_data():
# Access to the file dwi_ml.zip:
# https://drive.google.com/uc?id=1beRWAorhaINCncttgwqVAP2rNOfx842Q
name_as_dict = {
    'data_for_tests_dwi_ml.zip': "f8bd3bd88e10d939a7168468e1e99a00"}  # was "59c9275d2fe83b7e2d6154877ab32b8b"
fetch_data(name_as_dict)

return testing_data_dir
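The name_as_dict above maps the test archive's filename to its expected MD5 checksum. As a hedged sketch (the local path is hypothetical, not part of the repository), such a checksum can be verified with the standard library alone:

import hashlib
from pathlib import Path

def md5_of(path: Path, chunk_size: int = 8192) -> str:
    # Read in chunks so large archives need not fit in memory.
    digest = hashlib.md5()
    with path.open('rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# Hypothetical local copy of the test data archive.
archive = Path('data_for_tests_dwi_ml.zip')
if archive.exists():
    assert md5_of(archive) == "f8bd3bd88e10d939a7168468e1e99a00", \
        "Checksum mismatch: the archive should be re-downloaded."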
