Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: remove sft._data usage part 1 - tractogram coloring scripts + more #1105

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,50 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.viz.color import clip_and_normalize_data_for_cmap
from nibabel.streamlines import ArraySequence


def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None,
max_range=None, min_cmap=None, max_cmap=None,
log=False, LUT=None):
def get_data_as_arraysequence(data, ref_sft):
""" Get data in the same shape as a reference StatefulTractogram's
streamlines, so it can be used to set data_per_point or
data_per_streamline. The data may represent one value per streamline or one
value per point. The function will return an ArraySequence with the same
shape as the streamlines.

Parameters
----------
data: np.ndarray
The data to convert to ArraySequence.
ref_sft: StatefulTractogram
The reference StatefulTractogram containing the streamlines.

Returns
-------
data_as_arraysequence: ArraySequence
The data as an ArraySequence.
"""

if data.shape[0] == len(ref_sft):
data_as_arraysequence = ArraySequence(data)
elif data.shape[0] == ref_sft._get_point_count():
data_as_arraysequence = ArraySequence()
# This function was created to avoid messing with _data, _offsets and
# _lengths, so this feel kind of bad. However, the other way would be
# to create a new ArraySequence and iterate over the streamlines, but
# that would be way slower.
data_as_arraysequence._data = data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a way to hstack the data, or vstack, and then convert to array sequence?

In my PR #890 , I had done this. Is it really that much slower?

image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something I often do is that I create an array sequence with my streamlines like this:
dpp = sft.streamlines.copy()
dpp._data = my_array.copy()

That way everything is initialized with the right length, memory safe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EmmaRenauld I feel like iterating over all streamlines would be problematic for large (>1M) tractograms.
@frheault the problem with your approach is the dpp is still "prealocated" using the streamlines, which is weird. I think this is a deeper problem stemming from the way ArraySequences work but I'd rather not have to preallocate data per points as the streamlines' points. Especially if more complicated processing is needed and some DPP may be left to the streamline's value by accident.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EmmaRenauld @frheault should be a bit cleaner now.

data_as_arraysequence._offsets = ref_sft.streamlines._offsets
data_as_arraysequence._lengths = ref_sft.streamlines._lengths
else:
raise ValueError("Data has the wrong shape. Expecting either one value"
" per streamline ({}) or one per point ({}) but got "
"{}."
.format(len(ref_sft), len(ref_sft.streamlines._data),
data.shape[0]))
return data_as_arraysequence


def add_data_as_color_dpp(sft, color):
"""
Normalizes data between 0 and 1 for an easier management with colormaps.
The real lower bound and upperbound are returned.
Expand Down Expand Up @@ -54,31 +92,21 @@ def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None,
ubound: float
The upper bound of the associated colormap.
"""
# If data is a list of lists, merge.
if isinstance(data[0], list) or isinstance(data[0], np.ndarray):
data = np.hstack(data)

values, lbound, ubound = clip_and_normalize_data_for_cmap(
data, clip_outliers, min_range, max_range,
min_cmap, max_cmap, log, LUT)

# Important: values are in float after clip_and_normalize.
color = np.asarray(cmap(values)[:, 0:3]) * 255
if len(color) == len(sft):
if color.total_nb_rows == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
(len(sft.streamlines[i]), 1))
for i in range(len(sft.streamlines))]
sft.data_per_point['color'] = tmp
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color
elif color.total_nb_rows == sft.streamlines.total_nb_rows:
sft.data_per_point['color'] = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. Expecting either one color per streamline "
"({}) or one per point ({}) but got {}."
.format(len(sft), len(sft.streamlines._data),
len(color)))
return sft, lbound, ubound
return sft


def convert_dps_to_dpp(sft, keys, overwrite=False):
Expand Down
11 changes: 6 additions & 5 deletions scilpy/viz/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_lookup_table(name):
name_list = name.split('-')
colors_list = [mcolors.to_rgba(color)[0:3] for color in name_list]
cmap = mcolors.LinearSegmentedColormap.from_list('CustomCmap',
colors_list)
colors_list)
return cmap

