-
Notifications
You must be signed in to change notification settings - Fork 19
/
utils.py
47 lines (39 loc) · 1.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import math
import numpy as np
def l2_dist(q: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Return the pairwise Euclidean distances between rows of *q* and *x*.

    Uses the expansion ``|q - x|^2 = |q|^2 + |x|^2 - 2 * q.x`` so the whole
    distance matrix is produced with a single matrix product.

    Args:
        q: 2-D array with one vector per row.
        x: 2-D array with one vector per row; must have the same number of
            columns as ``q``.

    Returns:
        Array of shape ``(q.shape[0], x.shape[0])`` where entry ``[i, j]`` is
        the L2 distance between ``q[i]`` and ``x[j]``.

    Raises:
        AssertionError: if either input is not 2-D or the column counts differ.
    """
    assert len(q.shape) == 2
    assert len(x.shape) == 2
    # The original compared q.shape[1] with itself, which can never fail.
    assert q.shape[1] == x.shape[1]

    sqr_q = np.sum(q ** 2, axis=1, keepdims=True)  # column of squared norms
    sqr_x = np.sum(x ** 2, axis=1, keepdims=True).T  # row of squared norms
    l2 = sqr_q + sqr_x - 2 * q @ x.T

    # Rounding in the expansion can leave tiny negative values; clamp them so
    # the square root does not yield NaN.
    l2 = np.maximum(l2, 0)
    return np.sqrt(l2)
def arg_sort(q, x):
    """Return, for each row of *q*, the indices of *x* rows ordered by distance.

    The distance matrix comes from ``l2_dist(q, x)``; each row of it is
    argsorted in ascending order, so column 0 holds the closest row of *x*.
    """
    distance_matrix = l2_dist(q, x)
    ranking = np.argsort(distance_matrix)
    return ranking
def intersect(gs, ids):
    """Return the mean overlap size between paired entries of *gs* and *ids*.

    The two arguments are walked in lockstep; for every pair the number of
    distinct values common to both is counted with ``np.intersect1d``, and the
    average of those counts is returned.
    """
    overlap_counts = []
    for truth, candidates in zip(gs, ids):
        shared = np.intersect1d(truth, list(candidates))
        overlap_counts.append(len(shared))
    return np.mean(overlap_counts)
def intersect_sizes(gs, ids):
    """Return an array with the overlap size of each (*gs*, *ids*) pair.

    The two arguments are walked in lockstep; element ``i`` of the result is
    the number of distinct values shared by ``gs[i]`` and ``ids[i]`` as counted
    by ``np.intersect1d``.
    """
    overlap_counts = []
    for truth, candidates in zip(gs, ids):
        shared = np.intersect1d(truth, list(candidates))
        overlap_counts.append(len(shared))
    return np.array(overlap_counts)
def test_recall(X, Q, G):
    """Print a table of recall values for growing numbers of probed items.

    Rows of ``X`` are ranked for every row of ``Q`` by L2 distance (via
    ``arg_sort``). For each probe budget ``t`` (powers of two, the last one
    exceeding ``len(X)`` so the final row covers every item) the first ``t``
    ranked indices are compared with the first ``top_k`` columns of ``G``, and
    the mean overlap divided by ``top_k`` is printed for each ``top_k``.

    Args:
        X: 2-D array of items, one per row; must be non-empty because
            ``math.log2(len(X))`` is evaluated.
        Q: 2-D array of queries, one per row.
        G: 2-D array indexable as ``G[:, :top_k]`` with one row per query;
            presumably ground-truth neighbour indices into ``X`` - confirm
            against the caller.

    Returns:
        None. The table is written to standard output.
    """
    ks = [1, 5, 10, 20, 50, 100, 1000]
    Ts = [2 ** i for i in range(2 + int(math.log2(len(X))))]
    sort_idx = arg_sort(Q, X)

    print("# Probed \t Items \t", end="")
    for top_k in ks:
        print("top-%d\t" % (top_k), end="")
    print()

    for t in Ts:
        ids = sort_idx[:, :t]
        # Every row of the slice has the same width, so the per-query item
        # count is simply that width (and stays defined with zero queries).
        items = ids.shape[1]
        print("%6d \t %6d \t" % (t, items), end="")

        # The denominator is always top_k, even when G has fewer columns.
        for top_k in ks:
            hit_ratio = intersect_sizes(G[:, :top_k], ids) / float(top_k)
            print("%.4f \t" % np.mean(hit_ratio), end="")
        print()