Skip to content

Commit

Permalink
fix thread/run thread/{thread_id} route order for craete_thread_and_r…
Browse files Browse the repository at this point in the history
…un to work
  • Loading branch information
phact committed Aug 31, 2024
1 parent 3666c71 commit 2e07330
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 39 deletions.
23 changes: 21 additions & 2 deletions client/tests/astra-assistants/test_run_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,38 @@ def run_with_assistant(assistant, client):



def create_and_run_with_assistant(assistant, client):
user_message = "What's your favorite animal."

thread = client.beta.threads.create()

client.beta.threads.messages.create(
thread_id=thread.id, role="user", content=user_message
)
run = client.beta.threads.create_and_run(
thread=thread,
assistant_id=assistant.id,
)

logger.info(run)





instructions="You're an animal expert who gives very long winded answers with flowery prose. Keep answers below 3 sentences."
def test_run_gpt_4o_mini(patched_openai_client):
gpt3_assistant = patched_openai_client.beta.assistants.create(
name="GPT3 Animal Tutor",
instructions=instructions,
model="gpt-4o_mini",
model="gpt-4o-mini",
)

assistant = patched_openai_client.beta.assistants.retrieve(gpt3_assistant.id)
logger.info(assistant)

run_with_assistant(gpt3_assistant, patched_openai_client)
create_and_run_with_assistant(gpt3_assistant, patched_openai_client)

def test_run_cohere(patched_openai_client):
cohere_assistant = patched_openai_client.beta.assistants.create(
Expand Down Expand Up @@ -91,4 +110,4 @@ def test_run_gemini(patched_openai_client):
instructions=instructions,
model="gemini/gemini-1.5-flash",
)
run_with_assistant(gemini_assistant, patched_openai_client)
run_with_assistant(gemini_assistant, patched_openai_client)
76 changes: 39 additions & 37 deletions impl/routes_v2/threads_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,45 @@ async def create_thread(
)
return astradb.upsert_table_from_base_model("threads", thread)

@router.post(
"/threads/runs",
responses={
200: {"model": RunObject, "description": "OK"},
},
tags=["Assistants"],
summary="Create a thread and run it in one request.",
response_model_by_alias=True,
response_model=None
)
async def create_thread_and_run(
create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
embedding_model: str = Depends(infer_embedding_model),
embedding_api_key: str = Depends(infer_embedding_api_key),
litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs),
) -> RunObject:
create_thread_request = create_thread_and_run_request.thread
if create_thread_request is None:
raise HTTPException(status_code=400, detail="thread is required.")

thread = await create_thread(create_thread_request, astradb)

create_run_request = CreateRunRequest(
assistant_id=create_thread_and_run_request.assistant_id,
model=create_thread_and_run_request.model,
instructions=create_thread_and_run_request.instructions,
tools=create_thread_and_run_request.tools,
metadata=create_thread_and_run_request.metadata
)
return await create_run(
thread_id=thread.id,
create_run_request=create_run_request,
astradb=astradb,
embedding_model=embedding_model,
embedding_api_key=embedding_api_key,
litellm_kwargs=litellm_kwargs,
)

@router.get(
"/threads/{thread_id}",
responses={
Expand Down Expand Up @@ -1823,41 +1862,4 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id):
return message_delta


@router.post(
"/threads/runs",
responses={
200: {"model": RunObject, "description": "OK"},
},
tags=["Assistants"],
summary="Create a thread and run it in one request.",
response_model_by_alias=True,
response_model=None
)
async def create_thread_and_run(
create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
embedding_model: str = Depends(infer_embedding_model),
embedding_api_key: str = Depends(infer_embedding_api_key),
litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs),
) -> RunObject:
create_thread_request = create_thread_and_run_request.thread
if create_thread_request is None:
raise HTTPException(status_code=400, detail="thread is required.")

thread = await create_thread(create_thread_request, astradb)

create_run_request = CreateRunRequest(
assistant_id=create_thread_and_run_request.assistant_id,
model=create_thread_and_run_request.model,
instructions=create_thread_and_run_request.instructions,
tools=create_thread_and_run_request.tools,
metadata=create_thread_and_run_request.metadata
)
return await create_run(
thread_id=thread.id,
create_run_request=create_run_request,
astradb=astradb,
embedding_model=embedding_model,
embedding_api_key=embedding_api_key,
litellm_kwargs=litellm_kwargs,
)

0 comments on commit 2e07330

Please sign in to comment.