-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
112 lines (106 loc) · 4.97 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from datetime import datetime
import os
import pandas as pd
def get_classifier(dataset, target_relation, num_classes, batch_size, embedding_model_path,
                   classifier_type='mlp',
                   **model_kwargs):
    """
    Return a classifier that will classify the tails for the target relation.

    dataset: pykeen.Dataset, knowledge graph dataset e.g fb15k-237
    target_relation: str, relation whose tails are classified
    num_classes: int, number of classification classes
    batch_size: batch size, forwarded unchanged to the classifier constructor
    embedding_model_path: str, path to trained embedding model using pykeen library
    classifier_type: str, 'mlp' (default) or 'rf'
    model_kwargs: extra keyword arguments forwarded to the classifier constructor;
        for 'rf' they take precedence over the built-in defaults
        (max_depth=3, class_weight='balanced', max_features='auto')

    Returns the constructed classifier instance.
    Raises ValueError if classifier_type is not one of the supported values.
    """
    common_kwargs = dict(dataset=dataset,
                         embedding_model_path=embedding_model_path,
                         target_relation=target_relation,
                         num_classes=num_classes,
                         batch_size=batch_size)

    if classifier_type == 'mlp':
        # Imported inside the branch, so only the selected classifier class is imported.
        from classifier import TargetRelationClassifier
        return TargetRelationClassifier(**common_kwargs, **model_kwargs)

    if classifier_type == 'rf':
        from classifier import RFRelationClassifier
        # Defaults are merged with the caller's kwargs instead of being passed next to
        # them, so supplying e.g. max_depth no longer fails with a duplicate-keyword
        # TypeError; the caller's value wins.
        # NOTE(review): if these are forwarded to scikit-learn's random forest,
        # max_features='auto' is rejected by recent scikit-learn releases - confirm.
        rf_kwargs = {'max_depth': 3,
                     'class_weight': 'balanced',
                     'max_features': 'auto'}
        rf_kwargs.update(model_kwargs)
        return RFRelationClassifier(**common_kwargs, **rf_kwargs)

    # Previously an unknown type fell through and returned None, which only
    # surfaced later as an AttributeError far away from the actual mistake.
    raise ValueError(
        "Unsupported classifier_type {!r}; expected 'mlp' or 'rf'".format(classifier_type)
    )
def suggest_relations(dataset):
    """
    Suggest a target relation and a list of relations to detect bias for a
    knowledge graph dataset.

    dataset: str, dataset name, matched case-insensitively; one of
        "fb15k237", "wikidata" or "wiki5m"

    Returns a tuple (target_relation, bias_relations): a str and a list of str.
    Raises ValueError if the dataset name is not recognised.
    """
    name = dataset.lower()
    if name == "fb15k237":
        target_relation = '/people/person/profession'
        bias_relations = ['/people/person/gender',
                          '/people/person/languages',
                          '/people/person/nationality',
                          '/people/person/profession',
                          '/people/person/places_lived./people/place_lived/location',
                          '/people/person/spouse_s./people/marriage/type_of_union',
                          '/people/person/religion'
                          #/people/person/place_of_birth - top have 14, 13, 9, 4, 4, 3,3..
                          ]
    elif name == "wikidata":
        target_relation = 'P21'
        bias_relations = ['P102', 'P106', 'P169']
    elif name == "wiki5m":
        target_relation = 'P106'
        bias_relations = ['P27', 'P735', 'P19', 'P54', 'P69', 'P641', 'P20', 'P1344', 'P1412', 'P413']
    else:
        # Without this branch an unknown name reached the return statement with
        # both locals unassigned and failed with an opaque UnboundLocalError.
        raise ValueError(
            "No suggested relations for dataset {!r}; "
            "expected one of 'fb15k237', 'wikidata', 'wiki5m'".format(dataset)
        )
    return target_relation, bias_relations
def save_result(result, dataset, args):
    """
    Save dataset summary, and output from Evaluator

    Everything is written under
    ./results/<args.dataset>_<embedding>_<YYYYmmddHHMM>/ where <embedding> is the
    file name of args.embedding_path without its extension, or args.embedding
    when no path is given.

    result: dict, bias evaluation result, keyed by measure name. A DataFrame value
        is written to <measure>/<measure>.csv; a dict value is written as one CSV
        per relation, <measure>/<measure>_<relation>.csv. Any other value type
        only gets its (empty) measure directory.
    dataset: pykeen.Dataset, knowledge graph dataset e.g fb15k-237; only its
        summary_str() method is used here
    args: arguments passed when running main program; the attributes
        embedding_path, embedding and dataset are read
    """
    if args.embedding_path:
        embedding = os.path.splitext(os.path.basename(args.embedding_path))[0]
    else:
        embedding = args.embedding
    date = datetime.now().strftime("%Y%m%d%H%M")
    out_dir = os.path.join("./results/", args.dataset+"_"+embedding+"_"+date)
    os.makedirs(out_dir, exist_ok=True)

    # Save Dataset Summary
    with open(os.path.join(out_dir, args.dataset+".txt"), 'w') as f:
        # writelines accepts any iterable of strings, a plain string included.
        f.writelines(dataset.summary_str())
    #TODO: save embedding training configuration?

    # Relation names for this dataset contain '/', which would otherwise be read
    # as path separators; compare case-insensitively, as suggest_relations does.
    strip_relation_prefix = args.dataset.lower() == 'fb15k237'

    for measure, value in result.items():
        measure_dir = os.path.join(out_dir, measure)
        # The timestamp only has minute resolution, so a second run within the same
        # minute reuses the directory; exist_ok avoids the FileExistsError.
        os.makedirs(measure_dir, exist_ok=True)
        if isinstance(value, pd.DataFrame):
            save_path = os.path.join(measure_dir, "{}.csv".format(measure))
            print("Save to {}".format(save_path))
            value.to_csv(save_path)
        elif isinstance(value, dict):
            for rel, rel_result in value.items():
                df = pd.DataFrame(rel_result)
                rel_name = rel.split('/')[-1] if strip_relation_prefix else rel
                save_path = os.path.join(measure_dir, "{}_{}.csv".format(measure, rel_name))
                print("Save to {}".format(save_path))
                df.to_csv(save_path)
def remove_infreq_attributes(attr_counts, key, threshold=10, nan_val=-1):
    """
    Map an infrequent attribute value to a placeholder.

    attr_counts: mapping indexed by key that gives the count for each value
    key: the attribute value to check
    threshold: counts less than or equal to this are treated as infrequent
    nan_val: placeholder returned for an infrequent value

    Returns nan_val when attr_counts[key] <= threshold, otherwise key itself.
    """
    return nan_val if attr_counts[key] <= threshold else key
def requires_preds_df(bias_measures):
    """
    Tell whether any of the given bias metrics needs a preds dataframe.

    :param bias_measures: a list of bias metrics, each exposing a
        ``require_preds_df`` attribute
    :return: bool, True if at least one metric has a truthy ``require_preds_df``
        and False otherwise (including for an empty list)
    """
    # any() stops at the first truthy flag, like the explicit loop-and-break.
    return any(measure.require_preds_df for measure in bias_measures)