diff --git a/core/cat/mad_hatter/mad_hatter.py b/core/cat/mad_hatter/mad_hatter.py index 5983983a3..5519ed25a 100644 --- a/core/cat/mad_hatter/mad_hatter.py +++ b/core/cat/mad_hatter/mad_hatter.py @@ -200,9 +200,8 @@ def toggle_plugin(self, plugin_id): # Execute hook on plugin deactivation # Deactivation hook must happen before actual deactivation, # otherwise the hook will not be available in _plugin_overrides anymore - for hook in self.plugins[plugin_id]._plugin_overrides: - if hook.name == "deactivated": - hook.function(self.plugins[plugin_id]) + if "deactivated" in self.plugins[plugin_id]._plugin_overrides: + self.plugins[plugin_id]._plugin_overrides["deactivated"].function(self.plugins[plugin_id]) # Deactivate the plugin self.plugins[plugin_id].deactivate() @@ -221,9 +220,8 @@ def toggle_plugin(self, plugin_id): # Execute hook on plugin activation # Activation hook must happen before actual activation, # otherwise the hook will still not be available in _plugin_overrides - for hook in self.plugins[plugin_id]._plugin_overrides: - if hook.name == "activated": - hook.function(self.plugins[plugin_id]) + if "activated" in self.plugins[plugin_id]._plugin_overrides: + self.plugins[plugin_id]._plugin_overrides["activated"].function(self.plugins[plugin_id]) # Add the plugin in the list of active plugins self.active_plugins.append(plugin_id) diff --git a/core/cat/mad_hatter/plugin.py b/core/cat/mad_hatter/plugin.py index b2f067334..5ee53e187 100644 --- a/core/cat/mad_hatter/plugin.py +++ b/core/cat/mad_hatter/plugin.py @@ -61,8 +61,8 @@ def __init__(self, plugin_path: str): self._forms: List[CatForm] = [] # list of plugin forms self._endpoints: List[CustomEndpoint] = [] # list of plugin endpoints - # list of @plugin decorated functions overriding default plugin behaviour - self._plugin_overrides = [] # TODO: make this a dictionary indexed by func name, for faster access + # list of @plugin decorated functions overriding default plugin behaviour + self._plugin_overrides = {} # plugin starts deactivated self._active = False @@ -101,20 +101,19 @@ def deactivate(self): self._tools = [] self._forms = [] self._deactivate_endpoints() - self._plugin_overrides = [] + self._plugin_overrides = {} self._active = False # get plugin settings JSON schema def settings_schema(self): # is "settings_schema" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "settings_schema": - return h.function() - else: - # if the "settings_schema" is not defined but - # "settings_model" is it get the schema from the model - if h.name == "settings_model": - return h.function().model_json_schema() + if "settings_schema" in self._plugin_overrides: + return self._plugin_overrides["settings_schema"].function() + else: + # if the "settings_schema" is not defined but + # "settings_model" is it get the schema from the model + if "settings_model" in self._plugin_overrides: + return self._plugin_overrides["settings_model"].function().model_json_schema() # default schema (empty) return PluginSettingsModel.model_json_schema() @@ -122,9 +121,8 @@ def settings_schema(self): # get plugin settings Pydantic model def settings_model(self): # is "settings_model" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "settings_model": - return h.function() + if "settings_model" in self._plugin_overrides: + return self._plugin_overrides["settings_model"].function() # default schema (empty) return PluginSettingsModel @@ -132,9 +130,8 @@ def settings_model(self): # load plugin settings def load_settings(self): # is "settings_load" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "load_settings": - return h.function() + if "load_settings" in self._plugin_overrides: + return self._plugin_overrides["load_settings"].function() # by default, plugin settings are saved inside the plugin folder # in a JSON file called settings.json @@ -159,9 +156,8 @@ def load_settings(self): # save plugin settings def save_settings(self, settings: Dict): # is "settings_save" hook defined in the plugin? - for h in self._plugin_overrides: - if h.name == "save_settings": - return h.function(settings) + if "save_settings" in self._plugin_overrides: + return self._plugin_overrides["save_settings"].function(settings) # by default, plugin settings are saved inside the plugin folder # in a JSON file called settings.json @@ -331,9 +327,8 @@ def _load_decorated_functions(self): self._tools = list(map(self._clean_tool, tools)) self._forms = list(map(self._clean_form, forms)) self._endpoints = list(map(self._clean_endpoint, endpoints)) - self._plugin_overrides = list( - map(self._clean_plugin_override, plugin_overrides) - ) + self._plugin_overrides = {override.name: override for override in list(map(self._clean_plugin_override, plugin_overrides))} + def plugin_specific_error_message(self): name = self.manifest.get("name")