-
Notifications
You must be signed in to change notification settings - Fork 327
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Decouple pydantic models and gql types
- Loading branch information
1 parent
a5cae16
commit 31387c6
Showing
5 changed files
with
75 additions
and
74 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from enum import Enum | ||
from typing import Any, Union | ||
|
||
import strawberry | ||
from pydantic import BaseModel | ||
|
||
JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] | ||
|
||
|
||
@strawberry.enum | ||
class PromptMessageRole(str, Enum): | ||
USER = "user" | ||
SYSTEM = "system" | ||
AI = "ai" # E.g. the assistant. Normalize to AI for consistency. | ||
|
||
|
||
class PromptStringTemplate(BaseModel): | ||
template: str | ||
|
||
|
||
class TextPromptMessage(BaseModel): | ||
role: PromptMessageRole | ||
content: str | ||
|
||
|
||
class JSONPromptMessage(BaseModel): | ||
role: PromptMessageRole | ||
content: JSONSerializable | ||
|
||
|
||
class PromptMessagesTemplateV1(BaseModel): | ||
_version: str = "messages-v1" | ||
template: list[Union[TextPromptMessage, JSONPromptMessage]] | ||
|
||
|
||
# TODO: Figure out enums, maybe just store whole tool blobs | ||
# class PromptToolParameter(BaseModel): | ||
# name: str | ||
# type: str | ||
# description: str | ||
# required: bool | ||
# default: str | ||
|
||
|
||
class PromptToolDefinition(BaseModel): | ||
definition: JSONSerializable | ||
|
||
|
||
class PromptTools(BaseModel): | ||
_version: str = "tools-v1" | ||
tools: list[PromptToolDefinition] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,45 @@ | ||
# Part of the Phoenix PromptHub feature set | ||
|
||
from enum import Enum | ||
from typing import Any, Union | ||
from typing import Union | ||
|
||
import strawberry | ||
from pydantic import BaseModel | ||
from strawberry.scalars import JSON | ||
|
||
JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]] | ||
|
||
|
||
@strawberry.enum | ||
class PromptMessageRole(str, Enum): | ||
USER = "user" | ||
SYSTEM = "system" | ||
AI = "ai" # E.g. the assistant. Normalize to AI for consistency. | ||
|
||
|
||
class PromptStringTemplate(BaseModel): | ||
template: str | ||
|
||
|
||
class TextPromptMessage(BaseModel): | ||
role: PromptMessageRole | ||
content: str | ||
|
||
def to_gql(self): | ||
return TextPromptMessageGQL.from_model(self) | ||
|
||
|
||
class JSONPromptMessage(BaseModel): | ||
role: PromptMessageRole | ||
content: JSONSerializable | ||
|
||
def to_gql(self): | ||
return JSONPromptMessageGQL.from_model(self) | ||
|
||
|
||
class PromptMessagesTemplateV1(BaseModel): | ||
_version: str = "messages-v1" | ||
template: list[Union[TextPromptMessage, JSONPromptMessage]] | ||
from phoenix.server.api.helpers.prompthub.models import ( | ||
PromptMessageRole, | ||
) | ||
from phoenix.server.api.helpers.prompthub.models import ( | ||
PromptStringTemplate as PromptStringTemplateModel, | ||
) | ||
|
||
|
||
@strawberry.type | ||
class TextPromptMessageGQL: | ||
class TextPromptMessage: | ||
role: PromptMessageRole | ||
content: str | ||
|
||
@classmethod | ||
def from_model(cls, model: TextPromptMessage) -> "TextPromptMessageGQL": | ||
return TextPromptMessageGQL(role=model.role, content=model.content) | ||
|
||
|
||
@strawberry.type | ||
class JSONPromptMessageGQL: | ||
class JSONPromptMessage: | ||
role: PromptMessageRole | ||
content: JSON | ||
|
||
@classmethod | ||
def from_model(cls, model: JSONPromptMessage) -> "JSONPromptMessageGQL": | ||
return JSONPromptMessageGQL(role=model.role, content=model.content) | ||
|
||
|
||
@strawberry.type | ||
class PromptMessagesTemplateV1GQL: | ||
version: str | ||
template: list[Union[TextPromptMessageGQL, JSONPromptMessageGQL]] | ||
|
||
@classmethod | ||
def from_model(cls, model: PromptMessagesTemplateV1) -> "PromptMessagesTemplateV1GQL": | ||
return PromptMessagesTemplateV1GQL( | ||
version=model._version, | ||
messages=[message.to_gql() for message in model.template], | ||
) | ||
class PromptMessagesTemplateV1: | ||
version: str = "messages-v1" | ||
messages: list[Union[TextPromptMessage, JSONPromptMessage]] | ||
|
||
|
||
@strawberry.type | ||
class PromptStringTemplateGQL: | ||
class PromptStringTemplate: | ||
template: str | ||
|
||
@classmethod | ||
def from_model(cls, model: PromptStringTemplate) -> "PromptStringTemplateGQL": | ||
return PromptStringTemplateGQL(template=model.template) | ||
def from_model(cls, model: PromptStringTemplateModel) -> "PromptStringTemplate": | ||
return PromptStringTemplate(template=model.template) | ||
|
||
|
||
PromptTemplateVersion = strawberry.union( | ||
"PromptTemplateVersion", (PromptStringTemplateGQL, PromptMessagesTemplateV1GQL) | ||
"PromptTemplateVersion", (PromptStringTemplate, PromptMessagesTemplateV1) | ||
) |