evaluation.py
import math
import numpy as np
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
    mean_squared_error,
    mean_absolute_error,
)


def calc_classification_metrics(pred_scores, pred_labels, labels):
    """Compute classification metrics, handling binary and multiclass cases separately."""
    if len(np.unique(labels)) == 2:  # binary classification
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels, pred_scores)
        # F1 at every point on the precision-recall curve; guard against 0/0.
        fscore = (2 * precisions * recalls) / (precisions + recalls)
        fscore[np.isnan(fscore)] = 0
        # Report precision, recall, and F1 at the threshold that maximizes F1.
        ix = np.argmax(fscore)
        threshold = thresholds[ix].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {
            "roc_auc": roc_auc_pred_score,
            "threshold": threshold,
            "pr_auc": pr_auc,
            "recall": recalls[ix].item(),
            "precision": precisions[ix].item(),
            "f1": fscore[ix].item(),
            "tn": tn.item(),
            "fp": fp.item(),
            "fn": fn.item(),
            "tp": tp.item(),
        }
    else:  # multiclass classification
        acc = (pred_labels == labels).mean()
        f1_micro = f1_score(y_true=labels, y_pred=pred_labels, average="micro")
        f1_macro = f1_score(y_true=labels, y_pred=pred_labels, average="macro")
        f1_weighted = f1_score(y_true=labels, y_pred=pred_labels, average="weighted")
        result = {
            "acc": acc,
            "f1_micro": f1_micro,
            "f1_macro": f1_macro,
            "f1_weighted": f1_weighted,
            "mcc": matthews_corrcoef(labels, pred_labels),
        }
    return result


def calc_regression_metrics(preds, labels):
    """Compute regression metrics: MSE, RMSE, and MAE."""
    mse = mean_squared_error(labels, preds)
    rmse = math.sqrt(mse)
    mae = mean_absolute_error(labels, preds)
    return {
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
    }
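

# Minimal usage sketch (not part of the original module): the dummy data and
# variable names below are illustrative assumptions, showing how the two
# helpers above might be called with numpy arrays.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Binary classification: probability-like scores, hard labels at a 0.5 cut.
    labels = rng.integers(0, 2, size=100)
    pred_scores = rng.random(100)
    pred_labels = (pred_scores >= 0.5).astype(int)
    print(calc_classification_metrics(pred_scores, pred_labels, labels))

    # Regression: continuous targets and slightly noisy predictions.
    targets = rng.normal(size=100)
    preds = targets + rng.normal(scale=0.1, size=100)
    print(calc_regression_metrics(preds, targets))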