Commit
feat: Add message history truncation
johnjcsmith committed Dec 23, 2024
1 parent 9e12f5e commit 820a8cf
Showing 5 changed files with 268 additions and 3 deletions.
3 changes: 3 additions & 0 deletions control-plane/src/modules/models/index.ts
@@ -4,6 +4,7 @@ import Anthropic from "@anthropic-ai/sdk";
import { ToolUseBlock } from "@anthropic-ai/sdk/resources";
import {
ChatIdentifiers,
CONTEXT_WINDOW,
EmbeddingIdentifiers,
getEmbeddingRouting,
getRouting,
@@ -47,6 +48,7 @@ export type Model = {
options: T,
) => Promise<StructuredCallOutput>;
identifier: ChatIdentifiers | EmbeddingIdentifiers;
contextWindow?: number;
embedQuery: (input: string) => Promise<number[]>;
};

@@ -72,6 +74,7 @@ export const buildModel = ({

return {
identifier,
contextWindow: CONTEXT_WINDOW[identifier],
embedQuery: async (input: string) => {
if (!isEmbeddingIdentifier(identifier)) {
throw new Error(`${identifier} is not an embedding model`);
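Because CONTEXT_WINDOW is a plain Record<string, number>, the lookup in buildModel yields undefined for any identifier without an entry, which is why the new contextWindow field on Model is optional. A minimal sketch of the consequence ("some-other-model" is a hypothetical identifier; the import path is assumed):

import { CONTEXT_WINDOW } from "./routing";

// Entries present in the map resolve to a number; anything else is undefined.
const known = CONTEXT_WINDOW["claude-3-5-sonnet"]; // 1000
const missing = CONTEXT_WINDOW["some-other-model"]; // undefined (no entry)

// Downstream, model-call.ts guards the undefined case with `model.contextWindow ?? 0`.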
5 changes: 5 additions & 0 deletions control-plane/src/modules/models/routing.ts
@@ -5,6 +5,11 @@ import { logger } from "../observability/logger";
import { BedrockCohereEmbeddings } from "../embeddings/bedrock-cohere-embeddings";
import { CohereEmbeddings } from "@langchain/cohere";

export const CONTEXT_WINDOW: Record<string, number> = {
"claude-3-5-sonnet": 1000,
"claude-3-haiku": 1000,
};

const routingOptions = {
"claude-3-5-sonnet": [
...(env.BEDROCK_AVAILABLE
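For a sense of scale: combined with the thresholds introduced in overflow.ts below, a 1000-token window leaves a 950-token budget for system prompt plus message history, and caps the system prompt alone at 700 tokens. A worked sketch:

// Budgets implied by TOTAL_CONTEXT_THRESHOLD (0.95) and
// SYSTEM_PROMPT_THRESHOLD (0.7), both defined in overflow.ts below.
const modelContextWindow = 1000; // CONTEXT_WINDOW["claude-3-5-sonnet"]
const totalBudget = modelContextWindow * 0.95; // 950 tokens for prompt + messages
const systemPromptBudget = modelContextWindow * 0.7; // 700 tokens max for the system prompt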
15 changes: 12 additions & 3 deletions control-plane/src/modules/workflows/agent/nodes/model-call.ts
@@ -1,5 +1,5 @@
import { ReleventToolLookup } from '../agent';
- import { toAnthropicMessages } from '../../workflow-messages';
import { toAnthropicMessage, toAnthropicMessages } from '../../workflow-messages';
import { logger } from '../../../observability/logger';
import { WorkflowAgentState, WorkflowAgentStateMessage } from '../state';
import { addAttributes, withSpan } from '../../../observability/tracer';
@@ -14,6 +14,7 @@ import { ToolUseBlock } from '@anthropic-ai/sdk/resources';
import { Schema, Validator } from 'jsonschema';
import { buildModelSchema, ModelOutput } from './model-output';
import { getSystemPrompt } from './system-prompt';
import { handleContextWindowOverflow } from '../overflow';

type WorkflowStateUpdate = Partial<WorkflowAgentState>;

@@ -40,8 +41,6 @@ const _handleModelCall = async (
'model.identifier': model.identifier,
});

- const renderedMessages = toAnthropicMessages(state.messages);

if (!!state.workflow.resultSchema) {
const resultSchemaErrors = validateFunctionSchema(
state.workflow.resultSchema as JsonSchemaInput
@@ -63,6 +62,15 @@

const systemPrompt = getSystemPrompt(state, schemaString);

const trimmedMessages = await handleContextWindowOverflow({
modelContextWindow: model.contextWindow ?? 0,
systemPrompt,
messages: state.messages,
render: toAnthropicMessage,
});

const renderedMessages = toAnthropicMessages(trimmedMessages);

if (state.workflow.debug) {
addAttributes({
'model.input.additional_context': state.additionalContext,
@@ -258,6 +266,7 @@ const _handleModelCall = async (
};
};


const detectCycle = (messages: WorkflowAgentStateMessage[]) => {
if (messages.length >= 100) {
throw new AgentError('Maximum workflow message length exceeded.');
184 changes: 184 additions & 0 deletions control-plane/src/modules/workflows/agent/overflow.test.ts
@@ -0,0 +1,184 @@
import { AgentError } from '../../../utilities/errors';
import { WorkflowAgentStateMessage } from './state';
import { handleContextWindowOverflow } from './overflow';
import { estimateTokenCount } from './utils';

jest.mock('./utils', () => ({
estimateTokenCount: jest.fn(),
}));

describe('handleContextWindowOverflow', () => {
beforeEach(() => {
jest.clearAllMocks();
});

it('should throw if system prompt exceeds threshold', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [];
const modelContextWindow = 1000;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(701); // system prompt (0.7 * 1000)

await expect(
handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
})
).rejects.toThrow(new AgentError('System prompt can not exceed 700 tokens'));
});

it('should not modify messages if total tokens are under threshold', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [
{ type: 'human', data: { message: 'Hello' } } as any,
{ type: 'agent', data: { message: 'Hi' } } as any,
];
const modelContextWindow = 1000;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(100) // system prompt
.mockResolvedValueOnce(200); // messages

const result = await handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
});

expect(result).toEqual(messages);
expect(messages).toHaveLength(2);
});

it('should handle empty messages array', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [];
const modelContextWindow = 1000;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(200) // system prompt
.mockResolvedValueOnce(0); // empty messages

const result = await handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
});

expect(result).toEqual(messages);
expect(messages).toHaveLength(0);
});

describe('truncate strategy', () => {
it('should remove messages until total tokens are under threshold', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = Array(5).fill({
type: 'human',
data: { message: 'Message' },
});
const modelContextWindow = 600;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(200) // system prompt
.mockResolvedValueOnce(900) // initial messages
.mockResolvedValueOnce(700) // after first removal
.mockResolvedValueOnce(500) // after second removal
.mockResolvedValueOnce(300); // after third removal

const result = await handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
});

expect(result).toHaveLength(2);
});

it('should throw if a single message exceeds the context window', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [
{ type: 'human', data: { message: 'Message' } } as any,
];
const modelContextWindow = 400;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(200) // system prompt
.mockResolvedValueOnce(400); // message

await expect(
handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
})
).rejects.toThrow(AgentError);
});


it('should remove tool invocation result when removing agent message', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [
{ id: "123", type: 'agent', data: { message: 'Hi', invocations: [
{
id: "toolCallId1",
},
{
id: "toolCallId2",
},
{
id: "toolCallId3",
},
]}} as any,
{ id: "456", type: 'invocation-result', data: { id: "toolCallId1" } } as any,
{ id: "456", type: 'invocation-result', data: { id: "toolCallId2" } } as any,
{ id: "456", type: 'invocation-result', data: { id: "toolCallId3" } } as any,
{ id: "789", type: 'human', data: { message: 'Hello' }} as any,
];
const modelContextWindow = 1100;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(200) // system prompt
.mockResolvedValueOnce(1000) // initial messages
.mockResolvedValueOnce(800) // after first removal

const result = await handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
});

expect(result).toHaveLength(1);
expect(result[0].type).toBe('human');
});

it('should remove agent message when removing tool invocation result', async () => {
const systemPrompt = 'System prompt';
const messages: WorkflowAgentStateMessage[] = [
{ id: "456", type: 'invocation-result', data: { id: "toolCallId1" } } as any,
{ id: "123", type: 'agent', data: { message: 'Hi', invocations: [
{
id: "toolCallId1",
},
]}} as any,
{ id: "789", type: 'human', data: { message: 'Hello' }} as any,
];
const modelContextWindow = 1100;

(estimateTokenCount as jest.Mock)
.mockResolvedValueOnce(200) // system prompt
.mockResolvedValueOnce(1000) // initial messages
.mockResolvedValueOnce(800) // after first removal
.mockResolvedValueOnce(600) // after second removal

const result = await handleContextWindowOverflow({
systemPrompt,
messages,
modelContextWindow,
});

expect(result).toHaveLength(1);
expect(result[0].type).toBe('human');
})
})
});
64 changes: 64 additions & 0 deletions control-plane/src/modules/workflows/agent/overflow.ts
@@ -0,0 +1,64 @@
import { AgentError } from "../../../utilities/errors";
import { logger } from "../../observability/logger";
import { WorkflowAgentStateMessage } from "./state";
import { estimateTokenCount } from "./utils";

const TOTAL_CONTEXT_THRESHOLD = 0.95;
const SYSTEM_PROMPT_THRESHOLD = 0.7;

export const handleContextWindowOverflow = async ({
systemPrompt,
messages,
modelContextWindow,
render = JSON.stringify
}: {
systemPrompt: string
messages: WorkflowAgentStateMessage[]
modelContextWindow: number
render? (message: WorkflowAgentStateMessage): unknown
//strategy?: "truncate"
}) => {
const systemPromptTokenCount = await estimateTokenCount(systemPrompt);

if (systemPromptTokenCount > modelContextWindow * SYSTEM_PROMPT_THRESHOLD) {
throw new AgentError(`System prompt can not exceed ${modelContextWindow * SYSTEM_PROMPT_THRESHOLD} tokens`);
}

let messagesTokenCount = await estimateTokenCount(messages.map(render).join("\n"));
if (messagesTokenCount + systemPromptTokenCount < (modelContextWindow * TOTAL_CONTEXT_THRESHOLD)) {
return messages;
}

logger.info("Chat history exceeds context window, early messages will be dropped", {
systemPromptTokenCount,
messagesTokenCount,
})

do {
if (messages.length === 1) {
throw new AgentError("Single chat message exceeds context window");
}

const removed = messages.shift();

// If the removed message is an invocation result, we need to remove the corresponding
// agent message also. This _technically_ shouldn't happen, as we remove from the earliest message forward.
if (removed?.type === 'invocation-result') {
messages = messages.filter((message) => message.type !== 'agent' || !message.data.invocations?.find((tool) => tool.id === removed.data.id));
}

// If the removed message is an agent message, we need to remove the invocation result message also
if (removed?.type === 'agent') {
removed.data.invocations?.forEach((tool) => {
messages = messages.filter((message) => message.type !== 'invocation-result' || message.data.id !== tool.id);
})
}

logger.info("Dropping early message");

messagesTokenCount = await estimateTokenCount(messages.map(render).join("\n"));

} while (messagesTokenCount + systemPromptTokenCount > modelContextWindow * TOTAL_CONTEXT_THRESHOLD);

return messages;
};
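The estimateTokenCount helper is imported from ./utils, which is not part of this diff (the tests above mock it). As a rough illustration only, assuming a simple characters-per-token heuristic rather than the real implementation, it might look like:

// Hypothetical sketch; the actual ./utils helper is not shown in this commit.
// ~4 characters per token is a common rough heuristic for English text.
export const estimateTokenCount = async (input: string): Promise<number> => {
  return Math.ceil(input.length / 4);
};

And a usage sketch of the new handler, mirroring the call site in model-call.ts above (state and model are assumed to be in scope; import paths follow this commit's layout):

import { handleContextWindowOverflow } from "./overflow";
import { toAnthropicMessage } from "../workflow-messages";

// Drops the oldest messages until system prompt + history fit within 95% of the
// model's context window. Throws AgentError if the system prompt alone exceeds
// 70% of the window, or if a single remaining message still overflows it.
const trimmed = await handleContextWindowOverflow({
  systemPrompt: "You are a helpful agent.",
  messages: state.messages,
  modelContextWindow: model.contextWindow ?? 0,
  render: toAnthropicMessage,
});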
