# decoder.py
import math

import torch
from torch.nn import functional as F
from torch.nn.parameter import Parameter


class ConvTransR(torch.nn.Module):
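    """Relation decoder: scores all relations for each (subject, object) pair.

    Subject and object embeddings are stacked as two 1-D channels, mixed by a
    1-D convolution, projected back to `embedding_dim` by a linear layer, and
    matched against every relation embedding with a dot product. The doubled
    relation count in the bias parameter suggests forward and inverse
    relations are scored together.
    """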
def __init__(self, num_relations, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, channels=50, kernel_size=3, use_bias=True):
super(ConvTransR, self).__init__()
self.inp_drop = torch.nn.Dropout(input_dropout)
self.hidden_drop = torch.nn.Dropout(hidden_dropout)
self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
self.loss = torch.nn.BCELoss()
        self.conv1 = torch.nn.Conv1d(2, channels, kernel_size, stride=1,
                                     padding=int(math.floor(kernel_size / 2)),
                                     bias=use_bias)  # for an odd kernel size this padding preserves the sequence length
self.bn0 = torch.nn.BatchNorm1d(2)
self.bn1 = torch.nn.BatchNorm1d(channels)
self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
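        # per-relation bias over forward and inverse relations (registered but not applied in forward)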
self.register_parameter('b', Parameter(torch.zeros(num_relations*2)))
self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)
self.bn3 = torch.nn.BatchNorm1d(embedding_dim)
        self.bn_init = torch.nn.BatchNorm1d(embedding_dim)

def forward(self, embedding, emb_rel, triplets, nodes_id=None, mode="train", negative_rate=0):
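        """Return a (batch_size, num_relations * 2) relation score matrix.

        embedding: (num_entities, embedding_dim) entity embeddings.
        emb_rel:   (num_relations * 2, embedding_dim) relation embeddings.
        triplets:  (batch_size, 3) rows of (subject, relation, object) indices.
        `nodes_id`, `mode`, and `negative_rate` are accepted but unused here.
        """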
e1_embedded_all = torch.tanh(embedding)
batch_size = len(triplets)
        e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)
        e2_embedded = e1_embedded_all[triplets[:, 2]].unsqueeze(1)
stacked_inputs = torch.cat([e1_embedded, e2_embedded], 1)
stacked_inputs = self.bn0(stacked_inputs)
x = self.inp_drop(stacked_inputs)
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.feature_map_drop(x)
x = x.view(batch_size, -1)
x = self.fc(x)
x = self.hidden_drop(x)
x = self.bn2(x)
x = F.relu(x)
x = torch.mm(x, emb_rel.transpose(1, 0))
return x


class ConvTransE(torch.nn.Module):
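    """Entity decoder: scores all candidate objects for each (subject, relation) pair.

    Mirrors ConvTransR, but stacks the subject embedding with the relation
    embedding and matches the projected query against entity embeddings
    (or against `partial_embedding` when only a candidate subset is scored).
    """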
def __init__(self, num_entities, embedding_dim, input_dropout=0, hidden_dropout=0, feature_map_dropout=0, channels=50, kernel_size=3, use_bias=True):
super(ConvTransE, self).__init__()
        # initialize relation embeddings inside the module (unused; emb_rel is passed to forward instead)
        # self.emb_rel = torch.nn.Embedding(num_relations, embedding_dim, padding_idx=0)
self.inp_drop = torch.nn.Dropout(input_dropout)
self.hidden_drop = torch.nn.Dropout(hidden_dropout)
self.feature_map_drop = torch.nn.Dropout(feature_map_dropout)
self.loss = torch.nn.BCELoss()
        self.conv1 = torch.nn.Conv1d(2, channels, kernel_size, stride=1,
                                     padding=int(math.floor(kernel_size / 2)),
                                     bias=use_bias)  # for an odd kernel size this padding preserves the sequence length
self.bn0 = torch.nn.BatchNorm1d(2)
self.bn1 = torch.nn.BatchNorm1d(channels)
self.bn2 = torch.nn.BatchNorm1d(embedding_dim)
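        # per-entity bias (registered but not applied in forward)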
self.register_parameter('b', Parameter(torch.zeros(num_entities)))
self.fc = torch.nn.Linear(embedding_dim * channels, embedding_dim)
self.bn3 = torch.nn.BatchNorm1d(embedding_dim)
        self.bn_init = torch.nn.BatchNorm1d(embedding_dim)

    def forward(self, embedding, emb_rel, triplets, nodes_id=None, mode="train", negative_rate=0, partial_embedding=None):
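        """Return a (batch_size, num_entities) object score matrix.

        When `partial_embedding` is given, scores are computed against that
        matrix instead of all entity embeddings. BatchNorm on the projected
        query is skipped for single-element batches, where it is undefined
        in training mode.
        """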
e1_embedded_all = torch.tanh(embedding)
batch_size = len(triplets)
e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)
rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)
stacked_inputs = torch.cat([e1_embedded, rel_embedded], 1)
stacked_inputs = self.bn0(stacked_inputs)
x = self.inp_drop(stacked_inputs)
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.feature_map_drop(x)
x = x.view(batch_size, -1)
x = self.fc(x)
x = self.hidden_drop(x)
if batch_size > 1:
x = self.bn2(x)
x = F.relu(x)
        if partial_embedding is None:
            x = torch.mm(x, e1_embedded_all.transpose(1, 0))
        else:
            x = torch.mm(x, partial_embedding.transpose(1, 0))
        return x

def forward_slow(self, embedding, emb_rel, triplets):
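        """Score only the gold object of each triplet (one scalar per row).

        Unlike forward(), which scores all entities at once, this variant
        computes a dot product between the projected (subject, relation)
        query and the embedding of the object given in the triplet.
        """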
e1_embedded_all = torch.tanh(embedding)
# e1_embedded_all = embedding
batch_size = len(triplets)
e1_embedded = e1_embedded_all[triplets[:, 0]].unsqueeze(1)
        # optionally project the subject embedding into a subspace (disabled)
        # e1_embedded = torch.matmul(e1_embedded, sub_trans)
rel_embedded = emb_rel[triplets[:, 1]].unsqueeze(1)
stacked_inputs = torch.cat([e1_embedded, rel_embedded], 1)
stacked_inputs = self.bn0(stacked_inputs)
x = self.inp_drop(stacked_inputs)
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.feature_map_drop(x)
x = x.view(batch_size, -1)
x = self.fc(x)
x = self.hidden_drop(x)
if batch_size > 1:
x = self.bn2(x)
x = F.relu(x)
        e2_embedded = e1_embedded_all[triplets[:, 2]]
        score = torch.sum(torch.mul(x, e2_embedded), dim=1)
        return score
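

if __name__ == "__main__":
    # Minimal smoke test; a sketch, not part of the original training pipeline.
    # The sizes below are arbitrary toy values, and the shape conventions are
    # inferred from how forward() indexes `embedding`, `emb_rel`, and `triplets`.
    num_entities, num_relations, embedding_dim = 10, 4, 16
    entity_emb = torch.randn(num_entities, embedding_dim)
    rel_emb = torch.randn(num_relations * 2, embedding_dim)  # forward and inverse relations
    triplets = torch.tensor([[0, 1, 2], [3, 0, 4]])  # (subject, relation, object) rows

    decoder_e = ConvTransE(num_entities, embedding_dim)
    decoder_r = ConvTransR(num_relations, embedding_dim)
    decoder_e.eval()  # eval mode so BatchNorm works with tiny batches
    decoder_r.eval()

    with torch.no_grad():
        entity_scores = decoder_e(entity_emb, rel_emb, triplets)    # (batch, num_entities)
        relation_scores = decoder_r(entity_emb, rel_emb, triplets)  # (batch, num_relations * 2)
    print(entity_scores.shape, relation_scores.shape)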