#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 19 11:48:09 2021
@author: Leon D. Lotter
"""
from os import getcwd
from os.path import splitext, join
import numpy as np
import pandas as pd
import nibabel as nb
from nilearn.datasets import load_mni152_template
from nilearn.input_data import NiftiLabelsMasker, NiftiMasker
from nilearn.image import load_img, math_img, new_img_like, get_data, resample_to_img, index_img
from nilearn.regions import connected_regions
from nilearn.reporting import get_clusters_table
from scipy.stats import zscore, spearmanr, pearsonr, rankdata
from scipy.ndimage import label
import matplotlib.pyplot as plt
from seaborn import regplot
import logging
lgr = logging.getLogger(__name__)
lgr.setLevel(logging.INFO)
def index_clusters_bin(img):
"""
Extracts clusters from a binarized volume and assigns indexes in the
order of cluster size to all voxels within each cluster
(1=largest to n=smallest)
    Input: img = binarized volume, file path or nifti object
    Output: volume with indices from 1 to n(clusters), 4D volume with one
        cluster per frame, pd.DataFrame with cluster sizes
"""
# load volume, binarize and remove nan (just in case)
img = math_img('np.nan_to_num(img) > 0', img=img)
# get clusters and store in separate volumes within 4D nifti
img4D, _ = connected_regions(img, min_region_size=1,
extract_type='connected_components')
# get 4D data
dat4D = get_data(img4D)
# get region sizes
sizes = list()
for i in range(dat4D.shape[3]): # iterate over fourth dimension of data
dat3D = dat4D[:,:,:,i] # get 3D array
sizes.append(len(dat3D[dat3D == 1])) # get and store size
# sort sizes, dataframe index denotes position in 4D volume
sizes = pd.DataFrame(sizes, columns=['size'])
sizes.sort_values('size', ascending=False, inplace=True)
# create new 3D array
dat = np.zeros(dat4D.shape[:3])
dat_4d = np.zeros(dat4D.shape)
for idx_new, idx_old in enumerate(sizes.index, start=1):
dat = dat + dat4D[:,:,:,idx_old] * idx_new
dat_4d[:,:,:,idx_new-1] = dat4D[:,:,:,idx_old]
# write into new volumes and return
img_idx = new_img_like(img, dat, copy_header=False)
img_4d = new_img_like(img4D, dat_4d, copy_header=False)
    return img_idx, img_4d, sizes
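
# Example usage (a minimal sketch; the path is hypothetical):
#   img_idx, img_4d, sizes = index_clusters_bin('binary_clusters.nii.gz')
#   img_idx.to_filename('clusters_indexed.nii.gz')  # 1 = largest cluster
#   print(sizes)  # cluster sizes, sorted descending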
#=============================================================================
def partialcorr3(x, y, z, method="pearson"):
"""Computes partial correlation between {x} and {y} controlled for {z}
Args:
x (array-like): input vector 1
y (array-like): input vector 2
z (array-like): input vector to be controlled for
corrtype (str, optional): "spearman" or "pearson. Defaults to "pearson".
Returns:
R (float): (ranked) partial correlation coefficient between x and y
"""
C = np.column_stack((x, y, z))
if method=="spearman":
C = rankdata(C, axis=0)
corr = np.corrcoef(C, rowvar=False) # Pearson product-moment correlation coefficients.
corr_inv = np.linalg.inv(corr) # the (multiplicative) inverse of a matrix.
P = -corr_inv[0,1] / (np.sqrt(corr_inv[0,0] * corr_inv[1,1]))
return P
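
# Example: partial correlation of two vectors controlling for a third
# (a minimal sketch with random data):
#   rng = np.random.default_rng(0)
#   x, z = rng.normal(size=100), rng.normal(size=100)
#   y = x + z + rng.normal(size=100)
#   print(partialcorr3(x, y, z, method="spearman"))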
def correlate_volumes_via_atlas(x_img, y_img, atlas, adjust_img=None, method='spearman',
                                colors=None, labels=None, pr=True, pl=True):
    """
    Correlates two volumes using ROI-wise averaged values defined by a given
    parcellation. If {adjust_img} is given and {method} is 'partialspearman'
    or 'partialpearson', calculates partial correlations.
    Can print correlation coefficients and draw a scatter plot.
    Input:
        x_img, y_img = volumes to correlate
        atlas = parcellation volume with integer labels
        adjust_img = volume to adjust for when calculating partial correlations,
            required for the partial correlation methods
        method = 'spearman', 'pearson', 'partialspearman', 'partialpearson'
        colors = colors of the points, must be a 1D array with one color per
            marker
        labels = labels to plot over each marker
        pr, pl = print correlation coefficient / show scatter plot
    Output: correlation coefficient, results pd.DataFrame
    """
# resample volumes to atlas space
i1 = resample_to_img(x_img, atlas)
i2 = resample_to_img(y_img, atlas)
# get roi-wise data
masker = NiftiLabelsMasker(atlas)
i1_dat = masker.fit_transform(i1)[0]
i2_dat = masker.fit_transform(i2)[0]
    if adjust_img is not None:
i3 = resample_to_img(adjust_img, atlas)
i3_dat = masker.fit_transform(i3)[0]
    # correlate
    if method == 'spearman':
        r, p = spearmanr(i1_dat, i2_dat)
        if pr is True: print(f"Spearman's r = {round(r, 2)}")
    elif method == 'pearson':
        r, p = pearsonr(i1_dat, i2_dat)
        if pr is True: print(f"Pearson's r = {round(r, 2)}")
    elif method == 'partialspearman':
        r = partialcorr3(i1_dat, i2_dat, i3_dat, method='spearman')
        if pr is True: print(f"Partial Spearman's r = {round(r, 2)}")
    elif method == 'partialpearson':
        r = partialcorr3(i1_dat, i2_dat, i3_dat, method='pearson')
        if pr is True: print(f"Partial Pearson's r = {round(r, 2)}")
    else:
        lgr.error(f'Method "{method}" not defined!')
        raise ValueError(f'Method "{method}" not defined!')
# plot
    if pl is True:
        if colors is None:
            regplot(x=i1_dat, y=i2_dat)
        else:
            regplot(x=i1_dat, y=i2_dat, color='black',
                    scatter_kws={'facecolors': colors})
if labels is not None:
for i, l in enumerate(labels):
plt.text(i1_dat[i], i2_dat[i], l)
    if adjust_img is None:
        return r, pd.DataFrame({'idx': list(range(1, len(i1_dat) + 1)),
                                'dat1': i1_dat,
                                'dat2': i2_dat,
                                'label': labels})
    else:
        return r, pd.DataFrame({'idx': list(range(1, len(i1_dat) + 1)),
                                'dat1': i1_dat,
                                'dat2': i2_dat,
                                'dat3': i3_dat,
                                'label': labels})
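
# Example usage (a minimal sketch; volume and atlas paths are hypothetical):
#   r, res = correlate_volumes_via_atlas('map1.nii.gz', 'map2.nii.gz',
#                                        'atlas.nii.gz', method='spearman',
#                                        pr=True, pl=False)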
#=============================================================================
def get_cluster_stats(img_thresh):
"""
Extracts connected components (clusters), cluster sizes, and cluster masses
from a thresholded image.
Input:
thresholded image as string, nifti, or numpy array
Output:
img_labels: array with labelled clusters (1 to n)
clust_labels: array with cluster labels
clust_sizes: array with cluster sizes
clust_masses: array with cluster masses (sum of in-cluster voxel values)
"""
# get data
if isinstance(img_thresh, (nb.nifti1.Nifti1Image, str)):
img_thresh_dat = load_img(img_thresh).get_fdata()
    elif isinstance(img_thresh, np.ndarray):
        img_thresh_dat = img_thresh
    else:
        raise TypeError('img_thresh must be a file path, Nifti1Image, or numpy array!')
# label clusters
img_labels, n_labels = label(img_thresh_dat)
# get cluster sizes
clust_labels, clust_sizes = np.unique(img_labels, return_counts=True)
clust_sizes = clust_sizes[clust_labels!=0]
clust_labels = clust_labels[clust_labels!=0]
# get cluster masses
clust_masses = np.zeros(clust_sizes.shape)
for i, c in enumerate(clust_labels):
clust_masses[i] = np.sum(img_thresh_dat[img_labels==c])
return img_labels, clust_labels, clust_sizes, clust_masses
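
# Example usage (a minimal sketch; the path to a thresholded map is hypothetical):
#   img_labels, labels, sizes, masses = get_cluster_stats('zmap_thresh.nii.gz')
#   for l, s, m in zip(labels, sizes, masses):
#       print(f'cluster {l}: {s} voxels, mass = {m:.1f}')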
#=============================================================================
def get_size_of_rois(img):
"""
Extracts a list of roi sizes from a volume or numpy array with indexed
clusters
Input: cluster volume or numpy array
Output: pandas dataframe with sizes
"""
    if isinstance(img, nb.nifti1.Nifti1Image):
        dat = get_data(img) # get data matrix
    elif isinstance(img, np.ndarray):
        dat = img # input is data matrix
    else:
        raise TypeError('img must be a Nifti1Image or numpy array!')
idx = np.unique(dat) # get roi indices
idx = idx[idx != 0] # drop zero from list
# iterate over roi list and extract roi sizes
sizes_list = list()
for i in idx:
sizes_list.append(len(dat[dat == i]))
# create df and return
sizes = pd.DataFrame({'idx':idx, 'size':sizes_list})
    return sizes
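
# Example usage (a minimal sketch; the parcellation path is hypothetical):
#   sizes = get_size_of_rois('atlas.nii.gz')
#   print(sizes.sort_values('size', ascending=False))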
#=============================================================================
def drop_atlasregions_by_size(img, threshold):
"""
Drops regions from a brain parcellation based on their voxel-size.
    Input: img = parcellation volume, threshold = size threshold in voxels
    Output: volume without sub-threshold-sized regions, pd.DataFrame with
        sizes of the kept regions
"""
# get sizes of rois in volume and rois to keep after thresholding
sizes = get_size_of_rois(img)
sizes_nogo = sizes[sizes['size'] <= threshold]
sizes_go = sizes[sizes['size'] > threshold]
# get data and set all voxels in regions below threshold to zero
    dat = get_data(img).copy() # copy so the input image's cached data is not modified in place
for i in sizes_nogo['idx']:
dat[dat == i] = 0
# write into new volume and return
img_drop = new_img_like(img, dat, copy_header=False)
lgr.info(f'Kept {len(sizes_go)} regions with sizes > {threshold} voxels.')
    return img_drop, sizes_go
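
# Example: drop parcels with 100 voxels or fewer (path is hypothetical):
#   img_drop, sizes_go = drop_atlasregions_by_size('atlas.nii.gz', threshold=100)
#   img_drop.to_filename('atlas_thresholded.nii.gz')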
#=============================================================================
def z_norm_vol(in_file, out_path=None, mask=None):
"""
Z-normalizes 3D volume to mean and sd of all non-zero voxels (default)
or all voxels included in 'mask', if given.
Input:
in_file = input volume, nifti object or file path
out_path = full path to store file, defaults to current path or path
of input volume (if input volume given as file path)
mask = binarized mask, if None, determined from background intensity of
in_file
    Output:
        z-normalized volume, binary mask used for normalization
"""
    img = math_img('np.nan_to_num(img)', img=in_file) # remove nan, just in case
    masker = NiftiMasker(mask_img=mask, standardize=False, standardize_confounds=False)
    img_data = masker.fit_transform(img) # shape (1, n_voxels)
    img_data_z = zscore(img_data, axis=None, nan_policy='omit') # z-score across all in-mask voxels
    img_z = masker.inverse_transform(img_data_z)
    if out_path is None and isinstance(in_file, str):
        # save as (in_file)_z.nii.gz
        path, _ = splitext(in_file)
        img_z.to_filename(path + '_z.nii.gz')
    elif out_path is None and not isinstance(in_file, str):
        # save as cwd/z.nii.gz
        img_z.to_filename(join(getcwd(), 'z.nii.gz'))
    else:
        # save as out_path
        img_z.to_filename(out_path)
    return img_z, masker.mask_img_ # volume and the mask actually used
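
# Example usage (path is hypothetical; also writes a *_z.nii.gz file next to the input):
#   img_z, mask = z_norm_vol('map1.nii.gz')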
#=============================================================================
def get_cluster_peaks(vol_file, z=True, tab_save=False, peak_dist=10,
vthr=2.5, cthr=None, cthr_prc=0.1):
"""
Extracts clusters and peak coordinates based on given voxel and cluster
thresholds. Cluster threshold either determined directly (e.g., p < 0.05)
or via percent of non-zero-voxels (e.g., 0.1 %).
Input:
vol_file = input volume, nifti object or file path
z = if True, input volume is normalized to mean and sd of all
non-zero voxels
        tab_save = if True, stores the results table at the path of the input
            volume, only possible if input volume is given as file path;
            alternatively, a file path to save the table to
peak_dist = minimum distance between reported subpeaks in mm
vthr, cthr = voxel- and cluster-level thresholds
if cthr = None, cthr is set to cthr_prc % of non-zero voxels
Output: results table and cluster-level threshold
"""
vol = math_img('np.nan_to_num(img)', img=vol_file) # remove nan, just in case
if z == True: # z-normalize volume and get in-brain mask
vol, inbrain = z_norm_vol(vol_file)
elif z == False: # get original volume and in-brain mask
inbrain = math_img('img != 0', img=vol)
if cthr is None:
n_vox = np.sum(inbrain.get_fdata()) # get number of inbrain voxels
cthr = round(n_vox * cthr_prc / 100) # cluster threshold based on %
# get cluster table
tab = get_clusters_table(stat_img=vol, stat_threshold=vthr,
cluster_threshold=cthr, min_distance=peak_dist)
    if tab_save is True:
        path, _ = splitext(vol_file) # get path of volume
        if z is True:
            tab.to_csv(path + '_z_clusters.csv') # save table with suffix
        else:
            tab.to_csv(path + '_clusters.csv') # save table with suffix
    elif tab_save: # tab_save given as an output file path
        tab.to_csv(tab_save) # save table as tab_save
print(f'Found {len(tab)} peaks, thresholding: voxel-value > {vthr}; '
f'cluster-size > {cthr}')
print(tab)
    return tab, cthr
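
# Example: peaks in a z-normalized map, voxel threshold z > 2.5, cluster
# threshold = 0.1 % of in-brain voxels (path is hypothetical):
#   tab, cthr = get_cluster_peaks('map1.nii.gz', z=True, vthr=2.5, cthr=None)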
#=============================================================================
def parcel_data_to_volume(data, atlas, save_path=None, rank=False):
"""
Writes parcellated data into nifti volume according to shape and labels
of 'atlas'. If rank==True, parcel-values will be ranks of values in data.
"""
# atlas data
a = load_img(atlas) # load atlas
a_dat = a.get_fdata() # get atlas data
idc = np.unique(a_dat) # get atlas indices
idc = idc[idc!=0] # remove zero from indices
# check input data
    if len(idc) != len(data):
        lgr.error('Input array not the same length as atlas has indices!')
        raise ValueError('Input array not the same length as atlas has indices!')
# new volume
dat = np.zeros(np.shape(a_dat)) # make zero 3d array
# replace data with ranks of data
    if rank:
        data = rankdata(data, method='min') # ranks, ties get the smallest rank
# create volume
for i, idx in enumerate(idc):
dat[a_dat==idx] = data[i] # write data in array
vol = new_img_like(a, dat)
# save & return
if save_path is not None:
vol.to_filename(save_path)
    return vol, save_path
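
# Example: write one value per parcel back into atlas space (a minimal sketch;
# paths are hypothetical and 'values' must have one entry per atlas label):
#   values = np.random.rand(100)
#   vol, _ = parcel_data_to_volume(values, 'atlas_100rois.nii.gz',
#                                  save_path='values_in_atlas.nii.gz')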
#=============================================================================
def combine_atlases(atlases, renumber=True, target_space=None):
"""
Combines Parcellations, either by adding them up or renumbering parcel
indices from 1 to n(rois) starting with the first input atlas. If parcels
from different atlases overlap, the nth atlas "overwrites" the n-1th.
Input:
atlases = list of atlases
renumber = if True, renumber parcel indices, if False, add atlases up
        target_space = space to resample atlases to, defaults to the nilearn
            MNI152 template
Output: new atlas volume
"""
# fetch mni152_2mm as default space
if target_space is None:
target_space = load_mni152_template()
dat = np.zeros(target_space.shape) # empty matrix
n_rois = 0 # starting point for number of rois
# loop over atlases
for atlas in atlases:
# load image data and first frame, if 4D volume
a = load_img(atlas)
if len(a.shape) == 4:
a = index_img(a, 0)
        # resample to target space and get data
a = resample_to_img(a, target_space, interpolation='nearest')
a_dat = get_data(a)
# get parcel indices
idc = np.unique(a_dat) # get indices
idc = idc[idc != 0] # remove zero
# loop over indices
        for i_idx, idx in enumerate(idc, start=n_rois+1):
            if renumber: # renumber atlas from n_rois+1 to n_rois+n(idc)
                dat[a_dat == idx] = i_idx # new index
            else: # use original parcel indices
                dat[a_dat == idx] = idx # original index
        n_rois = n_rois + len(idc) # save current largest index
# write new volume
combined = new_img_like(target_space, dat.round(0))
    return combined
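
# Example: merge a cortical and a subcortical parcellation into one volume,
# renumbering labels consecutively (paths are hypothetical):
#   combined = combine_atlases(['cortex.nii.gz', 'subcortex.nii.gz'], renumber=True)
#   combined.to_filename('combined_atlas.nii.gz')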