Skip to content

Commit

Permalink
Fix: Nova doesn't work (#645)
Browse files Browse the repository at this point in the history
* If model is Amazon Nova, combine multiple system prompts into one text. #629

* If model is Amazon Nova, set the upper limit of topK to 128. #629

* Remove title key from input JSON schema of AgentTool. #629

* Optimization for Amazon Nova.
- Change the system prompt when doing 'Retrieved Context Citation' with Amazon Nova.
- If the tool result has more than one element, pass it as single text content formatted as JSON array.

* Fix: mypy errors.

* Add stack trace to backend error logs.

* Fix for multimodal tool results for Amazon Nova

* Add comments, and document changes.
- `RemoveTitle` in app/agents/tools/agent_tool.py
- `_prepare_nova_model_params()` in app/bedrock.py
- `build_rag_prompt()` and `get_prompt_to_cite_tool_results()` in app/prompt.py
- `BaseTool` -> `AgentTool` in docs/AGENT.md

* Move `is_nova_model()` back to app/bedrock.py
- To avoid circular imports, add `from __future__ import annotations` and `if TYPE_CHECKING` to app/bedrock.py

* Update document of Agent functionality
- docs/AGENT.md

* Change `run_result_to_tool_result_content_model()` to be an instance method of `ToolResultContentModel`
- `agent_tool.run_result_to_tool_result_content_model()` -> `ToolResultContentModel.from_tool_run_result()`
  • Loading branch information
Yukinobu-Mine authored Dec 18, 2024
1 parent 3b48782 commit 087e263
Show file tree
Hide file tree
Showing 14 changed files with 235 additions and 77 deletions.
50 changes: 23 additions & 27 deletions backend/app/agents/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
TextToolResultModel,
JsonToolResultModel,
RelatedDocumentModel,
ToolResultContentModel,
ToolResultContentModelBody,
)
from app.repositories.models.custom_bot import BotModel
from app.routes.schemas.conversation import type_model_name
from pydantic import BaseModel, JsonValue
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from mypy_boto3_bedrock_runtime.type_defs import (
ToolSpecificationTypeDef,
)
Expand All @@ -27,28 +26,22 @@ class ToolRunResult(TypedDict):
related_documents: list[RelatedDocumentModel]


def run_result_to_tool_result_content_model(
run_result: ToolRunResult, display_citation: bool
) -> ToolResultContentModel:
return ToolResultContentModel(
content_type="toolResult",
body=ToolResultContentModelBody(
tool_use_id=run_result["tool_use_id"],
content=[
related_document.to_tool_result_model(
display_citation=display_citation,
)
for related_document in run_result["related_documents"]
],
status=run_result["status"],
),
)


class InvalidToolError(Exception):
pass


class RemoveTitle(GenerateJsonSchema):
"""Custom JSON schema generator that doesn't output `title`s for types and parameters."""

def field_title_should_be_set(self, schema) -> bool:
return False

def generate(self, schema, mode="validation") -> JsonSchemaValue:
value = super().generate(schema, mode)
del value["title"]
return value


class AgentTool(Generic[T]):
def __init__(
self,
Expand All @@ -59,19 +52,16 @@ def __init__(
[T, BotModel | None, type_model_name | None],
ToolFunctionResult | list[ToolFunctionResult],
],
bot: BotModel | None = None,
model: type_model_name | None = None,
):
self.name = name
self.description = description
self.args_schema = args_schema
self.function = function
self.bot = bot
self.model: type_model_name | None = model

def _generate_input_schema(self) -> dict[str, Any]:
"""Converts the Pydantic model to a JSON schema."""
return self.args_schema.model_json_schema()
# Specify a custom generator `RemoveTitle` because some foundation models do not work properly if there are unnecessary titles.
return self.args_schema.model_json_schema(schema_generator=RemoveTitle)

