From 8a28dd61adb4bb0fe42af0b098d03b8fbb3c1174 Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sat, 25 May 2024 17:12:58 -0400 Subject: [PATCH 01/13] update naming and start implementing new LLM API support --- TODO.md | 9 ++- .../llama_conversation/__init__.py | 14 ++-- custom_components/llama_conversation/agent.py | 80 ++++++++++++------- .../llama_conversation/config_flow.py | 38 +++++++-- .../llama_conversation/manifest.json | 2 +- docs/Setup.md | 6 +- hacs.json | 4 +- tests/llama_conversation/test_agent.py | 14 ++-- 8 files changed, 108 insertions(+), 59 deletions(-) diff --git a/TODO.md b/TODO.md index 9ce830b..1d6c9fd 100644 --- a/TODO.md +++ b/TODO.md @@ -1,6 +1,13 @@ # TODO -- [x] detection/mitigation of too many entities being exposed & blowing out the context length +- [ ] support new LLM APIs + - rewrite how services are called + - handle no API selected + - rewrite prompts + service block formats + - update dataset so new models will work with the API +- [ ] make ICL examples into conversation turns +- [ ] translate ICL examples + make better ones - [ ] areas/room support +- [x] detection/mitigation of too many entities being exposed & blowing out the context length - [ ] figure out DPO to improve response quality - [ ] train the model to respond to house events - present the model with an event + a "prompt" from the user of what you want it to do (i.e. turn on the lights when I get home = the model turns on lights when your entity presence triggers as being home) diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index d00c9b1..0d35ff1 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -1,4 +1,4 @@ -"""The Local LLaMA Conversation integration.""" +"""The Local LLM Conversation integration.""" from __future__ import annotations import logging @@ -9,8 +9,8 @@ from homeassistant.helpers import config_validation as cv from .agent import ( - LLaMAAgent, - LocalLLaMAAgent, + LocalLLMAgent, + LlamaCppAgent, GenericOpenAIAPIAgent, TextGenerationWebuiAgent, LlamaCppPythonAPIAgent, @@ -38,19 +38,19 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry): hass.data[DOMAIN][entry.entry_id] = entry # call update handler - agent: LLaMAAgent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id) + agent: LocalLLMAgent = await ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) agent._update_options() return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Set up Local LLaMA Conversation from a config entry.""" + """Set up Local LLM Conversation from a config entry.""" def create_agent(backend_type): agent_cls = None if backend_type in [ BACKEND_TYPE_LLAMA_HF, BACKEND_TYPE_LLAMA_EXISTING ]: - agent_cls = LocalLLaMAAgent + agent_cls = LlamaCppAgent elif backend_type == BACKEND_TYPE_GENERIC_OPENAI: agent_cls = GenericOpenAIAPIAgent elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI: @@ -78,7 +78,7 @@ def create_agent(backend_type): async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Unload Local LLaMA.""" + """Unload Local LLM.""" hass.data[DOMAIN].pop(entry.entry_id) ha_conversation.async_unset_agent(hass, entry) return True diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 7df88e8..6d9d156 100644 --- a/custom_components/llama_conversation/agent.py +++ 
b/custom_components/llama_conversation/agent.py @@ -18,10 +18,10 @@ from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry -from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL +from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API from homeassistant.core import HomeAssistant, callback -from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError -from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er +from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError +from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm from homeassistant.helpers.event import async_track_state_change, async_call_later from homeassistant.util import ulid @@ -114,8 +114,8 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) -class LLaMAAgent(AbstractConversationAgent): - """Base LLaMA conversation agent.""" +class LocalLLMAgent(AbstractConversationAgent): + """Base Local LLM conversation agent.""" hass: HomeAssistant entry_id: str @@ -225,6 +225,22 @@ async def async_process( return ConversationResult( response=intent_response, conversation_id=conversation_id ) + + llm_api: llm.API | None = None + if self.entry.options.get(CONF_LLM_HASS_API): + try: + llm_api = llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API] + ) + except HomeAssistantError as err: + _LOGGER.error("Error getting LLM API: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Error preparing LLM API: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=user_input.conversation_id + ) if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id @@ -235,7 +251,7 @@ async def async_process( if len(conversation) == 0 or refresh_system_prompt: try: - message = self._generate_system_prompt(raw_prompt) + message = self._generate_system_prompt(raw_prompt, llm_api) except TemplateError as err: _LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) @@ -407,7 +423,7 @@ def _format_prompt( _LOGGER.debug(formatted_prompt) return formatted_prompt - def _generate_system_prompt(self, prompt_template: str) -> str: + def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API) -> str: """Generate the system prompt with current entity states""" entities_to_expose, domains = self._async_get_exposed_entities() @@ -487,21 +503,14 @@ def expose_attributes(attributes): formatted_states = "\n".join(device_states) + "\n" - service_dict = self.hass.services.async_services() - all_services = [] - all_service_names = [] - for domain in domains: - # scripts show up as individual services - if domain == "script": - all_services.extend(["script.reload()", "script.turn_on()", "script.turn_off()", "script.toggle()"]) - continue - - for name, service in service_dict.get(domain, {}).items(): - args = flatten_vol_schema(service.schema) - args_to_expose = set(args).intersection(allowed_service_call_arguments) - all_services.append(f"{domain}.{name}({','.join(args_to_expose)})") - all_service_names.append(f"{domain}.{name}") - 
formatted_services = ", ".join(all_services) + if llm_api: + tools = [ + f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}" + for tool in llm_api.async_get_tools() + ] + formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools) + else: + formatted_services = "No tools exposed." render_variables = { "devices": formatted_states, @@ -509,15 +518,16 @@ def expose_attributes(attributes): } if self.in_context_examples: - num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) - render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names)) + # num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) + # render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names)) + render_variables["response_examples"] = "" return template.Template(prompt_template, self.hass).async_render( render_variables, parse_result=False, ) -class LocalLLaMAAgent(LLaMAAgent): +class LlamaCppAgent(LocalLLMAgent): model_path: str llm: LlamaType grammar: Any @@ -612,7 +622,7 @@ def _load_grammar(self, filename: str): self.grammar = None def _update_options(self): - LLaMAAgent._update_options(self) + LocalLLMAgent._update_options(self) model_reloaded = False if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \ @@ -662,7 +672,7 @@ async def cache_current_prompt(_now): def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]: """Takes the super class function results and sorts the entities with the recently updated at the end""" - entities, domains = LLaMAAgent._async_get_exposed_entities(self) + entities, domains = LocalLLMAgent._async_get_exposed_entities(self) # ignore sorting if prompt caching is disabled if not self.entry.options.get(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED): @@ -739,10 +749,20 @@ def _cache_prompt(self) -> None: self.cache_refresh_after_cooldown = True return + llm_api: llm.API | None = None + if self.entry.options.get(CONF_LLM_HASS_API): + try: + llm_api = llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API] + ) + except HomeAssistantError: + _LOGGER.exception("Failed to get LLM API when caching prompt!") + return + try: raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) prompt = self._format_prompt([ - { "role": "system", "message": self._generate_system_prompt(raw_prompt)}, + { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, { "role": "user", "message": "" } ], include_generation_prompt=False) @@ -839,7 +859,7 @@ def _generate(self, conversation: dict) -> str: return result -class GenericOpenAIAPIAgent(LLaMAAgent): +class GenericOpenAIAPIAgent(LocalLLMAgent): api_host: str api_key: str model_name: str @@ -1046,7 +1066,7 @@ def _completion_params(self, conversation: dict) -> (str, dict): return endpoint, request_params -class OllamaAPIAgent(LLaMAAgent): +class OllamaAPIAgent(LocalLLMAgent): api_host: str api_key: str model_name: str diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index ed7e17f..5cca95c 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -1,4 +1,4 @@ -"""Config flow for Local LLaMA 
Conversation integration.""" +"""Config flow for Local LLM Conversation integration.""" from __future__ import annotations import os @@ -13,18 +13,20 @@ from homeassistant import config_entries from homeassistant.core import HomeAssistant -from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, UnitOfTime +from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, CONF_LLM_HASS_API, UnitOfTime from homeassistant.data_entry_flow import ( AbortFlow, FlowHandler, FlowManager, FlowResult, ) +from homeassistant.helpers import llm from homeassistant.helpers.selector import ( NumberSelector, NumberSelectorConfig, NumberSelectorMode, TemplateSelector, + SelectOptionDict, SelectSelector, SelectSelectorConfig, SelectSelectorMode, @@ -279,7 +281,7 @@ async def async_step_finish( """ Finish configuration """ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN): - """Handle a config flow for Local LLaMA Conversation.""" + """Handle a config flow for Local LLM Conversation.""" VERSION = 1 install_wheel_task = None @@ -584,7 +586,7 @@ async def async_step_model_parameters( persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en")) selected_default_options[CONF_PROMPT] = selected_default_options[CONF_PROMPT].replace("", persona) - schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type)) + schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type)) if user_input: self.options = user_input @@ -626,7 +628,7 @@ def async_get_options_flow( class OptionsFlow(config_entries.OptionsFlow): - """Local LLaMA config flow options handler.""" + """Local LLM config flow options handler.""" def __init__(self, config_entry: config_entries.ConfigEntry) -> None: """Initialize options flow.""" @@ -656,9 +658,10 @@ async def async_step_init( description_placeholders["filename"] = filename if len(errors) == 0: - return self.async_create_entry(title="LLaMA Conversation", data=user_input) + return self.async_create_entry(title="Local LLM Conversation", data=user_input) schema = local_llama_config_option_schema( + self.hass, self.config_entry.options, self.config_entry.data[CONF_BACKEND_TYPE], ) @@ -682,12 +685,31 @@ def insert_after_key(input_dict: dict, key_name: str, other_dict: dict): return result -def local_llama_config_option_schema(options: MappingProxyType[str, Any], backend_type: str) -> dict: - """Return a schema for Local LLaMA completion options.""" +def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyType[str, Any], backend_type: str) -> dict: + """Return a schema for Local LLM completion options.""" if not options: options = DEFAULT_OPTIONS + apis: list[SelectOptionDict] = [ + SelectOptionDict( + label="No control", + value="none", + ) + ] + apis.extend( + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ) + result = { + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=apis)), vol.Required( CONF_PROMPT, description={"suggested_value": options.get(CONF_PROMPT)}, diff --git a/custom_components/llama_conversation/manifest.json b/custom_components/llama_conversation/manifest.json index 9f09f63..e90a943 100644 --- a/custom_components/llama_conversation/manifest.json +++ b/custom_components/llama_conversation/manifest.json @@ -1,6 +1,6 @@ { "domain": "llama_conversation", - 
"name": "LLaMA Conversation", + "name": "Local LLM Conversation", "version": "0.2.17", "codeowners": ["@acon96"], "config_flow": true, diff --git a/docs/Setup.md b/docs/Setup.md index 95b449c..f888212 100644 --- a/docs/Setup.md +++ b/docs/Setup.md @@ -35,7 +35,7 @@ The following link will open your Home Assistant installation and download the i [![Open your Home Assistant instance and open a repository inside the Home Assistant Community Store.](https://my.home-assistant.io/badges/hacs_repository.svg)](https://my.home-assistant.io/redirect/hacs_repository/?category=Integration&repository=home-llm&owner=acon96) -After installation, A "LLaMA Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now. +After installation, A "Local LLM Conversation" device should show up in the `Settings > Devices and Services > [Devices]` tab now. ## Path 1: Using the Home Model with the Llama.cpp Backend ### Overview @@ -44,7 +44,7 @@ This setup path involves downloading a fine-tuned model from HuggingFace and int ### Step 1: Wheel Installation for llama-cpp-python 1. In Home Assistant: navigate to `Settings > Devices and Services` 2. Select the `+ Add Integration` button in the bottom right corner -3. Search for, and select `LLaMA Conversation` +3. Search for, and select `Local LLM Conversation` 4. With the `Llama.cpp (HuggingFace)` backend selected, click `Submit` This should download and install `llama-cpp-python` from GitHub. If the installation fails for any reason, follow the manual installation instructions [here](./Backend%20Configuration.md#wheels). @@ -82,7 +82,7 @@ In order to access the model from another machine, we need to run the Ollama API 1. In Home Assistant: navigate to `Settings > Devices and Services` 2. Select the `+ Add Integration` button in the bottom right corner -3. Search for, and select `LLaMA Conversation` +3. Search for, and select `Local LLM Conversation` 4. Select `Ollama API` from the dropdown and click `Submit` 5. 
Set up the connection to the API: - **IP Address**: Fill out IP Address for the machine hosting Ollama diff --git a/hacs.json b/hacs.json index b4513aa..d0b9dbc 100644 --- a/hacs.json +++ b/hacs.json @@ -1,6 +1,6 @@ { - "name": "LLaMA Conversation", - "homeassistant": "2023.10.0", + "name": "Local LLM Conversation", + "homeassistant": "2024.5.5", "content_in_root": false, "render_readme": true } diff --git a/tests/llama_conversation/test_agent.py b/tests/llama_conversation/test_agent.py index 49df0ce..05b6722 100644 --- a/tests/llama_conversation/test_agent.py +++ b/tests/llama_conversation/test_agent.py @@ -4,7 +4,7 @@ import jinja2 from unittest.mock import patch, MagicMock, PropertyMock, AsyncMock, ANY -from custom_components.llama_conversation.agent import LocalLLaMAAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent +from custom_components.llama_conversation.agent import LlamaCppAgent, OllamaAPIAgent, TextGenerationWebuiAgent, GenericOpenAIAPIAgent from custom_components.llama_conversation.const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, @@ -140,10 +140,10 @@ async def call_now(func, *args, **kwargs): @pytest.fixture def local_llama_agent_fixture(config_entry, home_assistant_mock): - with patch.object(LocalLLaMAAgent, '_load_icl_examples') as load_icl_examples_mock, \ - patch.object(LocalLLaMAAgent, '_load_grammar') as load_grammar_mock, \ - patch.object(LocalLLaMAAgent, 'entry', new_callable=PropertyMock) as entry_mock, \ - patch.object(LocalLLaMAAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \ + with patch.object(LlamaCppAgent, '_load_icl_examples') as load_icl_examples_mock, \ + patch.object(LlamaCppAgent, '_load_grammar') as load_grammar_mock, \ + patch.object(LlamaCppAgent, 'entry', new_callable=PropertyMock) as entry_mock, \ + patch.object(LlamaCppAgent, '_async_get_exposed_entities') as get_exposed_entities_mock, \ patch('homeassistant.helpers.template.Template') as template_mock, \ patch('custom_components.llama_conversation.agent.importlib.import_module') as import_module_mock, \ patch('custom_components.llama_conversation.agent.install_llama_cpp_python') as install_llama_cpp_python_mock: @@ -174,7 +174,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock): "target_device": "light.kitchen_light", }).encode() - agent_obj = LocalLLaMAAgent( + agent_obj = LlamaCppAgent( home_assistant_mock, config_entry ) @@ -191,7 +191,7 @@ def local_llama_agent_fixture(config_entry, home_assistant_mock): async def test_local_llama_agent(local_llama_agent_fixture): - local_llama_agent: LocalLLaMAAgent + local_llama_agent: LlamaCppAgent all_mocks: dict[str, MagicMock] local_llama_agent, all_mocks = local_llama_agent_fixture From 367607b14f832e0d88cf993375042e53ab5122f6 Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sat, 25 May 2024 21:24:45 -0400 Subject: [PATCH 02/13] more rewrite work for new LLM API --- custom_components/llama_conversation/agent.py | 109 +++++++----------- .../llama_conversation/config_flow.py | 16 +-- custom_components/llama_conversation/const.py | 3 - .../llama_conversation/translations/en.json | 8 +- custom_components/llama_conversation/utils.py | 7 +- requirements.txt | 9 ++ tests/llama_conversation/test_agent.py | 2 - tests/llama_conversation/test_config_flow.py | 6 +- 8 files changed, 69 insertions(+), 91 deletions(-) diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 6d9d156..e691656 100644 --- a/custom_components/llama_conversation/agent.py +++ 
b/custom_components/llama_conversation/agent.py @@ -5,6 +5,7 @@ import threading import importlib from typing import Literal, Any, Callable +import voluptuous as vol import requests import re @@ -15,6 +16,7 @@ import time from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent +import homeassistant.components.conversation as ha_conversation from homeassistant.components.conversation.const import DOMAIN as CONVERSATION_DOMAIN from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry @@ -39,7 +41,6 @@ CONF_BACKEND_TYPE, CONF_DOWNLOADED_MODEL_FILE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_PROMPT_TEMPLATE, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, @@ -74,7 +75,6 @@ DEFAULT_BACKEND_TYPE, DEFAULT_REQUEST_TIMEOUT, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_PROMPT_TEMPLATE, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, @@ -209,8 +209,6 @@ async def async_process( remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION) remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS) service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX) - allowed_service_call_arguments = self.entry.options \ - .get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS) try: service_call_pattern = re.compile(service_call_regex) @@ -302,72 +300,49 @@ async def async_process( # parse response exposed_entities = list(self._async_get_exposed_entities()[0].keys()) - to_say = service_call_pattern.sub("", response).strip() + to_say = "" for block in service_call_pattern.findall(response.strip()): - services = block.split("\n") - _LOGGER.info(f"running services: {' '.join(services)}") + _LOGGER.info(f"calling tool: {block}") - for line in services: - if len(line) == 0: - break + parsed_tool_call = json.loads(block) + + to_say = to_say + parsed_tool_call.get("to_say", "") + + # try to fix certain arguments + # make sure brightness is 0-255 and not a percentage + if "brightness" in parsed_tool_call["arguments"] and 0.0 < parsed_tool_call["arguments"]["brightness"] <= 1.0: + parsed_tool_call["arguments"]["brightness"] = int(parsed_tool_call["arguments"]["brightness"] * 255) + + # convert string "tuple" to a list for RGB colors + if "rgb_color" in parsed_tool_call["arguments"] and isinstance(parsed_tool_call["arguments"]["rgb_color"], str): + parsed_tool_call["arguments"]["rgb_color"] = [ int(x) for x in parsed_tool_call["arguments"]["rgb_color"][1:-1].split(",") ] + + tool_input = llm.ToolInput( + tool_name=parsed_tool_call["tool"], + tool_args=parsed_tool_call["arguments"], + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=ha_conversation.DOMAIN, + ) + + # TODO: multi-turn with the model where it acts on the response from the tool? 
+ try: + tool_response = await llm_api.async_call_tool( + self.hass, tool_input + ) + except (HomeAssistantError, vol.Invalid) as e: + tool_response = {"error": type(e).__name__} + if str(e): + tool_response["error_text"] = str(e) - # parse old format or JSON format - try: - json_output = json.loads(line) - service = json_output["service"] - entity = json_output["target_device"] - domain, service = tuple(service.split(".")) - if "to_say" in json_output: - to_say = to_say + json_output.pop("to_say") - - extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] } - except Exception: - try: - service = line.split("(")[0] - entity = line.split("(")[1][:-1] - domain, service = tuple(service.split(".")) - extra_arguments = {} - except Exception: - to_say += f" Failed to parse call from '{line}'!" - continue - - # fix certain arguments - # make sure brightness is 0-255 and not a percentage - if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] <= 1.0: - extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255) - - # convert string "tuple" to a list for RGB colors - if "rgb_color" in extra_arguments and isinstance(extra_arguments["rgb_color"], str): - extra_arguments["rgb_color"] = [ int(x) for x in extra_arguments["rgb_color"][1:-1].split(",") ] - - # only acknowledge requests to exposed entities - if entity not in exposed_entities: - to_say += f" Can't find device '{entity}'!" - else: - # copy arguments to service call - service_data = {ATTR_ENTITY_ID: entity} - for attr in allowed_service_call_arguments: - if attr in extra_arguments.keys(): - service_data[attr] = extra_arguments[attr] - - try: - _LOGGER.debug(f"service data: {service_data}") - await self.hass.services.async_call( - domain, - service, - service_data=service_data, - blocking=True, - ) - except Exception as err: - to_say += f"\nFailed to run: {line}" - _LOGGER.exception(f"Failed to run: {line}") - - if template_desc["assistant"]["suffix"]: - to_say = to_say.replace(template_desc["assistant"]["suffix"], "") # remove the eos token if it is returned (some backends + the old model does this) + _LOGGER.debug("Tool response: %s", tool_response) # generate intent response to Home Assistant intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(to_say) + intent_response.set return ConversationResult( response=intent_response, conversation_id=conversation_id ) @@ -429,8 +404,6 @@ def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API) -> str extra_attributes_to_expose = self.entry.options \ .get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE) - allowed_service_call_arguments = self.entry.options \ - .get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS) def icl_example_generator(num_examples, entity_names, service_names): entity_domains = set([x.split(".")[0] for x in entity_names]) @@ -459,8 +432,8 @@ def icl_example_generator(num_examples, entity_names, service_names): device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0] example = { "to_say": chosen_example["response"], - "service": chosen_service, - "target_device": device, + "tool": chosen_service, + "arguments": { "name": device }, } yield json.dumps(example) + "\n" @@ -505,7 +478,7 @@ def expose_attributes(attributes): if llm_api: tools = [ - f"{tool.name}({flatten_vol_schema(tool.parameters)}) - {tool.description}" + f"{tool.name}({', 
'.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" for tool in llm_api.async_get_tools() ] formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools) diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index 5cca95c..b8331b7 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -60,7 +60,6 @@ CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_TEXT_GEN_WEBUI_PRESET, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, @@ -100,7 +99,6 @@ DEFAULT_USE_GBNF_GRAMMAR, DEFAULT_GBNF_GRAMMAR_FILE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_REFRESH_SYSTEM_PROMPT, DEFAULT_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_NUM_INTERACTIONS, @@ -589,10 +587,14 @@ async def async_step_model_parameters( schema = vol.Schema(local_llama_config_option_schema(self.hass, selected_default_options, backend_type)) if user_input: - self.options = user_input + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) + try: # validate input schema(user_input) + + self.options = user_input return await self.async_step_finish() except Exception as ex: _LOGGER.exception("An unknown error has occurred!") @@ -657,6 +659,9 @@ async def async_step_init( errors["base"] = "missing_icl_file" description_placeholders["filename"] = filename + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) + if len(errors) == 0: return self.async_create_entry(title="Local LLM Conversation", data=user_input) @@ -750,11 +755,6 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT description={"suggested_value": options.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE)}, default=DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, ): TextSelector(TextSelectorConfig(multiple=True)), - vol.Required( - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, - description={"suggested_value": options.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS)}, - default=DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, - ): TextSelector(TextSelectorConfig(multiple=True)), vol.Required( CONF_SERVICE_CALL_REGEX, description={"suggested_value": options.get(CONF_SERVICE_CALL_REGEX)}, diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index db3682e..cf936e5 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -59,8 +59,6 @@ DEFAULT_SSL = False CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose" DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"] -CONF_ALLOWED_SERVICE_CALL_ARGUMENTS = "allowed_service_call_arguments" -DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"] CONF_PROMPT_TEMPLATE = "prompt_template" PROMPT_TEMPLATE_CHATML = "chatml" PROMPT_TEMPLATE_COMMAND_R = "command-r" @@ -197,7 +195,6 @@ CONF_ENABLE_FLASH_ATTENTION: DEFAULT_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION: 
DEFAULT_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS, diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 2119e06..21ed522 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -51,6 +51,7 @@ "model_parameters": { "data": { "max_new_tokens": "Maximum tokens to return in response", + "llm_hass_api": "Selected LLM API", "prompt": "System Prompt", "prompt_template": "Prompt Format", "temperature": "Temperature", @@ -62,7 +63,6 @@ "ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)", "ollama_json_mode": "JSON Output Mode", "extra_attributes_to_expose": "Additional attribute to expose in the context", - "allowed_service_call_arguments": "Arguments allowed to be pass to service calls", "enable_flash_attention": "Enable Flash Attention", "gbnf_grammar": "Enable GBNF Grammar", "gbnf_grammar_file": "GBNF Grammar Filename", @@ -86,11 +86,11 @@ "n_batch_threads": "Batch Thread Count" }, "data_description": { + "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.", "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.", - "allowed_service_call_arguments": "This is the list of parameters that are allowed to be passed to Home Assistant service calls.", "gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.", "prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below." 
}, @@ -103,6 +103,7 @@ "step": { "init": { "data": { + "llm_hass_api": "Selected LLM API", "max_new_tokens": "Maximum tokens to return in response", "prompt": "System Prompt", "prompt_template": "Prompt Format", @@ -115,7 +116,6 @@ "ollama_keep_alive": "Keep Alive/Inactivity Timeout (minutes)", "ollama_json_mode": "JSON Output Mode", "extra_attributes_to_expose": "Additional attribute to expose in the context", - "allowed_service_call_arguments": "Arguments allowed to be pass to service calls", "enable_flash_attention": "Enable Flash Attention", "gbnf_grammar": "Enable GBNF Grammar", "gbnf_grammar_file": "GBNF Grammar Filename", @@ -139,11 +139,11 @@ "n_batch_threads": "Batch Thread Count" }, "data_description": { + "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.", "extra_attributes_to_expose": "This is the list of Home Assistant 'attributes' that are exposed to the model. This limits how much information the model is able to see and answer questions on.", - "allowed_service_call_arguments": "This is the list of parameters that are allowed to be passed to Home Assistant service calls.", "gbnf_grammar": "Forces the model to output properly formatted responses. Ensure the file specified below exists in the integration directory.", "prompt_caching": "Prompt caching attempts to pre-process the prompt (house state) and cache the processing that needs to be done to understand the prompt. Enabling this will cause the model to re-process the prompt any time an entity state changes in the house, restricted by the interval below." } diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index 2fabec5..f132aae 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -37,10 +37,15 @@ def _flatten(current_schema, prefix=''): _flatten(subval, prefix) elif isinstance(current_schema.schema, dict): for key, val in current_schema.schema.items(): + if isinstance(key, vol.Any): + key = "|".join(key.validators) + if isinstance(key, vol.Optional): + key = "?" 
+ str(key) + _flatten(val, prefix + str(key) + '/') elif isinstance(current_schema, vol.validators._WithSubValidators): for subval in current_schema.validators: - _flatten(subval, prefix) + _flatten(subval, prefix) elif callable(current_schema): flattened.append(prefix[:-1] if prefix else prefix) _flatten(schema) diff --git a/requirements.txt b/requirements.txt index 0b299a7..da3bcd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# training + dataset requirements transformers tensorboard datasets @@ -11,9 +12,17 @@ sentencepiece deep-translator langcodes +# integration requirements +requests==2.31.0 +huggingface-hub==0.23.0 +webcolors==1.13 + +# types from Home Assistant homeassistant hassil home-assistant-intents + +# testing requirements pytest pytest-asyncio pytest-homeassistant-custom-component \ No newline at end of file diff --git a/tests/llama_conversation/test_agent.py b/tests/llama_conversation/test_agent.py index 05b6722..571706c 100644 --- a/tests/llama_conversation/test_agent.py +++ b/tests/llama_conversation/test_agent.py @@ -18,7 +18,6 @@ CONF_BACKEND_TYPE, CONF_DOWNLOADED_MODEL_FILE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_PROMPT_TEMPLATE, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, @@ -54,7 +53,6 @@ DEFAULT_BACKEND_TYPE, DEFAULT_REQUEST_TIMEOUT, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_PROMPT_TEMPLATE, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, diff --git a/tests/llama_conversation/test_config_flow.py b/tests/llama_conversation/test_config_flow.py index 13fb4b5..b5d8af7 100644 --- a/tests/llama_conversation/test_config_flow.py +++ b/tests/llama_conversation/test_config_flow.py @@ -24,7 +24,6 @@ CONF_BACKEND_TYPE, CONF_DOWNLOADED_MODEL_FILE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_PROMPT_TEMPLATE, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, @@ -66,7 +65,6 @@ DEFAULT_BACKEND_TYPE, DEFAULT_REQUEST_TIMEOUT, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_PROMPT_TEMPLATE, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, @@ -171,7 +169,6 @@ async def test_validate_config_flow_generic_openai(mock_setup_entry, hass: HomeA CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT, CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS, @@ -261,7 +258,6 @@ async def test_validate_config_flow_ollama(mock_setup_entry, hass: HomeAssistant CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT, CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, - CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS, CONF_REFRESH_SYSTEM_PROMPT: DEFAULT_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION: DEFAULT_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS: DEFAULT_REMEMBER_NUM_INTERACTIONS, @@ -299,7 +295,7 @@ def test_validate_options_schema(): universal_options = [ CONF_PROMPT, CONF_PROMPT_TEMPLATE, CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, CONF_NUM_IN_CONTEXT_EXAMPLES, - CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, + 
CONF_MAX_TOKENS, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_SERVICE_CALL_REGEX, CONF_REFRESH_SYSTEM_PROMPT, CONF_REMEMBER_CONVERSATION, CONF_REMEMBER_NUM_INTERACTIONS, ] From 8546767310f902a7c1568e7935c8303ff94dd50b Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sat, 1 Jun 2024 23:06:42 -0400 Subject: [PATCH 03/13] version with working ICL using the new APIs --- custom_components/llama_conversation/agent.py | 377 ++++++++++++------ custom_components/llama_conversation/const.py | 17 +- .../in_context_examples.csv | 65 ++- .../llama_conversation/output.gbnf | 29 -- generate.py | 4 +- 5 files changed, 298 insertions(+), 194 deletions(-) delete mode 100644 custom_components/llama_conversation/output.gbnf diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index e691656..1455bbd 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -25,7 +25,9 @@ from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm from homeassistant.helpers.event import async_track_state_change, async_call_later -from homeassistant.util import ulid +from homeassistant.util import ulid, color + +import voluptuous_serialize from .utils import closest_color, flatten_vol_schema, install_llama_cpp_python, validate_llama_cpp_python_installation from .const import ( @@ -146,7 +148,7 @@ def _load_icl_examples(self, filename: str): with open(icl_filename, encoding="utf-8-sig") as f: self.in_context_examples = list(csv.DictReader(f)) - if set(self.in_context_examples[0].keys()) != set(["service", "response" ]): + if set(self.in_context_examples[0].keys()) != set(["type", "request", "tool", "response" ]): raise Exception("ICL csv file did not have 2 columns: service & response") if len(self.in_context_examples) == 0: @@ -210,6 +212,17 @@ async def async_process( remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS) service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX) + llm_context = llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=ha_conversation.DOMAIN, + device_id=user_input.device_id, + ) + + _LOGGER.info(llm_context) + try: service_call_pattern = re.compile(service_call_regex) except Exception as err: @@ -224,11 +237,11 @@ async def async_process( response=intent_response, conversation_id=conversation_id ) - llm_api: llm.API | None = None + llm_api: llm.APIInstance | None = None if self.entry.options.get(CONF_LLM_HASS_API): try: - llm_api = llm.async_get_api( - self.hass, self.entry.options[CONF_LLM_HASS_API] + llm_api = await llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API], llm_context ) except HomeAssistantError as err: _LOGGER.error("Error getting LLM API: %s", err) @@ -283,7 +296,7 @@ async def async_process( intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, + intent.IntentResponseErrorCode.FAILED_TO_HANDLE, f"Sorry, there was a problem talking to the backend: {repr(err)}", ) return ConversationResult( @@ -298,15 +311,27 @@ async def async_process( self.history[conversation_id] = conversation # parse response - exposed_entities = 
list(self._async_get_exposed_entities()[0].keys()) - - to_say = "" + to_say = service_call_pattern.sub("", response.strip()) for block in service_call_pattern.findall(response.strip()): - _LOGGER.info(f"calling tool: {block}") - parsed_tool_call = json.loads(block) + try: + vol.Schema({ + vol.Required("name"): str, + vol.Required("arguments"): dict, + })(parsed_tool_call) + except vol.Error as ex: + _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}") + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_INTENT_MATCH, + f"I'm sorry, I didn't produce a correctly formatted tool call! Please see the logs for more info.", + ) + return ConversationResult( + response=intent_response, conversation_id=conversation_id + ) - to_say = to_say + parsed_tool_call.get("to_say", "") + _LOGGER.info(f"calling tool: {block}") # try to fix certain arguments # make sure brightness is 0-255 and not a percentage @@ -318,31 +343,35 @@ async def async_process( parsed_tool_call["arguments"]["rgb_color"] = [ int(x) for x in parsed_tool_call["arguments"]["rgb_color"][1:-1].split(",") ] tool_input = llm.ToolInput( - tool_name=parsed_tool_call["tool"], + tool_name=parsed_tool_call["name"], tool_args=parsed_tool_call["arguments"], - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=ha_conversation.DOMAIN, ) - # TODO: multi-turn with the model where it acts on the response from the tool? try: - tool_response = await llm_api.async_call_tool( - self.hass, tool_input - ) + tool_response = await llm_api.async_call_tool(tool_input) except (HomeAssistantError, vol.Invalid) as e: tool_response = {"error": type(e).__name__} if str(e): tool_response["error_text"] = str(e) + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.NO_INTENT_MATCH, + f"There was an error calling the tool! 
({tool_response})", + ) + return ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + _LOGGER.debug("Tool response: %s", tool_response) + + # TODO: optionally handle multi-turn with the model where it acts on the response from the tool + # if self.entry.options.get("", ""): + # pass # generate intent response to Home Assistant intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(to_say) - intent_response.set return ConversationResult( response=intent_response, conversation_id=conversation_id ) @@ -398,44 +427,68 @@ def _format_prompt( _LOGGER.debug(formatted_prompt) return formatted_prompt - def _generate_system_prompt(self, prompt_template: str, llm_api: llm.API) -> str: + def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance) -> str: """Generate the system prompt with current entity states""" entities_to_expose, domains = self._async_get_exposed_entities() extra_attributes_to_expose = self.entry.options \ .get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE) - def icl_example_generator(num_examples, entity_names, service_names): - entity_domains = set([x.split(".")[0] for x in entity_names]) + def icl_example_generator(num_examples, entity_names): entity_names = entity_names[:] - - # filter out examples for disabled services - selected_in_context_examples = [] - for x in self.in_context_examples: - if x["service"] in service_names and x["service"].split(".")[0] in entity_domains: - selected_in_context_examples.append(x) - - # if we filtered everything then just sample randomly - if len(selected_in_context_examples) == 0: - selected_in_context_examples = self.in_context_examples[:] - - random.shuffle(selected_in_context_examples) + in_context_examples = self.in_context_examples[:] + + random.shuffle(in_context_examples) random.shuffle(entity_names) - num_examples_to_generate = min(num_examples, len(selected_in_context_examples)) + num_examples_to_generate = min(num_examples, len(in_context_examples)) if num_examples_to_generate < num_examples: - _LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(selected_in_context_examples)} are available!") + _LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!") + examples = [] for x in range(num_examples_to_generate): - chosen_example = selected_in_context_examples.pop() - chosen_service = chosen_example["service"] - device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0] - example = { - "to_say": chosen_example["response"], - "tool": chosen_service, - "arguments": { "name": device }, - } - yield json.dumps(example) + "\n" + chosen_example = in_context_examples.pop() + request = chosen_example["request"] + response = chosen_example["response"] + + random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0] + random_area = "bedroom" # todo, pick a random area + random_brightness = round(random.random(), 2) + random_color = random.choice(list(color.COLORS.keys())) + + tool_arguments = {} + + if "" in request: + request.replace("", random_area) + response.replace("", random_area) + tool_arguments["area"] = random_area + + if "" in request: + request.replace("", random_device) + response.replace("", random_device) + tool_arguments["name"] = random_device + + if "" in request: + request.replace("", str(random_brightness)) + 
response.replace("", str(random_brightness)) + tool_arguments["brightness"] = random_brightness + + if "" in request: + request.replace("", random_color) + response.replace("", random_color) + tool_arguments["color"] = random_color + + + examples.append({ + "request": request, + "response": response, + "tool": { + "name": chosen_example["tool"], + "arguments": tool_arguments + } + }) + + return examples def expose_attributes(attributes): result = attributes["state"] @@ -477,23 +530,107 @@ def expose_attributes(attributes): formatted_states = "\n".join(device_states) + "\n" if llm_api: + def format_tool(tool: llm.Tool, reduced: bool): + def serialize(value): + """This is horrible. Why is vol so hard to convert back into a readable schema?""" + + if value is cv.ensure_list: + return { "type": "list" } + + if value is color.color_name_to_rgb: + return { "type": "string" } + + if value is intent.non_empty_string: + return { "type": "string" } + + # media player registers an intent using a lambda... + # there's literally no way to detect that properly + try: + if value(100) == 1: + _LOGGER.debug("bad") + return { "type": "integer" } + except Exception: + pass + + if isinstance(value, list): + result = {} + for x in value: + result.update(serialize(x)) + return result + + return cv.custom_serializer(value) + + raw_parameters: list = voluptuous_serialize.convert( + tool.parameters, custom_serializer=serialize) + + # handle vol.Any in the key side of things + processed_parameters = [] + for param in raw_parameters: + _LOGGER.info(param["name"]) + if isinstance(param["name"], vol.Any): + for possible_name in param["name"].validators: + actual_param = param.copy() + actual_param["name"] = possible_name + actual_param["required"] = False + processed_parameters.append(actual_param) + else: + processed_parameters.append(param) + + if reduced: + return { + "name": tool.name, + "description": tool.description, + "parameters": { + "properties": { + x["name"]: x.get("type", "string") for x in processed_parameters + }, + "required": [ + x["name"] for x in processed_parameters if x.get("required") + ] + } + } + else: + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": { + x["name"]: { + "type": x.get("type", "string"), + "description": x.get("description", ""), + } for x in processed_parameters + }, + "required": [ + x["name"] for x in processed_parameters if x.get("required") + ] + } + } + } + + # def format_tool_reduced(tool: llm.Tool): + # return f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" + tools = [ - f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" - for tool in llm_api.async_get_tools() + format_tool(tool, True) + for tool in llm_api.tools ] - formatted_services = llm_api.prompt_template + "\n" + "\n".join(tools) + formatted_tools = json.dumps(tools) else: - formatted_services = "No tools exposed." 
+ formatted_tools = "" render_variables = { "devices": formatted_states, - "services": formatted_services, + "tools": formatted_tools, } if self.in_context_examples: - # num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) - # render_variables["response_examples"] = "\n".join(icl_example_generator(num_examples, list(entities_to_expose.keys()), all_service_names)) - render_variables["response_examples"] = "" + num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) + render_variables["response_examples"] = icl_example_generator(num_examples, list(entities_to_expose.keys())) + + _LOGGER.info(render_variables) return template.Template(prompt_template, self.hass).async_render( render_variables, @@ -704,86 +841,86 @@ async def _async_cache_prompt(self, entity, old_state, new_state): refresh_end = time.time() _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") - def _cache_prompt(self) -> None: - # if a refresh is already scheduled then exit - if self.cache_refresh_after_cooldown: - return + # def _cache_prompt(self) -> None: + # # if a refresh is already scheduled then exit + # if self.cache_refresh_after_cooldown: + # return - # if we are inside the cooldown period, request a refresh and exit - current_time = time.time() - fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) - if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval: - self.cache_refresh_after_cooldown = True - return + # # if we are inside the cooldown period, request a refresh and exit + # current_time = time.time() + # fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) + # if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval: + # self.cache_refresh_after_cooldown = True + # return - # try to acquire the lock, if we are still running for some reason, request a refresh and exit - lock_acquired = self.model_lock.acquire(False) - if not lock_acquired: - self.cache_refresh_after_cooldown = True - return + # # try to acquire the lock, if we are still running for some reason, request a refresh and exit + # lock_acquired = self.model_lock.acquire(False) + # if not lock_acquired: + # self.cache_refresh_after_cooldown = True + # return - llm_api: llm.API | None = None - if self.entry.options.get(CONF_LLM_HASS_API): - try: - llm_api = llm.async_get_api( - self.hass, self.entry.options[CONF_LLM_HASS_API] - ) - except HomeAssistantError: - _LOGGER.exception("Failed to get LLM API when caching prompt!") - return + # llm_api: llm.APIInstance | None = None + # if self.entry.options.get(CONF_LLM_HASS_API): + # try: + # llm_api = await llm.async_get_api( + # self.hass, self.entry.options[CONF_LLM_HASS_API] + # ) + # except HomeAssistantError: + # _LOGGER.exception("Failed to get LLM API when caching prompt!") + # return - try: - raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - prompt = self._format_prompt([ - { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, - { "role": "user", "message": "" } - ], include_generation_prompt=False) + # try: + # raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + # prompt = self._format_prompt([ + # { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, + # { "role": "user", "message": "" } + # ], 
include_generation_prompt=False) - input_tokens = self.llm.tokenize( - prompt.encode(), add_bos=False - ) + # input_tokens = self.llm.tokenize( + # prompt.encode(), add_bos=False + # ) - temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) - top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)) - top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) - grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None + # temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) + # top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)) + # top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) + # grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None - _LOGGER.debug(f"Options: {self.entry.options}") + # _LOGGER.debug(f"Options: {self.entry.options}") - _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...") + # _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...") - # grab just one token. should prime the kv cache with the system prompt - next(self.llm.generate( - input_tokens, - temp=temperature, - top_k=top_k, - top_p=top_p, - grammar=grammar - )) + # # grab just one token. should prime the kv cache with the system prompt + # next(self.llm.generate( + # input_tokens, + # temp=temperature, + # top_k=top_k, + # top_p=top_p, + # grammar=grammar + # )) - self.last_cache_prime = time.time() - finally: - self.model_lock.release() + # self.last_cache_prime = time.time() + # finally: + # self.model_lock.release() - # schedule a refresh using async_call_later - # if the flag is set after the delay then we do another refresh + # # schedule a refresh using async_call_later + # # if the flag is set after the delay then we do another refresh - @callback - async def refresh_if_requested(_now): - if self.cache_refresh_after_cooldown: - self.cache_refresh_after_cooldown = False + # @callback + # async def refresh_if_requested(_now): + # if self.cache_refresh_after_cooldown: + # self.cache_refresh_after_cooldown = False - refresh_start = time.time() - _LOGGER.debug(f"refreshing cached prompt after cooldown...") - await self.hass.async_add_executor_job(self._cache_prompt) + # refresh_start = time.time() + # _LOGGER.debug(f"refreshing cached prompt after cooldown...") + # await self.hass.async_add_executor_job(self._cache_prompt) - refresh_end = time.time() - _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") + # refresh_end = time.time() + # _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") - refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) - async_call_later(self.hass, float(refresh_delay), refresh_if_requested) + # refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) + # async_call_later(self.hass, float(refresh_delay), refresh_if_requested) def _generate(self, conversation: dict) -> str: diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index cf936e5..27ad006 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -11,16 +11,21 @@ } DEFAULT_PROMPT_BASE = """ The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }} -Services: {{ services }} +Tools: {{ tools }} Devices: {{ devices }}""" ICL_EXTRAS = """ -Respond 
to the following user instruction by responding in the same format as the following examples: -{{ response_examples }}""" +{% for item in response_examples %} +{{ item.request }} +{{ item.response }} + {{ item.tool | to_json }} +{% endfor %}""" ICL_NO_SYSTEM_PROMPT_EXTRAS = """ -Respond to the following user instruction by responding in the same format as the following examples: -{{ response_examples }} - +{% for item in response_examples %} +{{ item.request }} +{{ item.response }} + {{ item.tool | to_json }} +{% endfor %} User instruction:""" DEFAULT_PROMPT = DEFAULT_PROMPT_BASE + ICL_EXTRAS CONF_CHAT_MODEL = "huggingface_model" diff --git a/custom_components/llama_conversation/in_context_examples.csv b/custom_components/llama_conversation/in_context_examples.csv index efc2fd7..98ab251 100644 --- a/custom_components/llama_conversation/in_context_examples.csv +++ b/custom_components/llama_conversation/in_context_examples.csv @@ -1,37 +1,28 @@ -service,response -fan.turn_on,Turning on the fan for you. -fan.turn_off,Switching off the fan as requested. -fan.toggle,I'll toggle the fan's state for you. -fan.increase_speed,Increasing the fan speed for you. -fan.decrease_speed,Reducing the fan speed as you requested. -cover.open_cover,Opening the garage door for you. -cover.close_cover,Closing the garage door as requested. -cover.stop_cover,Stopping the garage door now. -cover.toggle,Toggling the garage door state for you. -light.turn_on,Turning on the light for you. -light.turn_off,Turning off the light as requested. -light.toggle,Toggling the light for you. -lock.lock,Locking the door for you. -lock.unlock,Unlocking the door as you requested. -media_player.turn_on,Turning on the media player for you. -media_player.turn_off,Turning off the media player as requested. -media_player.toggle,Toggling the media player for you. -media_player.volume_up,Increasing the volume for you. -media_player.volume_down,Reducing the volume as you requested. -media_player.volume_mute,Muting the volume for you. -media_player.media_play_pause,Toggling play/pause on the media player. -media_player.media_play,Starting media playback. -media_player.media_pause,Pausing the media playback. -media_player.media_stop,Stopping the media playback. -media_player.media_next_track,Skipping to the next track. -media_player.media_previous_track,Going back to the previous track. -switch.turn_on,Turning on the switch for you. -switch.turn_off,Turning off the switch as requested. -switch.toggle,Toggling the switch for you. -vacuum.start,Starting the vacuum now. -vacuum.stop,Stopping the vacuum. -vacuum.pause,Pausing the vacuum for now. -vacuum.return_to_base,Sending the vacuum back to its base. -todo.add_item,the todo has been added to your todo list. -timer.start,Starting timer now. -timer.cancel,Timer has been canceled. \ No newline at end of file +type,request,tool,response +fan,Turn on the ,HassTurnOn,Turning on the fan for you. +fan,Turn on the fans in the ,HassTurnOn,Turning on the fans in +fan,Turn off the ,HassTurnOff,Switching off the fan as requested. +fan,Toggle the fan for me,HassToggle,I'll toggle the fan's state for you. +cover,Can you open the ?,HassOpenCover,Opening the garage door for you. +cover,Close the now,HassCloseCover,Closing the garage door as requested. +light,Turn the light on,HassTurnOn,Turning on the light for you. +light,Lights off in ,HassTurnOff,Turning off the light as requested. +light,Hit toggle for the lights in ,HassToggle,Toggling the lights in the for you. 
+light,Set the brightness for to ,HassLightSet,Setting the brightness now. +light,Make the lights in ,HassLightSet,The color should be changed now. +media_player,TODO,HassTurnOn,Turning on the media player for you. +media_player,TODO,HassTurnOff,Turning off the media player as requested. +media_player,TODO,HassToggle,Toggling the media player for you. +media_player,TODO,HassSetVolume,Muting the volume for you. +media_player,TODO,HassSetVolume,Setting the volume as requested. +media_player,TODO,HassMediaUnpause,Starting media playback. +media_player,TODO,HassMediaPause,Pausing the media playback. +media_player,TODO,HassMediaNext,Skipping to the next track. +switch,TODO,HassTurnOn,Turning on the switch for you. +switch,TODO,HassTurnOff,Turning off the switch as requested. +switch,TODO,HassToggle,Toggling the switch for you. +vacuum,TODO,HassVacuumStart,Starting the vacuum now. +vacuum,TODO,HassVacuumReturnToBase,Sending the vacuum back to its base. +todo,TODO,HassListAddItem,the todo has been added to your todo list. +timer,TODO,HassStartTimer,Starting timer now. +timer,TODO,HassCancelTimer,Timer has been canceled. \ No newline at end of file diff --git a/custom_components/llama_conversation/output.gbnf b/custom_components/llama_conversation/output.gbnf deleted file mode 100644 index 26069ba..0000000 --- a/custom_components/llama_conversation/output.gbnf +++ /dev/null @@ -1,29 +0,0 @@ -root ::= (tosay "\n")+ functioncalls? - -tosay ::= [0-9a-zA-Z #%.?!]* -functioncalls ::= - "```homeassistant\n" (object ws)* "```" - -value ::= object | array | string | number | ("true" | "false" | "null") ws -object ::= - "{" ws ( - string ":" ws value - ("," ws string ":" ws value)* - )? "}" ws - -array ::= - "[" ws ( - value - ("," ws value)* - )? "]" ws - -string ::= - "\"" ( - [^"\\] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes - )* "\"" ws - -number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws - -# Optional space: by convention, applied in this grammar after literal chars when allowed -ws ::= ([ \t\n] ws)? 
\ No newline at end of file diff --git a/generate.py b/generate.py index d56f291..f9d170d 100644 --- a/generate.py +++ b/generate.py @@ -31,7 +31,7 @@ def generate(model, tokenizer, prompt): def format_example(example): sys_prompt = SYSTEM_PROMPT - services_block = "Services: " + ", ".join(sorted(example["available_services"])) + services_block = "Services: " + ", ".join(sorted(example["available_tools"])) states_block = "Devices:\n" + "\n".join(example["states"]) question = "Request:\n" + example["question"] response_start = "Response:\n" @@ -52,7 +52,7 @@ def main(): "fan.family_room = off", "lock.front_door = locked" ], - "available_services": ["turn_on", "turn_off", "toggle", "lock", "unlock" ], + "available_tools": ["turn_on", "turn_off", "toggle", "lock", "unlock" ], "question": request, } From 00d002d9c04437585c1a21da2024d62c7b8a0589 Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sun, 2 Jun 2024 12:25:26 -0400 Subject: [PATCH 04/13] make tool formats work + dynamic quantization detection from HF --- .../llama_conversation/__init__.py | 4 +- custom_components/llama_conversation/agent.py | 292 +++++++++--------- .../llama_conversation/config_flow.py | 57 +++- custom_components/llama_conversation/const.py | 14 +- .../llama_conversation/translations/en.json | 11 + custom_components/llama_conversation/utils.py | 41 ++- 6 files changed, 253 insertions(+), 166 deletions(-) diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index 0d35ff1..5a8a915 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -38,8 +38,8 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry): hass.data[DOMAIN][entry.entry_id] = entry # call update handler - agent: LocalLLMAgent = await ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) - agent._update_options() + agent: LocalLLMAgent = ha_conversation.get_agent_manager(hass).async_get_agent(entry.entry_id) + await hass.async_add_executor_job(agent._update_options) return True diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 1455bbd..8d775f1 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -29,7 +29,8 @@ import voluptuous_serialize -from .utils import closest_color, flatten_vol_schema, install_llama_cpp_python, validate_llama_cpp_python_installation +from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \ + validate_llama_cpp_python_installation from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, @@ -44,6 +45,7 @@ CONF_DOWNLOADED_MODEL_FILE, CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_PROMPT_TEMPLATE, + CONF_TOOL_FORMAT, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, @@ -78,6 +80,7 @@ DEFAULT_REQUEST_TIMEOUT, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_PROMPT_TEMPLATE, + DEFAULT_TOOL_FORMAT, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, DEFAULT_GBNF_GRAMMAR_FILE, @@ -103,6 +106,9 @@ TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT, DOMAIN, PROMPT_TEMPLATE_DESCRIPTIONS, + TOOL_FORMAT_FULL, + TOOL_FORMAT_REDUCED, + TOOL_FORMAT_MINIMAL, ) # make type checking work for llama-cpp-python without importing it directly at runtime @@ -310,6 +316,14 @@ async def async_process( conversation.pop(1) self.history[conversation_id] = conversation + if not llm_api: + # return the output without messing with it if 
there is no API exposed to the model + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_speech(response.strip()) + return ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + # parse response to_say = service_call_pattern.sub("", response.strip()) for block in service_call_pattern.findall(response.strip()): @@ -426,6 +440,124 @@ def _format_prompt( _LOGGER.debug(formatted_prompt) return formatted_prompt + + def _format_tool(self, tool: llm.Tool): + style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) + + if style == TOOL_FORMAT_MINIMAL: + return f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" + + raw_parameters: list = voluptuous_serialize.convert( + tool.parameters, custom_serializer=custom_custom_serializer) + + # handle vol.Any in the key side of things + processed_parameters = [] + for param in raw_parameters: + _LOGGER.info(param["name"]) + if isinstance(param["name"], vol.Any): + for possible_name in param["name"].validators: + actual_param = param.copy() + actual_param["name"] = possible_name + actual_param["required"] = False + processed_parameters.append(actual_param) + else: + processed_parameters.append(param) + + if style == TOOL_FORMAT_REDUCED: + return { + "name": tool.name, + "description": tool.description, + "parameters": { + "properties": { + x["name"]: x.get("type", "string") for x in processed_parameters + }, + "required": [ + x["name"] for x in processed_parameters if x.get("required") + ] + } + } + elif style == TOOL_FORMAT_FULL: + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": { + x["name"]: { + "type": x.get("type", "string"), + "description": x.get("description", ""), + } for x in processed_parameters + }, + "required": [ + x["name"] for x in processed_parameters if x.get("required") + ] + } + } + } + + raise Exception(f"Unknown tool format {style}") + + def _generate_icl_examples(self, num_examples, entity_names): + entity_names = entity_names[:] + entity_domains = set([x.split(".")[0] for x in entity_names]) + + in_context_examples = [ + x for x in self.in_context_examples + if x["type"] in entity_domains + ] + + random.shuffle(in_context_examples) + random.shuffle(entity_names) + + num_examples_to_generate = min(num_examples, len(in_context_examples)) + if num_examples_to_generate < num_examples: + _LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!") + + examples = [] + for _ in range(num_examples_to_generate): + chosen_example = in_context_examples.pop() + request = chosen_example["request"] + response = chosen_example["response"] + + random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0] + random_area = "bedroom" # todo, pick a random area + random_brightness = round(random.random(), 2) + random_color = random.choice(list(color.COLORS.keys())) + + tool_arguments = {} + + if "" in request: + request = request.replace("", random_area) + response = response.replace("", random_area) + tool_arguments["area"] = random_area + + if "" in request: + request = request.replace("", random_device) + response = response.replace("", random_device) + tool_arguments["name"] = random_device + + if "" in request: + request = request.replace("", str(random_brightness)) + response = response.replace("", 
str(random_brightness)) + tool_arguments["brightness"] = random_brightness + + if "" in request: + request = request.replace("", random_color) + response = response.replace("", random_color) + tool_arguments["color"] = random_color + + examples.append({ + "request": request, + "response": response, + "tool": { + "name": chosen_example["tool"], + "arguments": tool_arguments + } + }) + + return examples def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance) -> str: """Generate the system prompt with current entity states""" @@ -433,62 +565,6 @@ def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance extra_attributes_to_expose = self.entry.options \ .get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE) - - def icl_example_generator(num_examples, entity_names): - entity_names = entity_names[:] - in_context_examples = self.in_context_examples[:] - - random.shuffle(in_context_examples) - random.shuffle(entity_names) - - num_examples_to_generate = min(num_examples, len(in_context_examples)) - if num_examples_to_generate < num_examples: - _LOGGER.warning(f"Attempted to generate {num_examples} ICL examples for conversation, but only {len(in_context_examples)} are available!") - - examples = [] - for x in range(num_examples_to_generate): - chosen_example = in_context_examples.pop() - request = chosen_example["request"] - response = chosen_example["response"] - - random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0] - random_area = "bedroom" # todo, pick a random area - random_brightness = round(random.random(), 2) - random_color = random.choice(list(color.COLORS.keys())) - - tool_arguments = {} - - if "" in request: - request.replace("", random_area) - response.replace("", random_area) - tool_arguments["area"] = random_area - - if "" in request: - request.replace("", random_device) - response.replace("", random_device) - tool_arguments["name"] = random_device - - if "" in request: - request.replace("", str(random_brightness)) - response.replace("", str(random_brightness)) - tool_arguments["brightness"] = random_brightness - - if "" in request: - request.replace("", random_color) - response.replace("", random_color) - tool_arguments["color"] = random_color - - - examples.append({ - "request": request, - "response": response, - "tool": { - "name": chosen_example["tool"], - "arguments": tool_arguments - } - }) - - return examples def expose_attributes(attributes): result = attributes["state"] @@ -529,108 +605,24 @@ def expose_attributes(attributes): formatted_states = "\n".join(device_states) + "\n" - if llm_api: - def format_tool(tool: llm.Tool, reduced: bool): - def serialize(value): - """This is horrible. Why is vol so hard to convert back into a readable schema?""" - - if value is cv.ensure_list: - return { "type": "list" } - - if value is color.color_name_to_rgb: - return { "type": "string" } - - if value is intent.non_empty_string: - return { "type": "string" } - - # media player registers an intent using a lambda... 
- # there's literally no way to detect that properly - try: - if value(100) == 1: - _LOGGER.debug("bad") - return { "type": "integer" } - except Exception: - pass - - if isinstance(value, list): - result = {} - for x in value: - result.update(serialize(x)) - return result - - return cv.custom_serializer(value) - - raw_parameters: list = voluptuous_serialize.convert( - tool.parameters, custom_serializer=serialize) - - # handle vol.Any in the key side of things - processed_parameters = [] - for param in raw_parameters: - _LOGGER.info(param["name"]) - if isinstance(param["name"], vol.Any): - for possible_name in param["name"].validators: - actual_param = param.copy() - actual_param["name"] = possible_name - actual_param["required"] = False - processed_parameters.append(actual_param) - else: - processed_parameters.append(param) - - if reduced: - return { - "name": tool.name, - "description": tool.description, - "parameters": { - "properties": { - x["name"]: x.get("type", "string") for x in processed_parameters - }, - "required": [ - x["name"] for x in processed_parameters if x.get("required") - ] - } - } - else: - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": { - "type": "object", - "properties": { - x["name"]: { - "type": x.get("type", "string"), - "description": x.get("description", ""), - } for x in processed_parameters - }, - "required": [ - x["name"] for x in processed_parameters if x.get("required") - ] - } - } - } - - # def format_tool_reduced(tool: llm.Tool): - # return f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" - - tools = [ - format_tool(tool, True) + if llm_api: + formatted_tools = json.dumps([ + self._format_tool(tool) for tool in llm_api.tools - ] - formatted_tools = json.dumps(tools) + ]) else: - formatted_tools = "" + formatted_tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so." 
render_variables = { "devices": formatted_states, "tools": formatted_tools, + "response_examples": [] } - if self.in_context_examples: + # only pass examples if there are loaded examples + an API was exposed + if self.in_context_examples and llm_api: num_examples = int(self.entry.options.get(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES)) - render_variables["response_examples"] = icl_example_generator(num_examples, list(entities_to_expose.keys())) - - _LOGGER.info(render_variables) + render_variables["response_examples"] = self._generate_icl_examples(num_examples, list(entities_to_expose.keys())) return template.Template(prompt_template, self.hass).async_render( render_variables, diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index b8331b7..c122cac 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -38,7 +38,7 @@ from homeassistant.util.package import is_installed from importlib.metadata import version -from .utils import download_model_from_hf, install_llama_cpp_python +from .utils import download_model_from_hf, install_llama_cpp_python, MissingQuantizationException from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, @@ -56,6 +56,7 @@ CONF_DOWNLOADED_MODEL_QUANTIZATION, CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS, CONF_PROMPT_TEMPLATE, + CONF_TOOL_FORMAT, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, @@ -95,6 +96,7 @@ DEFAULT_BACKEND_TYPE, DEFAULT_DOWNLOADED_MODEL_QUANTIZATION, DEFAULT_PROMPT_TEMPLATE, + DEFAULT_TOOL_FORMAT, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, DEFAULT_GBNF_GRAMMAR_FILE, @@ -123,6 +125,9 @@ BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER, BACKEND_TYPE_OLLAMA, PROMPT_TEMPLATE_DESCRIPTIONS, + TOOL_FORMAT_FULL, + TOOL_FORMAT_REDUCED, + TOOL_FORMAT_MINIMAL, TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT, @@ -172,7 +177,7 @@ def STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file=None, selected_language=Non } ) -def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_quantization=None, selected_language=None): +def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_quantization=None, selected_language=None, available_quantizations=None): return vol.Schema( { vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig( @@ -181,7 +186,7 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q multiple=False, mode=SelectSelectorMode.DROPDOWN, )), - vol.Required(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS), + vol.Required(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=downloaded_model_quantization if downloaded_model_quantization else DEFAULT_DOWNLOADED_MODEL_QUANTIZATION): vol.In(available_quantizations if available_quantizations else CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS), vol.Required(CONF_SELECTED_LANGUAGE, default=selected_language if selected_language else "en"): SelectSelector(SelectSelectorConfig( options=CONF_SELECTED_LANGUAGE_OPTIONS, translation_key=CONF_SELECTED_LANGUAGE, @@ -387,13 +392,35 @@ async def async_step_local_model( raise ValueError() if self.download_error: - errors["base"] = "download_failed" - 
description_placeholders["exception"] = str(self.download_error) - schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA( - chat_model=self.model_config[CONF_CHAT_MODEL], - downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION], - selected_language=self.selected_language - ) + if isinstance(self.download_error, MissingQuantizationException): + available_quants = list(set(self.download_error.available_quants).intersection(set(CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS))) + + if len(available_quants) == 0: + errors["base"] = "no_supported_ggufs" + schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA( + chat_model=self.model_config[CONF_CHAT_MODEL], + downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION], + selected_language=self.selected_language + ) + else: + errors["base"] = "missing_quantization" + description_placeholders["missing"] = self.download_error.missing_quant + description_placeholders["available"] = ", ".join(self.download_error.available_quants) + + schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA( + chat_model=self.model_config[CONF_CHAT_MODEL], + downloaded_model_quantization=self.download_error.available_quants[0], + selected_language=self.selected_language, + available_quantizations=available_quants, + ) + else: + errors["base"] = "download_failed" + description_placeholders["exception"] = str(self.download_error) + schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA( + chat_model=self.model_config[CONF_CHAT_MODEL], + downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION], + selected_language=self.selected_language + ) if user_input and "result" not in user_input: self.selected_language = user_input.pop(CONF_SELECTED_LANGUAGE, self.hass.config.language) @@ -730,6 +757,16 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT multiple=False, mode=SelectSelectorMode.DROPDOWN, )), + vol.Required( + CONF_TOOL_FORMAT, + description={"suggested_value": options.get(CONF_TOOL_FORMAT)}, + default=DEFAULT_TOOL_FORMAT, + ): SelectSelector(SelectSelectorConfig( + options=[TOOL_FORMAT_FULL, TOOL_FORMAT_REDUCED, TOOL_FORMAT_MINIMAL], + translation_key=CONF_TOOL_FORMAT, + multiple=False, + mode=SelectSelectorMode.DROPDOWN, + )), vol.Required( CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)}, diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index 27ad006..9b7fcaa 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -56,7 +56,12 @@ CONF_SELECTED_LANGUAGE = "selected_language" CONF_SELECTED_LANGUAGE_OPTIONS = [ "en", "de", "fr", "es" ] CONF_DOWNLOADED_MODEL_QUANTIZATION = "downloaded_model_quantization" -CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = ["F16", "Q8_0", "Q5_K_M", "Q4_K_M", "Q3_K_M"] +CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS = [ + "Q4_0", "Q4_1", "Q5_0", "Q5_1", "IQ2_XXS", "IQ2_XS", "IQ2_S", "IQ2_M", "IQ1_S", "IQ1_M", + "Q2_K", "Q2_K_S", "IQ3_XXS", "IQ3_S", "IQ3_M", "Q3_K", "IQ3_XS", "Q3_K_S", "Q3_K_M", "Q3_K_L", + "IQ4_NL", "IQ4_XS", "Q4_K", "Q4_K_S", "Q4_K_M", "Q5_K", "Q5_K_S", "Q5_K_M", "Q6_K", "Q8_0", + "F16", "BF16", "F32" +] DEFAULT_DOWNLOADED_MODEL_QUANTIZATION = "Q4_K_M" CONF_DOWNLOADED_MODEL_FILE = "downloaded_model_file" DEFAULT_DOWNLOADED_MODEL_FILE = "" @@ -137,6 +142,11 @@ "generation_prompt": "<|start_header_id|>assistant<|end_header_id|>\n\n" } } +CONF_TOOL_FORMAT = "tool_format" 
+TOOL_FORMAT_FULL = "full_tool_format" +TOOL_FORMAT_REDUCED = "reduced_tool_format" +TOOL_FORMAT_MINIMAL = "min_tool_format" +DEFAULT_TOOL_FORMAT = TOOL_FORMAT_FULL CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention" DEFAULT_ENABLE_FLASH_ATTENTION = False CONF_USE_GBNF_GRAMMAR = "gbnf_grammar" @@ -163,7 +173,7 @@ CONF_PROMPT_CACHING_INTERVAL = "prompt_caching_interval" DEFAULT_PROMPT_CACHING_INTERVAL = 30 CONF_SERVICE_CALL_REGEX = "service_call_regex" -DEFAULT_SERVICE_CALL_REGEX = r"({[\S \t]*?})" +DEFAULT_SERVICE_CALL_REGEX = r" ({[\S \t]*})" FINE_TUNED_SERVICE_CALL_REGEX = r"```homeassistant\n([\S \t\n]*?)```" CONF_REMOTE_USE_CHAT_ENDPOINT = "remote_use_chat_endpoint" DEFAULT_REMOTE_USE_CHAT_ENDPOINT = False diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 21ed522..49859da 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -2,6 +2,8 @@ "config": { "error": { "download_failed": "The download failed to complete: {exception}", + "missing_quantization": "The GGUF quantization level {missing} does not exist in the provided HuggingFace repo. The following quantization levels were found: {available}", + "no_supported_ggufs": "The provided HuggingFace repo does not contain any compatible GGUF files!", "failed_to_connect": "Failed to connect to the remote API: {exception}", "missing_model_api": "The selected model is not provided by this API.", "missing_model_file": "The provided file does not exist.", @@ -54,6 +56,7 @@ "llm_hass_api": "Selected LLM API", "prompt": "System Prompt", "prompt_template": "Prompt Format", + "tool_format": "Tool Format", "temperature": "Temperature", "top_k": "Top K", "top_p": "Top P", @@ -107,6 +110,7 @@ "max_new_tokens": "Maximum tokens to return in response", "prompt": "System Prompt", "prompt_template": "Prompt Format", + "tool_format": "Tool Format", "temperature": "Temperature", "top_k": "Top K", "top_p": "Top P", @@ -170,6 +174,13 @@ "no_prompt_template": "None" } }, + "tool_format": { + "options": { + "full_tool_format": "Full JSON Tool Format", + "reduced_tool_format": "Reduced JSON Tool Format", + "min_tool_format": "Minimal Function Style Tool Format" + } + }, "model_backend": { "options": { "llama_cpp_hf": "Llama.cpp (HuggingFace)", diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index f132aae..4c443ae 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -8,7 +8,10 @@ import webcolors from importlib.metadata import version +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import intent from homeassistant.requirements import pip_kwargs +from homeassistant.util import color from homeassistant.util.package import install_package, is_installed from .const import ( @@ -18,6 +21,11 @@ _LOGGER = logging.getLogger(__name__) +class MissingQuantizationException(Exception): + def __init__(self, missing_quant: str, available_quants: list[str]): + self.missing_quant = missing_quant + self.available_quants = available_quants + def closest_color(requested_color): min_colors = {} for key, name in webcolors.CSS3_HEX_TO_NAMES.items(): @@ -51,6 +59,35 @@ def _flatten(current_schema, prefix=''): _flatten(schema) return flattened +def custom_custom_serializer(value): + """Why is vol so hard to convert back into a readable schema?""" + + if value is 
cv.ensure_list: + return { "type": "list" } + + if value is color.color_name_to_rgb: + return { "type": "string" } + + if value is intent.non_empty_string: + return { "type": "string" } + + # media player registers an intent using a lambda... + # there's literally no way to detect that properly + try: + if value(100) == 1: + _LOGGER.debug("bad") + return { "type": "integer" } + except Exception: + pass + + if isinstance(value, list): + result = {} + for x in value: + result.update(custom_custom_serializer(x)) + return result + + return cv.custom_serializer(value) + def download_model_from_hf(model_name: str, quantization_type: str, storage_folder: str): try: from huggingface_hub import hf_hub_download, HfFileSystem @@ -62,8 +99,8 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold wanted_file = [f for f in potential_files if (f".{quantization_type.lower()}." in f or f".{quantization_type.upper()}." in f)] if len(wanted_file) != 1: - raise Exception(f"The quantization '{quantization_type}' does not exist in the HF repo for {model_name}") - + available_quants = [file.split(".")[-2].upper() for file in potential_files] + raise MissingQuantizationException(quantization_type, available_quants) try: os.makedirs(storage_folder, exist_ok=True) except Exception as ex: From dbed6de6cdbf24f5d5f19526a2961bcdc49086e7 Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sun, 2 Jun 2024 13:02:57 -0400 Subject: [PATCH 05/13] Finish renaming stuff --- README.md | 15 +- TODO.md | 1 + custom_components/llama_conversation/const.py | 35 +-- .../in_context_examples.csv | 10 +- docs/banner.svg | 225 ------------------ docs/logo.svg | 224 ----------------- 6 files changed, 19 insertions(+), 491 deletions(-) delete mode 100644 docs/banner.svg delete mode 100644 docs/logo.svg diff --git a/README.md b/README.md index a0569cb..1c76757 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,23 @@ # Home LLM -![Banner Logo](/docs/banner.svg) - -This project provides the required "glue" components to control your Home Assistant installation with a completely local Large Language Model acting as a personal assistant. The goal is to provide a drop in solution to be used as a "conversation agent" component by Home Assistant. The 2 main pieces of this solution are Home LLM and Llama Conversation. +This project provides the required "glue" components to control your Home Assistant installation with a **completely local** Large Language Model acting as a personal assistant. The goal is to provide a drop in solution to be used as a "conversation agent" component by Home Assistant. The 2 main pieces of this solution are the Home LLM model and Local LLM Conversation integration. ## Quick Start Please see the [Setup Guide](./docs/Setup.md) for more information on installation. -## LLama Conversation Integration +## Local LLM Conversation Integration In order to integrate with Home Assistant, we provide a custom component that exposes the locally running LLM as a "conversation agent". This component can be interacted with in a few ways: - using a chat interface so you can chat with it. - integrating with Speech-to-Text and Text-to-Speech addons so you can just speak to it. -The component can either run the model directly as part of the Home Assistant software using llama-cpp-python, or you can run [Ollama](https://ollama.com/) (simple) or the [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) project (advanced) to provide access to the LLM via an API interface. 
+The integration can run the model in 2 different ways:
+1. Directly as part of the Home Assistant software using llama-cpp-python
+2. On a separate machine using one of the following backends:
+  - [Ollama](https://ollama.com/) (easier)
+  - [LocalAI](https://localai.io/) via the Generic OpenAI backend (easier)
+  - [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) project (advanced)
+  - [llama.cpp example server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) (advanced)
 
 ## Home LLM Model
 The "Home" models are a fine tuning of various Large Language Models that are under 5B parameters. The models are able to control devices in the user's house as well as perform basic question answering. The fine tuning dataset is a [custom synthetic dataset](./data) designed to teach the model function calling based on the device information in the context.
@@ -25,6 +29,7 @@ The latest models can be found on HuggingFace:
Old Models +NOTE: These models are only compatible with version 0.2.17 and older! 3B v2 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v2-GGUF (ChatML prompt format) 3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF (ChatML prompt format) diff --git a/TODO.md b/TODO.md index 1d6c9fd..482c6a7 100644 --- a/TODO.md +++ b/TODO.md @@ -7,6 +7,7 @@ - [ ] make ICL examples into conversation turns - [ ] translate ICL examples + make better ones - [ ] areas/room support +- [ ] convert requests to aiohttp - [x] detection/mitigation of too many entities being exposed & blowing out the context length - [ ] figure out DPO to improve response quality - [ ] train the model to respond to house events diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index 9b7fcaa..bb40cf2 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -1,4 +1,4 @@ -"""Constants for the LLaMa Conversation integration.""" +"""Constants for the Local LLM Conversation integration.""" import types, os DOMAIN = "llama_conversation" @@ -229,7 +229,7 @@ CONF_TEXT_GEN_WEBUI_PRESET: "" } ) - +# TODO: warn the user if they picked an old, incompatible home-llm model? OPTIONS_OVERRIDES = { "home-3b-v4": { CONF_PROMPT: DEFAULT_PROMPT_BASE, @@ -238,42 +238,13 @@ CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, CONF_USE_GBNF_GRAMMAR: True, }, - "home-3b-v3": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, - CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR, - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, - CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - CONF_USE_GBNF_GRAMMAR: True, - }, - "home-3b-v2": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, - CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - CONF_USE_GBNF_GRAMMAR: True, - }, - "home-3b-v1": { + "home-1b-v4": { CONF_PROMPT: DEFAULT_PROMPT_BASE, CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR, CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - }, - "home-1b-v3": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, - CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR2, - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, - CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, CONF_USE_GBNF_GRAMMAR: True, }, - "home-1b-v2": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, - CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - }, - "home-1b-v1": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, - CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, - CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - }, "mistral": { CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS, CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_MISTRAL, diff --git a/custom_components/llama_conversation/in_context_examples.csv b/custom_components/llama_conversation/in_context_examples.csv index 98ab251..435a455 100644 --- a/custom_components/llama_conversation/in_context_examples.csv +++ b/custom_components/llama_conversation/in_context_examples.csv @@ -18,11 +18,11 @@ media_player,TODO,HassSetVolume,Setting the volume as requested. media_player,TODO,HassMediaUnpause,Starting media playback. media_player,TODO,HassMediaPause,Pausing the media playback. media_player,TODO,HassMediaNext,Skipping to the next track. -switch,TODO,HassTurnOn,Turning on the switch for you. -switch,TODO,HassTurnOff,Turning off the switch as requested. -switch,TODO,HassToggle,Toggling the switch for you. 
-vacuum,TODO,HassVacuumStart,Starting the vacuum now. -vacuum,TODO,HassVacuumReturnToBase,Sending the vacuum back to its base. +switch,Turn on the ,HassTurnOn,Turning on the switch for you. +switch,Turn off the switches in ,HassTurnOff,Turning off the devices as requested. +switch,Toggle the switch ,HassToggle,Toggling the switch for you. +vacuum,Start the vacuum called ,HassVacuumStart,Starting the vacuum now. +vacuum,Stop the vacuum,HassVacuumReturnToBase,Sending the vacuum back to its base. todo,TODO,HassListAddItem,the todo has been added to your todo list. timer,TODO,HassStartTimer,Starting timer now. timer,TODO,HassCancelTimer,Timer has been canceled. \ No newline at end of file diff --git a/docs/banner.svg b/docs/banner.svg deleted file mode 100644 index 9ed0cb9..0000000 --- a/docs/banner.svg +++ /dev/null @@ -1,225 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - Home LLM & ​ Llama Conversation - \ No newline at end of file diff --git a/docs/logo.svg b/docs/logo.svg deleted file mode 100644 index 04875ed..0000000 --- a/docs/logo.svg +++ /dev/null @@ -1,224 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file From b50904d73b2ae2493379b02f2ffcffaed00e5dcf Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Sun, 2 Jun 2024 22:13:54 -0400 Subject: [PATCH 06/13] handle multi-turn tool models --- custom_components/llama_conversation/agent.py | 38 +++++++++++++++---- .../llama_conversation/config_flow.py | 7 ++++ custom_components/llama_conversation/const.py | 3 ++ .../llama_conversation/translations/en.json | 2 + 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 8d775f1..8262982 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -46,6 +46,7 @@ CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, + CONF_TOOL_MULTI_TURN_CHAT, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, @@ -81,6 +82,7 @@ DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_PROMPT_TEMPLATE, DEFAULT_TOOL_FORMAT, + DEFAULT_TOOL_MULTI_TURN_CHAT, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, DEFAULT_GBNF_GRAMMAR_FILE, @@ -227,7 +229,7 @@ async def async_process( device_id=user_input.device_id, ) - _LOGGER.info(llm_context) + _LOGGER.info(llm_context.context) try: service_call_pattern = re.compile(service_call_regex) @@ -316,7 +318,7 @@ async def async_process( conversation.pop(1) self.history[conversation_id] = conversation - if not llm_api: + if llm_api is None: # return the output without messing with it if there is no API exposed to the model intent_response = 
intent.IntentResponse(language=user_input.language) intent_response.async_set_speech(response.strip()) @@ -379,9 +381,29 @@ async def async_process( _LOGGER.debug("Tool response: %s", tool_response) - # TODO: optionally handle multi-turn with the model where it acts on the response from the tool - # if self.entry.options.get("", ""): - # pass + # handle models that generate a function call and wait for the result before providing a response + if self.entry.options.get(CONF_TOOL_MULTI_TURN_CHAT, DEFAULT_TOOL_MULTI_TURN_CHAT): + conversation.append({"role": "tool", "message": json.dumps(tool_response)}) + + # generate a response based on the tool result + try: + _LOGGER.debug(conversation) + to_say = await self._async_generate(conversation) + _LOGGER.debug(to_say) + + except Exception as err: + _LOGGER.exception("There was a problem talking to the backend") + + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.FAILED_TO_HANDLE, + f"Sorry, there was a problem talking to the backend: {repr(err)}", + ) + return ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + conversation.append({"role": "assistant", "message": response}) # generate intent response to Home Assistant intent_response = intent.IntentResponse(language=user_input.language) @@ -430,7 +452,8 @@ def _format_prompt( for message in prompt: role = message["role"] message = message["message"] - role_desc = template_desc[role] + # fall back to the "user" role for unknown roles + role_desc = template_desc.get(role, template_desc["user"]) formatted_prompt = ( formatted_prompt + f"{role_desc['prefix']}{message}{role_desc['suffix']}\n" ) @@ -453,7 +476,6 @@ def _format_tool(self, tool: llm.Tool): # handle vol.Any in the key side of things processed_parameters = [] for param in raw_parameters: - _LOGGER.info(param["name"]) if isinstance(param["name"], vol.Any): for possible_name in param["name"].validators: actual_param = param.copy() @@ -828,7 +850,7 @@ async def _async_cache_prompt(self, entity, old_state, new_state): self.last_updated_entities[entity] = refresh_start _LOGGER.debug(f"refreshing cached prompt because {entity} changed...") - await self.hass.async_add_executor_job(self._cache_prompt) + # await self.hass.async_add_executor_job(self._cache_prompt) refresh_end = time.time() _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index c122cac..fe97a4c 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -57,6 +57,7 @@ CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS, CONF_PROMPT_TEMPLATE, CONF_TOOL_FORMAT, + CONF_TOOL_MULTI_TURN_CHAT, CONF_ENABLE_FLASH_ATTENTION, CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, @@ -97,6 +98,7 @@ DEFAULT_DOWNLOADED_MODEL_QUANTIZATION, DEFAULT_PROMPT_TEMPLATE, DEFAULT_TOOL_FORMAT, + DEFAULT_TOOL_MULTI_TURN_CHAT, DEFAULT_ENABLE_FLASH_ATTENTION, DEFAULT_USE_GBNF_GRAMMAR, DEFAULT_GBNF_GRAMMAR_FILE, @@ -767,6 +769,11 @@ def local_llama_config_option_schema(hass: HomeAssistant, options: MappingProxyT multiple=False, mode=SelectSelectorMode.DROPDOWN, )), + vol.Required( + CONF_TOOL_MULTI_TURN_CHAT, + description={"suggested_value": options.get(CONF_TOOL_MULTI_TURN_CHAT)}, + default=DEFAULT_TOOL_MULTI_TURN_CHAT, + ): BooleanSelector(BooleanSelectorConfig()), vol.Required( 
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, description={"suggested_value": options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES)}, diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index bb40cf2..3817e12 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -86,6 +86,7 @@ "system": { "prefix": "<|im_start|>system\n", "suffix": "<|im_end|>" }, "user": { "prefix": "<|im_start|>user\n", "suffix": "<|im_end|>" }, "assistant": { "prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>" }, + "tool": { "prefix": "<|im_start|>tool", "suffix": "<|im_end|>" }, "generation_prompt": "<|im_start|>assistant" }, PROMPT_TEMPLATE_COMMAND_R: { @@ -147,6 +148,8 @@ TOOL_FORMAT_REDUCED = "reduced_tool_format" TOOL_FORMAT_MINIMAL = "min_tool_format" DEFAULT_TOOL_FORMAT = TOOL_FORMAT_FULL +CONF_TOOL_MULTI_TURN_CHAT = "tool_multi_turn_chat" +DEFAULT_TOOL_MULTI_TURN_CHAT = False CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention" DEFAULT_ENABLE_FLASH_ATTENTION = False CONF_USE_GBNF_GRAMMAR = "gbnf_grammar" diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 49859da..9945482 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -57,6 +57,7 @@ "prompt": "System Prompt", "prompt_template": "Prompt Format", "tool_format": "Tool Format", + "tool_multi_turn_chat": "Multi-Turn Tool Use", "temperature": "Temperature", "top_k": "Top K", "top_p": "Top P", @@ -111,6 +112,7 @@ "prompt": "System Prompt", "prompt_template": "Prompt Format", "tool_format": "Tool Format", + "tool_multi_turn_chat": "Multi-Turn Tool Use", "temperature": "Temperature", "top_k": "Top K", "top_p": "Top P", From b10ede765e87c94dedfb87b08481bffdefce226a Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Wed, 5 Jun 2024 23:37:05 -0400 Subject: [PATCH 07/13] update readme and todo --- README.md | 2 ++ TODO.md | 13 +++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 1c76757..e5cdc0d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # Home LLM This project provides the required "glue" components to control your Home Assistant installation with a **completely local** Large Language Model acting as a personal assistant. The goal is to provide a drop in solution to be used as a "conversation agent" component by Home Assistant. The 2 main pieces of this solution are the Home LLM model and Local LLM Conversation integration. +NOTE: This integration has **NOT** been updated yet to support the new LLM API changes in Home Assistant `2024.6.0`. + ## Quick Start Please see the [Setup Guide](./docs/Setup.md) for more information on installation. 
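The multi-turn tool option introduced above works by appending the tool result to the conversation as a new "tool" turn and re-prompting the model, using the "tool" prefix/suffix this patch adds to the ChatML entry in PROMPT_TEMPLATE_DESCRIPTIONS. As a quick illustration of what that rendered prompt looks like, here is a minimal standalone sketch: the template dict mirrors the ChatML entry above, the helper mirrors the role-fallback logic in LocalLLMAgent._format_prompt, and the function name, sample conversation, and tool result value are illustrative only, not code from this patch.

import json

# mirrors PROMPT_TEMPLATE_DESCRIPTIONS[PROMPT_TEMPLATE_CHATML], including the new "tool" role
CHATML = {
    "system": {"prefix": "<|im_start|>system\n", "suffix": "<|im_end|>"},
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>"},
    "tool": {"prefix": "<|im_start|>tool", "suffix": "<|im_end|>"},
    "generation_prompt": "<|im_start|>assistant",
}

def render_chatml(conversation, include_generation_prompt=True):
    """Render a conversation roughly the way LocalLLMAgent._format_prompt does."""
    formatted = ""
    for turn in conversation:
        # unknown roles fall back to the "user" role, matching the fallback added in this patch
        role_desc = CHATML.get(turn["role"], CHATML["user"])
        formatted += f"{role_desc['prefix']}{turn['message']}{role_desc['suffix']}\n"
    if include_generation_prompt:
        formatted += CHATML["generation_prompt"]
    return formatted

conversation = [
    {"role": "system", "message": "You are 'Al', a helpful AI Assistant that controls the devices in a house."},
    {"role": "user", "message": "turn on the kitchen light"},
    {"role": "assistant", "message": '{"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}'},
    # with multi-turn tool use enabled, the agent appends the tool result and asks for a reply
    {"role": "tool", "message": json.dumps({"result": "success"})},
]
print(render_chatml(conversation))

Because include_generation_prompt is left on, the rendered string ends with the bare assistant header, so the backend generates the assistant's final reply to the tool result rather than another tool call.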
diff --git a/TODO.md b/TODO.md index 482c6a7..555bd9e 100644 --- a/TODO.md +++ b/TODO.md @@ -3,16 +3,14 @@ - rewrite how services are called - handle no API selected - rewrite prompts + service block formats + - implement new LLM API that has `HassCallService` so old models can still work - update dataset so new models will work with the API - [ ] make ICL examples into conversation turns - [ ] translate ICL examples + make better ones - [ ] areas/room support -- [ ] convert requests to aiohttp -- [x] detection/mitigation of too many entities being exposed & blowing out the context length -- [ ] figure out DPO to improve response quality -- [ ] train the model to respond to house events - - present the model with an event + a "prompt" from the user of what you want it to do (i.e. turn on the lights when I get home = the model turns on lights when your entity presence triggers as being home) - - basically lets you write automations in plain english +- [ ] convert requests to aiohttp +- [x] detection/mitigation of too many entities being exposed & blowing out the context length +- [ ] figure out DPO to improve response quality - [x] setup github actions to build wheels that are optimized for RPIs - [x] mixtral + prompting (no fine tuning) - add in context learning variables to sys prompt template @@ -53,3 +51,6 @@ - set up vectordb - ingest home assistant docs - "context request" from above to initiate a RAG search +- [ ] train the model to respond to house events + - present the model with an event + a "prompt" from the user of what you want it to do (i.e. turn on the lights when I get home = the model turns on lights when your entity presence triggers as being home) + - basically lets you write automations in plain english \ No newline at end of file From 36e29bedf0dae0312ab4b1dd29593297373ed6ed Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Thu, 6 Jun 2024 22:40:59 -0400 Subject: [PATCH 08/13] add an LLM API to support the existing models --- .../llama_conversation/__init__.py | 83 +++++++++++++- custom_components/llama_conversation/agent.py | 108 +++++++++++++----- custom_components/llama_conversation/const.py | 7 +- custom_components/llama_conversation/utils.py | 5 - hacs.json | 2 +- 5 files changed, 170 insertions(+), 35 deletions(-) diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index 5a8a915..c68f42d 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -5,8 +5,13 @@ import homeassistant.components.conversation as ha_conversation from homeassistant.config_entries import ConfigEntry +from homeassistant.const import ATTR_ENTITY_ID from homeassistant.core import HomeAssistant -from homeassistant.helpers import config_validation as cv +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import config_validation as cv, llm +from homeassistant.util.json import JsonObjectType + +import voluptuous as vol from .agent import ( LocalLLMAgent, @@ -27,6 +32,8 @@ BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER, BACKEND_TYPE_OLLAMA, DOMAIN, + HOME_LLM_API_ID, + SERVICE_TOOL_NAME, ) _LOGGER = logging.getLogger(__name__) @@ -46,6 +53,10 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry): async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Local LLM Conversation from a config entry.""" + # make sure the API is registered + if not any([x.id == HOME_LLM_API_ID for x in 
llm.async_get_apis(hass)]): + llm.async_register_api(hass, HomeLLMAPI(hass)) + def create_agent(backend_type): agent_cls = None @@ -102,3 +113,73 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry): _LOGGER.debug("Migration to version %s successful", config_entry.version) return True + +class HassServiceTool(llm.Tool): + """Tool to get the current time.""" + + name = SERVICE_TOOL_NAME + description = "Executes a Home Assistant service" + + # Optional. A voluptuous schema of the input parameters. + parameters = vol.Schema({ + vol.Required('service'): str, + vol.Required('target_device'): str, + vol.Optional('rgb_color'): str, + vol.Optional('brightness'): float, + vol.Optional('temperature'): float, + vol.Optional('humidity'): float, + vol.Optional('fan_mode'): str, + vol.Optional('hvac_mode'): str, + vol.Optional('preset_mode'): str, + vol.Optional('duration'): str, + vol.Optional('item'): str, + }) + + optional_allowed_args = [] + + async def async_call( + self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext + ) -> JsonObjectType: + """Call the tool.""" + domain, service = tuple(tool_input.tool_args["service"].split(".")) + target_device = tool_input.tool_args["target_device"] + + service_data = {ATTR_ENTITY_ID: target_device} + for attr in self.optional_allowed_args: + if attr in tool_input.tool_args.keys(): + service_data[attr] = tool_input.tool_args[attr] + try: + await hass.services.async_call( + domain, + service, + service_data=service_data, + blocking=True, + ) + except Exception: + _LOGGER.exception("Failed to execute service for model") + return { "result": "failed" } + + return { "result": "success" } + +class HomeLLMAPI(llm.API): + """ + An API that allows calling Home Assistant services to maintain compatibility + with the older (v3 and older) Home LLM models + """ + + def __init__(self, hass: HomeAssistant) -> None: + """Init the class.""" + super().__init__( + hass=hass, + id=HOME_LLM_API_ID, + name="Home-LLM (v1-v3)", + ) + + async def async_get_api_instance(self, llm_context: llm.LLMContext) -> llm.APIInstance: + """Return the instance of the API.""" + return llm.APIInstance( + api=self, + api_prompt="Call services in Home Assistant by passing the service name and the device to control.", + llm_context=llm_context, + tools=[HassServiceTool()], + ) diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 8262982..7329025 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -106,7 +106,10 @@ TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT, + ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS, DOMAIN, + HOME_LLM_API_ID, + SERVICE_TOOL_NAME, PROMPT_TEMPLATE_DESCRIPTIONS, TOOL_FORMAT_FULL, TOOL_FORMAT_REDUCED, @@ -329,12 +332,30 @@ async def async_process( # parse response to_say = service_call_pattern.sub("", response.strip()) for block in service_call_pattern.findall(response.strip()): - parsed_tool_call = json.loads(block) - try: - vol.Schema({ + parsed_tool_call: dict = json.loads(block) + + if llm_api.api.id == HOME_LLM_API_ID: + schema_to_validate = vol.Schema({ + vol.Required('service'): str, + vol.Required('target_device'): str, + vol.Optional('rgb_color'): str, + vol.Optional('brightness'): float, + vol.Optional('temperature'): float, + vol.Optional('humidity'): float, + vol.Optional('fan_mode'): str, + vol.Optional('hvac_mode'): str, + vol.Optional('preset_mode'): 
str, + vol.Optional('duration'): str, + vol.Optional('item'): str, + }) + else: + schema_to_validate = vol.Schema({ vol.Required("name"): str, vol.Required("arguments"): dict, - })(parsed_tool_call) + }) + + try: + schema_to_validate(parsed_tool_call) except vol.Error as ex: _LOGGER.info(f"LLM produced an improperly formatted response: {repr(ex)}") @@ -350,18 +371,27 @@ async def async_process( _LOGGER.info(f"calling tool: {block}") # try to fix certain arguments + args_dict = parsed_tool_call if llm_api.api.id == HOME_LLM_API_ID else parsed_tool_call["arguments"] + # make sure brightness is 0-255 and not a percentage - if "brightness" in parsed_tool_call["arguments"] and 0.0 < parsed_tool_call["arguments"]["brightness"] <= 1.0: - parsed_tool_call["arguments"]["brightness"] = int(parsed_tool_call["arguments"]["brightness"] * 255) + if "brightness" in args_dict and 0.0 < args_dict["brightness"] <= 1.0: + args_dict["brightness"] = int(args_dict["brightness"] * 255) # convert string "tuple" to a list for RGB colors - if "rgb_color" in parsed_tool_call["arguments"] and isinstance(parsed_tool_call["arguments"]["rgb_color"], str): - parsed_tool_call["arguments"]["rgb_color"] = [ int(x) for x in parsed_tool_call["arguments"]["rgb_color"][1:-1].split(",") ] + if "rgb_color" in args_dict and isinstance(args_dict["rgb_color"], str): + args_dict["rgb_color"] = [ int(x) for x in args_dict["rgb_color"][1:-1].split(",") ] - tool_input = llm.ToolInput( - tool_name=parsed_tool_call["name"], - tool_args=parsed_tool_call["arguments"], - ) + if llm_api.api.id == HOME_LLM_API_ID: + to_say = to_say + parsed_tool_call.pop("to_say", "") + tool_input = llm.ToolInput( + tool_name=SERVICE_TOOL_NAME, + tool_args=parsed_tool_call, + ) + else: + tool_input = llm.ToolInput( + tool_name=parsed_tool_call["name"], + tool_args=parsed_tool_call["arguments"], + ) try: tool_response = await llm_api.async_call_tool(tool_input) @@ -464,14 +494,17 @@ def _format_prompt( _LOGGER.debug(formatted_prompt) return formatted_prompt - def _format_tool(self, tool: llm.Tool): + def _format_tool(self, name: str, parameters: vol.Schema, description: str): style = self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) if style == TOOL_FORMAT_MINIMAL: - return f"{tool.name}({', '.join(flatten_vol_schema(tool.parameters))}) - {tool.description}" + result = f"{name}({','.join(flatten_vol_schema(parameters))})" + if description: + result = result + f" - {description}" + return result raw_parameters: list = voluptuous_serialize.convert( - tool.parameters, custom_serializer=custom_custom_serializer) + parameters, custom_serializer=custom_custom_serializer) # handle vol.Any in the key side of things processed_parameters = [] @@ -487,8 +520,8 @@ def _format_tool(self, tool: llm.Tool): if style == TOOL_FORMAT_REDUCED: return { - "name": tool.name, - "description": tool.description, + "name": name, + "description": description, "parameters": { "properties": { x["name"]: x.get("type", "string") for x in processed_parameters @@ -502,8 +535,8 @@ def _format_tool(self, tool: llm.Tool): return { "type": "function", "function": { - "name": tool.name, - "description": tool.description, + "name": name, + "description": description, "parameters": { "type": "object", "properties": { @@ -627,17 +660,40 @@ def expose_attributes(attributes): formatted_states = "\n".join(device_states) + "\n" - if llm_api: - formatted_tools = json.dumps([ - self._format_tool(tool) - for tool in llm_api.tools - ]) + if llm_api: + if llm_api.api.id == HOME_LLM_API_ID: + 
service_dict = self.hass.services.async_services() + all_services = [] + for domain in domains: + # scripts show up as individual services + if domain == "script": + all_services.extend(["script.reload()", "script.turn_on()", "script.turn_off()", "script.toggle()"]) + continue + + for name, service in service_dict.get(domain, {}).items(): + args = flatten_vol_schema(service.schema) + args_to_expose = set(args).intersection(ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS) + service_schema = vol.Schema({ + vol.Optional(arg): str for arg in args_to_expose + }) + + all_services.append((f"{domain}.{name}", service_schema, "")) + + tools = [ + self._format_tool(*tool) + for tool in all_services + ] + else: + tools = [ + self._format_tool(tool.name, tool.parameters, tool.description) + for tool in llm_api.tools + ] else: - formatted_tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so." + tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so." render_variables = { "devices": formatted_states, - "tools": formatted_tools, + "tools": tools, "response_examples": [] } diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index 3817e12..e0c10ce 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -2,6 +2,8 @@ import types, os DOMAIN = "llama_conversation" +HOME_LLM_API_ID = "home-llm-service-api" +SERVICE_TOOL_NAME = "HassCallService" CONF_PROMPT = "prompt" PERSONA_PROMPTS = { "en": "You are 'Al', a helpful AI Assistant that controls the devices in a house. Complete the following task as instructed with the information provided only.", @@ -11,7 +13,7 @@ } DEFAULT_PROMPT_BASE = """ The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }} -Tools: {{ tools }} +Tools: {{ tools | to_json }} Devices: {{ devices }}""" ICL_EXTRAS = """ @@ -69,6 +71,7 @@ DEFAULT_SSL = False CONF_EXTRA_ATTRIBUTES_TO_EXPOSE = "extra_attributes_to_expose" DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "media_title", "volume_level", "item", "wind_speed"] +ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS = ["rgb_color", "brightness", "temperature", "humidity", "fan_mode", "hvac_mode", "preset_mode", "item", "duration"] CONF_PROMPT_TEMPLATE = "prompt_template" PROMPT_TEMPLATE_CHATML = "chatml" PROMPT_TEMPLATE_COMMAND_R = "command-r" @@ -232,7 +235,7 @@ CONF_TEXT_GEN_WEBUI_PRESET: "" } ) -# TODO: warn the user if they picked an old, incompatible home-llm model? +# TODO: re-add old models but select the legacy API OPTIONS_OVERRIDES = { "home-3b-v4": { CONF_PROMPT: DEFAULT_PROMPT_BASE, diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index 4c443ae..a036330 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -45,11 +45,6 @@ def _flatten(current_schema, prefix=''): _flatten(subval, prefix) elif isinstance(current_schema.schema, dict): for key, val in current_schema.schema.items(): - if isinstance(key, vol.Any): - key = "|".join(key.validators) - if isinstance(key, vol.Optional): - key = "?" 
+ str(key) - _flatten(val, prefix + str(key) + '/') elif isinstance(current_schema, vol.validators._WithSubValidators): for subval in current_schema.validators: diff --git a/hacs.json b/hacs.json index d0b9dbc..bf0d1db 100644 --- a/hacs.json +++ b/hacs.json @@ -1,6 +1,6 @@ { "name": "Local LLM Conversation", - "homeassistant": "2024.5.5", + "homeassistant": "2024.6.0", "content_in_root": false, "render_readme": true } From ab32942006cc7ef32af530f8bb2025e05d0ae3fa Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Thu, 6 Jun 2024 23:08:04 -0400 Subject: [PATCH 09/13] more fixes for llm API --- README.md | 5 ++- .../llama_conversation/__init__.py | 4 -- custom_components/llama_conversation/agent.py | 2 +- .../llama_conversation/config_flow.py | 8 ++++ custom_components/llama_conversation/const.py | 44 ++++++++++++++++--- .../llama_conversation/translations/en.json | 4 +- docs/Backend Configuration.md | 3 ++ docs/Setup.md | 8 +++- 8 files changed, 60 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index e5cdc0d..f652367 100644 --- a/README.md +++ b/README.md @@ -31,13 +31,14 @@ The latest models can be found on HuggingFace:
Old Models -NOTE: These models are only compatible with version 0.2.17 and older! 3B v2 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v2-GGUF (ChatML prompt format) -3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF (ChatML prompt format) 1B v2 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v2-GGUF (ChatML prompt format) 1B v1 (Based on Phi-1.5): https://huggingface.co/acon96/Home-1B-v1-GGUF (ChatML prompt format) +NOTE: The models below are only compatible with version 0.2.17 and older! +3B v1 (Based on Phi-2): https://huggingface.co/acon96/Home-3B-v1-GGUF (ChatML prompt format) +
The model is quantized using Llama.cpp in order to enable running the model in super low resource environments that are common with Home Assistant installations such as Raspberry Pis. diff --git a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index c68f42d..5b18b4a 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -53,10 +53,6 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry): async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up Local LLM Conversation from a config entry.""" - # make sure the API is registered - if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]): - llm.async_register_api(hass, HomeLLMAPI(hass)) - def create_agent(backend_type): agent_cls = None diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 7329025..1b651ee 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -403,7 +403,7 @@ async def async_process( intent_response = intent.IntentResponse(language=user_input.language) intent_response.async_set_error( intent.IntentResponseErrorCode.NO_INTENT_MATCH, - f"There was an error calling the tool! ({tool_response})", + f"I'm sorry! I encountered an error calling the tool. See the logs for more info.", ) return ConversationResult( response=intent_response, conversation_id=conversation_id diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index fe97a4c..23d8331 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -134,12 +134,15 @@ TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT, DOMAIN, + HOME_LLM_API_ID, DEFAULT_OPTIONS, OPTIONS_OVERRIDES, RECOMMENDED_CHAT_MODELS, EMBEDDED_LLAMA_CPP_PYTHON_VERSION ) +from . 
import HomeLLMAPI + _LOGGER = logging.getLogger(__name__) def is_local_backend(backend): @@ -312,6 +315,11 @@ async def async_step_user( """Handle the initial step.""" self.model_config = {} self.options = {} + + # make sure the API is registered + if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(self.hass)]): + llm.async_register_api(self.hass, HomeLLMAPI(self.hass)) + return await self.async_step_pick_backend() async def async_step_pick_backend( diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index e0c10ce..a603bc2 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -16,6 +16,11 @@ Tools: {{ tools | to_json }} Devices: {{ devices }}""" +DEFAULT_PROMPT_BASE_LEGACY = """ +The current time and date is {{ (as_timestamp(now()) | timestamp_custom("%I:%M %p on %A %B %d, %Y", "")) }} +Services: {{ tools | join(", ") }} +Devices: +{{ devices }}""" ICL_EXTRAS = """ {% for item in response_examples %} {{ item.request }} @@ -235,21 +240,46 @@ CONF_TEXT_GEN_WEBUI_PRESET: "" } ) -# TODO: re-add old models but select the legacy API + OPTIONS_OVERRIDES = { - "home-3b-v4": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, + "home-3b-v3": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR, CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - CONF_USE_GBNF_GRAMMAR: True, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, }, - "home-1b-v4": { - CONF_PROMPT: DEFAULT_PROMPT_BASE, + "home-3b-v2": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, + CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, + CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, + }, + "home-3b-v1": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR, CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, - CONF_USE_GBNF_GRAMMAR: True, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, + }, + "home-1b-v3": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, + CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR2, + CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, + CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, + }, + "home-1b-v2": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, + CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, + CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, + }, + "home-1b-v1": { + CONF_PROMPT: DEFAULT_PROMPT_BASE_LEGACY, + CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES: False, + CONF_SERVICE_CALL_REGEX: FINE_TUNED_SERVICE_CALL_REGEX, + CONF_TOOL_FORMAT: TOOL_FORMAT_MINIMAL, }, "mistral": { CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_NO_SYSTEM_PROMPT_EXTRAS, diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 9945482..3deae35 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -90,7 +90,7 @@ "n_batch_threads": "Batch Thread Count" }, "data_description": { - "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.", + "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. 
If you are using a Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-v3)'", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.", @@ -145,7 +145,7 @@ "n_batch_threads": "Batch Thread Count" }, "data_description": { - "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices.", + "llm_hass_api": "Select 'Assist' if you want the model to be able to control devices. If you are using a Home-LLM v1, v2, or v3 model then select 'Home-LLM (v1-v3)'", "prompt": "See [here](https://github.com/acon96/home-llm/blob/develop/docs/Model%20Prompting.md) for more information on model prompting.", "in_context_examples": "If you are using a model that is not specifically fine-tuned for use with this integration: enable this", "remote_use_chat_endpoint": "If this is enabled, then the integration will use the chat completion HTTP endpoint instead of the text completion one.", diff --git a/docs/Backend Configuration.md b/docs/Backend Configuration.md index c2efa2a..2f4dcc1 100644 --- a/docs/Backend Configuration.md +++ b/docs/Backend Configuration.md @@ -5,8 +5,11 @@ There are multiple backends to choose for running the model that the Home Assist # Common Options | Option Name | Description | Suggested Value | |-----------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------| +| LLM API | This is the set of tools that are provided to the LLM. Use Assist for the built-in API. If you are using Home-LLM v1, v2, or v3, then select the dedicated `Home-LLM (v1-v3)` API | | | System Prompt | [see here](./Model%20Prompting.md) | | | Prompt Format | The format for the context of the model | | +| Tool Format | The format of the tools that are provided to the model. Full, Reduced, or Minimal | | +| Multi-Turn Tool Use | Enable this if the model you are using expects to receive the result from the tool call before responding to the user | | | Maximum tokens to return in response | Limits the number of tokens that can be produced by each model response | 512 | | Additional attribute to expose in the context | Extra attributes that will be exposed to the model via the `{{ devices }}` template variable | | | Arguments allowed to be pass to service calls | Any arguments not listed here will be filtered out of service calls. Used to restrict the model from modifying certain parts of your home. | | diff --git a/docs/Setup.md b/docs/Setup.md index f888212..8727cde 100644 --- a/docs/Setup.md +++ b/docs/Setup.md @@ -64,7 +64,9 @@ Pressing `Submit` will download the model from HuggingFace. ### Step 3: Model Configuration This step allows you to configure how the model is "prompted". See [here](./Model%20Prompting.md) for more information on how that works. -For now, defaults for the model should have been populated and you can just scroll to the bottom and click `Submit`. +For now, defaults for the model should have been populated. If you would like the model to be able to control devices then you must select the `Home-LLM (v1-v3)` API. 
This API is included to ensure compatibility with the Home-LLM models that were trained before the introduction of the built-in Home Assistant LLM API. + +Once the desired API has been selected, scroll to the bottom and click `Submit`. The model will be loaded into memory and should now be available to select as a conversation agent! @@ -95,7 +97,9 @@ In order to access the model from another machine, we need to run the Ollama API ### Step 3: Model Configuration This step allows you to configure how the model is "prompted". See [here](./Model%20Prompting.md) for more information on how that works. -For now, defaults for the model should have been populated and you can just scroll to the bottom and click `Submit`. +For now, defaults for the model should have been populated. If you would like the model to be able to control devices then you must select the `Assist` API. + +Once the desired API has been selected, scroll to the bottom and click `Submit`. > NOTE: The key settings in this case are that our prompt references the `{{ response_examples }}` variable and the `Enable in context learning (ICL) examples` option is turned on. From bee5d4e3847a7d6de253266ef34f25a75c9da75e Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Thu, 6 Jun 2024 23:14:25 -0400 Subject: [PATCH 10/13] re-enable prompt caching --- custom_components/llama_conversation/agent.py | 153 +++++++++--------- 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py index 1b651ee..f9a4f9f 100644 --- a/custom_components/llama_conversation/agent.py +++ b/custom_components/llama_conversation/agent.py @@ -223,17 +223,6 @@ async def async_process( remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, DEFAULT_REMEMBER_NUM_INTERACTIONS) service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX) - llm_context = llm.LLMContext( - platform=DOMAIN, - context=user_input.context, - user_prompt=user_input.text, - language=user_input.language, - assistant=ha_conversation.DOMAIN, - device_id=user_input.device_id, - ) - - _LOGGER.info(llm_context.context) - try: service_call_pattern = re.compile(service_call_regex) except Exception as err: @@ -252,7 +241,16 @@ async def async_process( if self.entry.options.get(CONF_LLM_HASS_API): try: llm_api = await llm.async_get_api( - self.hass, self.entry.options[CONF_LLM_HASS_API], llm_context + self.hass, + self.entry.options[CONF_LLM_HASS_API], + llm_context=llm.LLMContext( + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=ha_conversation.DOMAIN, + device_id=user_input.device_id, + ) ) except HomeAssistantError as err: _LOGGER.error("Error getting LLM API: %s", err) @@ -905,92 +903,91 @@ async def _async_cache_prompt(self, entity, old_state, new_state): if entity: self.last_updated_entities[entity] = refresh_start + llm_api: llm.APIInstance | None = None + if self.entry.options.get(CONF_LLM_HASS_API): + try: + llm_api = await llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API] + ) + except HomeAssistantError: + _LOGGER.exception("Failed to get LLM API when caching prompt!") + return + _LOGGER.debug(f"refreshing cached prompt because {entity} changed...") - # await self.hass.async_add_executor_job(self._cache_prompt) + await self.hass.async_add_executor_job(self._cache_prompt, llm_api) refresh_end = time.time() _LOGGER.debug(f"cache refresh took 
{(refresh_end - refresh_start):.2f} sec") - # def _cache_prompt(self) -> None: - # # if a refresh is already scheduled then exit - # if self.cache_refresh_after_cooldown: - # return - - # # if we are inside the cooldown period, request a refresh and exit - # current_time = time.time() - # fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) - # if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval: - # self.cache_refresh_after_cooldown = True - # return + def _cache_prompt(self, llm_api: llm.API) -> None: + # if a refresh is already scheduled then exit + if self.cache_refresh_after_cooldown: + return - # # try to acquire the lock, if we are still running for some reason, request a refresh and exit - # lock_acquired = self.model_lock.acquire(False) - # if not lock_acquired: - # self.cache_refresh_after_cooldown = True - # return + # if we are inside the cooldown period, request a refresh and exit + current_time = time.time() + fastest_prime_interval = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) + if self.last_cache_prime and current_time - self.last_cache_prime < fastest_prime_interval: + self.cache_refresh_after_cooldown = True + return - # llm_api: llm.APIInstance | None = None - # if self.entry.options.get(CONF_LLM_HASS_API): - # try: - # llm_api = await llm.async_get_api( - # self.hass, self.entry.options[CONF_LLM_HASS_API] - # ) - # except HomeAssistantError: - # _LOGGER.exception("Failed to get LLM API when caching prompt!") - # return - - # try: - # raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) - # prompt = self._format_prompt([ - # { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, - # { "role": "user", "message": "" } - # ], include_generation_prompt=False) + # try to acquire the lock, if we are still running for some reason, request a refresh and exit + lock_acquired = self.model_lock.acquire(False) + if not lock_acquired: + self.cache_refresh_after_cooldown = True + return + try: + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) + prompt = self._format_prompt([ + { "role": "system", "message": self._generate_system_prompt(raw_prompt, llm_api)}, + { "role": "user", "message": "" } + ], include_generation_prompt=False) - # input_tokens = self.llm.tokenize( - # prompt.encode(), add_bos=False - # ) + input_tokens = self.llm.tokenize( + prompt.encode(), add_bos=False + ) - # temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) - # top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)) - # top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) - # grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None + temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) + top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K)) + top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) + grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None - # _LOGGER.debug(f"Options: {self.entry.options}") + _LOGGER.debug(f"Options: {self.entry.options}") - # _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...") + _LOGGER.debug(f"Processing {len(input_tokens)} input tokens...") - # # grab just one token. 
should prime the kv cache with the system prompt - # next(self.llm.generate( - # input_tokens, - # temp=temperature, - # top_k=top_k, - # top_p=top_p, - # grammar=grammar - # )) + # grab just one token. should prime the kv cache with the system prompt + next(self.llm.generate( + input_tokens, + temp=temperature, + top_k=top_k, + top_p=top_p, + grammar=grammar + )) - # self.last_cache_prime = time.time() - # finally: - # self.model_lock.release() + self.last_cache_prime = time.time() + finally: + self.model_lock.release() - # # schedule a refresh using async_call_later - # # if the flag is set after the delay then we do another refresh + # schedule a refresh using async_call_later + # if the flag is set after the delay then we do another refresh - # @callback - # async def refresh_if_requested(_now): - # if self.cache_refresh_after_cooldown: - # self.cache_refresh_after_cooldown = False + @callback + async def refresh_if_requested(_now): + if self.cache_refresh_after_cooldown: + self.cache_refresh_after_cooldown = False - # refresh_start = time.time() - # _LOGGER.debug(f"refreshing cached prompt after cooldown...") - # await self.hass.async_add_executor_job(self._cache_prompt) + refresh_start = time.time() + _LOGGER.debug(f"refreshing cached prompt after cooldown...") + await self.hass.async_add_executor_job(self._cache_prompt) - # refresh_end = time.time() - # _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") + refresh_end = time.time() + _LOGGER.debug(f"cache refresh took {(refresh_end - refresh_start):.2f} sec") - # refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) - # async_call_later(self.hass, float(refresh_delay), refresh_if_requested) + refresh_delay = self.entry.options.get(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL) + async_call_later(self.hass, float(refresh_delay), refresh_if_requested) def _generate(self, conversation: dict) -> str: From 21640dc3215cfc734d892b32fe342c073d2a1cdd Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Thu, 6 Jun 2024 23:54:08 -0400 Subject: [PATCH 11/13] release notes, fix service call args, and other release prep/cleanup --- README.md | 1 + .../llama_conversation/__init__.py | 20 ++++++------------- .../llama_conversation/config_flow.py | 2 +- custom_components/llama_conversation/const.py | 6 +++--- .../in_context_examples.csv | 18 ++++++----------- .../llama_conversation/translations/en.json | 4 ++-- custom_components/llama_conversation/utils.py | 6 ++---- 7 files changed, 21 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index f652367..3fde4bb 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ In order to facilitate running the project entirely on the system where Home Ass ## Version History | Version | Description | |---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| v0.3 | Adds support for Home Assistant LLM APIs, improved model prompting and tool formatting options, and automatic detection of GGUF quantization levels on HuggingFace | | v0.2.17 | Disable native llama.cpp wheel optimizations, add Command R prompt format | | v0.2.16 | Fix for missing huggingface_hub package preventing startup | | v0.2.15 | Fix startup error when using llama.cpp backend and add flash attention to llama.cpp backend | diff --git 
a/custom_components/llama_conversation/__init__.py b/custom_components/llama_conversation/__init__.py index 5b18b4a..362700b 100644 --- a/custom_components/llama_conversation/__init__.py +++ b/custom_components/llama_conversation/__init__.py @@ -31,6 +31,7 @@ BACKEND_TYPE_GENERIC_OPENAI, BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER, BACKEND_TYPE_OLLAMA, + ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS, DOMAIN, HOME_LLM_API_ID, SERVICE_TOOL_NAME, @@ -94,17 +95,10 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry): """Migrate old entry.""" _LOGGER.debug("Migrating from version %s", config_entry.version) - if config_entry.version > 1: - # This means the user has downgraded from a future version - return False - - # if config_entry.version < 2: - # # just ensure that the defaults are set - # new_options = dict(DEFAULT_OPTIONS) - # new_options.update(config_entry.options) - - # config_entry.version = 2 - # hass.config_entries.async_update_entry(config_entry, options=new_options) + # 1 -> 2: This was a breaking change so force users to re-create entries + if config_entry.version == 1: + _LOGGER.error("Cannot upgrade models that were created prior to v0.3. Please delete and re-create them.") + return False _LOGGER.debug("Migration to version %s successful", config_entry.version) @@ -131,8 +125,6 @@ class HassServiceTool(llm.Tool): vol.Optional('item'): str, }) - optional_allowed_args = [] - async def async_call( self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext ) -> JsonObjectType: @@ -141,7 +133,7 @@ async def async_call( target_device = tool_input.tool_args["target_device"] service_data = {ATTR_ENTITY_ID: target_device} - for attr in self.optional_allowed_args: + for attr in ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS: if attr in tool_input.tool_args.keys(): service_data[attr] = tool_input.tool_args[attr] try: diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py index 23d8331..f07424e 100644 --- a/custom_components/llama_conversation/config_flow.py +++ b/custom_components/llama_conversation/config_flow.py @@ -291,7 +291,7 @@ async def async_step_finish( class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for Local LLM Conversation.""" - VERSION = 1 + VERSION = 2 install_wheel_task = None install_wheel_error = None download_task = None diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index a603bc2..75159f5 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -173,7 +173,7 @@ CONF_TEXT_GEN_WEBUI_PRESET = "text_generation_webui_preset" CONF_OPENAI_API_KEY = "openai_api_key" CONF_TEXT_GEN_WEBUI_ADMIN_KEY = "text_generation_webui_admin_key" -CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_tern" +CONF_REFRESH_SYSTEM_PROMPT = "refresh_prompt_per_turn" DEFAULT_REFRESH_SYSTEM_PROMPT = True CONF_REMEMBER_CONVERSATION = "remember_conversation" DEFAULT_REMEMBER_CONVERSATION = True @@ -315,5 +315,5 @@ } } -INTEGRATION_VERSION = "0.2.17" -EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.70" \ No newline at end of file +INTEGRATION_VERSION = "0.3" +EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.77" \ No newline at end of file diff --git a/custom_components/llama_conversation/in_context_examples.csv b/custom_components/llama_conversation/in_context_examples.csv index 435a455..18a88f5 100644 --- 
a/custom_components/llama_conversation/in_context_examples.csv +++ b/custom_components/llama_conversation/in_context_examples.csv @@ -10,19 +10,13 @@ light,Lights off in ,HassTurnOff,Turning off the light as requested. light,Hit toggle for the lights in ,HassToggle,Toggling the lights in the for you. light,Set the brightness for to ,HassLightSet,Setting the brightness now. light,Make the lights in ,HassLightSet,The color should be changed now. -media_player,TODO,HassTurnOn,Turning on the media player for you. -media_player,TODO,HassTurnOff,Turning off the media player as requested. -media_player,TODO,HassToggle,Toggling the media player for you. -media_player,TODO,HassSetVolume,Muting the volume for you. -media_player,TODO,HassSetVolume,Setting the volume as requested. -media_player,TODO,HassMediaUnpause,Starting media playback. -media_player,TODO,HassMediaPause,Pausing the media playback. -media_player,TODO,HassMediaNext,Skipping to the next track. +media_player,Can you turn on ,HassTurnOn,Turning on the media player for you. +media_player, should be turned off,HassTurnOff,Turning off the media player as requested. +media_player,Can you press play on ,HassMediaUnpause,Starting media playback. +media_player,Pause the ,HassMediaPause,Pausing the media playback. +media_player,Play the next thing on ,HassMediaNext,Skipping to the next track. switch,Turn on the ,HassTurnOn,Turning on the switch for you. switch,Turn off the switches in ,HassTurnOff,Turning off the devices as requested. switch,Toggle the switch ,HassToggle,Toggling the switch for you. vacuum,Start the vacuum called ,HassVacuumStart,Starting the vacuum now. -vacuum,Stop the vacuum,HassVacuumReturnToBase,Sending the vacuum back to its base. -todo,TODO,HassListAddItem,the todo has been added to your todo list. -timer,TODO,HassStartTimer,Starting timer now. -timer,TODO,HassCancelTimer,Timer has been canceled. \ No newline at end of file +vacuum,Stop the vacuum,HassVacuumReturnToBase,Sending the vacuum back to its base. 
\ No newline at end of file diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 3deae35..ac77b77 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -73,7 +73,7 @@ "openai_api_key": "API Key", "text_generation_webui_admin_key": "Admin Key", "service_call_regex": "Service Call Regex", - "refresh_prompt_per_tern": "Refresh System Prompt Every Turn", + "refresh_prompt_per_turn": "Refresh System Prompt Every Turn", "remember_conversation": "Remember conversation", "remember_num_interactions": "Number of past interactions to remember", "in_context_examples": "Enable in context learning (ICL) examples", @@ -128,7 +128,7 @@ "openai_api_key": "API Key", "text_generation_webui_admin_key": "Admin Key", "service_call_regex": "Service Call Regex", - "refresh_prompt_per_tern": "Refresh System Prompt Every Turn", + "refresh_prompt_per_turn": "Refresh System Prompt Every Turn", "remember_conversation": "Remember conversation", "remember_num_interactions": "Number of past interactions to remember", "in_context_examples": "Enable in context learning (ICL) examples", diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index a036330..d7c8c05 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -55,7 +55,7 @@ def _flatten(current_schema, prefix=''): return flattened def custom_custom_serializer(value): - """Why is vol so hard to convert back into a readable schema?""" + """a vol schema is really not straightforward to convert back into a dictionary""" if value is cv.ensure_list: return { "type": "list" } @@ -67,10 +67,9 @@ def custom_custom_serializer(value): return { "type": "string" } # media player registers an intent using a lambda... - # there's literally no way to detect that properly + # there's literally no way to detect that properly. with that in mind, we have this try: if value(100) == 1: - _LOGGER.debug("bad") return { "type": "integer" } except Exception: pass @@ -105,7 +104,6 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold repo_id=model_name, repo_type="model", filename=wanted_file[0].removeprefix(model_name + "/"), - resume_download=True, cache_dir=storage_folder, ) From b56d54b945769a14a272cbf17f5ddd065591b230 Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Thu, 6 Jun 2024 23:55:12 -0400 Subject: [PATCH 12/13] remove note about not being updated yet --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 3fde4bb..1e55a0d 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # Home LLM This project provides the required "glue" components to control your Home Assistant installation with a **completely local** Large Language Model acting as a personal assistant. The goal is to provide a drop in solution to be used as a "conversation agent" component by Home Assistant. The 2 main pieces of this solution are the Home LLM model and Local LLM Conversation integration. -NOTE: This integration has **NOT** been updated yet to support the new LLM API changes in Home Assistant `2024.6.0`. - ## Quick Start Please see the [Setup Guide](./docs/Setup.md) for more information on installation. 
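For readers following the patches above: the sketch below is not part of the patch series. It shows roughly what the legacy `Home-LLM (v1-v3)` system prompt looks like once `DEFAULT_PROMPT_BASE_LEGACY` is rendered with the minimal tool format. The service list, device-state lines, and entity names are invented for illustration, and the Home Assistant template helpers (`now()`, `as_timestamp`) from the real template are omitted so the snippet runs with plain Jinja2.

```python
# Illustrative sketch only -- renders a trimmed-down version of the legacy prompt
# template with made-up services and device states to show the expected layout.
from jinja2 import Template

# simplified stand-in for DEFAULT_PROMPT_BASE_LEGACY (the time/date line is omitted)
LEGACY_PROMPT = "Services: {{ tools | join(', ') }}\nDevices:\n{{ devices }}"

# with TOOL_FORMAT_MINIMAL, each service renders as domain.service(allowed_args)
tools = [
    "light.turn_on(rgb_color,brightness)",
    "light.turn_off()",
    "switch.toggle()",
]

# hypothetical device block; the real lines come from _async_get_exposed_entities()
devices = "light.kitchen 'Kitchen Light' = on;80%\nswitch.fan 'Ceiling Fan' = off"

print(Template(LEGACY_PROMPT).render(tools=tools, devices=devices))
```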
From 9f08e6f8a167dfbb89ec4d25c38384e2e7b3b52f Mon Sep 17 00:00:00 2001 From: Alex O'Connell Date: Fri, 7 Jun 2024 00:05:03 -0400 Subject: [PATCH 13/13] final tweaks --- custom_components/llama_conversation/utils.py | 2 +- docs/Setup.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py index d7c8c05..3ff1695 100644 --- a/custom_components/llama_conversation/utils.py +++ b/custom_components/llama_conversation/utils.py @@ -48,7 +48,7 @@ def _flatten(current_schema, prefix=''): _flatten(val, prefix + str(key) + '/') elif isinstance(current_schema, vol.validators._WithSubValidators): for subval in current_schema.validators: - _flatten(subval, prefix) + _flatten(subval, prefix) elif callable(current_schema): flattened.append(prefix[:-1] if prefix else prefix) _flatten(schema) diff --git a/docs/Setup.md b/docs/Setup.md index 8727cde..19d195c 100644 --- a/docs/Setup.md +++ b/docs/Setup.md @@ -52,9 +52,9 @@ This should download and install `llama-cpp-python` from GitHub. If the installa Once `llama-cpp-python` is installed, continue to the model selection. ### Step 2: Model Selection -The next step is to specify which model will be used by the integration. You may select any repository on HuggingFace that has a model in GGUF format in it. We will use `acon96/Home-3B-v3-GGUF` for this example. If you have less than 4GB of RAM then use `acon96/Home-1B-v2-GGUF`. +The next step is to specify which model will be used by the integration. You may select any repository on HuggingFace that has a model in GGUF format in it. We will use `acon96/Home-3B-v3-GGUF` for this example. If you have less than 4GB of RAM then use `acon96/Home-1B-v3-GGUF`. -**Model Name**: Use either `acon96/Home-3B-v3-GGUF` or `acon96/Home-1B-v2-GGUF` +**Model Name**: Use either `acon96/Home-3B-v3-GGUF` or `acon96/Home-1B-v3-GGUF` **Quantization Level**: The model will be downloaded in the selected quantization level from the HuggingFace repository. If unsure which level to choose, select `Q4_K_M`. Pressing `Submit` will download the model from HuggingFace.
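As a closing illustration (not part of the patches above), here is a minimal sketch of the two tool-call shapes the rewritten parser in `agent.py` validates: the flat service-call object used by the legacy `Home-LLM (v1-v3)` API versus the `name`/`arguments` object used with the built-in Assist API. The example JSON payloads and entity IDs are invented, and the legacy schema is abbreviated; only the schema fields shown and the 0-1 brightness rescaling mirror the code in this series.

```python
# Illustrative sketch only -- mirrors the schema selection added to async_process().
# The example model outputs below are invented.
import json
import voluptuous as vol

# legacy Home-LLM (v1-v3) API: flat service-call object
# (the real schema also lists temperature, humidity, fan_mode, hvac_mode, etc.)
LEGACY_SCHEMA = vol.Schema({
    vol.Required("service"): str,
    vol.Required("target_device"): str,
    vol.Optional("rgb_color"): str,
    vol.Optional("brightness"): float,
})

# built-in Assist API: tool name plus an arguments dict
ASSIST_SCHEMA = vol.Schema({
    vol.Required("name"): str,
    vol.Required("arguments"): dict,
})

legacy_block = '{"service": "light.turn_on", "target_device": "light.kitchen", "brightness": 0.5}'
assist_block = '{"name": "HassTurnOn", "arguments": {"name": "kitchen light"}}'

for raw, schema in ((legacy_block, LEGACY_SCHEMA), (assist_block, ASSIST_SCHEMA)):
    parsed = schema(json.loads(raw))
    # like the integration, rescale a 0-1 brightness fraction to the 0-255 range
    if "brightness" in parsed and 0.0 < parsed["brightness"] <= 1.0:
        parsed["brightness"] = int(parsed["brightness"] * 255)
    print(parsed)
```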