-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add message history truncation
- Loading branch information
1 parent
9e12f5e
commit 820a8cf
Showing
5 changed files
with
268 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
184 changes: 184 additions & 0 deletions
184
control-plane/src/modules/workflows/agent/overflow.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import { AgentError } from '../../../utilities/errors'; | ||
import { WorkflowAgentStateMessage } from './state'; | ||
import { handleContextWindowOverflow } from './overflow'; | ||
import { estimateTokenCount } from './utils'; | ||
|
||
// Replace the real token estimator so every test controls the reported token counts.
jest.mock('./utils', () => ({
  estimateTokenCount: jest.fn(),
}));

// Typed handle on the mocked estimator; values are queued per test in call order:
// first the system prompt, then the rendered history after each (re)count.
const estimateTokenCountMock = estimateTokenCount as jest.Mock;

// Builds a message fixture from a partial shape; the single cast lives here
// instead of being repeated on every literal.
const toMessage = (shape: object): WorkflowAgentStateMessage =>
  shape as WorkflowAgentStateMessage;

describe('handleContextWindowOverflow', () => {
  beforeEach(() => {
    // resetAllMocks (unlike clearAllMocks) also discards queued
    // mockResolvedValueOnce values, so nothing leaks into the next test.
    jest.resetAllMocks();
  });

  it('should throw if system prompt exceeds threshold', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [];
    const modelContextWindow = 1000;

    estimateTokenCountMock.mockResolvedValueOnce(701); // system prompt (limit is 0.7 * 1000)

    await expect(
      handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      }),
    ).rejects.toThrow(new AgentError('System prompt can not exceed 700 tokens'));
  });

  it('should not modify messages if total tokens are under threshold', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [
      toMessage({ type: 'human', data: { message: 'Hello' } }),
      toMessage({ type: 'agent', data: { message: 'Hi' } }),
    ];
    const modelContextWindow = 1000;

    estimateTokenCountMock
      .mockResolvedValueOnce(100) // system prompt
      .mockResolvedValueOnce(200); // messages

    const result = await handleContextWindowOverflow({
      systemPrompt,
      messages,
      modelContextWindow,
    });

    expect(result).toEqual(messages);
    expect(messages).toHaveLength(2);
  });

  it('should handle empty messages array', async () => {
    const systemPrompt = 'System prompt';
    const messages: WorkflowAgentStateMessage[] = [];
    const modelContextWindow = 1000;

    estimateTokenCountMock
      .mockResolvedValueOnce(200) // system prompt
      .mockResolvedValueOnce(0); // empty messages

    const result = await handleContextWindowOverflow({
      systemPrompt,
      messages,
      modelContextWindow,
    });

    expect(result).toEqual(messages);
    expect(messages).toHaveLength(0);
  });

  describe('truncate strategy', () => {
    it('should remove messages until total tokens are under threshold', async () => {
      const systemPrompt = 'System prompt';
      // Array.from gives five distinct objects (Array.fill would alias one object).
      const messages: WorkflowAgentStateMessage[] = Array.from({ length: 5 }, () =>
        toMessage({ type: 'human', data: { message: 'Message' } }),
      );
      const modelContextWindow = 600;

      estimateTokenCountMock
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(900) // initial messages
        .mockResolvedValueOnce(700) // after first removal
        .mockResolvedValueOnce(500) // after second removal
        .mockResolvedValueOnce(300); // after third removal

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(result).toHaveLength(2);
      expect(estimateTokenCountMock).toHaveBeenCalledTimes(5);
    });

    it('should throw if a single message exceeds the context window', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        toMessage({ type: 'human', data: { message: 'Message' } }),
      ];
      const modelContextWindow = 400;

      estimateTokenCountMock
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(400); // message

      await expect(
        handleContextWindowOverflow({
          systemPrompt,
          messages,
          modelContextWindow,
        }),
      ).rejects.toThrow(AgentError);
    });

    it('should remove tool invocation result when removing agent message', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        toMessage({
          id: '123',
          type: 'agent',
          data: {
            message: 'Hi',
            invocations: [{ id: 'toolCallId1' }, { id: 'toolCallId2' }, { id: 'toolCallId3' }],
          },
        }),
        toMessage({ id: '456', type: 'invocation-result', data: { id: 'toolCallId1' } }),
        toMessage({ id: '457', type: 'invocation-result', data: { id: 'toolCallId2' } }),
        toMessage({ id: '458', type: 'invocation-result', data: { id: 'toolCallId3' } }),
        toMessage({ id: '789', type: 'human', data: { message: 'Hello' } }),
      ];
      const modelContextWindow = 1100;

      estimateTokenCountMock
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(1000) // initial messages
        .mockResolvedValueOnce(800); // after the agent message and its results are removed

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(result).toHaveLength(1);
      expect(result[0].type).toBe('human');
      expect(estimateTokenCountMock).toHaveBeenCalledTimes(3);
    });

    it('should remove agent message when removing tool invocation result', async () => {
      const systemPrompt = 'System prompt';
      const messages: WorkflowAgentStateMessage[] = [
        toMessage({ id: '456', type: 'invocation-result', data: { id: 'toolCallId1' } }),
        toMessage({
          id: '123',
          type: 'agent',
          data: { message: 'Hi', invocations: [{ id: 'toolCallId1' }] },
        }),
        toMessage({ id: '789', type: 'human', data: { message: 'Hello' } }),
      ];
      const modelContextWindow = 1100;

      // Only one removal pass happens (the result and its agent message go together),
      // so exactly three estimates are consumed.
      estimateTokenCountMock
        .mockResolvedValueOnce(200) // system prompt
        .mockResolvedValueOnce(1000) // initial messages
        .mockResolvedValueOnce(800); // after the result and its agent message are removed

      const result = await handleContextWindowOverflow({
        systemPrompt,
        messages,
        modelContextWindow,
      });

      expect(result).toHaveLength(1);
      expect(result[0].type).toBe('human');
      expect(estimateTokenCountMock).toHaveBeenCalledTimes(3);
    });
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import { AgentError } from "../../../utilities/errors"; | ||
import { logger } from "../../observability/logger"; | ||
import { WorkflowAgentStateMessage } from "./state"; | ||
import { estimateTokenCount } from "./utils"; | ||
|
||
/** Share of the model context window that system prompt and chat history may fill together. */
const TOTAL_CONTEXT_THRESHOLD = 0.95;
/** Share of the model context window that the system prompt may fill on its own. */
const SYSTEM_PROMPT_THRESHOLD = 0.7;

