Skip to content

Commit

Permalink
feat: issue cheshire-cat-ai#980
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe Coco committed Dec 3, 2024
1 parent 2fdf311 commit 7d56d0f
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions core/cat/mad_hatter/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(self, plugin_path: str):
self._forms: List[CatForm] = [] # list of plugin forms
self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints

# list of @plugin decorated functions overriding default plugin behaviour
self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access
# list of @plugin decorated functions overriding default plugin behaviour
self._plugin_overrides = {}

# plugin starts deactivated
self._active = False
Expand Down Expand Up @@ -101,40 +101,37 @@ def deactivate(self):
self._tools = []
self._forms = []
self._deactivate_endpoints()
self._plugin_overrides = []
self._plugin_overrides = {}
self._active = False

# get plugin settings JSON schema
def settings_schema(self):
# is "settings_schema" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "settings_schema":
return h.function()
else:
# if the "settings_schema" is not defined but
# "settings_model" is it get the schema from the model
if h.name == "settings_model":
return h.function().model_json_schema()
if "settings_schema" in self._plugin_overrides:
return self._plugin_overrides["settings_schema"]()
else:
# if the "settings_schema" is not defined but
# "settings_model" is it get the schema from the model
if "settings_model" in self._plugin_overrides:
return self._plugin_overrides["settings_model"]().model_json_schema()

# default schema (empty)
return PluginSettingsModel.model_json_schema()

# get plugin settings Pydantic model
def settings_model(self):
# is "settings_model" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "settings_model":
return h.function()
if "settings_model" in self._plugin_overrides:
return self._plugin_overrides["settings_model"]()

# default schema (empty)
return PluginSettingsModel

# load plugin settings
def load_settings(self):
# is "settings_load" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "load_settings":
return h.function()
if "load_settings" in self._plugin_overrides:
return self._plugin_overrides["load_settings"]()

# by default, plugin settings are saved inside the plugin folder
# in a JSON file called settings.json
Expand All @@ -159,9 +156,8 @@ def load_settings(self):
# save plugin settings
def save_settings(self, settings: Dict):
# is "settings_save" hook defined in the plugin?
for h in self._plugin_overrides:
if h.name == "save_settings":
return h.function(settings)
if "save_settings" in self._plugin_overrides:
return self._plugin_overrides["save_settings"](settings)

# by default, plugin settings are saved inside the plugin folder
# in a JSON file called settings.json
Expand Down Expand Up @@ -331,9 +327,8 @@ def _load_decorated_functions(self):
self._tools = list(map(self._clean_tool, tools))
self._forms = list(map(self._clean_form, forms))
self._endpoints = list(map(self._clean_endpoint, endpoints))
self._plugin_overrides = list(
map(self._clean_plugin_override, plugin_overrides)
)
self._plugin_overrides = {override.__name__: override for _, override in plugin_overrides}


def plugin_specific_error_message(self):
name = self.manifest.get("name")
Expand Down

0 comments on commit 7d56d0f

Please sign in to comment.