return plt.colormaps.get_cmap(name)
Expand Down Expand Up @@ -283,10 +283,10 @@ def prepare_colorbar_figure(cmap, lbound, ubound, nb_values=255, nb_ticks=10,
return fig


def ambiant_occlusion(sft, colors, factor=4):
def ambient_occlusion(sft, colors, factor=4):
"""
Apply ambiant occlusion to a set of colors based on point density
around each points.
around each points.

Parameters
----------
Expand All @@ -296,14 +296,14 @@ def ambiant_occlusion(sft, colors, factor=4):
The original colors to modify.
factor : float
The factor of occlusion (how density will affect the saturation).

Returns
-------
np.ndarray
The modified colors.
"""

pts = sft.streamlines._data
pts = sft.streamlines.get_data()
hsv = mcolors.rgb_to_hsv(colors)

tree = KDTree(pts)
Expand All @@ -324,6 +324,7 @@ def ambiant_occlusion(sft, colors, factor=4):

return mcolors.hsv_to_rgb(hsv)


def generate_local_coloring(sft):
"""
Generate a coloring based on the local orientation of the streamlines.
Expand Down
2 changes: 1 addition & 1 deletion scripts/scil_bundle_diameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def main():
counter = 0
labels_dict = {label: ([], []) for label in unique_labels}
pts_labels = map_coordinates(data_labels,
sft.streamlines._data.T-0.5,
sft.streamlines.get_data().T-0.5,
order=0)
# For each label, all positions and directions are needed to get
# a tube estimation per label.
Expand Down
42 changes: 27 additions & 15 deletions scripts/scil_bundle_label_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
assert_inputs_exist,
assert_output_dirs_exist_and_empty)
from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft
from scilpy.tractograms.dps_and_dpp_management import (
add_data_as_color_dpp, get_data_as_arraysequence)
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map
from scilpy.tractanalysis.distance_to_centroid import min_dist_to_centroid
from scilpy.tractograms.streamline_and_mask_operations import \
Expand Down Expand Up @@ -171,8 +173,8 @@ def main():
sft_centroid.streamlines = srm.transform(sft_centroid.streamlines)

uniformize_bundle_sft(concat_sft, ref_bundle=sft_centroid[0])
labels, dists = min_dist_to_centroid(concat_sft.streamlines._data,
sft_centroid.streamlines._data,
labels, dists = min_dist_to_centroid(concat_sft.streamlines.get_data(),
sft_centroid.streamlines.get_data(),
args.nb_pts)
labels += 1 # 0 means no labels

Expand Down Expand Up @@ -228,7 +230,7 @@ def main():
final_labels = ArraySequence(final_label)
final_dists = ArraySequence(final_dists)

kd_tree = cKDTree(final_streamlines._data)
kd_tree = cKDTree(final_streamlines.get_data())
labels_map = np.zeros(binary_bundle.shape, dtype=np.int16)
distance_map = np.zeros(binary_bundle.shape, dtype=float)
indices = np.array(np.nonzero(binary_bundle), dtype=int).T
Expand All @@ -239,8 +241,8 @@ def main():
if not len(neighbor_ids):
continue

labels_val = final_labels._data[neighbor_ids]
dists_val = final_dists._data[neighbor_ids]
labels_val = final_labels.get_data()[neighbor_ids]
dists_val = final_dists.get_data()[neighbor_ids]
sum_dists_vox = np.sum(dists_val)
weights_vox = np.exp(-dists_val / sum_dists_vox)

Expand Down Expand Up @@ -274,33 +276,43 @@ def main():

if len(sft):
tmp_labels = ndi.map_coordinates(labels_map,
sft.streamlines._data.T-0.5,
sft.streamlines.get_data().T-0.5,
order=0)
tmp_dists = ndi.map_coordinates(distance_map,
sft.streamlines._data.T-0.5,
sft.streamlines.get_data().T-0.5,
order=0)
tmp_corr = ndi.map_coordinates(corr_map,
sft.streamlines._data.T-0.5,
sft.streamlines.get_data().T-0.5,
order=0)
cmap = plt.colormaps[args.colormap]
new_sft.data_per_point['color'] = ArraySequence(
new_sft.streamlines)

# Nicer visualisation for MI-Brain
new_sft.data_per_point['color']._data = cmap(
tmp_labels / np.max(tmp_labels))[:, 0:3] * 255
colors = cmap(tmp_labels / np.max(tmp_labels))[:, 0:3] * 255
data = get_data_as_arraysequence(colors, sft)
new_sft = add_data_as_color_dpp(
new_sft, data)

save_tractogram(new_sft,
os.path.join(sub_out_dir, 'labels.trk'))

if len(sft):
new_sft.data_per_point['color']._data = cmap(
# Nicer visualisation for MI-Brain
colors = cmap(
tmp_dists / np.max(tmp_dists))[:, 0:3] * 255
data = get_data_as_arraysequence(colors, sft)
new_sft = add_data_as_color_dpp(
new_sft, data)

save_tractogram(new_sft,
os.path.join(sub_out_dir, 'distance.trk'))

if len(sft):
new_sft.data_per_point['color']._data = cmap(tmp_corr)[
:, 0:3] * 255
colors = cmap(
tmp_dists / np.max(tmp_corr))[:, 0:3] * 255
data = get_data_as_arraysequence(colors, sft)
new_sft = add_data_as_color_dpp(
new_sft, data)

save_tractogram(new_sft,
os.path.join(sub_out_dir, 'correlation.trk'))

Expand Down
43 changes: 26 additions & 17 deletions scripts/scil_tractogram_assign_custom_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import logging

from dipy.io.streamline import save_tractogram
from fury import colormap
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
Expand All @@ -63,11 +62,13 @@
assert_inputs_exist,
assert_outputs_exist,
load_matrix_in_any_format)
from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp
from scilpy.tractograms.dps_and_dpp_management import (
add_data_as_color_dpp, get_data_as_arraysequence)
from scilpy.tractograms.streamline_operations import (
get_streamlines_as_linspaces, get_angles)
from scilpy.viz.color import (
get_lookup_table, prepare_colorbar_figure, ambiant_occlusion,
clip_and_normalize_data_for_cmap,
get_lookup_table, prepare_colorbar_figure, ambient_occlusion,
generate_local_coloring)


Expand Down Expand Up @@ -117,7 +118,7 @@ def _build_arg_parser():
"last points are set to 0.")

g2 = p.add_argument_group(title='Coloring options')
g2.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int,
g2.add_argument('--ambient_occlusion', nargs='?', const=4, type=int,
help='Impact factor of the ambiant occlusion '
'approximation. [%(default)s]')
g2.add_argument('--colormap', default='jet',
Expand Down Expand Up @@ -204,6 +205,7 @@ def main():
elif args.load_dpp or args.from_anatomy:
sft.to_vox()
concat_points = np.vstack(sft.streamlines).T

expected_shape = len(concat_points)
sft.to_rasmm()
if args.load_dpp:
Expand All @@ -220,23 +222,30 @@ def main():
elif args.local_orientation:
data = generate_local_coloring(sft)
else: # args.local_angle:
data = get_angles(sft, add_zeros=True)
data = get_angles(sft, add_zeros=True, degrees=False)
data = np.hstack(data)

# Processing
# Clip and normalize the data according to the colormap
values, lbound, ubound = clip_and_normalize_data_for_cmap(
data, args.clip_outliers, args.min_range, args.max_range,
args.min_cmap, args.max_cmap, args.log, LUT)

# Transform the values to RGB in [0, 255]
if not args.local_orientation:
sft, lbound, ubound = add_data_as_color_dpp(
sft, cmap, data, args.clip_outliers, args.min_range, args.max_range,
args.min_cmap, args.max_cmap, args.log, LUT)
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)
else:
sft.data_per_point['color'] = sft.streamlines.copy()
data *= 255
sft.data_per_point['color']._data = data.astype(np.uint8)

# Saving
if args.ambiant_occlusion:
sft.data_per_point['color']._data = ambiant_occlusion(
sft, sft.data_per_point['color']._data, args.ambiant_occlusion)
color = (values * 255).astype(np.uint8)

# Add ambient occlusion to the coloring
if args.ambient_occlusion:
color = ambient_occlusion(
sft, color, args.ambient_occlusion)

# Set the color data in the tractogram
data = get_data_as_arraysequence(color, sft)
sft = add_data_as_color_dpp(
sft, data)

save_tractogram(sft, args.out_tractogram)

if args.out_colorbar:
Expand Down
27 changes: 16 additions & 11 deletions scripts/scil_tractogram_assign_uniform_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
add_overwrite_arg,
add_verbose_arg,
add_reference_arg, assert_headers_compatible)
from scilpy.viz.color import format_hexadecimal_color_to_rgb, ambiant_occlusion
from scilpy.tractograms.dps_and_dpp_management import (
add_data_as_color_dpp, get_data_as_arraysequence)
from scilpy.viz.color import format_hexadecimal_color_to_rgb, ambient_occlusion


def _build_arg_parser():
Expand All @@ -41,8 +43,8 @@ def _build_arg_parser():
p.add_argument('in_tractograms', nargs='+',
help='Input tractograms (.trk or .tck).')

p.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int,
help='Impact factor of the ambiant occlusion '
p.add_argument('--ambient_occlusion', nargs='?', const=4, type=int,
help='Impact factor of the ambient occlusion '
'approximation.\n Use factor or 2. Decrease for '
'lighter and increase for darker [%(default)s].')

Expand Down Expand Up @@ -121,9 +123,7 @@ def main():

sft = load_tractogram_with_reference(parser, args, filename)

sft.data_per_point['color'] = sft.streamlines.copy()
sft.data_per_point['color']._data = np.zeros(
(len(sft.streamlines._data), 3), dtype=np.uint8)
colors = np.zeros((sft._get_point_count(), 3), dtype=np.uint8)

if args.dict_colors:
base, ext = os.path.splitext(filename)
Expand All @@ -141,11 +141,16 @@ def main():

red, green, blue = format_hexadecimal_color_to_rgb(color)

colors = np.tile([red, green, blue], (len(sft.streamlines._data), 1))
if args.ambiant_occlusion:
colors = ambiant_occlusion(sft, colors,
factor=args.ambiant_occlusion)
sft.data_per_point['color']._data = colors
colors = np.tile([red, green, blue], (sft._get_point_count(), 1))
if args.ambient_occlusion:
colors = ambient_occlusion(sft, colors,
factor=args.ambient_occlusion)

# Set the color data in the tractogram
data = get_data_as_arraysequence(colors, sft)
sft = add_data_as_color_dpp(
sft, data)

save_tractogram(sft, out_filenames[i])


Expand Down
4 changes: 2 additions & 2 deletions scripts/scil_tractogram_remove_invalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main():

# Processing
ori_len = len(sft)
ori_len_pts = len(sft.streamlines._data)
ori_len_pts = sft._get_point_count()
if args.cut_invalid:
sft, cutting_counter = cut_invalid_streamlines(sft,
epsilon=args.threshold)
Expand All @@ -92,7 +92,7 @@ def main():
sft = remove_overlapping_points_streamlines(sft, args.threshold)
logging.warning("data_per_point will be discarded.")
logging.warning('Removed {} overlapping points from tractogram.'.format(
ori_len_pts - len(sft.streamlines._data)))
ori_len_pts - sft._get_point_count()))

logging.warning('Removed {} invalid streamlines.'.format(
ori_len - len(sft)))
Expand Down
Loading