Skip to content

Commit

Permalink
Rework model names
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Dec 20, 2024
1 parent 1154306 commit 728a9fa
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 21 deletions.
10 changes: 5 additions & 5 deletions src/phoenix/server/api/helpers/prompthub/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class PromptMessageRole(str, Enum):
AI = "ai" # E.g. the assistant. Normalize to AI for consistency.


class PromptStringTemplate(BaseModel):
template: str


class TextPromptMessage(BaseModel):
role: PromptMessageRole
content: str
Expand All @@ -28,11 +24,15 @@ class JSONPromptMessage(BaseModel):
content: JSONSerializable


class PromptMessagesTemplateV1(BaseModel):
class PromptChatTemplateV1(BaseModel):
_version: str = "messages-v1"
template: list[Union[TextPromptMessage, JSONPromptMessage]]


class PromptStringTemplate(BaseModel):
template: str


# TODO: Figure out enums, maybe just store whole tool blobs
# class PromptToolParameter(BaseModel):
# name: str
Expand Down
4 changes: 2 additions & 2 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
PromptVersion,
)
from phoenix.server.api.types.PromptVersionTemplate import (
PromptChatTemplateV1,
PromptMessageRole,
PromptMessagesTemplateV1,
TextPromptMessage,
)
from phoenix.server.api.types.SortDir import SortDir
Expand Down Expand Up @@ -544,7 +544,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
created_at=datetime.now(),
)
elif type_name == PromptVersion.__name__:
template = PromptMessagesTemplateV1(
template = PromptChatTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
Expand Down
4 changes: 2 additions & 2 deletions src/phoenix/server/api/types/Prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
connection_from_list,
)
from phoenix.server.api.types.PromptVersionTemplate import (
PromptChatTemplateV1,
PromptMessageRole,
PromptMessagesTemplateV1,
TextPromptMessage,
)

Expand Down Expand Up @@ -46,7 +46,7 @@ async def prompt_versions(
before=before if isinstance(before, CursorString) else None,
)

template = PromptMessagesTemplateV1(
template = PromptChatTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
Expand Down
4 changes: 2 additions & 2 deletions src/phoenix/server/api/types/PromptVersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from strawberry.relay import Node, NodeID
from strawberry.scalars import JSON

from phoenix.server.api.types.PromptVersionTemplate import PromptTemplateVersion
from phoenix.server.api.types.PromptVersionTemplate import PromptTemplate


@strawberry.enum
Expand All @@ -30,7 +30,7 @@ class PromptVersion(Node):
description: str
template_type: PromptTemplateType
template_format: PromptTemplateFormat
template: PromptTemplateVersion
template: PromptTemplate
invocation_parameters: Optional[JSON] = None
tools: Optional[JSON] = None
output_schema: Optional[JSON] = None
Expand Down
13 changes: 3 additions & 10 deletions src/phoenix/server/api/types/PromptVersionTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from phoenix.server.api.helpers.prompthub.models import (
PromptMessageRole,
)
from phoenix.server.api.helpers.prompthub.models import (
PromptStringTemplate as PromptStringTemplateModel,
)


@strawberry.type
Expand All @@ -26,7 +23,7 @@ class JSONPromptMessage:


@strawberry.type
class PromptMessagesTemplateV1:
class PromptChatTemplateV1:
version: str = "messages-v1"
messages: list[Union[TextPromptMessage, JSONPromptMessage]]

Expand All @@ -35,11 +32,7 @@ class PromptMessagesTemplateV1:
class PromptStringTemplate:
template: str

@classmethod
def from_model(cls, model: PromptStringTemplateModel) -> "PromptStringTemplate":
return PromptStringTemplate(template=model.template)


PromptTemplateVersion = strawberry.union(
"PromptTemplateVersion", (PromptStringTemplate, PromptMessagesTemplateV1)
PromptTemplate = strawberry.union(
"PromptTemplateVersion", (PromptStringTemplate, PromptChatTemplateV1)
)

0 comments on commit 728a9fa

Please sign in to comment.