From aa132ace3d3b334f08f1af7c39ee970869dd5585 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 12 Dec 2024 14:58:33 -0500 Subject: [PATCH 1/4] ENH: remote sft._data usage --- scilpy/tractograms/dps_and_dpp_management.py | 64 +++++++++++++------ scilpy/viz/color.py | 11 ++-- .../scil_tractogram_assign_custom_color.py | 43 ++++++++----- .../scil_tractogram_assign_uniform_color.py | 27 ++++---- scripts/scil_tractogram_remove_invalid.py | 4 +- 5 files changed, 96 insertions(+), 53 deletions(-) diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index bf40e4350..3b2b119ea 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -1,12 +1,50 @@ # -*- coding: utf-8 -*- import numpy as np -from scilpy.viz.color import clip_and_normalize_data_for_cmap +from nibabel.streamlines import ArraySequence -def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None, - max_range=None, min_cmap=None, max_cmap=None, - log=False, LUT=None): +def get_data_as_arraysequence(data, ref_sft): + """ Get data in the same shape as a reference StatefulTractogram's + streamlines, so it can be used to set data_per_point or + data_per_streamline. The data may represent one value per streamline or one + value per point. The function will return an ArraySequence with the same + shape as the streamlines. + + Parameters + ---------- + data: np.ndarray + The data to convert to ArraySequence. + ref_sft: StatefulTractogram + The reference StatefulTractogram containing the streamlines. + + Returns + ------- + data_as_arraysequence: ArraySequence + The data as an ArraySequence. + """ + + if data.shape[0] == len(ref_sft): + data_as_arraysequence = ArraySequence(data) + elif data.shape[0] == ref_sft._get_point_count(): + data_as_arraysequence = ArraySequence() + # This function was created to avoid messing with _data, _offsets and + # _lengths, so this feel kind of bad. However, the other way would be + # to create a new ArraySequence and iterate over the streamlines, but + # that would be way slower. + data_as_arraysequence._data = data + data_as_arraysequence._offsets = ref_sft.streamlines._offsets + data_as_arraysequence._lengths = ref_sft.streamlines._lengths + else: + raise ValueError("Data has the wrong shape. Expecting either one value" + " per streamline ({}) or one per point ({}) but got " + "{}." + .format(len(ref_sft), len(ref_sft.streamlines._data), + data.shape[0])) + return data_as_arraysequence + + +def add_data_as_color_dpp(sft, color): """ Normalizes data between 0 and 1 for an easier management with colormaps. The real lower bound and upperbound are returned. @@ -54,31 +92,21 @@ def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None, ubound: float The upper bound of the associated colormap. """ - # If data is a list of lists, merge. - if isinstance(data[0], list) or isinstance(data[0], np.ndarray): - data = np.hstack(data) - - values, lbound, ubound = clip_and_normalize_data_for_cmap( - data, clip_outliers, min_range, max_range, - min_cmap, max_cmap, log, LUT) - # Important: values are in float after clip_and_normalize. 
- color = np.asarray(cmap(values)[:, 0:3]) * 255 - if len(color) == len(sft): + if color.total_nb_rows == len(sft): tmp = [np.tile([color[i][0], color[i][1], color[i][2]], (len(sft.streamlines[i]), 1)) for i in range(len(sft.streamlines))] sft.data_per_point['color'] = tmp - elif len(color) == len(sft.streamlines._data): - sft.data_per_point['color'] = sft.streamlines - sft.data_per_point['color']._data = color + elif color.total_nb_rows == sft.streamlines.total_nb_rows: + sft.data_per_point['color'] = color else: raise ValueError("Error in the code... Colors do not have the right " "shape. Expecting either one color per streamline " "({}) or one per point ({}) but got {}." .format(len(sft), len(sft.streamlines._data), len(color))) - return sft, lbound, ubound + return sft def convert_dps_to_dpp(sft, keys, overwrite=False): diff --git a/scilpy/viz/color.py b/scilpy/viz/color.py index 632bba7a8..c13efb2c3 100644 --- a/scilpy/viz/color.py +++ b/scilpy/viz/color.py @@ -102,7 +102,7 @@ def get_lookup_table(name): name_list = name.split('-') colors_list = [mcolors.to_rgba(color)[0:3] for color in name_list] cmap = mcolors.LinearSegmentedColormap.from_list('CustomCmap', - colors_list) + colors_list) return cmap return plt.colormaps.get_cmap(name) @@ -283,10 +283,10 @@ def prepare_colorbar_figure(cmap, lbound, ubound, nb_values=255, nb_ticks=10, return fig -def ambiant_occlusion(sft, colors, factor=4): +def ambient_occlusion(sft, colors, factor=4): """ Apply ambiant occlusion to a set of colors based on point density - around each points. + around each points. Parameters ---------- @@ -296,14 +296,14 @@ def ambiant_occlusion(sft, colors, factor=4): The original colors to modify. factor : float The factor of occlusion (how density will affect the saturation). - + Returns ------- np.ndarray The modified colors. """ - pts = sft.streamlines._data + pts = sft.streamlines.get_data() hsv = mcolors.rgb_to_hsv(colors) tree = KDTree(pts) @@ -324,6 +324,7 @@ def ambiant_occlusion(sft, colors, factor=4): return mcolors.hsv_to_rgb(hsv) + def generate_local_coloring(sft): """ Generate a coloring based on the local orientation of the streamlines. diff --git a/scripts/scil_tractogram_assign_custom_color.py b/scripts/scil_tractogram_assign_custom_color.py index 321770568..08c66022a 100755 --- a/scripts/scil_tractogram_assign_custom_color.py +++ b/scripts/scil_tractogram_assign_custom_color.py @@ -50,7 +50,6 @@ import logging from dipy.io.streamline import save_tractogram -from fury import colormap import nibabel as nib import numpy as np import matplotlib.pyplot as plt @@ -63,11 +62,13 @@ assert_inputs_exist, assert_outputs_exist, load_matrix_in_any_format) -from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp +from scilpy.tractograms.dps_and_dpp_management import ( + add_data_as_color_dpp, get_data_as_arraysequence) from scilpy.tractograms.streamline_operations import ( get_streamlines_as_linspaces, get_angles) from scilpy.viz.color import ( - get_lookup_table, prepare_colorbar_figure, ambiant_occlusion, + clip_and_normalize_data_for_cmap, + get_lookup_table, prepare_colorbar_figure, ambient_occlusion, generate_local_coloring) @@ -117,7 +118,7 @@ def _build_arg_parser(): "last points are set to 0.") g2 = p.add_argument_group(title='Coloring options') - g2.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int, + g2.add_argument('--ambient_occlusion', nargs='?', const=4, type=int, help='Impact factor of the ambiant occlusion ' 'approximation. 
[%(default)s]') g2.add_argument('--colormap', default='jet', @@ -204,6 +205,7 @@ def main(): elif args.load_dpp or args.from_anatomy: sft.to_vox() concat_points = np.vstack(sft.streamlines).T + expected_shape = len(concat_points) sft.to_rasmm() if args.load_dpp: @@ -220,23 +222,30 @@ def main(): elif args.local_orientation: data = generate_local_coloring(sft) else: # args.local_angle: - data = get_angles(sft, add_zeros=True) + data = get_angles(sft, add_zeros=True, degrees=False) data = np.hstack(data) - # Processing + # Clip and normalize the data according to the colormap + values, lbound, ubound = clip_and_normalize_data_for_cmap( + data, args.clip_outliers, args.min_range, args.max_range, + args.min_cmap, args.max_cmap, args.log, LUT) + + # Transform the values to RGB in [0, 255] if not args.local_orientation: - sft, lbound, ubound = add_data_as_color_dpp( - sft, cmap, data, args.clip_outliers, args.min_range, args.max_range, - args.min_cmap, args.max_cmap, args.log, LUT) + color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8) else: - sft.data_per_point['color'] = sft.streamlines.copy() - data *= 255 - sft.data_per_point['color']._data = data.astype(np.uint8) - - # Saving - if args.ambiant_occlusion: - sft.data_per_point['color']._data = ambiant_occlusion( - sft, sft.data_per_point['color']._data, args.ambiant_occlusion) + color = (values * 255).astype(np.uint8) + + # Add ambient occlusion to the coloring + if args.ambient_occlusion: + color = ambient_occlusion( + sft, color, args.ambient_occlusion) + + # Set the color data in the tractogram + data = get_data_as_arraysequence(color, sft) + sft = add_data_as_color_dpp( + sft, data) + save_tractogram(sft, args.out_tractogram) if args.out_colorbar: diff --git a/scripts/scil_tractogram_assign_uniform_color.py b/scripts/scil_tractogram_assign_uniform_color.py index 6a1e15413..241558bbe 100755 --- a/scripts/scil_tractogram_assign_uniform_color.py +++ b/scripts/scil_tractogram_assign_uniform_color.py @@ -30,7 +30,9 @@ add_overwrite_arg, add_verbose_arg, add_reference_arg, assert_headers_compatible) -from scilpy.viz.color import format_hexadecimal_color_to_rgb, ambiant_occlusion +from scilpy.tractograms.dps_and_dpp_management import ( + add_data_as_color_dpp, get_data_as_arraysequence) +from scilpy.viz.color import format_hexadecimal_color_to_rgb, ambient_occlusion def _build_arg_parser(): @@ -41,8 +43,8 @@ def _build_arg_parser(): p.add_argument('in_tractograms', nargs='+', help='Input tractograms (.trk or .tck).') - p.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int, - help='Impact factor of the ambiant occlusion ' + p.add_argument('--ambient_occlusion', nargs='?', const=4, type=int, + help='Impact factor of the ambient occlusion ' 'approximation.\n Use factor or 2. 
Decrease for ' 'lighter and increase for darker [%(default)s].') @@ -121,9 +123,7 @@ def main(): sft = load_tractogram_with_reference(parser, args, filename) - sft.data_per_point['color'] = sft.streamlines.copy() - sft.data_per_point['color']._data = np.zeros( - (len(sft.streamlines._data), 3), dtype=np.uint8) + colors = np.zeros((sft._get_point_count(), 3), dtype=np.uint8) if args.dict_colors: base, ext = os.path.splitext(filename) @@ -141,11 +141,16 @@ def main(): red, green, blue = format_hexadecimal_color_to_rgb(color) - colors = np.tile([red, green, blue], (len(sft.streamlines._data), 1)) - if args.ambiant_occlusion: - colors = ambiant_occlusion(sft, colors, - factor=args.ambiant_occlusion) - sft.data_per_point['color']._data = colors + colors = np.tile([red, green, blue], (sft._get_point_count(), 1)) + if args.ambient_occlusion: + colors = ambient_occlusion(sft, colors, + factor=args.ambient_occlusion) + + # Set the color data in the tractogram + data = get_data_as_arraysequence(colors, sft) + sft = add_data_as_color_dpp( + sft, data) + save_tractogram(sft, out_filenames[i]) diff --git a/scripts/scil_tractogram_remove_invalid.py b/scripts/scil_tractogram_remove_invalid.py index 9e5ce3b58..580775ea6 100755 --- a/scripts/scil_tractogram_remove_invalid.py +++ b/scripts/scil_tractogram_remove_invalid.py @@ -77,7 +77,7 @@ def main(): # Processing ori_len = len(sft) - ori_len_pts = len(sft.streamlines._data) + ori_len_pts = sft._get_point_count() if args.cut_invalid: sft, cutting_counter = cut_invalid_streamlines(sft, epsilon=args.threshold) @@ -92,7 +92,7 @@ def main(): sft = remove_overlapping_points_streamlines(sft, args.threshold) logging.warning("data_per_point will be discarded.") logging.warning('Removed {} overlapping points from tractogram.'.format( - ori_len_pts - len(sft.streamlines._data))) + ori_len_pts - sft._get_point_count())) logging.warning('Removed {} invalid streamlines.'.format( ori_len - len(sft))) From c8b0cf60ece2644c155c55192bb780f565d135c9 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 12 Dec 2024 15:30:11 -0500 Subject: [PATCH 2/4] ENH: more _data cleanup --- scripts/scil_bundle_diameter.py | 2 +- scripts/scil_bundle_label_map.py | 42 ++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/scripts/scil_bundle_diameter.py b/scripts/scil_bundle_diameter.py index 65b0afdaf..8b4625913 100755 --- a/scripts/scil_bundle_diameter.py +++ b/scripts/scil_bundle_diameter.py @@ -262,7 +262,7 @@ def main(): counter = 0 labels_dict = {label: ([], []) for label in unique_labels} pts_labels = map_coordinates(data_labels, - sft.streamlines._data.T-0.5, + sft.streamlines.get_data().T-0.5, order=0) # For each label, all positions and directions are needed to get # a tube estimation per label. 
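The hunk above replaces direct access to the private `_data` attribute with the public ArraySequence API used throughout this series. A minimal sketch of the equivalence, using toy streamlines that are illustrative only (not from the patch):

    import numpy as np
    from nibabel.streamlines import ArraySequence

    # Two streamlines of 3 and 4 points: the flat point array has 7 rows.
    streamlines = ArraySequence([np.zeros((3, 3)), np.ones((4, 3))])
    flat = streamlines.get_data()  # public replacement for streamlines._data
    assert flat.shape == (streamlines.total_nb_rows, 3)
    # Same call pattern as in scil_bundle_diameter.py:
    #   map_coordinates(data_labels, flat.T - 0.5, order=0)
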
diff --git a/scripts/scil_bundle_label_map.py b/scripts/scil_bundle_label_map.py index 9281a5c92..134b47085 100755 --- a/scripts/scil_bundle_label_map.py +++ b/scripts/scil_bundle_label_map.py @@ -37,6 +37,8 @@ assert_inputs_exist, assert_output_dirs_exist_and_empty) from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft +from scilpy.tractograms.dps_and_dpp_management import ( + add_data_as_color_dpp, get_data_as_arraysequence) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.distance_to_centroid import min_dist_to_centroid from scilpy.tractograms.streamline_and_mask_operations import \ @@ -171,8 +173,8 @@ def main(): sft_centroid.streamlines = srm.transform(sft_centroid.streamlines) uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0]) - labels, dists = min_dist_to_centroid(concat_sft.streamlines._data, - sft_centroid.streamlines._data, + labels, dists = min_dist_to_centroid(concat_sft.streamlines.get_data(), + sft_centroid.streamlines.get_data(), args.nb_pts) labels += 1 # 0 means no labels @@ -228,7 +230,7 @@ def main(): final_labels = ArraySequence(final_label) final_dists = ArraySequence(final_dists) - kd_tree = cKDTree(final_streamlines._data) + kd_tree = cKDTree(final_streamlines.get_data()) labels_map = np.zeros(binary_bundle.shape, dtype=np.int16) distance_map = np.zeros(binary_bundle.shape, dtype=float) indices = np.array(np.nonzero(binary_bundle), dtype=int).T @@ -239,8 +241,8 @@ def main(): if not len(neighbor_ids): continue - labels_val = final_labels._data[neighbor_ids] - dists_val = final_dists._data[neighbor_ids] + labels_val = final_labels.get_data()[neighbor_ids] + dists_val = final_dists.get_data()[neighbor_ids] sum_dists_vox = np.sum(dists_val) weights_vox = np.exp(-dists_val / sum_dists_vox) @@ -274,33 +276,43 @@ def main(): if len(sft): tmp_labels = ndi.map_coordinates(labels_map, - sft.streamlines._data.T-0.5, + sft.streamlines.get_data().T-0.5, order=0) tmp_dists = ndi.map_coordinates(distance_map, - sft.streamlines._data.T-0.5, + sft.streamlines.get_data().T-0.5, order=0) tmp_corr = ndi.map_coordinates(corr_map, - sft.streamlines._data.T-0.5, + sft.streamlines.get_data().T-0.5, order=0) cmap = plt.colormaps[args.colormap] - new_sft.data_per_point['color'] = ArraySequence( - new_sft.streamlines) # Nicer visualisation for MI-Brain - new_sft.data_per_point['color']._data = cmap( - tmp_labels / np.max(tmp_labels))[:, 0:3] * 255 + colors = cmap(tmp_labels / np.max(tmp_labels))[:, 0:3] * 255 + data = get_data_as_arraysequence(colors, sft) + new_sft = add_data_as_color_dpp( + new_sft, data) + save_tractogram(new_sft, os.path.join(sub_out_dir, 'labels.trk')) if len(sft): - new_sft.data_per_point['color']._data = cmap( + # Nicer visualisation for MI-Brain + colors = cmap( tmp_dists / np.max(tmp_dists))[:, 0:3] * 255 + data = get_data_as_arraysequence(colors, sft) + new_sft = add_data_as_color_dpp( + new_sft, data) + save_tractogram(new_sft, os.path.join(sub_out_dir, 'distance.trk')) if len(sft): - new_sft.data_per_point['color']._data = cmap(tmp_corr)[ - :, 0:3] * 255 + colors = cmap( + tmp_dists / np.max(tmp_corr))[:, 0:3] * 255 + data = get_data_as_arraysequence(colors, sft) + new_sft = add_data_as_color_dpp( + new_sft, data) + save_tractogram(new_sft, os.path.join(sub_out_dir, 'correlation.trk')) From 9d14d3fda2f8bc0600b19fff18d07878da6ee0e5 Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 26 Dec 2024 11:49:50 -0500 Subject: [PATCH 3/4] ENH: better dps/dpp handling --- 
scilpy/tractograms/dps_and_dpp_management.py | 58 ++++++++---- .../tests/test_dps_and_dpp_management.py | 88 ++++++++++++++----- 2 files changed, 104 insertions(+), 42 deletions(-) diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index 21dfa0823..fda790b35 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np +from collections.abc import Iterable from nibabel.streamlines import ArraySequence @@ -23,18 +24,27 @@ def get_data_as_arraysequence(data, ref_sft): data_as_arraysequence: ArraySequence The data as an ArraySequence. """ - - if data.shape[0] == len(ref_sft): + # Check if data has the right shape, either one value per streamline or one + # value per point. + if data.shape[0] == ref_sft._get_streamline_count(): + # Two consective if statements to handle both 1D and 2D arrays + # and turn them into lists of lists of lists. + # Check if the data is a vector or a scalar. + if len(data.shape) == 1: + data = data[:, None] + # ArraySequence expects a list of lists of lists, so we need to add + # an extra dimension. + if len(data.shape) == 2: + data = data[:, None, :] data_as_arraysequence = ArraySequence(data) + elif data.shape[0] == ref_sft._get_point_count(): - data_as_arraysequence = ArraySequence() - # This function was created to avoid messing with _data, _offsets and - # _lengths, so this feel kind of bad. However, the other way would be - # to create a new ArraySequence and iterate over the streamlines, but - # that would be way slower. - data_as_arraysequence._data = data - data_as_arraysequence._offsets = ref_sft.streamlines._offsets - data_as_arraysequence._lengths = ref_sft.streamlines._lengths + # Split the data into a list of arrays, one per streamline. + # np.split takes the indices at which to split the array, so use + # np.cumsum to get the indices of the end of each streamline. + data_split = np.split(data, np.cumsum(ref_sft.streamlines._lengths)[:-1]) + # Create an ArraySequence from the list of arrays. + data_as_arraysequence = ArraySequence(data_split) else: raise ValueError("Data has the wrong shape. Expecting either one value" " per streamline ({}) or one per point ({}) but got " @@ -93,19 +103,31 @@ def add_data_as_color_dpp(sft, color): The upper bound of the associated colormap. """ - if color.total_nb_rows == len(sft): - tmp = [np.tile([color[i][0], color[i][1], color[i][2]], + if len(color) == sft._get_streamline_count(): + if color.common_shape != (3,): + raise ValueError("Colors do not have the right shape. Expecting " + "RBG values, but got values of shape {}.".format( + color.common_shape)) + + tmp = [np.tile([color[i][0][0], color[i][0][1], color[i][0][2]], (len(sft.streamlines[i]), 1)) for i in range(len(sft.streamlines))] sft.data_per_point['color'] = tmp - elif color.total_nb_rows == sft.streamlines.total_nb_rows: + + elif len(color) == sft._get_point_count(): + + if color.common_shape != (3,): + raise ValueError("Colors do not have the right shape. Expecting " + "RBG values, but got values of shape {}.".format( + color.common_shape)) + sft.data_per_point['color'] = color else: - raise ValueError("Error in the code... Colors do not have the right " - "shape. Expecting either one color per streamline " - "({}) or one per point ({}) but got {}." - .format(len(sft), len(sft.streamlines._data), - len(color))) + raise ValueError("Colors do not have the right shape. 
Expecting either" + " one color per streamline ({}) or one per point ({})" + " but got {}.".format(sft._get_streamline_count(), + sft._get_point_count(), + color.total_nb_rows)) return sft diff --git a/scilpy/tractograms/tests/test_dps_and_dpp_management.py b/scilpy/tractograms/tests/test_dps_and_dpp_management.py index c32f2eace..4d0ccd410 100644 --- a/scilpy/tractograms/tests/test_dps_and_dpp_management.py +++ b/scilpy/tractograms/tests/test_dps_and_dpp_management.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- import nibabel as nib import numpy as np +import pytest + from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin from scilpy.image.volume_space_management import DataVolume from scilpy.tests.utils import nan_array_equal from scilpy.tractograms.dps_and_dpp_management import ( + get_data_as_arraysequence, add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines, project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps, perform_correlation_on_endpoints) @@ -27,45 +30,82 @@ def _get_small_sft(): return fake_sft -def test_add_data_as_color_dpp(): - lut = get_lookup_table('viridis') +def test_get_data_as_arraysequence_dpp(): + fake_sft = _get_small_sft() + + some_data = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5]) + + # Test 1: One value per point. + array_seq = get_data_as_arraysequence(some_data, fake_sft) - # Important. cmap(1) != cmap(1.0) - lowest_color = np.asarray(lut(0.0)[0:3]) * 255 - highest_color = np.asarray(lut(1.0)[0:3]) * 255 + assert fake_sft._get_point_count() == array_seq.total_nb_rows + +def test_get_data_as_arraysequence_dps(): fake_sft = _get_small_sft() + some_data = np.asarray([2, 20]) + + # Test 1: One value per point. + array_seq = get_data_as_arraysequence(some_data, fake_sft) + assert fake_sft._get_streamline_count() == array_seq.total_nb_rows + + +def test_get_data_as_arraysequence_dps_2D(): + fake_sft = _get_small_sft() + + some_data = np.asarray([[2], [20]]) + + # Test 1: One value per point. + array_seq = get_data_as_arraysequence(some_data, fake_sft) + assert fake_sft._get_streamline_count() == array_seq.total_nb_rows + + +def test_get_data_as_arraysequence_error(): + fake_sft = _get_small_sft() + + some_data = np.asarray([2, 20, 200, 0.1]) + + # Test 1: One value per point. + with pytest.raises(ValueError): + _ = get_data_as_arraysequence(some_data, fake_sft) + + +def test_add_data_as_dpp_1_per_point(): + + fake_sft = _get_small_sft() + cmap = get_lookup_table('jet') + # Not testing the clipping options. Will be tested through viz.utils tests # Test 1: One value per point. # Lowest cmap color should be first point of second streamline. 
- some_data = [[2, 20, 200], [0.1, 0.3, 22, 5]] - colored_sft, lbound, ubound = add_data_as_color_dpp( - fake_sft, lut, some_data) + values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5]) + color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8) + + array_seq = get_data_as_arraysequence(color, fake_sft) + colored_sft = add_data_as_color_dpp( + fake_sft, array_seq) assert len(colored_sft.data_per_streamline.keys()) == 0 assert list(colored_sft.data_per_point.keys()) == ['color'] - assert lbound == 0.1 - assert ubound == 200 - assert np.array_equal(colored_sft.data_per_point['color'][1][0, :], - lowest_color) - assert np.array_equal(colored_sft.data_per_point['color'][0][2, :], - highest_color) + + +def test_add_data_as_dpp_1_per_streamline(): + + fake_sft = _get_small_sft() + cmap = get_lookup_table('jet') # Test 2: One value per streamline # Lowest cmap color should be every point in first streamline - some_data = np.asarray([4, 5]) - colored_sft, lbound, ubound = add_data_as_color_dpp( - fake_sft, lut, some_data) + values = np.asarray([4, 5]) + color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8) + array_seq = get_data_as_arraysequence(color, fake_sft) + + colored_sft = add_data_as_color_dpp( + fake_sft, array_seq) + assert len(colored_sft.data_per_streamline.keys()) == 0 assert list(colored_sft.data_per_point.keys()) == ['color'] - assert lbound == 4 - assert ubound == 5 - # Lowest cmap color should be first point of second streamline. - # Same value for all points. - colors_first_line = colored_sft.data_per_point['color'][0] - assert np.array_equal(colors_first_line[0, :], lowest_color) - assert np.all(colors_first_line[1:, :] == colors_first_line[0, :]) def test_convert_dps_to_dpp(): From 4bf9d512e8226ab71235499beb93255962e6eeaa Mon Sep 17 00:00:00 2001 From: AntoineTheb Date: Thu, 26 Dec 2024 17:52:13 -0500 Subject: [PATCH 4/4] FIX: pep8 --- scilpy/tractograms/dps_and_dpp_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index fda790b35..38223e995 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import numpy as np -from collections.abc import Iterable from nibabel.streamlines import ArraySequence @@ -40,9 +39,10 @@ def get_data_as_arraysequence(data, ref_sft): elif data.shape[0] == ref_sft._get_point_count(): # Split the data into a list of arrays, one per streamline. - # np.split takes the indices at which to split the array, so use + # np.split takes the indices at which to split the array, so use # np.cumsum to get the indices of the end of each streamline. - data_split = np.split(data, np.cumsum(ref_sft.streamlines._lengths)[:-1]) + data_split = np.split( + data, np.cumsum(ref_sft.streamlines._lengths)[:-1]) # Create an ArraySequence from the list of arrays. data_as_arraysequence = ArraySequence(data_split) else:
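
A minimal usage sketch of the two helpers refactored in this series, chaining get_data_as_arraysequence and add_data_as_color_dpp the same way the updated tests do. It assumes `sft` is an already-loaded StatefulTractogram and that the per-point values are already normalized to [0, 1] (in the scripts this is handled by clip_and_normalize_data_for_cmap):

    import numpy as np
    from scilpy.viz.color import get_lookup_table
    from scilpy.tractograms.dps_and_dpp_management import (
        add_data_as_color_dpp, get_data_as_arraysequence)

    cmap = get_lookup_table('jet')
    # One normalized value per point of the tractogram.
    values = np.random.rand(sft.streamlines.total_nb_rows)
    # Map values to RGB in [0, 255], as in scil_tractogram_assign_custom_color.py.
    color = (np.asarray(cmap(values))[:, 0:3] * 255).astype(np.uint8)
    # Wrap the flat (n_points, 3) array into an ArraySequence matching the
    # streamlines' lengths, then attach it as data_per_point['color'].
    array_seq = get_data_as_arraysequence(color, sft)
    sft = add_data_as_color_dpp(sft, array_seq)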