/**
 * Returns a new array equal to `history` without its earliest message and
 * without the messages paired with it; `history` itself is left untouched.
 *
 * - Dropping an invocation result also drops any agent message that issued that invocation.
 * - Dropping an agent message also drops the results of all of its invocations.
 */
const dropEarliestMessage = (
  history: readonly WorkflowAgentStateMessage[],
): WorkflowAgentStateMessage[] => {
  const [removed, ...rest] = history;

  // This _technically_ shouldn't happen as we remove from the earliest message forward.
  if (removed?.type === "invocation-result") {
    const removedId = removed.data.id;
    return rest.filter(
      (message) =>
        message.type !== "agent" ||
        !message.data.invocations?.some((tool) => tool.id === removedId),
    );
  }

  if (removed?.type === "agent") {
    const invocations = removed.data.invocations ?? [];
    return rest.filter((message) => {
      if (message.type !== "invocation-result") {
        return true;
      }
      const resultId = message.data.id;
      return !invocations.some((tool) => tool.id === resultId);
    });
  }

  return rest;
};

/**
 * Trims the chat history so that system prompt plus history fit into the model context window.
 *
 * Token counts come from `estimateTokenCount`, applied to the system prompt and to the
 * rendered messages joined with newlines. Messages are dropped from the front of the
 * history, together with their paired agent / invocation-result messages, until the
 * total is at most `TOTAL_CONTEXT_THRESHOLD` of `modelContextWindow`.
 *
 * The `messages` array passed in is never mutated. It is returned as-is when nothing
 * has to be dropped; otherwise a new, shorter array is returned.
 *
 * @param systemPrompt - Prompt text counted against `SYSTEM_PROMPT_THRESHOLD` of the window.
 * @param messages - Chat history, earliest message first.
 * @param modelContextWindow - Size of the model context window in tokens.
 * @param render - Converts one message into the value that is joined for token
 *   estimation. Defaults to `JSON.stringify`; it is always called with the message only.
 * @returns The messages that fit into the context window.
 * @throws AgentError when the system prompt exceeds its share of the window, or when
 *   the history still does not fit with at most one message left.
 */
export const handleContextWindowOverflow = async ({
  systemPrompt,
  messages,
  modelContextWindow,
  render = JSON.stringify,
}: {
  systemPrompt: string;
  messages: WorkflowAgentStateMessage[];
  modelContextWindow: number;
  render?(message: WorkflowAgentStateMessage): unknown;
  //strategy?: "truncate"
}): Promise<WorkflowAgentStateMessage[]> => {
  const systemPromptLimit = modelContextWindow * SYSTEM_PROMPT_THRESHOLD;
  const totalLimit = modelContextWindow * TOTAL_CONTEXT_THRESHOLD;

  // Call render with the message only: passing it straight to map() would also
  // hand it the index and the array (as JSON.stringify's replacer/space arguments).
  const countMessageTokens = (history: readonly WorkflowAgentStateMessage[]) =>
    estimateTokenCount(history.map((message) => render(message)).join("\n"));

  const systemPromptTokenCount = await estimateTokenCount(systemPrompt);

  if (systemPromptTokenCount > systemPromptLimit) {
    throw new AgentError(`System prompt can not exceed ${systemPromptLimit} tokens`);
  }

  let messagesTokenCount = await countMessageTokens(messages);

  // Same boundary as the loop exit below: a total exactly at the limit still fits.
  if (messagesTokenCount + systemPromptTokenCount <= totalLimit) {
    return messages;
  }

  logger.info("Chat history exceeds context window, early messages will be dropped", {
    systemPromptTokenCount,
    messagesTokenCount,
  });

  let remaining: WorkflowAgentStateMessage[] = messages;

  do {
    // Nothing further can be dropped. `<=` also covers a history emptied by
    // paired removals, which would otherwise loop forever.
    if (remaining.length <= 1) {
      throw new AgentError("Single chat message exceeds context window");
    }

    remaining = dropEarliestMessage(remaining);

    logger.info("Dropping early message");

    messagesTokenCount = await countMessageTokens(remaining);
  } while (messagesTokenCount + systemPromptTokenCount > totalLimit);

  return remaining;
};