projection_embedder.py (forked from uma-pi1/kge)
import torch.nn
import torch.nn.functional

from kge.model import KgeEmbedder


class ProjectionEmbedder(KgeEmbedder):
    """Adds a linear projection layer to a base embedder."""

    def __init__(
        self, config, dataset, configuration_key, vocab_size, init_for_load_only=False
    ):
        super().__init__(
            config, dataset, configuration_key, init_for_load_only=init_for_load_only
        )

        # initialize base_embedder; propagate its type into the config if unset
        if self.configuration_key + ".base_embedder.type" not in config.options:
            config.set(
                self.configuration_key + ".base_embedder.type",
                self.get_option("base_embedder.type"),
            )
        self.base_embedder = KgeEmbedder.create(
            config, dataset, self.configuration_key + ".base_embedder", vocab_size
        )

        # initialize projection; a negative dim means "use the base embedder's dim"
        if self.dim < 0:
            self.dim = self.base_embedder.dim
        self.dropout = self.get_option("dropout")
        self.regularize = self.check_option("regularize", ["", "lp"])
        self.projection = torch.nn.Linear(self.base_embedder.dim, self.dim, bias=False)
        if not init_for_load_only:
            self.initialize(self.projection.weight.data)

    def _embed(self, embeddings):
        # project the base embeddings, then optionally apply dropout
        embeddings = self.projection(embeddings)
        if self.dropout > 0:
            embeddings = torch.nn.functional.dropout(
                embeddings, p=self.dropout, training=self.training
            )
        return embeddings

    def embed(self, indexes):
        return self._embed(self.base_embedder.embed(indexes))

    def embed_all(self):
        return self._embed(self.base_embedder.embed_all())

    def penalty(self, **kwargs):
        # TODO factor out to a utility method
        if self.regularize == "" or self.get_option("regularize_weight") == 0.0:
            result = []
        elif self.regularize == "lp":
            # Lp norm of the projection weights, scaled by regularize_weight
            p = self.get_option("regularize_args.p")
            result = [
                (
                    f"{self.configuration_key}.L{p}_penalty",
                    self.get_option("regularize_weight")
                    * self.projection.weight.norm(p=p).sum(),
                )
            ]
        else:
            raise ValueError("unknown penalty")
        # combine own penalty terms with those of the base embedder
        return super().penalty(**kwargs) + result + self.base_embedder.penalty(**kwargs)
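Below is a minimal, hypothetical usage sketch (not part of the upstream file). It illustrates what embed() computes — a base-embedder lookup followed by the linear projection and optional dropout — in plain PyTorch with made-up sizes, bypassing the LibKGE config machinery (KgeEmbedder.create, get_option) entirely.

if __name__ == "__main__":
    import torch

    # made-up sizes for illustration only
    vocab_size, base_dim, proj_dim = 10, 8, 4

    # stand-in for a lookup base embedder
    base = torch.nn.Embedding(vocab_size, base_dim)
    projection = torch.nn.Linear(base_dim, proj_dim, bias=False)

    indexes = torch.tensor([0, 3, 7])
    embeddings = projection(base(indexes))  # base lookup, then projection
    embeddings = torch.nn.functional.dropout(embeddings, p=0.1, training=True)
    print(embeddings.shape)  # torch.Size([3, 4])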