From 9d5b0a30e6a529376cdd5a18c2285914490b794f Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Fri, 11 Oct 2024 15:10:03 -0400 Subject: [PATCH] feat: Add model listing (#4948) * Add query for model providers * gql build * Hardcode all models for now --- app/schema.graphql | 6 +++ src/phoenix/server/api/queries.py | 44 +++++++++++++++++++ src/phoenix/server/api/types/ModelProvider.py | 9 ++++ 3 files changed, 59 insertions(+) create mode 100644 src/phoenix/server/api/types/ModelProvider.py diff --git a/app/schema.graphql b/app/schema.graphql index dc51800104e..f3433d7fde0 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -926,6 +926,11 @@ type Model { ): PerformanceTimeSeries! } +type ModelProvider { + name: String! + modelNames: [String!]! +} + type Mutation { createSystemApiKey(input: CreateApiKeyInput!): CreateSystemApiKeyMutationPayload! createUserApiKey(input: CreateUserApiKeyInput!): CreateUserApiKeyMutationPayload! @@ -1118,6 +1123,7 @@ type PromptResponse { } type Query { + modelProviders(vendors: [String!]!): [ModelProvider!]! users(first: Int = 50, last: Int, after: String, before: String): UserConnection! userRoles: [UserRole!]! userApiKeys: [UserApiKey!]! diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index 0c52856db80..6ca3e5c8481 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -60,6 +60,7 @@ from phoenix.server.api.types.Functionality import Functionality from phoenix.server.api.types.InferencesRole import AncillaryInferencesRole, InferencesRole from phoenix.server.api.types.Model import Model +from phoenix.server.api.types.ModelProvider import ModelProvider from phoenix.server.api.types.node import from_global_id, from_global_id_with_expected_type from phoenix.server.api.types.pagination import ( ConnectionArgs, @@ -78,6 +79,49 @@ @strawberry.type class Query: + @strawberry.field + async def model_providers( + self, vendors: List[str], info: Info[Context, None] + ) -> List[ModelProvider]: + all_vendors = { + "OpenAI": ModelProvider( # https://platform.openai.com/docs/models + name="OpenAI", # currently only models using the chat completions API + model_names=[ + "o1-preview", + "o1-preview-2024-09-12", + "o1-mini", + "o1-mini-2024-09-12", + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-turbo-preview", + "gpt-4-0125-preview", + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-0613", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-instruct", + ], + ), + "Anthropic": ModelProvider( # https://docs.anthropic.com/en/docs/about-claude/models#model-comparison + name="Anthropic", # currently only models using the messages API + model_names=[ + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + ], + ), + } + return [all_vendors[vendor] for vendor in vendors] + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def users( self, diff --git a/src/phoenix/server/api/types/ModelProvider.py b/src/phoenix/server/api/types/ModelProvider.py new file mode 100644 index 00000000000..680bc3aaaa9 --- /dev/null +++ b/src/phoenix/server/api/types/ModelProvider.py @@ -0,0 +1,9 @@ +from typing import List + +import strawberry + + +@strawberry.type +class ModelProvider: + name: str + model_names: List[str]