slf_attn.py
import torch
from torch import nn
import torch.nn.functional as f
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nn_Softargmax = nn.Softmax  # fix wrong name


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, p, d_input=None):
        # Note: p (a dropout probability) is accepted here but not used by this module.
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        if d_input is None:
            d_xq = d_xk = d_xv = d_model
        else:
            d_xq, d_xk, d_xv = d_input

        # Make sure that the embedding dimension of the model is a multiple of the number of heads
        assert d_model % self.num_heads == 0
        self.d_k = d_model // self.num_heads

        # These are still of dimension d_model. They will be split into the heads
        self.W_q = nn.Linear(d_xq, d_model, bias=False)
        self.W_k = nn.Linear(d_xk, d_model, bias=False)
        self.W_v = nn.Linear(d_xv, d_model, bias=False)

        # Outputs of all sub-layers need to be of dimension d_model
        self.W_h = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V):
        batch_size = Q.size(0)
        k_length = K.size(-2)

        # Scaling by d_k so that the soft(arg)max doesn't saturate
        Q = Q / np.sqrt(self.d_k)                    # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(Q, K.transpose(2, 3))  # (bs, n_heads, q_length, k_length)

        A = nn_Softargmax(dim=-1)(scores)            # (bs, n_heads, q_length, k_length)

        # Get the weighted average of the values
        H = torch.matmul(A, V)                       # (bs, n_heads, q_length, dim_per_head)

        return H, A

    def split_heads(self, x, batch_size):
        """
        Split the last dimension into (heads X depth).
        Return after transpose to put in shape (batch_size X num_heads X seq_length X d_k).
        """
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def group_heads(self, x, batch_size):
        """
        Combine the heads again to get (batch_size X seq_length X (num_heads times d_k)).
        """
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)

    def forward(self, X_q, X_k, X_v):
        batch_size, seq_length, dim = X_q.size()

        # After transforming, split into num_heads
        Q = self.split_heads(self.W_q(X_q), batch_size)  # (bs, n_heads, q_length, dim_per_head)
        K = self.split_heads(self.W_k(X_k), batch_size)  # (bs, n_heads, k_length, dim_per_head)
        V = self.split_heads(self.W_v(X_v), batch_size)  # (bs, n_heads, v_length, dim_per_head)

        # Calculate the attention weights for each of the heads
        H_cat, A = self.scaled_dot_product_attention(Q, K, V)

        # Put all the heads back together by concat
        H_cat = self.group_heads(H_cat, batch_size)  # (bs, q_length, dim)

        # Final linear layer
        H = self.W_h(H_cat)  # (bs, q_length, dim)

        return H, A
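

# Shape sanity check for the full forward pass (a minimal sketch, not part of the
# original file; the names _mha, _x, _h, _a are illustrative). With d_model=512 and
# num_heads=8, each head works in d_k = 512 // 8 = 64 dimensions, and the output
# keeps the (batch_size, seq_length, d_model) shape of the query input.
_mha = MultiHeadAttention(d_model=512, num_heads=8, p=0)
_x = torch.rand(2, 5, 512)   # (batch_size=2, seq_length=5, d_model=512)
_h, _a = _mha(_x, _x, _x)    # self-attention: queries, keys and values are the same tensor
print(_h.shape)              # expected: torch.Size([2, 5, 512])
print(_a.shape)              # expected: torch.Size([2, 8, 5, 5]) -> (bs, n_heads, q_length, k_length)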


temp_mha = MultiHeadAttention(d_model=512, num_heads=8, p=0)


def print_out(Q, K, V):
    temp_out, temp_attn = temp_mha.scaled_dot_product_attention(Q, K, V)
    print(temp_attn.shape)
    print('Attention weights are:', temp_attn.squeeze())
    print('Output is:', temp_out.squeeze())


test_K = torch.tensor(
    [[10,  0,  0],
     [ 0, 10,  0],
     [ 0,  0, 10],
     [ 0,  0, 10]]
).float()[None, None]

test_V = torch.tensor(
    [[   1, 0, 0],
     [  10, 0, 0],
     [ 100, 5, 0],
     [1000, 6, 0]]
).float()[None, None]

test_Q = torch.tensor(
    [[0, 0, 10], [0, 10, 0], [10, 10, 0]]
).float()[None, None]

print_out(test_Q, test_K, test_V)
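
# Roughly what the call above should show (an informal reading, not output copied from
# the original file): temp_mha was built with d_model=512 and num_heads=8, so scores
# are scaled by 1/sqrt(d_k) = 1/8; a raw dot product of 100 becomes 12.5 and the
# soft(arg)max is still close to one-hot.
#   - Query [0, 0, 10] matches the 3rd and 4th keys equally, so its output is about
#     the average of the 3rd and 4th value rows: ~[550, 5.5, 0].
#   - Query [0, 10, 0] matches only the 2nd key, so its output is ~[10, 0, 0].
#   - Query [10, 10, 0] matches the 1st and 2nd keys equally, giving ~[5.5, 0, 0].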