-
Notifications
You must be signed in to change notification settings - Fork 11
/
validators.py
76 lines (63 loc) · 3.66 KB
/
validators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class SafetyValidator:
    """Check safety-setting names against the values this validator accepts.

    ``validate`` prints an ``[ ERROR ]`` line naming the rejected value and
    listing the accepted ones, then returns ``False``.  It returns ``True``
    when both the category and the threshold are accepted.
    """

    def __init__(self):
        # Harm categories accepted by validate_categories().
        self.valid_categories = [
            'HARM_CATEGORY_HARASSMENT',
            'HARM_CATEGORY_HATE_SPEECH',
            'HARM_CATEGORY_SEXUALLY_EXPLICIT',
            'HARM_CATEGORY_DANGEROUS_CONTENT',
        ]
        # Blocking thresholds accepted by validate_thresholds().
        self.valid_thresholds = [
            'BLOCK_LOW_AND_ABOVE',
            'BLOCK_MEDIUM_AND_ABOVE',
            'BLOCK_ONLY_HIGH',
            'BLOCK_NONE',
        ]

    def validate(self, category, threshold):
        """Return True if *category* and *threshold* are both accepted.

        The category is checked first.  If it is rejected, an error is printed
        and the threshold is never examined.
        """
        if self.validate_categories(category):
            if self.validate_thresholds(threshold):
                return True
            accepted = ', '.join(self.valid_thresholds)
            print(f"[ ERROR ]: Invalid safety threshold: {threshold}. Valid thresholds are: {accepted}")
            return False
        accepted = ', '.join(self.valid_categories)
        print(f"[ ERROR ]: Invalid safety category: {category}. Valid categories are: {accepted}")
        return False

    def validate_categories(self, category):
        """Return True if *category* appears in ``valid_categories``."""
        accepted = self.valid_categories
        return category in accepted

    def validate_thresholds(self, threshold):
        """Return True if *threshold* appears in ``valid_thresholds``."""
        accepted = self.valid_thresholds
        return threshold in accepted
class InputValidator:
    """Validate user-supplied generation parameters before a request is made.

    Each ``validate_*`` method prints an ``[ ERROR ]`` message for the first
    problem it finds and returns ``False``.  It returns ``True`` when the
    input is acceptable.

    Note: the ``json`` parameter name shadows the stdlib module inside these
    methods; it is kept for compatibility with existing callers.
    """

    # Substrings of a model name that mark it as supporting system
    # instructions, JSON output and multimodal input.
    ADVANCED_MODEL_VERSIONS = ("1.5", "2.0")

    def __init__(self):
        self.safety_validator = SafetyValidator()

    def _supports_advanced_features(self, model):
        """Return True if *model* contains any of ``ADVANCED_MODEL_VERSIONS``.

        A missing (``None``/empty) model name is treated as unsupported instead
        of raising ``TypeError`` on the substring test.
        """
        model_name = model or ""
        return any(version in model_name for version in self.ADVANCED_MODEL_VERSIONS)

    def validate_text_input(self, prompt, candidate_count, system_prompt, json, model, safety_categories, safety_thresholds):
        """Validate a single-prompt text request; *prompt* must be non-empty."""
        if not prompt:
            print("[ ERROR ]: Invalid input detected. Please enter a valid message.")
            return False
        return self.validate_common_params(candidate_count, system_prompt, json, model, safety_categories, safety_thresholds)

    def validate_chat_input(self, candidate_count, system_prompt, json, model, safety_categories, safety_thresholds):
        """Validate a chat request (only the shared parameters are checked)."""
        return self.validate_common_params(candidate_count, system_prompt, json, model, safety_categories, safety_thresholds)

    def validate_multimodal_input(self, candidate_count, system_prompt, json, model, safety_categories, safety_thresholds):
        """Validate a multimodal request; *model* must be a 1.5 or 2.0 model."""
        # The original `"1.5" or "2.0" not in model` was always truthy, so
        # every model was rejected; test each version substring explicitly.
        if not self._supports_advanced_features(model):
            print(f"[ ERROR ]: Multimodal mode is only supported in Gemini 1.5 and 2.0. Current model: {model}")
            return False
        return self.validate_common_params(candidate_count, system_prompt, json, model, safety_categories, safety_thresholds)

    def validate_common_params(self, candidate_count, system_prompt, json, model, safety_categories, safety_thresholds):
        """Validate the parameters shared by every request mode.

        Args:
            candidate_count: ``None`` or a positive int; only 1 is accepted.
            system_prompt: truthy value requires a 1.5/2.0 model.
            json: truthy value (JSON output requested) requires a 1.5/2.0 model.
            model: model name searched for a supported version substring.
            safety_categories: sequence of category names, or ``None``.
            safety_thresholds: sequence of thresholds paired by position with
                *safety_categories*, or ``None``.

        Returns:
            True if every check passes, otherwise False (after printing why).
        """
        if candidate_count is not None:
            if not isinstance(candidate_count, int) or candidate_count <= 0:
                print("[ ERROR ]: Candidate count must be a positive integer.")
                return False
            if candidate_count > 1:
                print("[ ERROR ]: Candidate count greater than 1 is not supported due to a bug in the API. Please use a candidate count of 1, or remove the parameter.")
                return False
        # `("1.5" or "2.0")` evaluates to just "1.5", which wrongly rejected
        # 2.0 models; the helper checks every supported version.
        if system_prompt and not self._supports_advanced_features(model):
            print(f"[ ERROR ]: System instructions are only supported in Gemini 1.5 and 2.0. Current model: {model}")
            return False
        if json and not self._supports_advanced_features(model):
            print(f"[ ERROR ]: JSON output is only supported in Gemini 1.5 and 2.0. Current model: {model}")
            return False
        # Check as soon as either list is supplied.  Previously, categories
        # given without thresholds (or vice versa) were silently ignored.
        if safety_categories or safety_thresholds:
            categories = list(safety_categories or [])
            thresholds = list(safety_thresholds or [])
            if len(categories) != len(thresholds):
                print("[ ERROR ]: Mismatch in safety categories and thresholds. Ensure each category has a corresponding threshold.")
                return False
            for category, threshold in zip(categories, thresholds):
                if not self.safety_validator.validate(category, threshold):
                    return False
        return True