diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index b2f067334..81797dbe0 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -61,8 +61,8 @@ def __init__(self, plugin_path: str): self._forms: List[CatForm] = [] # list of plugin forms self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints - # list of @plugin decorated functions overriding default plugin behaviour - self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access + # list of @plugin decorated functions overriding default plugin behaviour + self._plugin_overrides = {} # plugin starts deactivated self._active = False @@ -101,20 +101,19 @@ def deactivate(self): self._tools = [] self._forms = [] self._deactivate_endpoints() - self._plugin_overrides = [] + self._plugin_overrides = {} self._active = False # get plugin settings JSON schema def settings_schema(self): # is "settings_schema" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "settings_schema": - return h.function() - else: - # if the "settings_schema" is not defined but - # "settings_model" is it get the schema from the model - if h.name == "settings_model": - return h.function().model_json_schema() + if "settings_schema" in self._plugin_overrides: + return self._plugin_overrides["settings_schema"]() + else: + # if the "settings_schema" is not defined but + # "settings_model" is it get the schema from the model + if "settings_model" in self._plugin_overrides: + return self._plugin_overrides["settings_model"]().model_json_schema() # default schema (empty) return PluginSettingsModel.model_json_schema() @@ -122,9 +121,8 @@ def settings_schema(self): # get plugin settings Pydantic model def settings_model(self): # is "settings_model" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "settings_model": - return h.function() + if "settings_model" in self._plugin_overrides: + return self._plugin_overrides["settings_model"]() # default schema (empty) return PluginSettingsModel @@ -132,9 +130,8 @@ def settings_model(self): # load plugin settings def load_settings(self): # is "settings_load" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "load_settings": - return h.function() + if "load_settings" in self._plugin_overrides: + return self._plugin_overrides["load_settings"]() # by default, plugin settings are saved inside the plugin folder # in a JSON file called settings.json @@ -159,9 +156,8 @@ def load_settings(self): # save plugin settings def save_settings(self, settings: Dict): # is "settings_save" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "save_settings": - return h.function(settings) + if "save_settings" in self._plugin_overrides: + return self._plugin_overrides["save_settings"](settings) # by default, plugin settings are saved inside the plugin folder # in a JSON file called settings.json @@ -331,9 +327,8 @@ def _load_decorated_functions(self): self._tools = list(map(self._clean_tool, tools)) self._forms = list(map(self._clean_form, forms)) self._endpoints = list(map(self._clean_endpoint, endpoints)) - self._plugin_overrides = list( - map(self._clean_plugin_override, plugin_overrides) - ) + self._plugin_overrides = {override.__name__: override for _, override in plugin_overrides} + def plugin_specific_error_message(self): name = self.manifest.get("name")