Skip to content

Commit

Permalink
Add rescale_z, rescale_non_lin
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Feb 6, 2024
1 parent 22e7910 commit 2ae0ac3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
36 changes: 21 additions & 15 deletions dwi_ml/testing/projects/tt_visu_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,15 @@ def build_argparser_transformer_visu():
gg = g.add_mutually_exclusive_group()
gg.add_argument('--rescale_0_1', action='store_true',
help="If set, rescale to max 1 per row. X = X/max(row)")
gg.add_argument('--rescale_dev', action='store_true',
help="If set, transform each value to X = (X - mu) / mu, "
"where mu=mean(row).\nThis acts as a modified "
"z-score / std, but \n"
"- When all values are equal, z-score is not "
"defined. This value is."
"- Compensates for the increasing length of the "
"rows.")
gg.add_argument('--rescale_z', action='store_true',
help="If set, transform each value to X = (X - mu) / std, "
"where \nmu and std are computed per row.")
gg.add_argument('--rescale_non_lin', action='store_true',
help="If set, transform each value so that values below "
"the equal \nattention are transformed to [0-0.5], "
"and values above to [0.5, 1].\n"
"(Ex: At point #3, 0.33 --> 0.5. At point #40, "
"0.025 --> 0.5.")

g = p.add_argument_group("Options defining how to deal with the heads")
gg = g.add_mutually_exclusive_group()
Expand Down Expand Up @@ -315,14 +316,16 @@ def tt_visualize_weights_main(args, parser):

average_heads = args.group_heads or args.group_all
average_layers = args.group_all
if args.group_with_max and not (args.rescale_0_1 or args.rescale_dev):
if args.group_with_max and not \
(args.rescale_0_1 or args.rescale_z or args.rescale_non_lin):
parser.error("--group_with_max is expected to be used together with "
"a rescaling option.")
visu_encoder_decoder(
weights, sft, model.direction_getter.add_eos, average_heads,
average_layers, args.group_with_max, args.resample_attention,
args.rescale_0_1, args.rescale_dev, save_colored_sft, run_bertviz,
show_as_matrices, prefix_total, has_decoder=has_decoder)
args.rescale_0_1, args.rescale_z, args.rescale_non_lin,
save_colored_sft, run_bertviz, show_as_matrices, prefix_total,
has_decoder=has_decoder)

