Skip to content

Commit

Permalink
Decouple pydantic models and gql types
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Dec 20, 2024
1 parent a5cae16 commit 31387c6
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 74 deletions.
Empty file.
51 changes: 51 additions & 0 deletions src/phoenix/server/api/helpers/prompthub/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from enum import Enum
from typing import Any, Union

import strawberry
from pydantic import BaseModel

JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]


@strawberry.enum
class PromptMessageRole(str, Enum):
USER = "user"
SYSTEM = "system"
AI = "ai" # E.g. the assistant. Normalize to AI for consistency.


class PromptStringTemplate(BaseModel):
template: str


class TextPromptMessage(BaseModel):
role: PromptMessageRole
content: str


class JSONPromptMessage(BaseModel):
role: PromptMessageRole
content: JSONSerializable


class PromptMessagesTemplateV1(BaseModel):
_version: str = "messages-v1"
template: list[Union[TextPromptMessage, JSONPromptMessage]]


# TODO: Figure out enums, maybe just store whole tool blobs
# class PromptToolParameter(BaseModel):
# name: str
# type: str
# description: str
# required: bool
# default: str


class PromptToolDefinition(BaseModel):
definition: JSONSerializable


class PromptTools(BaseModel):
_version: str = "tools-v1"
tools: list[PromptToolDefinition]
11 changes: 4 additions & 7 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
from phoenix.server.api.types.PromptVersionTemplate import (
PromptMessageRole,
PromptMessagesTemplateV1,
PromptMessagesTemplateV1GQL,
TextPromptMessage,
)
from phoenix.server.api.types.SortDir import SortDir
Expand Down Expand Up @@ -545,25 +544,23 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
created_at=datetime.now(),
)
elif type_name == PromptVersion.__name__:
template_model = PromptMessagesTemplateV1(
template=[
template = PromptMessagesTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
content="Hello what's the weather in Antarctica like?",
)
]
)

template_gql = PromptMessagesTemplateV1GQL.from_model(template_model)

if node_id == 2:
return PromptVersion(
id_attr=2,
user="alice",
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template=template_gql,
template=template,
invocation_parameters={"temperature": 0.5},
tools={
"_version": "tools-v1",
Expand Down Expand Up @@ -603,7 +600,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template=template_gql,
template=template,
invocation_parameters=None,
tools=None,
output_schema=None,
Expand Down
11 changes: 4 additions & 7 deletions src/phoenix/server/api/types/Prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from phoenix.server.api.types.PromptVersionTemplate import (
PromptMessageRole,
PromptMessagesTemplateV1,
PromptMessagesTemplateV1GQL,
TextPromptMessage,
)

Expand Down Expand Up @@ -47,25 +46,23 @@ async def prompt_versions(
before=before if isinstance(before, CursorString) else None,
)

template_model = PromptMessagesTemplateV1(
template=[
template = PromptMessagesTemplateV1(
messages=[
TextPromptMessage(
role=PromptMessageRole.USER,
content="Hello what's the weather in Antarctica like?",
)
]
)

template_gql = PromptMessagesTemplateV1GQL.from_model(template_model)

dummy_data = [
PromptVersion(
id_attr=2,
user="alice",
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template=template_gql,
template=template,
invocation_parameters={"temperature": 0.5},
tools={
"_version": "tools-v1",
Expand Down Expand Up @@ -103,7 +100,7 @@ async def prompt_versions(
description="A dummy prompt version",
template_type=PromptTemplateType.CHAT,
template_format=PromptTemplateFormat.MUSTACHE,
template=template_gql,
template=template,
model_name="gpt-4o",
model_provider="openai",
),
Expand Down
76 changes: 16 additions & 60 deletions src/phoenix/server/api/types/PromptVersionTemplate.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,45 @@
# Part of the Phoenix PromptHub feature set

from enum import Enum
from typing import Any, Union
from typing import Union

import strawberry
from pydantic import BaseModel
from strawberry.scalars import JSON

JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]


@strawberry.enum
class PromptMessageRole(str, Enum):
USER = "user"
SYSTEM = "system"
AI = "ai" # E.g. the assistant. Normalize to AI for consistency.


class PromptStringTemplate(BaseModel):
template: str


class TextPromptMessage(BaseModel):
role: PromptMessageRole
content: str

def to_gql(self):
return TextPromptMessageGQL.from_model(self)


class JSONPromptMessage(BaseModel):
role: PromptMessageRole
content: JSONSerializable

def to_gql(self):
return JSONPromptMessageGQL.from_model(self)


class PromptMessagesTemplateV1(BaseModel):
_version: str = "messages-v1"
template: list[Union[TextPromptMessage, JSONPromptMessage]]
from phoenix.server.api.helpers.prompthub.models import (
PromptMessageRole,
)
from phoenix.server.api.helpers.prompthub.models import (
PromptStringTemplate as PromptStringTemplateModel,
)


@strawberry.type
class TextPromptMessageGQL:
class TextPromptMessage:
role: PromptMessageRole
content: str

@classmethod
def from_model(cls, model: TextPromptMessage) -> "TextPromptMessageGQL":
return TextPromptMessageGQL(role=model.role, content=model.content)


@strawberry.type
class JSONPromptMessageGQL:
class JSONPromptMessage:
role: PromptMessageRole
content: JSON

@classmethod
def from_model(cls, model: JSONPromptMessage) -> "JSONPromptMessageGQL":
return JSONPromptMessageGQL(role=model.role, content=model.content)


@strawberry.type
class PromptMessagesTemplateV1GQL:
version: str
template: list[Union[TextPromptMessageGQL, JSONPromptMessageGQL]]

@classmethod
def from_model(cls, model: PromptMessagesTemplateV1) -> "PromptMessagesTemplateV1GQL":
return PromptMessagesTemplateV1GQL(
version=model._version,
messages=[message.to_gql() for message in model.template],
)
class PromptMessagesTemplateV1:
version: str = "messages-v1"
messages: list[Union[TextPromptMessage, JSONPromptMessage]]


@strawberry.type
class PromptStringTemplateGQL:
class PromptStringTemplate:
template: str

@classmethod
def from_model(cls, model: PromptStringTemplate) -> "PromptStringTemplateGQL":
return PromptStringTemplateGQL(template=model.template)
def from_model(cls, model: PromptStringTemplateModel) -> "PromptStringTemplate":
return PromptStringTemplate(template=model.template)


PromptTemplateVersion = strawberry.union(
"PromptTemplateVersion", (PromptStringTemplateGQL, PromptMessagesTemplateV1GQL)
"PromptTemplateVersion", (PromptStringTemplate, PromptMessagesTemplateV1)
)

0 comments on commit 31387c6

Please sign in to comment.