-
Notifications
You must be signed in to change notification settings - Fork 3
/
config_handler.py
123 lines (93 loc) · 3.48 KB
/
config_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import json
import os
import traceback
from typing import Dict, Any, Optional, Union

from fastapi import responses, status

from notify_handler import notify_others_about_change_thread
# In-memory configuration; stays None until one of the loader functions
# (called from init_config) assigns it.
__config: Optional[Dict[str, Any]] = None

# Shipped default config. Its top-level keys are also the set of keys that
# survive when an existing current config is loaded.
BASE_CONFIG_PATH = "base_config.json"

# Persisted config that change_config writes to.
CURRENT_CONFIG_PATH = "/config/current_config.json"

# Service name -> URL, handed to notify_others_about_change_thread whenever
# change_config reports a change.
SERVICES_TO_NOTIFY = {
    "TOKENIZER": "http://refinery-tokenizer:80",
}
def get_config_value(
    key: str, subkey: Optional[str] = None
) -> Union[str, Dict[str, str]]:
    """Look up a value in the loaded configuration.

    Args:
        key: Top-level config key.
        subkey: Optional key inside ``config[key]``. Only used when truthy, so
            ``None`` and ``""`` both return the whole top-level value.

    Returns:
        ``config[key]`` when no subkey is given, otherwise ``config[key][subkey]``.

    Raises:
        ValueError: If the config has not been loaded yet, if ``key`` is
            missing, or if ``subkey`` is given but ``config[key]`` is not a dict
            containing it.
    """
    # Without this guard, `key not in None` raises an unhelpful TypeError.
    if __config is None:
        raise ValueError("Config is not initialized, call init_config() first")
    if key not in __config:
        raise ValueError(f"Key {key} couldn't be found in config")
    value = __config[key]
    if not subkey:
        return value
    if isinstance(value, dict) and subkey in value:
        return value[subkey]
    raise ValueError(f"Subkey {subkey} couldn't be found in config[{key}]")
def __read_and_change_base_config():
    """Initialise the config from the base config file and persist it.

    Loads ``BASE_CONFIG_PATH`` into the module-level config and overrides
    ``s3_region`` with the ``S3_REGION`` environment variable (default
    ``"eu-west-1"``). It then writes the result via ``__save_current_config``.
    """
    global __config
    print("reading base config file", flush=True)
    # Context manager so the handle is closed even if json.load raises.
    with open(BASE_CONFIG_PATH) as f:
        __config = json.load(f)
    # NOTE(review): if base_config.json has no "s3_region" key, the next start
    # (which goes through __load_and_remove_outdated_config_keys) will prune
    # this key again. Confirm the base file declares it.
    __config["s3_region"] = os.getenv("S3_REGION", "eu-west-1")
    __save_current_config()
def change_config(changes: Dict[str, Any]) -> bool:
    """Apply ``changes`` to the loaded config and persist it if anything was written.

    Rules:
    - Only keys already present in the config are applied.
    - ``KERN_S3_ENDPOINT`` is always skipped, because ``init_config`` sets it
      from the environment on every start.
    - A dict-valued change is merged per subkey into a dict-valued entry, and
      only subkeys that already exist are written.
    - Any other change value replaces the entry wholesale.

    Args:
        changes: Mapping of config keys to their new values.

    Returns:
        True if at least one value was written (the config was then saved),
        False otherwise.
    """
    something_changed = False
    for key in changes:
        if key == "KERN_S3_ENDPOINT" or key not in __config:
            continue
        new_value = changes[key]
        if isinstance(new_value, dict):
            current = __config[key]
            # A non-dict entry has no subkeys to update. Without this check, a
            # string entry turns `subkey in current` into a substring test,
            # and the item assignment below raises TypeError mid-update.
            if not isinstance(current, dict):
                continue
            for subkey in new_value:
                if subkey in current:
                    current[subkey] = new_value[subkey]
                    something_changed = True
        else:
            __config[key] = new_value
            something_changed = True
    if something_changed:
        __save_current_config()
    else:
        print("nothing was changed with input", changes, flush=True)
    return something_changed
def __save_current_config() -> None:
    """Write the in-memory config to ``CURRENT_CONFIG_PATH`` as 4-space-indented JSON.

    Opens the file in ``"w"`` mode, so any existing content is replaced.
    """
    print("saving config file", flush=True)
    target_path = CURRENT_CONFIG_PATH
    with open(target_path, "w") as out_file:
        json.dump(__config, out_file, indent=4)
def init_config() -> None:
    """Load the configuration at startup and refresh the S3 endpoint.

    Two cases:
    - No current config file exists: it is built from the base config.
    - A current config file exists: it is loaded and keys absent from the base
      config are dropped.

    In both cases ``KERN_S3_ENDPOINT`` is then set from the environment
    (``None`` if unset). This happens in memory only; nothing is written here.
    """
    has_current_config = os.path.exists(CURRENT_CONFIG_PATH)
    if has_current_config:
        __load_and_remove_outdated_config_keys()
    else:
        __read_and_change_base_config()
    # Refreshed on every start so it always reflects the current environment.
    print("setting s3 endpoint", flush=True)
    __config["KERN_S3_ENDPOINT"] = os.getenv("KERN_S3_ENDPOINT")
def __load_and_remove_outdated_config_keys():
    """Load the current config file and prune keys the base config no longer has.

    Does nothing if ``CURRENT_CONFIG_PATH`` does not exist. If any keys are
    pruned, the result is written back via ``__save_current_config``.

    Keys present in the base config but missing from the current file are NOT
    added here.
    NOTE(review): this also prunes any key this module sets outside the base
    file (e.g. "s3_region", "KERN_S3_ENDPOINT") unless base_config.json
    declares it. Confirm that is intended.
    """
    global __config
    if not os.path.exists(CURRENT_CONFIG_PATH):
        return
    with open(CURRENT_CONFIG_PATH) as current_file:
        __config = json.load(current_file)
    with open(BASE_CONFIG_PATH) as base_file:
        base_keys = json.load(base_file)
    outdated_keys = [name for name in __config if name not in base_keys]
    if not outdated_keys:
        return
    print("removing outdated config keys", outdated_keys, flush=True)
    for name in outdated_keys:
        del __config[name]
    __save_current_config()
def get_config() -> Dict[str, Any]:
    """Return the live module-level config dict itself, not a copy.

    Mutating the result mutates the live config. Returns ``None`` until the
    config has been loaded.
    """
    return __config
def change_json(config_data) -> responses.PlainTextResponse:
    """Apply a config change request and notify dependent services on change.

    Args:
        config_data: Mapping of config keys to new values; passed unchanged to
            ``change_config``.

    Returns:
        On success, 200 with ``"Did update: <True|False>"``. If
        ``change_config`` or ``notify_others_about_change_thread`` raises,
        500 with ``"Error: <message>"``.
    """
    try:
        has_changed = change_config(config_data)
        if has_changed:
            # Only tell other services about the config when it actually changed.
            notify_others_about_change_thread(SERVICES_TO_NOTIFY)
        return responses.PlainTextResponse(
            f"Did update: {has_changed}", status_code=status.HTTP_200_OK
        )
    except Exception as e:
        # Request boundary: keep the full traceback in the server log instead
        # of surfacing the failure only to the client.
        traceback.print_exc()
        return responses.PlainTextResponse(
            f"Error: {str(e)}", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
def full_config_json() -> responses.JSONResponse:
    """Return the entire current config as a 200 JSON response."""
    current = get_config()
    return responses.JSONResponse(content=current, status_code=status.HTTP_200_OK)