-
Notifications
You must be signed in to change notification settings - Fork 1
/
decision.py
78 lines (61 loc) · 2.76 KB
/
decision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
'''
Plotting clinical benefit curves
'''
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
def calculate_net_benefit_model(thresh_group, y_pred_score, y_label):
    """
    Function: Calculate the net benefit of the model at each threshold.

    For every threshold, samples whose score is strictly greater than the
    threshold are predicted positive, and the net benefit is
    ``tp / n - (fp / n) * (thresh / (1 - thresh))`` where ``n`` is the
    number of samples.

    Returns: net_benefit_model, ndarray of float, one net benefit per threshold
    Parameters: thresh_group, iterable of thresholds compared with y_pred_score
        to get predicted labels (a threshold of exactly 1 makes the odds term
        divide by zero)
    Parameters: y_pred_score, array-like of predicted scores, one per sample
    Parameters: y_label, array-like of binary true labels; nonzero / True
        entries are treated as the positive class
    Raises: ValueError if y_label is empty or its length differs from
        y_pred_score
    """
    scores = np.ravel(np.asarray(y_pred_score))
    is_positive = np.ravel(np.asarray(y_label)).astype(bool)
    n = is_positive.size
    if n == 0:
        raise ValueError("y_label must contain at least one sample")
    if scores.size != n:
        raise ValueError("y_pred_score and y_label must have the same length")
    is_negative = ~is_positive

    # Count TP/FP directly instead of unpacking a confusion matrix: the
    # matrix collapses to 1x1 when only one class is present, which breaks
    # a four-way unpack.
    benefits = []
    for thresh in thresh_group:
        predicted_positive = scores > thresh
        tp = np.count_nonzero(predicted_positive & is_positive)
        fp = np.count_nonzero(predicted_positive & is_negative)
        benefits.append((tp / n) - (fp / n) * (thresh / (1 - thresh)))
    # Build the array once rather than re-allocating with np.append per step.
    return np.array(benefits, dtype=float)
def calculate_net_benefit_all(thresh_group, y_label):
    """
    Function: Calculate the net benefit of treating every sample as positive.

    With every sample treated, the true positives are all actual positives
    and the false positives are all actual negatives, so the net benefit at
    a threshold is
    ``positives / n - (negatives / n) * (thresh / (1 - thresh))``.

    Returns: net_benefit_all, ndarray of float, one net benefit per threshold
    Parameters: thresh_group, iterable of thresholds (a threshold of exactly 1
        makes the odds term divide by zero)
    Parameters: y_label, array-like of binary true labels; nonzero / True
        entries are treated as the positive class
    Raises: ValueError if y_label is empty
    """
    is_positive = np.ravel(np.asarray(y_label)).astype(bool)
    total = is_positive.size
    if total == 0:
        raise ValueError("y_label must contain at least one sample")

    # Counting directly also works when only one class is present, where a
    # confusion matrix would collapse to a single cell.
    positives = np.count_nonzero(is_positive)
    negatives = total - positives
    positive_rate = positives / total
    negative_rate = negatives / total

    return np.array(
        [
            positive_rate - negative_rate * (thresh / (1 - thresh))
            for thresh in thresh_group
        ],
        dtype=float,
    )
def plot_DCA(ax, thresh_group, net_benefit_model, net_benefit_all):
    """
    Function: Draw a decision curve (net benefit vs. threshold) on ``ax``.

    Plots the model curve, the "Treat all" curve and a zero "Treat none"
    line, shades the region where the model beats both references, and
    configures axis limits, labels, spine colours and the legend.

    Returns: ax, the same axes object that was passed in
    Parameters: ax, axes-like object to draw on (uses plot, fill_between,
        set_xlim, set_ylim, set_xlabel, set_ylabel, spines and legend)
    Parameters: thresh_group, thresholds used as x values
    Parameters: net_benefit_model, array-like of model net benefits
    Parameters: net_benefit_all, array-like of treat-all net benefits
    """
    # Coerce to arrays so plain lists/tuples work with .min()/.max() below.
    net_benefit_model = np.asarray(net_benefit_model, dtype=float)
    net_benefit_all = np.asarray(net_benefit_all, dtype=float)

    # Curves
    ax.plot(thresh_group, net_benefit_model, color='crimson', label='Model')
    ax.plot(thresh_group, net_benefit_all, color='black', label='Treat all')
    ax.plot((0, 1), (0, 0), color='black', linestyle=':', label='Treat none')

    # Shade where the model is better than both treat-all and treat-none.
    best_reference = np.maximum(net_benefit_all, 0)
    model_or_reference = np.maximum(net_benefit_model, best_reference)
    ax.fill_between(
        thresh_group, model_or_reference, best_reference,
        color='crimson', alpha=0.2,
    )

    # Figure configuration
    ax.set_xlim(0, 1)
    # Base the y-range on finite values only: NaN/inf entries (e.g. from a
    # threshold of 1) would otherwise yield non-finite axis limits.
    finite_model = net_benefit_model[np.isfinite(net_benefit_model)]
    if finite_model.size:
        ax.set_ylim(finite_model.min() - 0.02, finite_model.max() + 0.15)
    ax.set_xlabel(
        xlabel='Threshold Probability',
        fontdict={'family': 'Times New Roman', 'fontsize': 15},
    )
    ax.set_ylabel(
        ylabel='Net Benefit',
        fontdict={'family': 'Times New Roman', 'fontsize': 15},
    )
    for side in ('right', 'top'):
        ax.spines[side].set_color((0.8, 0.8, 0.8))
    ax.legend()
    return ax