if args.show_now:
plt.show()
Expand All @@ -331,7 +334,8 @@ def tt_visualize_weights_main(args, parser):
def visu_encoder_decoder(
weights: Tuple, sft: StatefulTractogram, has_eos: bool,
average_heads: bool, average_layers: bool, group_with_max: bool,
resample_nb: int, rescale_0_1: bool, rescale_dev: bool,
resample_nb: int, rescale_0_1: bool, rescale_z: bool,
rescale_non_lin: bool,
save_colored_sft: bool, run_bertviz: bool, show_as_matrices: bool,
prefix_name: str, has_decoder: bool = True):
"""
Expand Down Expand Up @@ -359,8 +363,10 @@ def visu_encoder_decoder(
Number of values to resample matrices
rescale_0_1: bool,
If true, rescale each line of the matrix using X / max(X).
rescale_dev: bool
If true, rescale each line of the matrix using (X-mu)/mu.
rescale_z: bool
If true, rescale each line of the matrix using (X-mu)/std
rescale_non_lin: bool
If true, rescale each line of the matrix to [0 - 0.5] and [0.5 - 1]
save_colored_sft: bool
run_bertviz: bool
For now, on one streamline.
Expand All @@ -382,7 +388,7 @@ def visu_encoder_decoder(
for i in range(len(weights)):
weights[i] = reshape_unpad_rescale_attention(
weights[i], average_heads, average_layers, group_with_max,
lengths, rescale_0_1, rescale_dev)
lengths, rescale_0_1, rescale_z, rescale_non_lin)

if has_decoder:
attention_names = ('encoder',)
Expand Down
65 changes: 53 additions & 12 deletions dwi_ml/testing/projects/tt_visu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def reshape_unpad_rescale_attention(
attention_per_layer, average_heads: bool, average_layers,
group_with_max, lengths, rescale_0_1, rescale_dev):
group_with_max, lengths, rescale_0_1, rescale_z, rescale_non_lin):
"""
Also sends to CPU.
Expand All @@ -23,7 +23,12 @@ def reshape_unpad_rescale_attention(
group_with_max: bool
lengths: List[int]
Unpadded lengths of the streamlines.
rescale_0_1: bool
rescale_0_1: bool,
If true, rescale each line of the matrix using X / max(X).
rescale_z: bool
If true, rescale each line of the matrix using (X-mu)/mu.
rescale_non_lin: bool
If true, rescale each line of the matrix to [0 - 0.5] and [0.5 - 1]
Returns
-------
Expand All @@ -34,23 +39,20 @@ def reshape_unpad_rescale_attention(
Where nheads=1 if average_heads.
"""
if rescale_0_1:
logging.info(
"We will normalize the attention: per row, to the range [0, 1]: \n"
" The attention when deciding the next direction at point N \n"
" is distributed in the N first points of the streamline such\n"
" that the point with most attention has value 1. "
"(att = att/max)")
print("Rescaling between 0-1: X = X/ max(row)")
elif rescale_z:
print("Rescaling using X = (X-mu) / mu.")

# 1. To numpy. Possibly average heads.
nb_layers = len(attention_per_layer)
for ll in range(nb_layers):
for layer in range(nb_layers):
# To numpy arrays
attention_per_layer[ll] = attention_per_layer[ll].cpu().numpy()
attention_per_layer[layer] = attention_per_layer[layer].cpu().numpy()

# Averaging heads (but keeping 4D).
if average_heads and not group_with_max:
attention_per_layer[ll] = np.mean(attention_per_layer[ll],
axis=1, keepdims=True)
attention_per_layer[layer] = np.mean(attention_per_layer[layer],
axis=1, keepdims=True)

# Possibly average layers (but keeping as list)
if average_layers and not group_with_max:
Expand All @@ -73,7 +75,46 @@ def reshape_unpad_rescale_attention(
# Normalizing each row so that max value = 1.
# Axis=2: Horizontally for each matrix.
if rescale_0_1:
# Rescale [0, max] --> [0, 1]
line_att = line_att / np.max(line_att, axis=2, keepdims=True)
elif rescale_z:
# Mu is not np.mean here. Ignoring future values.
# Expected values for mu = [1, 1/2, 1/3, etc.]
mask = np.ones((line_att.shape[1], line_att.shape[2]))
mask = np.tril(mask)[None, :, :]

nb = np.arange(line_att.shape[1]) + 1

mu = np.sum(line_att, axis=2) / nb[None, :]
mu = mu[:, :, None]

std = (line_att - mu)**2
std = np.sqrt(np.sum(std * mask, axis=2) / nb[None, :])
std = std[:, :, None]

line_att = (line_att - mu) / np.maximum(std, 1e-6)

# Back to triu
line_att = line_att * mask
elif rescale_non_lin:
nb = np.arange(line_att.shape[1]) + 1
mean = 1/nb[None, :, None]

where_below = line_att <= mean
where_above = ~where_below

# Rescale [0, mean] --> [0, 0.5]
# x = ( (x-0) / (mean - 0) )*0.5
tmp1 = 0.5 * (line_att / mean)

# Rescale [mean, 1] --> [0.5, 1]
# x = 0.5 + ( (x - mean) / (1-mean) ) * 0.5
# But (1 - mean) creates an error for first point. It does not
# belong to where_above so we don't care about this value.
mean[:, 0, :] = 10
tmp2 = 0.5 + 0.5 * ((line_att - mean) / (1.0 - mean))

line_att = tmp1 * where_below + tmp2 * where_above

if average_heads and group_with_max:
line_att = np.max(line_att, axis=0, keepdims=True)
Expand Down

0 comments on commit 2ae0ac3

Please sign in to comment.