diff --git a/src/phoenix/server/api/helpers/prompthub/models.py b/src/phoenix/server/api/helpers/prompthub/models.py index 9a4d91d5cb..728531fefa 100644 --- a/src/phoenix/server/api/helpers/prompthub/models.py +++ b/src/phoenix/server/api/helpers/prompthub/models.py @@ -14,10 +14,6 @@ class PromptMessageRole(str, Enum): AI = "ai" # E.g. the assistant. Normalize to AI for consistency. -class PromptStringTemplate(BaseModel): - template: str - - class TextPromptMessage(BaseModel): role: PromptMessageRole content: str @@ -28,11 +24,15 @@ class JSONPromptMessage(BaseModel): content: JSONSerializable -class PromptMessagesTemplateV1(BaseModel): +class PromptChatTemplateV1(BaseModel): _version: str = "messages-v1" template: list[Union[TextPromptMessage, JSONPromptMessage]] +class PromptStringTemplate(BaseModel): + template: str + + # TODO: Figure out enums, maybe just store whole tool blobs # class PromptToolParameter(BaseModel): # name: str diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index e9520e8dba..236efc5166 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -62,8 +62,8 @@ PromptVersion, ) from phoenix.server.api.types.PromptVersionTemplate import ( + PromptChatTemplateV1, PromptMessageRole, - PromptMessagesTemplateV1, TextPromptMessage, ) from phoenix.server.api.types.SortDir import SortDir @@ -544,7 +544,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: created_at=datetime.now(), ) elif type_name == PromptVersion.__name__: - template = PromptMessagesTemplateV1( + template = PromptChatTemplateV1( messages=[ TextPromptMessage( role=PromptMessageRole.USER, diff --git a/src/phoenix/server/api/types/Prompt.py b/src/phoenix/server/api/types/Prompt.py index 3b8290e214..2dc684c84b 100644 --- a/src/phoenix/server/api/types/Prompt.py +++ b/src/phoenix/server/api/types/Prompt.py @@ -15,8 +15,8 @@ connection_from_list, ) from phoenix.server.api.types.PromptVersionTemplate import ( + PromptChatTemplateV1, PromptMessageRole, - PromptMessagesTemplateV1, TextPromptMessage, ) @@ -46,7 +46,7 @@ async def prompt_versions( before=before if isinstance(before, CursorString) else None, ) - template = PromptMessagesTemplateV1( + template = PromptChatTemplateV1( messages=[ TextPromptMessage( role=PromptMessageRole.USER, diff --git a/src/phoenix/server/api/types/PromptVersion.py b/src/phoenix/server/api/types/PromptVersion.py index bf264b012a..db7442e84b 100644 --- a/src/phoenix/server/api/types/PromptVersion.py +++ b/src/phoenix/server/api/types/PromptVersion.py @@ -7,7 +7,7 @@ from strawberry.relay import Node, NodeID from strawberry.scalars import JSON -from phoenix.server.api.types.PromptVersionTemplate import PromptTemplateVersion +from phoenix.server.api.types.PromptVersionTemplate import PromptTemplate @strawberry.enum @@ -30,7 +30,7 @@ class PromptVersion(Node): description: str template_type: PromptTemplateType template_format: PromptTemplateFormat - template: PromptTemplateVersion + template: PromptTemplate invocation_parameters: Optional[JSON] = None tools: Optional[JSON] = None output_schema: Optional[JSON] = None diff --git a/src/phoenix/server/api/types/PromptVersionTemplate.py b/src/phoenix/server/api/types/PromptVersionTemplate.py index 0c42e8ebd1..67d24a7e1b 100644 --- a/src/phoenix/server/api/types/PromptVersionTemplate.py +++ b/src/phoenix/server/api/types/PromptVersionTemplate.py @@ -8,9 +8,6 @@ from phoenix.server.api.helpers.prompthub.models import ( PromptMessageRole, ) -from phoenix.server.api.helpers.prompthub.models import ( - PromptStringTemplate as PromptStringTemplateModel, -) @strawberry.type @@ -26,7 +23,7 @@ class JSONPromptMessage: @strawberry.type -class PromptMessagesTemplateV1: +class PromptChatTemplateV1: version: str = "messages-v1" messages: list[Union[TextPromptMessage, JSONPromptMessage]] @@ -35,11 +32,7 @@ class PromptMessagesTemplateV1: class PromptStringTemplate: template: str - @classmethod - def from_model(cls, model: PromptStringTemplateModel) -> "PromptStringTemplate": - return PromptStringTemplate(template=model.template) - -PromptTemplateVersion = strawberry.union( - "PromptTemplateVersion", (PromptStringTemplate, PromptMessagesTemplateV1) +PromptTemplate = strawberry.union( + "PromptTemplateVersion", (PromptStringTemplate, PromptChatTemplateV1) )