diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index f09e3186..a2764820 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from typing import List import numpy as np @@ -320,13 +321,28 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, compressed streamlines.' Else, uses simple computation from endpoints. Faster. Also, works with incomplete parcellation. + + Returns + ------- + matrix: np.ndarray + With use_scilpy: shape (nb_labels + 1, nb_labels + 1) + (last label is "Not Found") + Else, shape (nb_labels, nb_labels) + labels: List + The list of labels """ - real_labels = np.unique(data_labels)[1:] + real_labels = list(np.sort(np.unique(data_labels))) nb_labels = len(real_labels) - matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) + logging.debug("Computing connectivity matrix for {} labels." + .format(nb_labels)) - start_blocs = [] - end_blocs = [] + if use_scilpy: + matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) + else: + matrix = np.zeros((nb_labels, nb_labels), dtype=int) + + start_labels = [] + end_labels = [] if use_scilpy: indices, points_to_idx = uncompress(streamlines, return_mapping=True) @@ -334,29 +350,33 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, for strl_vox_indices in indices: segments_info = segmenting_func(strl_vox_indices, data_labels) if len(segments_info) > 0: - start = segments_info[0]['start_label'] - end = segments_info[0]['end_label'] - start_blocs.append(start) - end_blocs.append(end) + start = real_labels.index(segments_info[0]['start_label']) + end = real_labels.index(segments_info[0]['end_label']) + else: + start = nb_labels + end = nb_labels - matrix[start, end] += 1 - if start != end: - matrix[end, start] += 1 + start_labels.append(start) + end_labels.append(end) + + matrix[start, end] += 1 + if start != end: + matrix[end, start] += 1 + + real_labels = real_labels + [np.NaN] - else: - # Putting it in 0,0, we will remember that this means 'other' - matrix[0, 0] += 1 - start_blocs.append(0) - end_blocs.append(0) else: for s in streamlines: # Vox space, corner origin # = we can get the nearest neighbor easily. # Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1. - start = data_labels[tuple(np.floor(s[0, :]).astype(int))] - end = data_labels[tuple(np.floor(s[-1, :]).astype(int))] - start_blocs.append(start) - end_blocs.append(end) + start = real_labels.index( + data_labels[tuple(np.floor(s[0, :]).astype(int))]) + end = real_labels.index( + data_labels[tuple(np.floor(s[-1, :]).astype(int))]) + + start_labels.append(start) + end_labels.append(end) matrix[start, end] += 1 if start != end: matrix[end, start] += 1 @@ -367,7 +387,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, if binary: matrix = matrix.astype(bool) - return matrix, start_blocs, end_blocs + return matrix, real_labels, start_labels, end_labels def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index f662f684..fb956c49 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -1,5 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +""" +Computes the connectivity matrix. +Labels associated with each line / row will be printed. +""" + import argparse import logging import os.path @@ -65,10 +71,16 @@ def main(): args = p.parse_args() if args.verbose: + # Currenlty, with debug, matplotlib prints a lot of stuff. Why?? logging.getLogger().setLevel(logging.INFO) tmp, ext = os.path.splitext(args.out_file) + + if ext != '.npy': + p.error("--out_file should have a .npy extension.") + out_fig = tmp + '.png' + out_ordered_labels = tmp + '_labels.txt' assert_inputs_exist(p, [args.in_labels, args.streamlines]) assert_outputs_exist(p, args, [args.out_file, out_fig], [args.save_biggest, args.save_smallest]) @@ -80,26 +92,36 @@ def main(): p.error("Streamlines not compatible with chosen volume.") else: args.reference = args.in_labels + + logging.info("Loading tractogram.") in_sft = load_tractogram_with_reference(p, args, args.streamlines) in_img = nib.load(args.in_labels) data_labels = get_data_as_labels(in_img) in_sft.to_vox() in_sft.to_corner() - matrix, start_blocs, end_blocs = compute_triu_connectivity_from_labels( - in_sft.streamlines, data_labels, - use_scilpy=args.use_longest_segment) + matrix, ordered_labels, start_blocs, end_blocs = \ + compute_triu_connectivity_from_labels( + in_sft.streamlines, data_labels, + use_scilpy=args.use_longest_segment) if args.hide_background is not None: - matrix[args.hide_background, :] = 0 - matrix[:, args.hide_background] = 0 + idx = ordered_labels.idx(args.hide_background) + matrix[idx, :] = 0 + matrix[:, idx] = 0 + ordered_labels[idx] = ("Hidden background ({})" + .format(args.hide_background)) + + logging.info("Labels are, in order: {}".format(ordered_labels)) # Options to try to investigate the connectivity matrix: # masking point (0,0) = streamline ending in wm. if args.save_biggest is not None: i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape) print("Saving biggest bundle: {} streamlines. From label {} to label " - "{}".format(matrix[i, j], i, j)) + "{} (line {}, column {} in the matrix)" + .format(matrix[i, j], ordered_labels[i], ordered_labels[j], + i, j)) biggest = find_streamlines_with_chosen_connectivity( in_sft.streamlines, i, j, start_blocs, end_blocs) sft = in_sft.from_sft(biggest, in_sft) @@ -109,15 +131,22 @@ def main(): tmp_matrix = np.ma.masked_equal(matrix, 0) i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape) print("Saving smallest bundle: {} streamlines. From label {} to label " - "{}".format(matrix[i, j], i, j)) - biggest = find_streamlines_with_chosen_connectivity( + "{} (line {}, column {} in the matrix)" + .format(matrix[i, j], ordered_labels[i], ordered_labels[j], + i, j)) + smallest = find_streamlines_with_chosen_connectivity( in_sft.streamlines, i, j, start_blocs, end_blocs) - sft = in_sft.from_sft(biggest, in_sft) + sft = in_sft.from_sft(smallest, in_sft) save_tractogram(sft, args.save_smallest) + ordered_labels = str(ordered_labels) + with open(out_ordered_labels, "w") as text_file: + text_file.write(ordered_labels) + if args.show_now: plt.imshow(matrix) plt.colorbar() + plt.title("Raw streamline count") plt.figure() plt.imshow(matrix > 0)