Commit
Merge pull request #48 from langtail/ensure-non-null-content
ensure that the content of the AI message when finalized is not null
vojtatranta authored Aug 9, 2024
2 parents 2bec9bd + d147cdd commit 6fcde04
Showing 2 changed files with 86 additions and 2 deletions.
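
Reviewer note: when a completion consists only of tool calls, every streamed
delta arrives with "content": null (see the fixture chunks in the test below),
so the finalized assistant message would otherwise carry content: null. A
minimal sketch of the normalization idea, assuming OpenAI-style messages where
content is typed string | null (an illustration, not the commit's code):

// A tool-call-only completion finalizes with content: null
const finalized = { role: "assistant" as const, content: null as string | null, tool_calls: [] }
// Coerce the nullish content to an empty string so the finalized message stays well-formed
const normalized = { ...finalized, content: finalized.content ?? "" }
// normalized.content === ""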
74 changes: 74 additions & 0 deletions src/react/useChatStream.test.ts
@@ -562,6 +562,80 @@ describe("useAIStream", () => {
})
})


it("should ensure that the content isn't nullish", async () => {
function createMockReadableStream(dataEmitter: DataEventListener) {
return new ReadableStream({
start(controller) {
dataEmitter.addEventListener('data', (data: string) => {
controller.enqueue(data)
controller.close()
})
},
});
}

const dataEmitter = new DataEventListener()

const stream = createMockReadableStream(dataEmitter)

let ran = false
const createReadableStream = vi.fn(() => {
// NOTE: return the stream only on the first call; reject on subsequent calls
if (ran) {
return Promise.reject(new Error('Error in tools!'))
}

ran = true
return Promise.resolve(stream)
})

const onToolCall = () => Promise.resolve('Result in test')

const { result } = renderHook(() =>
useChatStream({
fetcher: createReadableStream,
onToolCall
}),
)

act(() => {
result.current.send('user input')
dataEmitter.dispatchEvent('data',
`{"id":"chatcmpl-9aJwNzlnvn1jG845CJe2QZH6AKcow","object":"chat.completion.chunk","created":1718443487,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_319be4768e","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}],"usage":null}\n
{"id":"chatcmpl-9aJwNzlnvn1jG845CJe2QZH6AKcow","object":"chat.completion.chunk","created":1718443487,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_319be4768e","choices":[{"index":0,"delta":{"content": null},"logprobs":null,"finish_reason":null}],"usage":null}\n
{"id":"chatcmpl-9aJwNzlnvn1jG845CJe2QZH6AKcow","object":"chat.completion.chunk","created":1718443487,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_319be4768e","choices":[{"index":0,"delta":{"tool_calls":[{"index":0, "id":"call_tNW2f79DhRvuuwrslSYt3yVT", "type": "function", "function":{"name":"get_weather", "arguments":"{\\"location\\":\\"Prague, Czech Republic\\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null}\n
{"id":"chatcmpl-9aJwNzlnvn1jG845CJe2QZH6AKcow","object":"chat.completion.chunk","created":1718443487,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_319be4768e","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null}\n
{"id":"chatcmpl-9aJwNzlnvn1jG845CJe2QZH6AKcow","object":"chat.completion.chunk","created":1718443487,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_319be4768e","choices":[],"usage":{"prompt_tokens":284,"completion_tokens":34,"total_tokens":318}}\n\n`
)
})

await vi.waitFor(() => {
expect(result.current.messages).toEqual([
{ role: 'user', content: 'user input' },
{
"content": "",
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "{\"location\":\"Prague, Czech Republic\"}",
"name": "get_weather",
},
"id": "call_tNW2f79DhRvuuwrslSYt3yVT",
"type": "function",
},
]
},
{
"content": "Result in test",
"role": "tool",
"tool_call_id": "call_tNW2f79DhRvuuwrslSYt3yVT",
}
])
})
})

it("should pass tool call result to the messages", async () => {
function createMockReadableStream(dataEmitter: DataEventListener) {
return new ReadableStream({
14 changes: 12 additions & 2 deletions src/react/useChatStream.ts
@@ -4,6 +4,7 @@ import {
ChatCompletionMessageToolCall,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionMessage,
} from "openai/resources"
import { chatStreamToRunner, type ChatCompletionStream } from "../stream"
import { useRef, useState } from "react"
@@ -118,6 +119,14 @@ export function combineAIMessageChunkWithCompleteMessages(
})
}

function normalizeMessage(message: ChatCompletionMessage) {
return {
...message,
// NOTE: ensure that the message content isn't null or undefined
content: message.content ?? "",
}
}
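// Illustrative example (editor's note, not part of the commit):
//   normalizeMessage({ role: "assistant", content: null, tool_calls: [...] })
//   // => { role: "assistant", content: "", tool_calls: [...] }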

function parameterToMessage(
parameter: ChatMessage | ChatMessage[] | string,
): ChatMessage[] {
@@ -183,6 +192,7 @@ export function useChatStream<
const generatingRef = useRef<boolean>(false)
const endedRef = useRef<boolean>(false)
const errorRef = useRef<Error | null>(null)
const messageMode = options.messageMode ?? 'append'

function setIsLoadingState(generating: boolean) {
generatingRef.current = generating
@@ -226,7 +236,7 @@
const abortController = new AbortController()
abortControllerRef.current = abortController

- switch (options.messageMode ?? 'append') {
+ switch (messageMode) {
case 'replace':
setMessagesState(parameterToMessage(parameter))
case 'append':
@@ -255,7 +265,7 @@
!("id" in currentMessage) ||
currentMessage.id !== finalMessage.id,
)
- .concat(finalMessage.choices.flatMap((choice) => choice.message))
+ .concat(finalMessage.choices.flatMap((choice) => normalizeMessage(choice.message)))

const userChatMessages = mapAIMessagesToChatCompletions(
messagesRef.current,