def to_converse_spec(self) -> ToolSpecificationTypeDef:
return ToolSpecificationTypeDef(
Expand All @@ -80,10 +70,16 @@ def to_converse_spec(self) -> ToolSpecificationTypeDef:
inputSchema={"json": self._generate_input_schema()},
)

def run(self, tool_use_id: str, input: dict[str, JsonValue]) -> ToolRunResult:
def run(
self,
tool_use_id: str,
input: dict[str, JsonValue],
model: type_model_name,
bot: BotModel | None = None,
) -> ToolRunResult:
try:
arg = self.args_schema.model_validate(input)
res = self.function(arg, self.bot, self.model)
res = self.function(arg, bot, model)
if isinstance(res, list):
related_documents = [
_function_result_to_related_document(
Expand Down
4 changes: 1 addition & 3 deletions backend/app/agents/tools/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def search_knowledge(
raise e


def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
def create_knowledge_tool(bot: BotModel) -> AgentTool:
description = (
"Answer a user's question using information. The description is: {}".format(
bot.knowledge.__str_in_claude_format__()
Expand All @@ -51,6 +51,4 @@ def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
description=description,
args_schema=KnowledgeToolInput,
function=search_knowledge,
bot=bot,
model=model,
)
74 changes: 51 additions & 23 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
from __future__ import annotations

import logging
import os
from typing import TypeGuard, Dict, Any, Optional, Tuple
from typing import TypeGuard, Dict, Any, Optional, Tuple, TYPE_CHECKING

from app.agents.tools.agent_tool import AgentTool
from app.config import BEDROCK_PRICING
from app.config import DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG
from app.config import DEFAULT_MISTRAL_GENERATION_CONFIG
from app.repositories.models.conversation import (
SimpleMessageModel,
ContentModel,
)

from app.repositories.models.custom_bot import GenerationParamsModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import get_bedrock_runtime_client

from mypy_boto3_bedrock_runtime.type_defs import (
ConverseStreamRequestRequestTypeDef,
MessageTypeDef,
ConverseResponseTypeDef,
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType
if TYPE_CHECKING:
from app.agents.tools.agent_tool import AgentTool
from app.repositories.models.conversation import (
SimpleMessageModel,
ContentModel,
)
from mypy_boto3_bedrock_runtime.type_defs import (
ConverseStreamRequestRequestTypeDef,
MessageTypeDef,
ConverseResponseTypeDef,
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
SystemContentBlockTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -46,7 +51,7 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
return role in ["user", "assistant"]


def _is_nova_model(model: type_model_name) -> bool:
def is_nova_model(model: type_model_name) -> bool:
"""Check if the model is an Amazon Nova model"""
return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]

Expand Down Expand Up @@ -83,7 +88,14 @@ def _prepare_nova_model_params(

# Add top_k if specified in generation params
if generation_params and generation_params.top_k is not None:
additional_fields["inferenceConfig"]["topK"] = generation_params.top_k
top_k = generation_params.top_k
if top_k > 128:
logger.warning(
"In Amazon Nova, an 'unexpected error' occurs if topK exceeds 128. To avoid errors, the upper limit of A is set to 128."
)
top_k = 128

additional_fields["inferenceConfig"]["topK"] = top_k

return inference_config, additional_fields

Expand Down Expand Up @@ -131,11 +143,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Prepare model-specific parameters
if _is_nova_model(model):
inference_config: InferenceConfigurationTypeDef
additional_model_request_fields: dict[str, Any]
system_prompts: list[SystemContentBlockTypeDef]
if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
model, generation_params
)
system_prompts = (
[
{
"text": "\n\n".join(instructions),
}
]
if len(instructions) > 0
else []
)

else:
# Standard handling for non-Nova models
inference_config = {
Expand Down Expand Up @@ -167,17 +192,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}
system_prompts = [
{
"text": instruction,
}
for instruction in instructions
if len(instruction) > 0
]

# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
"system": [
{"text": instruction}
for instruction in instructions
if len(instruction) > 0
],
"system": system_prompts,
"additionalModelRequestFields": additional_model_request_fields,
}

Expand Down
52 changes: 51 additions & 1 deletion backend/app/prompt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from app.bedrock import is_nova_model
from app.vector_search import SearchResult
from app.routes.schemas.conversation import type_model_name


def build_rag_prompt(
search_results: list[SearchResult],
model: type_model_name,
display_citation: bool = True,
) -> str:
context_prompt = ""
for result in search_results:
context_prompt += f"<search_result>\n<content>\n{result['content']}</content>\n<source>\n{result['rank']}\n</source>\n</search_result>"

# Prompt for RAG
inserted_prompt = """To answer the user's question, you are given a set of search results. Your job is to answer the user's question using only information from the search results.
If the search results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the search results to validate a user's assertion.
Expand All @@ -24,6 +28,7 @@ def build_rag_prompt(
)

if display_citation:
# Prompt for 'Retrieved Context Citation'.
inserted_prompt += """
If you reference information from a search result within your answer, you must include a citation to source where the information was found.
Each result has a corresponding source ID that you should reference.
Expand All @@ -32,7 +37,23 @@ def build_rag_prompt(
Do NOT outputs sources at the end of your answer.
Followings are examples of how to reference sources in your answer. Note that the source ID is embedded in the answer in the format [^<source_id>].
"""
# Prompt to output Markdown-style citation.
if is_nova_model(model=model):
# For Amazon Nova, provides only good examples.
inserted_prompt += """
<example>
first answer [^3]. second answer [^1][^2].
</example>
<example>
first answer [^1][^5]. second answer [^2][^3][^4]. third answer [^4].
</example>
"""

else:
# For other models, provide good examples and bad examples.
inserted_prompt += """
<GOOD-example>
first answer [^3]. second answer [^1][^2].
</GOOD-example>
Expand All @@ -57,9 +78,17 @@ def build_rag_prompt(
"""

else:
# Prompt when 'Retrieved Context Citation' is not specified.
inserted_prompt += """
Do NOT include citations in the format [^<source_id>] in your answer.
"""
if is_nova_model(model=model):
# For Amazon Nova, do not provide examples.
pass

else:
# For other models, suppress output of Markdown-style citation.
inserted_prompt += """
Followings are examples of how to answer.
<GOOD-example>
Expand All @@ -78,14 +107,33 @@ def build_rag_prompt(
return inserted_prompt


PROMPT_TO_CITE_TOOL_RESULTS = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
def get_prompt_to_cite_tool_results(model: type_model_name) -> str:
# Prompt for 'Retrieved Context Citation' of agent chat.
inserted_prompt = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
If the tool results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the tool results to validate a user's assertion.
Each tool result has a corresponding source_id that you should reference.
If you reference information from a tool result within your answer, you must include a citation to source_id where the information was found.
Followings are examples of how to reference source_id in your answer. Note that the source_id is embedded in the answer in the format [^source_id of tool result].
"""
# Prompt to output Markdown-style citation.
if is_nova_model(model=model):
# For Amazon Nova, provides only good examples.
inserted_prompt += """
<example>
first answer [^ccc]. second answer [^aaa][^bbb].
</example>
<example>
first answer [^aaa][^eee]. second answer [^bbb][^ccc][^ddd]. third answer [^ddd].
</example>
"""

else:
# For other models, provide good examples and bad examples.
inserted_prompt += """
<examples>
<GOOD-example>
first answer [^ccc]. second answer [^aaa][^bbb].
Expand All @@ -110,3 +158,5 @@ def build_rag_prompt(
</BAD-example>
</examples>
"""

return inserted_prompt
1 change: 1 addition & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def delete_large_messages(items):

except ClientError as e:
logger.error(f"An error occurred: {e.response['Error']['Message']}")
raise e


def change_conversation_title(user_id: str, conversation_id: str, new_title: str):
Expand Down
Loading

0 comments on commit 087e263

Please sign in to comment.