-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
73 lines (65 loc) · 2.07 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Predict survival classes for a CSV dataset using a previously saved model
import pandas
import os
from sklearn import model_selection
from sklearn.externals import joblib
from sklearn.tree import DecisionTreeClassifier
# Directory containing this script; the dataset file and the "models" folder
# are looked up relative to it.
file_dir = os.path.dirname(__file__)
# Defaults that init() may overwrite from user input.
# Base name of the saved model file (the ".joblib.pkl" suffix is added later).
model_name = 'svm_model_linear'
# Name of the CSV file the samples are read from.
input_filename = "test_dataset.csv"
def init():
    """Interactively choose the model and the input CSV file.

    Prints the model menu, reads the option from standard input and stores
    the matching model base name in the module-level ``model_name``.  Then
    asks for a dataset file name and, if the reply is non-empty, stores it
    in the module-level ``input_filename``; an empty reply leaves the
    current value untouched.

    A reply that is not an integer from 1 to 5 neither raises ``ValueError``
    nor silently keeps the previous model: the user is asked again until a
    valid option is entered.
    """
    global model_name, input_filename

    # Menu option -> base name of the saved model file.
    models = {
        1: "naive_bayes_model",
        2: "knn_model",
        3: "decision_tree_model",
        4: "svm_model_linear",
        5: "svm_model_rbf",
    }

    print(
        "\n\n---> Welcome to UBCpredict <---\n\n"
        "Please choose a model:\n\n"
        "1) Naive Bayes\n"
        "2) K-Nearest Neighbor\n"
        "3) Decision Trees\n"
        "4) Support vector machine (SVM) (linear)\n"
        "5) Support vector machine (SVM) (RBF)\n\n"
        "Option: "
    )
    while True:
        try:
            option = int(input())
        except ValueError:
            # Non-numeric reply: treat it like any other invalid option.
            option = None
        if option in models:
            model_name = models[option]
            break
        print("Invalid option, please enter a number from 1 to 5: ")

    print("\n\nFilename: test_dataset.csv?\n:")
    # Strip so that a reply made only of whitespace counts as "keep default".
    filename = input().strip()
    if filename != "":
        input_filename = filename
def get_predictdata():
    """Load the feature matrix from the selected CSV file.

    Reads the file named by the module-level ``input_filename``, resolved
    against ``file_dir``, using the 13 column names listed below.  Because
    the names are supplied explicitly, the first line of the file is read
    as data, not as a header.

    Returns:
        A 2-D array holding the first 12 columns of every row; the last
        column (``class``) is dropped.
    """
    names = ['AgeRecode', 'sex', 'YOD', 'martialStatus', 'grade', 'tumorSize',
             'lymphNodes', 'TNinsitu', 'HT_ICDO3', 'primarySite',
             'derivedAJCC', 'regionalNodePostive', 'class']
    # os.path.join picks the right separator for the platform and also copes
    # with an empty file_dir, unlike a hard-coded "\\".
    path = os.path.join(file_dir, input_filename)
    dataframe = pandas.read_csv(path, names=names)
    return dataframe.values[:, 0:12]
def predict(X):
    """Load the selected model and print one prediction line per sample.

    The model is read from ``<file_dir>/models/<model_name>.joblib.pkl`` and
    its ``predict`` method is called on ``X``.  Each predicted class is
    printed as a numbered (1-based) message; classes outside 0-4 are
    reported explicitly instead of being skipped silently.

    Args:
        X: Feature matrix passed unchanged to ``model.predict``.
    """
    # Build the path portably; a hard-coded "\\" only works on Windows.
    model_path = os.path.join(file_dir, "models", model_name + ".joblib.pkl")
    # NOTE(review): joblib.load unpickles the file - only load model files
    # from a trusted source.
    model = joblib.load(model_path)
    prediction = model.predict(X)
    print("\n\nPrediction result:\n\n")

    # Predicted class -> message shown to the user.
    messages = {
        0: "The patient will meet his creator soon!",
        1: "The patient will survive for atleast 2.5 years!",
        2: "The patient will survive for atleast 5 years!",
        3: "The patient will survive for atleast 7.5 years!",
        4: "The patient will survive for atleast 10 years!",
    }
    for i, p in enumerate(prediction, start=1):
        message = messages.get(p)
        if message is None:
            message = "Unrecognised prediction class: " + repr(p)
        print(str(i) + ") " + message)
# Script driver: ask the user for the model and file, load the feature
# matrix, then print the predictions. The order matters because
# get_predictdata() and predict() read the globals that init() sets.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so these
# statements also run when the module is imported.
init()
X = get_predictdata()
predict(X)