-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_few_shot.py
executable file
·98 lines (83 loc) · 3.49 KB
/
prepare_few_shot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import torch
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
class FewShotSetCreator:
    """Build leave-one-out few-shot evaluation examples from a triple pool.

    The pool is read from ``<data_dir>/<dataset_name>/<split>_pool.del`` as a
    headerless tab-separated table. Per the indexing used below, column 0 is
    reported as the unseen entity, column 1 as its slot, and columns 2-4 as the
    (head, relation, tail) triple.

    For every row, the remaining rows of the same fixed-size chunk form the
    few-shot context, truncated to ``num_shots`` according to
    ``context_selection``.
    """

    # Rows per chunk of the pool; each chunk is treated as one unseen entity's
    # triples. NOTE(review): assumes the pool file is written in contiguous
    # groups of this size -- confirm against the script that generates it.
    TRIPLES_PER_ENTITY = 11

    # Accepted values for ``context_selection``.
    CONTEXT_SELECTIONS = ("most_common", "least_common", "random")

    def __init__(
        self,
        dataset_name="wikidata5m_v3_semi_inductive",
        split="valid",
        use_inverse=False,
        context_selection="most_common",
        data_dir="data",
    ):
        """Load the triple pool for ``split`` of ``dataset_name``.

        :param dataset_name: sub-directory of ``data_dir`` holding the files.
        :param split: prefix of the ``<split>_pool.del`` file to load.
        :param use_inverse: if True, also add inverse-relation views of rows
            whose slot column is 2 (see ``_load_triple_pool``).
        :param context_selection: one of ``CONTEXT_SELECTIONS``; decides which
            rows are kept when the context is truncated to ``num_shots``.
        :param data_dir: root directory containing the dataset directories.
        """
        self.dataset_name = dataset_name
        self.split = split
        self.use_inverse = use_inverse
        self.context_selection = context_selection
        self.data_dir = data_dir
        if self.use_inverse:
            train_triples = self._load_train_triples()
            # Offset added to relation ids to form inverse relations.
            # NOTE(review): this counts *distinct* relation ids seen in train;
            # it equals max_id + 1 only if ids are contiguous and all occur in
            # train -- confirm.
            self.num_relations = len(np.unique(train_triples[:, 1]))
        self.triple_pool = self._load_triple_pool()

    def _read_del(self, filename):
        """Read ``<data_dir>/<dataset_name>/<filename>`` as a headerless TSV array."""
        path = os.path.join(self.data_dir, self.dataset_name, filename)
        return pd.read_csv(path, delimiter="\t", header=None).to_numpy()

    def _load_train_triples(self):
        """Return the rows of ``train.del`` (column 1 is used as the relation id)."""
        return self._read_del("train.del")

    def _load_triple_pool(self):
        """Return the ``<split>_pool.del`` rows, optionally with inverse rows.

        Without ``use_inverse`` the file contents are returned unchanged.
        With it, the result is the rows whose slot (column 1) is 0, followed
        by inverse copies of the rows whose slot is 2: relation (column 3)
        shifted by ``num_relations``, head/tail (columns 2 and 4) swapped, and
        slot rewritten to 0.
        """
        triple_pool = self._read_del(f"{self.split}_pool.del")
        if not self.use_inverse:
            return triple_pool

        inverse_pool = triple_pool.copy()
        inverse_pool[:, 3] += self.num_relations
        # ``triple_pool`` is a separate array from ``inverse_pool``, so these
        # column swaps cannot read already-overwritten values.
        inverse_pool[:, 2] = triple_pool[:, 4]
        inverse_pool[:, 4] = triple_pool[:, 2]

        forward_rows = triple_pool[triple_pool[:, 1] == 0]
        # Boolean indexing returns a copy, so rewriting the slot is safe.
        inverse_rows = inverse_pool[inverse_pool[:, 1] == 2]
        inverse_rows[:, 1] = 0
        # NOTE(review): concatenating all forward rows before all inverse rows
        # means fixed-size chunking in ``create_few_shot_dataset`` only lines
        # up with entities if each part is itself laid out in whole chunks --
        # confirm against the data layout.
        return np.concatenate((forward_rows, inverse_rows))

    def _select_context(self, context, num_shots):
        """Return at most ``num_shots`` rows of ``context`` per ``context_selection``.

        Presumably rows within a chunk are ordered from most to least common
        (hence the strategy names) -- TODO confirm against the pool generator.
        """
        if self.context_selection == "most_common":
            return context[:num_shots]
        if self.context_selection == "least_common":
            # Use an explicit start index: ``context[-0:]`` would return every
            # row when num_shots == 0.
            return context[max(len(context) - num_shots, 0):]
        if self.context_selection == "random":
            # Shuffles in place; ``context`` is a fresh copy made by the caller.
            np.random.shuffle(context)
            return context[:num_shots]
        raise ValueError(
            f"unknown context_selection {self.context_selection!r}; "
            f"expected one of {self.CONTEXT_SELECTIONS}"
        )

    def create_few_shot_dataset(self, num_shots, show_progress=True):
        """Create one leave-one-out example per row of ``triple_pool``.

        :param num_shots: maximum number of context rows per example (>= 0).
        :param show_progress: wrap the chunk loop in a tqdm progress bar.
        :returns: list of dicts with keys ``"unseen_entity"`` (int, column 0),
            ``"unseen_slot"`` (int, column 1), ``"triple"`` (columns 2-4 of
            the row; a view sharing memory with ``triple_pool``) and
            ``"context"`` (array of up to ``num_shots`` other full rows from
            the same chunk).
        :raises ValueError: if ``num_shots`` is negative or
            ``context_selection`` is not one of ``CONTEXT_SELECTIONS``.
        """
        if self.context_selection not in self.CONTEXT_SELECTIONS:
            raise ValueError(
                f"unknown context_selection {self.context_selection!r}; "
                f"expected one of {self.CONTEXT_SELECTIONS}"
            )
        if num_shots < 0:
            raise ValueError(f"num_shots must be non-negative, got {num_shots}")

        print(f"create few shot set for {self.split} with {num_shots} shots")
        pool = self.triple_pool
        size = self.TRIPLES_PER_ENTITY
        # Consecutive slices of ``size`` rows; the last one may be shorter
        # (same as torch.Tensor.split(size), unlike np.split's equal sections).
        chunks = [pool[start:start + size] for start in range(0, len(pool), size)]
        if show_progress:
            chunks = tqdm(chunks)

        eval_list = []
        for chunk in chunks:
            for i, triple in enumerate(chunk):
                # np.delete returns a new array, so shuffling it is safe.
                context = self._select_context(np.delete(chunk, i, axis=0), num_shots)
                eval_list.append({
                    "unseen_entity": triple[0].item(),
                    "unseen_slot": triple[1].item(),
                    "triple": triple[2:],
                    "context": context,
                })
        return eval_list
if __name__ == '__main__':
    # Command-line entry point: build the few-shot set and optionally save it.
    parser = argparse.ArgumentParser(
        description="Create a leave-one-out few-shot evaluation set from a triple pool."
    )
    parser.add_argument("--dataset", "-d", type=str, default="wikidata5m_v3_semi_inductive")
    parser.add_argument("--split", "-s", type=str, default="valid")
    parser.add_argument("--num_shots", "-k", type=int, default=10)
    parser.add_argument(
        "--use_inverse",
        action="store_true",
        help="also add inverse-relation views of tail-slot triples",
    )
    parser.add_argument(
        "--context_selection",
        type=str,
        choices=["most_common", "least_common", "random"],
        default="most_common",
        help="which context rows to keep when truncating to num_shots",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=None,
        help="if given, torch.save the resulting few-shot list to this path",
    )
    args = parser.parse_args()
    few_shot_set_creator = FewShotSetCreator(
        dataset_name=args.dataset,
        split=args.split,
        use_inverse=args.use_inverse,
        context_selection=args.context_selection,
    )
    few_shot_set = few_shot_set_creator.create_few_shot_dataset(args.num_shots)
    # Previously the result was computed and discarded; persist it on request.
    if args.output is not None:
        torch.save(few_shot_set, args.output)