From a4aaa08bc19decff078f35886e78b6a7c8d2aa47 Mon Sep 17 00:00:00 2001 From: lipandeng Date: Fri, 6 Dec 2024 17:36:15 +0800 Subject: [PATCH] feat: manually mirror eino's code from bytedance --- README.md | 18 +- README.zh_CN.md | 54 + callbacks/aspect_inject.go | 234 +++ callbacks/aspect_inject_test.go | 191 +++ callbacks/doc.go | 108 ++ callbacks/handler_builder.go | 191 +++ callbacks/interface.go | 100 ++ callbacks/internal/manager.go | 21 + callbacks/manager.go | 210 +++ callbacks/manager_test.go | 142 ++ callbacks/template/default.go | 51 + callbacks/template/template.go | 530 ++++++ callbacks/template/template_test.go | 335 ++++ components/document/callback_extra_loader.go | 94 ++ .../document/callback_extra_transformer.go | 91 ++ components/document/doc.go | 17 + components/document/interface.go | 48 + components/document/option.go | 189 +++ components/document/option_test.go | 65 + components/document/parser/ext_parser.go | 131 ++ components/document/parser/interface.go | 29 + components/document/parser/option.go | 115 ++ components/document/parser/option_test.go | 54 + components/document/parser/parser_test.go | 119 ++ components/document/parser/testdata/test.md | 2 + components/document/parser/text_parser.go | 59 + components/embedding/callback_extra.go | 120 ++ components/embedding/callback_extra_test.go | 33 + components/embedding/doc.go | 17 + components/embedding/interface.go | 24 + components/embedding/option.go | 60 + components/embedding/option_test.go | 30 + components/indexer/callback_extra.go | 89 + components/indexer/callback_extra_test.go | 35 + components/indexer/doc.go | 17 + components/indexer/interface.go | 32 + components/indexer/option.go | 73 + components/indexer/option_test.go | 45 + components/model/callback_extra.go | 126 ++ components/model/callback_extra_test.go | 35 + components/model/doc.go | 17 + components/model/interface.go | 38 + components/model/option.go | 132 ++ components/model/option_test.go | 88 + 
components/prompt/callback_extra.go | 96 ++ components/prompt/callback_extra_test.go | 35 + components/prompt/chat_template.go | 89 + components/prompt/chat_template_test.go | 115 ++ components/prompt/doc.go | 17 + components/prompt/interface.go | 29 + components/prompt/option.go | 48 + components/prompt/option_test.go | 51 + components/retriever/callback_extra.go | 100 ++ components/retriever/callback_extra_test.go | 35 + components/retriever/doc.go | 17 + components/retriever/interface.go | 42 + components/retriever/option.go | 111 ++ components/retriever/option_test.go | 60 + components/tool/callback_extra.go | 88 + components/tool/callback_extra_test.go | 37 + components/tool/doc.go | 17 + components/tool/interface.go | 45 + components/tool/option.go | 78 + components/tool/option_test.go | 54 + components/tool/utils/create_options.go | 176 ++ components/tool/utils/doc.go | 17 + components/tool/utils/invokable_func.go | 176 ++ components/tool/utils/invokable_func_test.go | 275 ++++ components/tool/utils/streamable_func.go | 128 ++ components/tool/utils/streamable_func_test.go | 99 ++ components/types.go | 63 + compose/chain.go | 600 +++++++ compose/chain_branch.go | 251 +++ compose/chain_branch_test.go | 274 ++++ compose/chain_parallel.go | 217 +++ compose/chain_test.go | 583 +++++++ compose/component_to_graph_node.go | 180 ++ compose/dag.go | 78 + compose/dag_test.go | 228 +++ compose/doc.go | 17 + compose/error.go | 61 + compose/generic_graph.go | 104 ++ compose/graph.go | 1004 ++++++++++++ compose/graph_add_node_options.go | 152 ++ compose/graph_call_options.go | 190 +++ compose/graph_call_options_test.go | 303 ++++ compose/graph_compile_options.go | 68 + compose/graph_node.go | 188 +++ compose/graph_node_checker.go | 34 + compose/graph_run.go | 544 ++++++ compose/graph_test.go | 1456 +++++++++++++++++ compose/introspect.go | 54 + compose/pregel.go | 72 + compose/runnable.go | 626 +++++++ compose/runnable_test.go | 210 +++ compose/state.go | 235 +++ 
compose/state_test.go | 321 ++++ compose/stream_concat.go | 271 +++ compose/stream_concat_test.go | 213 +++ compose/stream_reader.go | 116 ++ compose/stream_reader_test.go | 102 ++ compose/tool_node.go | 290 ++++ compose/tool_node_test.go | 518 ++++++ compose/types.go | 48 + compose/types_composable.go | 30 + compose/types_lambda.go | 265 +++ compose/types_lambda_test.go | 212 +++ compose/utils.go | 339 ++++ compose/utils_test.go | 163 ++ doc.go | 17 + flow/agent/agent_option.go | 70 + flow/agent/multiagent/host/callback.go | 121 ++ flow/agent/multiagent/host/compose.go | 207 +++ flow/agent/multiagent/host/compose_test.go | 338 ++++ flow/agent/multiagent/host/doc.go | 17 + flow/agent/multiagent/host/options.go | 29 + flow/agent/multiagent/host/types.go | 141 ++ flow/agent/react/callback.go | 34 + flow/agent/react/doc.go | 17 + flow/agent/react/react.go | 371 +++++ flow/agent/react/react_test.go | 605 +++++++ flow/retriever/multiquery/multi_query.go | 211 +++ flow/retriever/multiquery/multi_query_test.go | 118 ++ flow/retriever/router/router.go | 193 +++ flow/retriever/router/router_test.go | 134 ++ flow/retriever/utils/utils.go | 83 + go.mod | 47 + go.sum | 149 ++ internal/gmap/gmap.go | 122 ++ internal/gmap/gmap_test.go | 89 + internal/gslice/gslice.go | 39 + internal/gslice/gslice_test.go | 36 + .../mock/components/document/document_mock.go | 159 ++ .../components/embedding/Embedding_mock.go | 77 + .../mock/components/indexer/indexer_mock.go | 78 + .../mock/components/model/ChatModel_mock.go | 112 ++ .../components/retriever/retriever_mock.go | 78 + internal/mock/doc.go | 51 + profile/README.md | 13 - schema/doc.go | 17 + schema/document.go | 175 ++ schema/document_test.go | 53 + schema/message.go | 705 ++++++++ schema/message_parser.go | 138 ++ schema/message_parser_test.go | 179 ++ schema/message_test.go | 658 ++++++++ schema/stream.go | 773 +++++++++ schema/stream_test.go | 551 +++++++ schema/tool.go | 188 +++ schema/tool_test.go | 99 ++ 
utils/generic/generic.go | 67 + utils/generic/generic_test.go | 88 + utils/generic/type_name.go | 71 + utils/generic/type_name_test.go | 86 + utils/safe/panic.go | 40 + 155 files changed, 24554 insertions(+), 14 deletions(-) create mode 100644 README.zh_CN.md create mode 100644 callbacks/aspect_inject.go create mode 100644 callbacks/aspect_inject_test.go create mode 100644 callbacks/doc.go create mode 100644 callbacks/handler_builder.go create mode 100644 callbacks/interface.go create mode 100644 callbacks/internal/manager.go create mode 100644 callbacks/manager.go create mode 100644 callbacks/manager_test.go create mode 100644 callbacks/template/default.go create mode 100644 callbacks/template/template.go create mode 100644 callbacks/template/template_test.go create mode 100644 components/document/callback_extra_loader.go create mode 100644 components/document/callback_extra_transformer.go create mode 100644 components/document/doc.go create mode 100644 components/document/interface.go create mode 100644 components/document/option.go create mode 100644 components/document/option_test.go create mode 100644 components/document/parser/ext_parser.go create mode 100644 components/document/parser/interface.go create mode 100644 components/document/parser/option.go create mode 100644 components/document/parser/option_test.go create mode 100644 components/document/parser/parser_test.go create mode 100644 components/document/parser/testdata/test.md create mode 100644 components/document/parser/text_parser.go create mode 100644 components/embedding/callback_extra.go create mode 100644 components/embedding/callback_extra_test.go create mode 100644 components/embedding/doc.go create mode 100644 components/embedding/interface.go create mode 100644 components/embedding/option.go create mode 100644 components/embedding/option_test.go create mode 100644 components/indexer/callback_extra.go create mode 100644 components/indexer/callback_extra_test.go create mode 100644 
components/indexer/doc.go create mode 100644 components/indexer/interface.go create mode 100644 components/indexer/option.go create mode 100644 components/indexer/option_test.go create mode 100644 components/model/callback_extra.go create mode 100644 components/model/callback_extra_test.go create mode 100644 components/model/doc.go create mode 100644 components/model/interface.go create mode 100644 components/model/option.go create mode 100644 components/model/option_test.go create mode 100644 components/prompt/callback_extra.go create mode 100644 components/prompt/callback_extra_test.go create mode 100644 components/prompt/chat_template.go create mode 100644 components/prompt/chat_template_test.go create mode 100644 components/prompt/doc.go create mode 100644 components/prompt/interface.go create mode 100644 components/prompt/option.go create mode 100644 components/prompt/option_test.go create mode 100644 components/retriever/callback_extra.go create mode 100644 components/retriever/callback_extra_test.go create mode 100644 components/retriever/doc.go create mode 100644 components/retriever/interface.go create mode 100644 components/retriever/option.go create mode 100644 components/retriever/option_test.go create mode 100644 components/tool/callback_extra.go create mode 100644 components/tool/callback_extra_test.go create mode 100644 components/tool/doc.go create mode 100644 components/tool/interface.go create mode 100644 components/tool/option.go create mode 100644 components/tool/option_test.go create mode 100644 components/tool/utils/create_options.go create mode 100644 components/tool/utils/doc.go create mode 100644 components/tool/utils/invokable_func.go create mode 100644 components/tool/utils/invokable_func_test.go create mode 100644 components/tool/utils/streamable_func.go create mode 100644 components/tool/utils/streamable_func_test.go create mode 100644 components/types.go create mode 100644 compose/chain.go create mode 100644 compose/chain_branch.go 
create mode 100644 compose/chain_branch_test.go create mode 100644 compose/chain_parallel.go create mode 100644 compose/chain_test.go create mode 100644 compose/component_to_graph_node.go create mode 100644 compose/dag.go create mode 100644 compose/dag_test.go create mode 100644 compose/doc.go create mode 100644 compose/error.go create mode 100644 compose/generic_graph.go create mode 100644 compose/graph.go create mode 100644 compose/graph_add_node_options.go create mode 100644 compose/graph_call_options.go create mode 100644 compose/graph_call_options_test.go create mode 100644 compose/graph_compile_options.go create mode 100644 compose/graph_node.go create mode 100644 compose/graph_node_checker.go create mode 100644 compose/graph_run.go create mode 100644 compose/graph_test.go create mode 100644 compose/introspect.go create mode 100644 compose/pregel.go create mode 100644 compose/runnable.go create mode 100644 compose/runnable_test.go create mode 100644 compose/state.go create mode 100644 compose/state_test.go create mode 100644 compose/stream_concat.go create mode 100644 compose/stream_concat_test.go create mode 100644 compose/stream_reader.go create mode 100644 compose/stream_reader_test.go create mode 100644 compose/tool_node.go create mode 100644 compose/tool_node_test.go create mode 100644 compose/types.go create mode 100644 compose/types_composable.go create mode 100644 compose/types_lambda.go create mode 100644 compose/types_lambda_test.go create mode 100644 compose/utils.go create mode 100644 compose/utils_test.go create mode 100644 doc.go create mode 100644 flow/agent/agent_option.go create mode 100644 flow/agent/multiagent/host/callback.go create mode 100644 flow/agent/multiagent/host/compose.go create mode 100644 flow/agent/multiagent/host/compose_test.go create mode 100644 flow/agent/multiagent/host/doc.go create mode 100644 flow/agent/multiagent/host/options.go create mode 100644 flow/agent/multiagent/host/types.go create mode 100644 
flow/agent/react/callback.go create mode 100644 flow/agent/react/doc.go create mode 100644 flow/agent/react/react.go create mode 100644 flow/agent/react/react_test.go create mode 100644 flow/retriever/multiquery/multi_query.go create mode 100644 flow/retriever/multiquery/multi_query_test.go create mode 100644 flow/retriever/router/router.go create mode 100644 flow/retriever/router/router_test.go create mode 100644 flow/retriever/utils/utils.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/gmap/gmap.go create mode 100644 internal/gmap/gmap_test.go create mode 100644 internal/gslice/gslice.go create mode 100644 internal/gslice/gslice_test.go create mode 100644 internal/mock/components/document/document_mock.go create mode 100644 internal/mock/components/embedding/Embedding_mock.go create mode 100644 internal/mock/components/indexer/indexer_mock.go create mode 100644 internal/mock/components/model/ChatModel_mock.go create mode 100644 internal/mock/components/retriever/retriever_mock.go create mode 100644 internal/mock/doc.go delete mode 100644 profile/README.md create mode 100644 schema/doc.go create mode 100644 schema/document.go create mode 100644 schema/document_test.go create mode 100644 schema/message.go create mode 100644 schema/message_parser.go create mode 100644 schema/message_parser_test.go create mode 100644 schema/message_test.go create mode 100644 schema/stream.go create mode 100644 schema/stream_test.go create mode 100644 schema/tool.go create mode 100644 schema/tool_test.go create mode 100644 utils/generic/generic.go create mode 100644 utils/generic/generic_test.go create mode 100644 utils/generic/type_name.go create mode 100644 utils/generic/type_name_test.go create mode 100644 utils/safe/panic.go diff --git a/README.md b/README.md index a46ae92..a4fca0c 100644 --- a/README.md +++ b/README.md @@ -1 +1,17 @@ -# .github \ No newline at end of file +# Eino + +English | [中文](README.zh_CN.md) + +## Overview + + +## Security 
+ +If you discover a potential security issue in this project, or think you may +have discovered a security issue, we ask that you notify Bytedance Security via our [security center](https://security.bytedance.com/src) or [vulnerability reporting email](mailto:sec@bytedance.com). + +Please do **not** create a public GitHub issue. + +## License + +This project is licensed under the [Apache-2.0 License](LICENSE.txt). \ No newline at end of file diff --git a/README.zh_CN.md b/README.zh_CN.md new file mode 100644 index 0000000..2daa101 --- /dev/null +++ b/README.zh_CN.md @@ -0,0 +1,54 @@ +# Eino + +[English](README.md) | 中文 + + +## 概括 + +Eino['aino] (近似音: i know) 旨在提供 Golang 语言的 AI 应用开发框架。 Eino 参考了开源社区中诸多优秀的 AI 应用开发框架,例如 LangChain、LangGraph、LlamaIndex 等,提供了更符合 Golang 编程习惯的 AI 应用开发框架。 + +Eino 提供了丰富的辅助AI应用开发的原子组件、集成组件、组件编排、切面扩展等能力,可以帮助开发者更加简单便捷地开发出架构清晰、易维护、高可用的AI应用。 + +## 框架特点 + +- **丰富的组件** + + 将多场景普遍使用的能力,抽象成可独立使用、可编排使用的组件,开箱即用。例如 ChatModel、PromptTemplate、Retriever、Loader 等。 + + 组件又可细分为:功能不可细拆的原子组件、由一到多种组件以某种范式组合而成的集成组件。 + +- **易用的图编排** + + 将各组件实例,作为图的节点,以图的点边关系连接,以边的方向逐步执行节点并传输数据流,将AI应用的逻辑以图的方式进行编排和执行。 + + 图编排可极大简化 **并行**、**异步** 逻辑的开发,并优化其代码结构 + +- **完善的流处理** + + 根据输入、输出是否为流式,可划分成 4 种交互模式。 图编排可根据上下游节点的输入、输出是否是流,自动进行 流 和 非流 的转换,极大地方便开发者对AI应用提供流的能力 + + | 函数名 | 模式说明 | + |-----------|-----------| + | Invoke | 输入是非流、输出是非流 | + | Stream | 输入是非流、输出是流 | + | Collect | 输入是流、输出是非流 | + | Transform | 输入是流、输出是流 | + +- **高扩展性的切面** + + 图编排为图、节点的执行前后提供切面的注入、执行机制。开发者可基于此机制,在不侵入主流程的前提下,灵活地设计和注入自己的切面能力。例如 Trace、埋点、日志等 + + +## 详细文档 + +// TODO:链接用户手册等文档 + +## 安全 + +如果你在该项目中发现潜在的安全问题,或你认为可能发现了安全问题,请通过我们的[安全中心](https://security.bytedance.com/src)或[漏洞报告邮箱](mailto:sec@bytedance.com)通知字节跳动安全团队。 + +请**不要**创建公开的 GitHub Issue。 + +## 开源许可证 + +本项目依据 [Apache-2.0 许可证](LICENSE.txt) 授权。 diff --git a/callbacks/aspect_inject.go b/callbacks/aspect_inject.go new file mode 100644 index 0000000..2ec6836 --- /dev/null +++ b/callbacks/aspect_inject.go @@ -0,0 +1,234 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package callbacks + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// Fast inject callback input / output aspect for component developer +// e.g. +// +// func (t *testchatmodel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (resp *schema.Message, err error) { +// defer func() { +// if err != nil { +// callbacks.OnError(ctx, err) +// } +// }() +// +// ctx = callbacks.OnStart(ctx, &model.CallbackInput{ +// Messages: input, +// Tools: nil, +// Extra: nil, +// }) +// +// // do something +// +// ctx = callbacks.OnEnd(ctx, &model.CallbackOutput{ +// Message: resp, +// Extra: nil, +// }) +// +// return resp, nil +// } +// +// OnStart invokes the OnStart logic for the particular context, ensuring that all registered +// handlers are executed in reverse order (compared to add order) when a process begins. +func OnStart(ctx context.Context, input CallbackInput) context.Context { + mgr, ok := managerFromCtx(ctx) + if !ok { + return ctx + } + + for i := len(mgr.handlers) - 1; i >= 0; i-- { + handler := mgr.handlers[i] + timingChecker, ok := handler.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnStart) { + ctx = handler.OnStart(ctx, mgr.runInfo, input) + } + } + + return ctx +} + +// OnEnd invokes the OnEnd logic of the particular context, allowing for proper cleanup +// and finalization when a process ends.
+// handlers are executed in normal order (compared to add order). +func OnEnd(ctx context.Context, output CallbackOutput) context.Context { + mgr, ok := managerFromCtx(ctx) + if !ok { + return ctx + } + + for i := 0; i < len(mgr.handlers); i++ { + handler := mgr.handlers[i] + timingChecker, ok := handler.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnEnd) { + ctx = handler.OnEnd(ctx, mgr.runInfo, output) + } + } + + return ctx +} + +// OnStartWithStreamInput invokes the OnStartWithStreamInput logic of the particular context, ensuring that +// every input stream should be closed properly in handler. +// handlers are executed in reverse order (compared to add order). +func OnStartWithStreamInput[T any](ctx context.Context, input *schema.StreamReader[T]) ( + nextCtx context.Context, newStreamReader *schema.StreamReader[T]) { + + mgr, ok := managerFromCtx(ctx) + if !ok { + return ctx, input + } + + if len(mgr.handlers) == 0 { + return ctx, input + } + + var neededHandlers []Handler + for i := range mgr.handlers { + h := mgr.handlers[i] + timingChecker, ok := h.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnStartWithStreamInput) { + neededHandlers = append(neededHandlers, h) + } + } + + if len(neededHandlers) == 0 { + return ctx, input + } + + cp := input.Copy(len(neededHandlers) + 1) + for i := len(neededHandlers) - 1; i >= 0; i-- { + h := neededHandlers[i] + ctx = h.OnStartWithStreamInput(ctx, mgr.runInfo, schema.StreamReaderWithConvert(cp[i], func(src T) (CallbackInput, error) { + return src, nil + })) + } + + return ctx, cp[len(cp)-1] +} + +// OnEndWithStreamOutput invokes the OnEndWithStreamOutput logic of the particular, ensuring that +// every input stream should be closed properly in handler. +// handlers are executed in normal order (compared to add order). 
+func OnEndWithStreamOutput[T any](ctx context.Context, output *schema.StreamReader[T]) ( + nextCtx context.Context, newStreamReader *schema.StreamReader[T]) { + + mgr, ok := managerFromCtx(ctx) + if !ok { + return ctx, output + } + + if len(mgr.handlers) == 0 { + return ctx, output + } + + var neededHandlers []Handler + for i := range mgr.handlers { + h := mgr.handlers[i] + timingChecker, ok := h.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnEndWithStreamOutput) { + neededHandlers = append(neededHandlers, h) + } + } + + if len(neededHandlers) == 0 { + return ctx, output + } + + cp := output.Copy(len(neededHandlers) + 1) + for i := 0; i < len(neededHandlers); i++ { + h := neededHandlers[i] + ctx = h.OnEndWithStreamOutput(ctx, mgr.runInfo, schema.StreamReaderWithConvert(cp[i], func(src T) (CallbackOutput, error) { + return src, nil + })) + } + + return ctx, cp[len(cp)-1] +} + +// OnError invokes the OnError logic of the particular, notice that error in stream will not represent here. +// handlers are executed in normal order (compared to add order). +func OnError(ctx context.Context, err error) context.Context { + mgr, ok := managerFromCtx(ctx) + if !ok { + return ctx + } + + for i := 0; i < len(mgr.handlers); i++ { + handler := mgr.handlers[i] + timingChecker, ok := handler.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, TimingOnError) { + ctx = handler.OnError(ctx, mgr.runInfo, err) + } + } + + return ctx +} + +// SwitchRunInfo updates the RunInfo in the context if a previous RunInfo already exists for that context. +func SwitchRunInfo(ctx context.Context, info *RunInfo) context.Context { + cbm, ok := managerFromCtx(ctx) + if !ok { + return ctx + } + + return ctxWithManager(ctx, cbm.withRunInfo(info)) +} + +// InitCallbacks initializes a new context with the provided RunInfo and handlers. +// If successful, it returns a new context containing RunInfo and handlers; otherwise, it returns a context with a nil manager. 
+func InitCallbacks(ctx context.Context, info *RunInfo, handlers ...Handler) context.Context { + mgr, ok := newManager(info, handlers...) + if ok { + return ctxWithManager(ctx, mgr) + } + + return ctxWithManager(ctx, nil) +} + +// Needed checks if any callback handlers exist in this context. +func Needed(ctx context.Context) bool { + _, cbmOK := managerFromCtx(ctx) + return cbmOK +} + +// NeededForTiming checks if any callback handlers exist in this context that are needed for this specific timing. +func NeededForTiming(ctx context.Context, timing CallbackTiming) bool { + mgr, ok := managerFromCtx(ctx) + if !ok { + return false + } + + if len(mgr.handlers) == 0 { + return false + } + + for i := 0; i < len(mgr.handlers); i++ { + handler := mgr.handlers[i] + timingChecker, ok := handler.(TimingChecker) + if !ok || timingChecker.Needed(ctx, mgr.runInfo, timing) { + return true + } + } + + return false +} diff --git a/callbacks/aspect_inject_test.go b/callbacks/aspect_inject_test.go new file mode 100644 index 0000000..0a3af79 --- /dev/null +++ b/callbacks/aspect_inject_test.go @@ -0,0 +1,191 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package callbacks + +import ( + "context" + "fmt" + "io" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAspectInject(t *testing.T) { + t.Run("ctx without manager", func(t *testing.T) { + ctx := context.Background() + ctx = OnStart(ctx, 1) + ctx = OnEnd(ctx, 2) + ctx = OnError(ctx, fmt.Errorf("3")) + isr, isw := schema.Pipe[int](2) + go func() { + for i := 0; i < 10; i++ { + isw.Send(i, nil) + } + isw.Close() + }() + + var nisr *schema.StreamReader[int] + ctx, nisr = OnStartWithStreamInput(ctx, isr) + j := 0 + for { + i, err := nisr.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, j, i) + j++ + } + nisr.Close() + + osr, osw := schema.Pipe[int](2) + go func() { + for i := 0; i < 10; i++ { + osw.Send(i, nil) + } + osw.Close() + }() + + var nosr *schema.StreamReader[int] + ctx, nosr = OnEndWithStreamOutput(ctx, osr) + j = 0 + for { + i, err := nosr.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, j, i) + j++ + } + nosr.Close() + }) + + t.Run("ctx with manager", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + hb := NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { + cnt += input.(int) + return ctx + }). + OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { + cnt += output.(int) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *RunInfo, err error) context.Context { + v, _ := strconv.ParseInt(err.Error(), 10, 64) + cnt += int(v) + return ctx + }). + OnStartWithStreamInputFn(func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { + for { + i, err := input.Recv() + if err == io.EOF { + break + } + + cnt += i.(int) + } + + input.Close() + return ctx + }). 
+ OnEndWithStreamOutputFn(func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { + for { + o, err := output.Recv() + if err == io.EOF { + break + } + + cnt += o.(int) + } + + output.Close() + return ctx + }).Build() + + manager, ok := newManager(nil, hb) + assert.True(t, ok) + + ctx = ctxWithManager(ctx, manager) + ctx = OnStart(ctx, 1) + ctx = OnEnd(ctx, 2) + ctx = OnError(ctx, fmt.Errorf("3")) + isr, isw := schema.Pipe[int](2) + go func() { + for i := 0; i < 10; i++ { + isw.Send(i, nil) + } + isw.Close() + }() + + var nisr *schema.StreamReader[int] + ctx, nisr = OnStartWithStreamInput(ctx, isr) + j := 0 + for { + i, err := nisr.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, j, i) + j++ + cnt += i + } + nisr.Close() + + osr, osw := schema.Pipe[int](2) + go func() { + for i := 0; i < 10; i++ { + osw.Send(i, nil) + } + osw.Close() + }() + + var nosr *schema.StreamReader[int] + ctx, nosr = OnEndWithStreamOutput(ctx, osr) + j = 0 + for { + i, err := nosr.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, j, i) + j++ + cnt += i + } + nosr.Close() + assert.Equal(t, 186, cnt) + + assert.True(t, NeededForTiming(ctx, TimingOnStart)) + assert.True(t, NeededForTiming(ctx, TimingOnEnd)) + assert.True(t, NeededForTiming(ctx, TimingOnError)) + assert.True(t, NeededForTiming(ctx, TimingOnStartWithStreamInput)) + assert.True(t, NeededForTiming(ctx, TimingOnEndWithStreamOutput)) + }) +} diff --git a/callbacks/doc.go b/callbacks/doc.go new file mode 100644 index 0000000..1ba9852 --- /dev/null +++ b/callbacks/doc.go @@ -0,0 +1,108 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package callbacks provides callback mechanisms for component execution in Eino. +// +// This package allows you to inject callback handlers at different stages of component execution, +// such as start, end, and error handling. It's particularly useful for implementing governance capabilities like logging, monitoring, and metrics collection. +// +// The package provides two ways to create callback handlers: +// +// 1. Create a callback handler using HandlerBuilder: +// +// handler := callbacks.NewHandlerBuilder(). +// OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { +// // Handle component start +// return ctx +// }). +// OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context { +// // Handle component end +// return ctx +// }). +// OnErrorFn(func(ctx context.Context, info *RunInfo, err error) context.Context { +// // Handle component error +// return ctx +// }). +// OnStartWithStreamInputFn(func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context { +// // Handle component start with stream input +// return ctx +// }). +// OnEndWithStreamOutputFn(func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context { +// // Handle component end with stream output +// return ctx +// }). +// Build() +// +// With this approach, you need to convert the callback input types yourself, and implement the logic for different component types in one handler. +// +// 2.
Use [template.HandlerHelper] to create a handler: +// +// Package template provides [template.HandlerHelper] as a convenient way to build callback handlers +// for different component types. It allows you to set specific handlers for each component type, +// and a fallback handler for unmatched components. +// +// e.g. +// +// // Create handlers for specific components +// modelHandler := &model.CallbackHandler{ +// OnStart: func(ctx context.Context, info *RunInfo, input *model.CallbackInput) context.Context { +// log.Printf("Model execution started: %s", info.ComponentName) +// return ctx +// }, +// } +// +// promptHandler := &prompt.CallbackHandler{ +// OnEnd: func(ctx context.Context, info *RunInfo, output *prompt.CallbackOutput) context.Context { +// log.Printf("Prompt execution completed: %s", output.Result) +// return ctx +// }, +// } +// +// // Create a fallback handler for unmatched components +// fallbackHandler := &template.DefaultCallbackHandler{ +// OnStart: func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { +// log.Printf("Generic component started: %s", info.ComponentName) +// return ctx +// }, +// } +// +// // Build the handler using HandlerHelper +// handler := template.NewHandlerHelper(). +// ChatModel(modelHandler). +// Prompt(promptHandler). +// Fallback(fallbackHandler).
+// Handler() +// +// [HandlerHelper] supports handlers for various component types including: +// - Prompt components (via prompt.CallbackHandler) +// - Chat model components (via model.CallbackHandler) +// - Embedding components (via embedding.CallbackHandler) +// - Indexer components (via indexer.CallbackHandler) +// - Retriever components (via retriever.CallbackHandler) +// - Document loader components (via loader.CallbackHandler) +// - Document transformer components (via transformer.CallbackHandler) +// - Tool components (via tool.CallbackHandler) +// - Graph (via template.DefaultCallbackHandler) +// - State graph (via template.DefaultCallbackHandler) +// - Chain (via template.DefaultCallbackHandler) +// - Passthrough (via template.DefaultCallbackHandler) +// - Tools node (via template.DefaultCallbackHandler) +// - Lambda (via template.DefaultCallbackHandler) +// +// Use the handler with a component: +// +// runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) +package callbacks diff --git a/callbacks/handler_builder.go b/callbacks/handler_builder.go new file mode 100644 index 0000000..41c7146 --- /dev/null +++ b/callbacks/handler_builder.go @@ -0,0 +1,191 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package callbacks + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// HandlerBuilder can be used to build a Handler with callback functions. +// e.g. 
+//
+//	handler := &HandlerBuilder{
+//		OnStartFn: func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context { return ctx }, // self defined start callback function
+//	}
+//
+//	graph := compose.NewGraph[inputType, outputType]()
+//	runnable, err := graph.Compile()
+//	if err != nil {...}
+//	runnable.Invoke(ctx, params, compose.WithCallbacks(handler)) // => only implement functions which you want to override
+//
+// Every callback field is optional. A nil field turns the corresponding Handler
+// method into a no-op that returns ctx unchanged; the two stream methods
+// additionally close the stream they were given when no callback consumes it.
+//
+// Deprecated: In most situations, it is preferred to use template.NewHandlerHelper. Otherwise, use NewHandlerBuilder().OnStartFn()...Build().
+type HandlerBuilder struct {
+	OnStartFn                func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context
+	OnEndFn                  func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context
+	OnErrorFn                func(ctx context.Context, info *RunInfo, err error) context.Context
+	OnStartWithStreamInputFn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context
+	OnEndWithStreamOutputFn  func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context
+}
+
+// OnStart calls OnStartFn if it is set; otherwise it returns ctx unchanged.
+func (h *HandlerBuilder) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
+	if h.OnStartFn != nil {
+		return h.OnStartFn(ctx, info, input)
+	}
+
+	return ctx
+}
+
+// OnEnd calls OnEndFn if it is set; otherwise it returns ctx unchanged.
+func (h *HandlerBuilder) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
+	if h.OnEndFn != nil {
+		return h.OnEndFn(ctx, info, output)
+	}
+
+	return ctx
+}
+
+// OnError calls OnErrorFn if it is set; otherwise it returns ctx unchanged.
+func (h *HandlerBuilder) OnError(ctx context.Context, info *RunInfo, err error) context.Context {
+	if h.OnErrorFn != nil {
+		return h.OnErrorFn(ctx, info, err)
+	}
+
+	return ctx
+}
+
+// OnStartWithStreamInput calls OnStartWithStreamInputFn if it is set, handing it
+// ownership of input. Without a callback the stream is closed here.
+func (h *HandlerBuilder) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context {
+	if h.OnStartWithStreamInputFn != nil {
+		return h.OnStartWithStreamInputFn(ctx, info, input)
+	}
+
+	// Nothing will read this stream. Tolerate a nil reader, mirroring the nil
+	// guard the callback manager applies before closing a stream.
+	if input != nil {
+		input.Close()
+	}
+
+	return ctx
+}
+
+// OnEndWithStreamOutput calls OnEndWithStreamOutputFn if it is set, handing it
+// ownership of output. Without a callback the stream is closed here.
+func (h *HandlerBuilder) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context {
+	if h.OnEndWithStreamOutputFn != nil {
+		return h.OnEndWithStreamOutputFn(ctx, info, output)
+	}
+
+	// See OnStartWithStreamInput: close the unread stream, ignoring nil.
+	if output != nil {
+		output.Close()
+	}
+
+	return ctx
+}
+
+// handlerBuilder is the fluent counterpart of HandlerBuilder returned by
+// NewHandlerBuilder. Besides Handler it implements TimingChecker, reporting a
+// timing as needed only when the matching callback has been registered.
+type handlerBuilder struct {
+	onStartFn                func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context
+	onEndFn                  func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context
+	onErrorFn                func(ctx context.Context, info *RunInfo, err error) context.Context
+	onStartWithStreamInputFn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context
+	onEndWithStreamOutputFn  func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context
+}
+
+// Compile-time checks that both builders satisfy the interfaces they are used as.
+var (
+	_ Handler       = (*HandlerBuilder)(nil)
+	_ Handler       = (*handlerBuilder)(nil)
+	_ TimingChecker = (*handlerBuilder)(nil)
+)
+
+// OnStart calls the registered start callback, or returns ctx unchanged.
+func (hb *handlerBuilder) OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
+	if hb.onStartFn != nil {
+		return hb.onStartFn(ctx, info, input)
+	}
+
+	return ctx
+}
+
+// OnEnd calls the registered end callback, or returns ctx unchanged.
+func (hb *handlerBuilder) OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
+	if hb.onEndFn != nil {
+		return hb.onEndFn(ctx, info, output)
+	}
+
+	return ctx
+}
+
+// OnError calls the registered error callback, or returns ctx unchanged.
+func (hb *handlerBuilder) OnError(ctx context.Context, info *RunInfo, err error) context.Context {
+	if hb.onErrorFn != nil {
+		return hb.onErrorFn(ctx, info, err)
+	}
+
+	return ctx
+}
+
+// OnStartWithStreamInput calls the registered stream-input callback, handing it
+// ownership of input. Without a callback the stream is closed here.
+func (hb *handlerBuilder) OnStartWithStreamInput(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context {
+	if hb.onStartWithStreamInputFn != nil {
+		return hb.onStartWithStreamInputFn(ctx, info, input)
+	}
+
+	// Nothing will read this stream; close it, ignoring a nil reader.
+	if input != nil {
+		input.Close()
+	}
+
+	return ctx
+}
+
+// OnEndWithStreamOutput calls the registered stream-output callback, handing it
+// ownership of output. Without a callback the stream is closed here.
+func (hb *handlerBuilder) OnEndWithStreamOutput(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context {
+	if hb.onEndWithStreamOutputFn != nil {
+		return hb.onEndWithStreamOutputFn(ctx, info, output)
+	}
+
+	// Nothing will read this stream; close it, ignoring a nil reader.
+	if output != nil {
+		output.Close()
+	}
+
+	return ctx
+}
+
+// Needed implements TimingChecker: it reports whether a callback has been
+// registered for timing. Unknown timings are never needed.
+func (hb *handlerBuilder) Needed(_ context.Context, _ *RunInfo, timing CallbackTiming) bool {
+	switch timing {
+	case TimingOnStart:
+		return hb.onStartFn != nil
+	case TimingOnEnd:
+		return hb.onEndFn != nil
+	case TimingOnError:
+		return hb.onErrorFn != nil
+	case TimingOnStartWithStreamInput:
+		return hb.onStartWithStreamInputFn != nil
+	case TimingOnEndWithStreamOutput:
+		return hb.onEndWithStreamOutputFn != nil
+	default:
+		return false
+	}
+}
+
+// NewHandlerBuilder returns an empty builder. Register callbacks with the
+// On...Fn methods and finish with Build.
+func NewHandlerBuilder() *handlerBuilder {
+	return &handlerBuilder{}
+}
+
+// OnStartFn registers fn as the start callback and returns the builder for chaining.
+func (hb *handlerBuilder) OnStartFn(fn func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context) *handlerBuilder {
+	hb.onStartFn = fn
+	return hb
+}
+
+// OnEndFn registers fn as the end callback and returns the builder for chaining.
+func (hb *handlerBuilder) OnEndFn(fn func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context) *handlerBuilder {
+	hb.onEndFn = fn
+	return hb
+}
+
+// OnErrorFn registers fn as the error callback and returns the builder for chaining.
+func (hb *handlerBuilder) OnErrorFn(fn func(ctx context.Context, info *RunInfo, err error) context.Context) *handlerBuilder {
+	hb.onErrorFn = fn
+	return hb
+}
+
+// OnStartWithStreamInputFn registers fn as the stream-input callback and returns
+// the builder for chaining. fn receives the stream that would otherwise be closed.
+func (hb *handlerBuilder) OnStartWithStreamInputFn(fn func(ctx context.Context, info *RunInfo, input *schema.StreamReader[CallbackInput]) context.Context) *handlerBuilder {
+	hb.onStartWithStreamInputFn = fn
+	return hb
+}
+
+// OnEndWithStreamOutputFn registers fn as the stream-output callback and returns
+// the builder for chaining. fn receives the stream that would otherwise be closed.
+func (hb *handlerBuilder) OnEndWithStreamOutputFn(fn func(ctx context.Context, info *RunInfo, output *schema.StreamReader[CallbackOutput]) context.Context) *handlerBuilder {
+	hb.onEndWithStreamOutputFn = fn
+	return hb
+}
+
+// Build returns a Handler with the functions set in the builder.
+// The returned Handler is the builder itself, so later On...Fn calls on the
+// builder are visible through it.
+func (hb *handlerBuilder) Build() Handler {
+	return hb
+}
diff --git a/callbacks/interface.go b/callbacks/interface.go
new file mode 100644
index 0000000..def6efa
--- /dev/null
+++ b/callbacks/interface.go
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package callbacks
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/schema"
+)
+
+// RunInfo describes the node a callback is fired for. It is passed to every
+// Handler method together with the callback payload.
+//
+// Component identifies the kind of component that is running; handlers built
+// by template.HandlerHelper dispatch on it. Name and Type further describe the
+// node (presumably its configured name and implementation type — TODO confirm
+// against the code that populates RunInfo).
+type RunInfo struct {
+	Name      string
+	Type      string
+	Component components.Component
+}
+
+// CallbackInput is the input of the callback.
+// The concrete type of the input is defined by the component.
+// Use a type assertion or the component's conversion function to obtain the type you need.
+// e.g.
+//
+//	CallbackInput in components/model/interface.go is:
+//	type CallbackInput struct {
+//		Messages []*schema.Message
+//		Config   *Config
+//		Extra map[string]any
+//	}
+//
+//	and provide a func of model.ConvCallbackInput() to convert CallbackInput to *model.CallbackInput
+//	in callback handler, you can use the following code to get the input:
+//
+//	modelCallbackInput := model.ConvCallbackInput(in)
+//	if modelCallbackInput == nil {
+//		// is not a model callback input, just ignore it
+//		return
+//	}
+type CallbackInput any
+
+// CallbackOutput is the output of the callback. Like CallbackInput, its
+// concrete type is defined by the component that produced it.
+type CallbackOutput any
+
+// Handler receives the callback events of a running node.
+//
+// Every method returns a context; the callback manager passes the context
+// returned by one handler on to the next handler, so a handler can attach
+// values for later callbacks.
+//
+// The stream variants are called with a *schema.StreamReader instead of a
+// single value. The handlers defined in this package close the stream when
+// they have no callback registered to consume it.
+type Handler interface {
+	OnStart(ctx context.Context, info *RunInfo, input CallbackInput) context.Context
+	OnEnd(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context
+
+	OnError(ctx context.Context, info *RunInfo, err error) context.Context
+
+	OnStartWithStreamInput(ctx context.Context, info *RunInfo,
+		input *schema.StreamReader[CallbackInput]) context.Context
+	OnEndWithStreamOutput(ctx context.Context, info *RunInfo,
+		output *schema.StreamReader[CallbackOutput]) context.Context
+}
+
+// globalHandlers holds the handlers that are prepended to the handler list of
+// every callback manager. It is a plain package variable with no locking.
+var globalHandlers []Handler
+
+// InitCallbackHandlers sets the global callback handlers.
+// It should be called BEFORE any callback handler by user.
+// It's useful when you want to inject some basic callbacks to all nodes.
+//
+// The slice is stored as-is (it is not copied) and replaces whatever was set
+// before. No synchronization is performed.
+func InitCallbackHandlers(handlers []Handler) {
+	globalHandlers = handlers
+}
+
+// GetGlobalHandlers returns the current global callback handlers. The result
+// is the stored slice itself, not a copy, and is nil until handlers are set.
+func GetGlobalHandlers() []Handler {
+	return globalHandlers
+}
+
+// CallbackTiming enumerates all the timing of callback aspects.
+type CallbackTiming uint8
+
+// Each timing corresponds to the Handler method of the same name.
+const (
+	// TimingOnStart corresponds to Handler.OnStart.
+	TimingOnStart CallbackTiming = iota
+	// TimingOnEnd corresponds to Handler.OnEnd.
+	TimingOnEnd
+	// TimingOnError corresponds to Handler.OnError.
+	TimingOnError
+	// TimingOnStartWithStreamInput corresponds to Handler.OnStartWithStreamInput.
+	TimingOnStartWithStreamInput
+	// TimingOnEndWithStreamOutput corresponds to Handler.OnEndWithStreamOutput.
+	TimingOnEndWithStreamOutput
+)
+
+// TimingChecker checks if the handler is needed for the given callback aspect timing.
+// It's recommended for callback handlers to implement this interface, but not mandatory.
+// If a callback handler is created by using template.HandlerHelper or handlerBuilder, then this interface is automatically implemented.
+// Eino's callback mechanism will try to use this interface to determine whether any handlers are needed for the given timing.
+// Also, the callback handler that is not needed for that timing will be skipped.
+type TimingChecker interface {
+	// Needed reports whether the handler wants to be called for timing on the
+	// node described by info.
+	Needed(ctx context.Context, info *RunInfo, timing CallbackTiming) bool
+}
diff --git a/callbacks/internal/manager.go b/callbacks/internal/manager.go
new file mode 100644
index 0000000..c6302b2
--- /dev/null
+++ b/callbacks/internal/manager.go
@@ -0,0 +1,21 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package internal
+
+// CtxManagerKey is the context key under which the callbacks package stores
+// its manager.
+type CtxManagerKey struct{}
+
+// TODO: move callback manager to internal
diff --git a/callbacks/manager.go b/callbacks/manager.go
new file mode 100644
index 0000000..90237c1
--- /dev/null
+++ b/callbacks/manager.go
@@ -0,0 +1,210 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package callbacks
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/callbacks/internal"
+	"github.com/cloudwego/eino/schema"
+)
+
+// Manager is a callback manager of one running node.
+//
+// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
+type Manager struct {
+	*manager
+}
+
+// manager fans a callback event out to a fixed list of handlers on behalf of
+// the node described by runInfo. All of its methods accept a nil receiver.
+type manager struct {
+	handlers []Handler
+	runInfo  *RunInfo
+}
+
+// NewManager creates a callback manager.
+// It will return a nil manager if no callback handler is provided, please check the return value first before using.
+//
+// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
+func NewManager(runInfo *RunInfo, handlers ...Handler) (*Manager, bool) {
+	m, ok := newManager(runInfo, handlers...)
+	if !ok {
+		return nil, false
+	}
+
+	return &Manager{
+		manager: m,
+	}, true
+}
+
+// newManager builds a manager whose handler list is the global handlers
+// followed by handlers. It returns (nil, false) when both are empty.
+func newManager(runInfo *RunInfo, handlers ...Handler) (*manager, bool) {
+	l := len(handlers) + len(globalHandlers)
+	if l == 0 {
+		return nil, false
+	}
+	// Copy into a fresh slice so the manager never aliases globalHandlers or
+	// the caller's variadic slice.
+	hs := make([]Handler, 0, l)
+	hs = append(hs, globalHandlers...)
+	hs = append(hs, handlers...)
+
+	return &manager{
+		handlers: hs,
+		runInfo:  runInfo,
+	}, true
+}
+
+// Handlers returns the manager's handler list (global handlers first), or nil
+// for a nil manager. The returned slice is not a copy.
+func (m *manager) Handlers() []Handler {
+	if m == nil {
+		return nil
+	}
+
+	return m.handlers
+}
+
+// WithRunInfo returns a new Manager sharing the handlers of mm but reporting
+// runInfo. A nil receiver yields nil.
+//
+// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
+func (mm *Manager) WithRunInfo(runInfo *RunInfo) *Manager {
+	if mm == nil {
+		return nil
+	}
+
+	m := mm.manager.withRunInfo(runInfo)
+
+	return &Manager{
+		manager: m,
+	}
+}
+
+// withRunInfo returns a copy of m with runInfo replaced; nil stays nil.
+func (m *manager) withRunInfo(runInfo *RunInfo) *manager {
+	if m == nil {
+		return nil
+	}
+
+	return &manager{
+		handlers: m.handlers,
+		runInfo:  runInfo,
+	}
+}
+
+// ManagerFromCtx returns the Manager stored in ctx, if any.
+//
+// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
+func ManagerFromCtx(ctx context.Context) (*Manager, bool) {
+	internalM, ok := managerFromCtx(ctx)
+	if ok {
+		return &Manager{
+			manager: internalM,
+		}, true
+	}
+
+	return nil, false
+}
+
+// managerFromCtx returns a shallow copy of the manager stored in ctx. It
+// reports false when ctx holds no manager or holds a nil one.
+func managerFromCtx(ctx context.Context) (*manager, bool) {
+	m, ok := ctx.Value(internal.CtxManagerKey{}).(*manager)
+	if ok && m != nil {
+		return &manager{
+			handlers: m.handlers,
+			runInfo:  m.runInfo,
+		}, true
+	}
+
+	return nil, false
+}
+
+// CtxWithManager returns a context carrying manager. Passing a nil manager
+// stores a nil entry, which ManagerFromCtx reports as "no manager".
+//
+// Deprecated: Manager will become the inner conception, use methods in aspect_inject.go instead
+func CtxWithManager(ctx context.Context, manager *Manager) context.Context {
+	// Guard the embedded-field access: manager.manager on a nil *Manager
+	// would be a nil pointer dereference.
+	if manager == nil {
+		return ctxWithManager(ctx, nil)
+	}
+
+	return ctxWithManager(ctx, manager.manager)
+}
+
+// ctxWithManager stores manager (which may be nil) in ctx under internal.CtxManagerKey.
+func ctxWithManager(ctx context.Context, manager *manager) context.Context {
+	return context.WithValue(ctx, internal.CtxManagerKey{}, manager)
+}
+
+// OnStart calls OnStart on every handler, last handler first, threading the
+// returned context through the chain.
+func (m *manager) OnStart(ctx context.Context, input CallbackInput) context.Context {
+	if m == nil {
+		return ctx
+	}
+
+	// Start callbacks run in reverse registration order; end and error
+	// callbacks below run in registration order.
+	for i := len(m.handlers) - 1; i >= 0; i-- {
+		ctx = m.handlers[i].OnStart(ctx, m.runInfo, input)
+	}
+
+	return ctx
+}
+
+// OnEnd calls OnEnd on every handler in registration order, threading the
+// returned context through the chain.
+func (m *manager) OnEnd(ctx context.Context, output CallbackOutput) context.Context {
+	if m == nil {
+		return ctx
+	}
+
+	for _, handler := range m.handlers {
+		ctx = handler.OnEnd(ctx, m.runInfo, output)
+	}
+
+	return ctx
+}
+
+// OnError calls OnError on every handler in registration order, threading the
+// returned context through the chain.
+func (m *manager) OnError(ctx context.Context, err error) context.Context {
+	if m == nil {
+		return ctx
+	}
+
+	for _, handler := range m.handlers {
+		ctx = handler.OnError(ctx, m.runInfo, err)
+	}
+
+	return ctx
+}
+
+// OnStartWithStreamInput gives every handler its own copy of input, last
+// handler first. When there is nobody to consume the stream it is closed.
+func (m *manager) OnStartWithStreamInput(
+	ctx context.Context, input *schema.StreamReader[CallbackInput]) context.Context {
+	if m == nil || len(m.handlers) == 0 {
+		// No consumer: release the stream, tolerating a nil reader.
+		if input != nil {
+			input.Close()
+		}
+		return ctx
+	}
+
+	ins := input.Copy(len(m.handlers))
+	for i := len(m.handlers) - 1; i >= 0; i-- {
+		ctx = m.handlers[i].OnStartWithStreamInput(ctx, m.runInfo, ins[i])
+	}
+
+	return ctx
+}
+
+// OnEndWithStreamOutput gives every handler its own copy of output, in
+// registration order. When there is nobody to consume the stream it is closed.
+func (m *manager) OnEndWithStreamOutput(
+	ctx context.Context, output *schema.StreamReader[CallbackOutput]) context.Context {
+	if m == nil || len(m.handlers) == 0 {
+		// No consumer: release the stream, tolerating a nil reader.
+		if output != nil {
+			output.Close()
+		}
+		return ctx
+	}
+
+	outs := output.Copy(len(m.handlers))
+	for i, handler := range m.handlers {
+		ctx = handler.OnEndWithStreamOutput(ctx, m.runInfo, outs[i])
+	}
+
+	return ctx
+}
diff --git a/callbacks/manager_test.go b/callbacks/manager_test.go
new file mode 100644
index 0000000..7a47c29
--- /dev/null
+++ b/callbacks/manager_test.go
@@ -0,0 +1,142 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package callbacks
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/schema"
+)
+
+// TestManager covers the unexported manager: storing it in and reading it from
+// a context, fanning callbacks out to global and per-call handlers, and the
+// nil-manager behaviour of every callback method.
+func TestManager(t *testing.T) {
+
+	t.Run("usable_manager", func(t *testing.T) {
+		defer func() {
+			globalHandlers = nil
+		}()
+
+		// A dedicated key type keeps these context values from colliding with
+		// string keys used elsewhere.
+		type ctxKey string
+		const globalKey, sessionKey ctxKey = "global", "session"
+
+		var startCnt, endCnt, errCnt int
+
+		// newHandler returns a handler that records "start" under key on
+		// OnStart, and on OnEnd/OnError replaces it with "end"/"error" only if
+		// it sees its own "start" marker. Each transition bumps a counter.
+		newHandler := func(key ctxKey) Handler {
+			return NewHandlerBuilder().
+				OnStartFn(func(ctx context.Context, info *RunInfo, input CallbackInput) context.Context {
+					startCnt++
+					return context.WithValue(ctx, key, "start")
+				}).
+				OnEndFn(func(ctx context.Context, info *RunInfo, output CallbackOutput) context.Context {
+					if ctx.Value(key).(string) == "start" {
+						endCnt++
+						return context.WithValue(ctx, key, "end")
+					}
+					return ctx
+				}).
+				OnErrorFn(func(ctx context.Context, info *RunInfo, err error) context.Context {
+					if ctx.Value(key).(string) == "start" {
+						errCnt++
+						return context.WithValue(ctx, key, "error")
+					}
+					return ctx
+				}).Build()
+		}
+
+		globalHandlers = []Handler{newHandler(globalKey)}
+		sessionHandler := newHandler(sessionKey)
+
+		manager, ok := newManager(&RunInfo{}, sessionHandler)
+		assert.True(t, ok)
+		c0 := context.Background()
+		c1 := ctxWithManager(c0, manager)
+		// Storing nil must shadow the manager stored in c1.
+		c2 := ctxWithManager(c1, nil)
+
+		m0, ok := managerFromCtx(c0)
+		assert.False(t, ok)
+		assert.Nil(t, m0)
+
+		m1, ok := managerFromCtx(c1)
+		assert.True(t, ok)
+		assert.NotNil(t, m1)
+		m2, ok := managerFromCtx(c2)
+		assert.False(t, ok)
+		assert.Nil(t, m2)
+
+		// c4 and c5 both derive from c3, so each handler sees its "start"
+		// marker in OnError as well as in OnEnd.
+		c3 := manager.OnStart(context.Background(), nil)
+		c4 := manager.OnError(c3, fmt.Errorf("mock err"))
+		c5 := manager.OnEnd(c3, nil)
+		assert.Equal(t, 2, startCnt)
+		assert.Equal(t, 2, endCnt)
+		assert.Equal(t, 2, errCnt)
+		assert.Equal(t, "start", c3.Value(globalKey).(string))
+		assert.Equal(t, "start", c3.Value(sessionKey).(string))
+		assert.Equal(t, "error", c4.Value(globalKey).(string))
+		assert.Equal(t, "error", c4.Value(sessionKey).(string))
+		assert.Equal(t, "end", c5.Value(globalKey).(string))
+		assert.Equal(t, "end", c5.Value(sessionKey).(string))
+	})
+
+	t.Run("empty manager", func(t *testing.T) {
+		ctx := context.Background()
+		globalHandlers = nil
+		mgr, ok := newManager(nil)
+		assert.False(t, ok)
+		assert.Nil(t, mgr)
+
+		// Every callback on a nil manager must be a no-op that hands the
+		// context back, including the stream variants with a nil stream.
+		nCtx := mgr.OnStart(ctx, nil)
+		assert.IsType(t, ctx, nCtx)
+
+		ctx = mgr.OnEnd(ctx, nil)
+		assert.IsType(t, ctx, nCtx)
+
+		ctx = mgr.OnError(ctx, fmt.Errorf("mock err"))
+		assert.IsType(t, ctx, nCtx)
+
+		sri, _ := schema.Pipe[CallbackInput](1)
+		ctx = mgr.OnStartWithStreamInput(ctx, sri)
+		assert.IsType(t, ctx, nCtx)
+
+		sro, _ := schema.Pipe[CallbackOutput](1)
+		ctx = mgr.OnEndWithStreamOutput(ctx, sro)
+		assert.IsType(t, ctx, nCtx)
+
+		ctx = mgr.OnStartWithStreamInput(ctx, nil)
+		assert.IsType(t, ctx, nCtx)
+
+		ctx = mgr.OnEndWithStreamOutput(ctx, nil)
+		assert.IsType(t, ctx, nCtx)
+	})
+}
diff --git
a/callbacks/template/default.go b/callbacks/template/default.go
new file mode 100644
index 0000000..bf84c5c
--- /dev/null
+++ b/callbacks/template/default.go
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package template
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/schema"
+)
+
+// DefaultCallbackHandler is the default callback handler implementation, can be used for callback handler builder in template.HandlerHelper (for example, Graph, StateGraph, Chain, Lambda, etc.).
+//
+// Each field is an optional callback for the timing of the same name and works
+// on the untyped callbacks.CallbackInput / callbacks.CallbackOutput values.
+// A nil field means the handler is not interested in that timing.
+type DefaultCallbackHandler struct {
+	OnStart                func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context
+	OnStartWithStreamInput func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context
+	OnEnd                  func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context
+	OnEndWithStreamOutput  func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context
+	OnError                func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context
+}
+
+// Needed checks if the callback handler is needed for the given timing.
+func (d *DefaultCallbackHandler) Needed(_ context.Context, _ *callbacks.RunInfo, timing callbacks.CallbackTiming) bool {
+	// A timing is needed exactly when the callback field of the same name has
+	// been populated; unknown timings leave the answer at false.
+	registered := false
+
+	switch timing {
+	case callbacks.TimingOnStart:
+		registered = d.OnStart != nil
+	case callbacks.TimingOnEnd:
+		registered = d.OnEnd != nil
+	case callbacks.TimingOnError:
+		registered = d.OnError != nil
+	case callbacks.TimingOnStartWithStreamInput:
+		registered = d.OnStartWithStreamInput != nil
+	case callbacks.TimingOnEndWithStreamOutput:
+		registered = d.OnEndWithStreamOutput != nil
+	}
+
+	return registered
+}
diff --git a/callbacks/template/template.go b/callbacks/template/template.go
new file mode 100644
index 0000000..b53d5a6
--- /dev/null
+++ b/callbacks/template/template.go
@@ -0,0 +1,530 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package template
+
+import (
+	"context"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/embedding"
+	"github.com/cloudwego/eino/components/indexer"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/prompt"
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/components/tool"
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/schema"
+)
+
+// NewHandlerHelper creates a new component template handler builder.
+// This builder can be used to configure and build a component template handler,
+// which can handle callback events for different components with its own struct definition,
+// and fallbackTemplate can be used to handle scenarios where none of the cases are hit as a fallback.
+//
+// The returned helper starts with an empty, non-nil composeTemplates map and
+// no component handlers set.
+func NewHandlerHelper() *HandlerHelper {
+	return &HandlerHelper{
+		composeTemplates: map[components.Component]*DefaultCallbackHandler{},
+	}
+}
+
+// HandlerHelper is a builder for creating a callbacks.Handler with specific handlers for different component types.
+// Create a handler with template.NewHandlerHelper().
+// e.g.
+//
+//	handler := template.NewHandlerHelper().
+//		ChatModel(&model.CallbackHandler{}).
+//		Prompt(&prompt.CallbackHandler{}).
+//		Handler()
+//
+// then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler))
+//
+// The typed fields hold one optional handler per component kind;
+// composeTemplates holds the handlers for compose-level components (graph,
+// chain, lambda, ...) keyed by component.
+type HandlerHelper struct {
+	promptHandler      *prompt.CallbackHandler
+	chatModelHandler   *model.CallbackHandler
+	embeddingHandler   *embedding.CallbackHandler
+	indexerHandler     *indexer.CallbackHandler
+	retrieverHandler   *retriever.CallbackHandler
+	loaderHandler      *document.LoaderCallbackHandler
+	transformerHandler *document.TransformerCallbackHandler
+	toolHandler        *tool.CallbackHandler
+	composeTemplates   map[components.Component]*DefaultCallbackHandler
+	fallbackTemplate   *DefaultCallbackHandler // execute when not matching any other condition
+}
+
+// Handler returns the callbacks.Handler created by HandlerHelper.
+// The handler wraps the helper itself rather than a snapshot, so handlers set
+// on the helper afterwards are seen by the returned Handler.
+func (c *HandlerHelper) Handler() callbacks.Handler {
+	return &handlerTemplate{c}
+}
+
+// Prompt sets the prompt handler for the handler helper, which will be called when the prompt component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Prompt(handler *prompt.CallbackHandler) *HandlerHelper {
+	c.promptHandler = handler
+	return c
+}
+
+// ChatModel sets the chat model handler for the handler helper, which will be called when the chat model component is executed.
+func (c *HandlerHelper) ChatModel(handler *model.CallbackHandler) *HandlerHelper {
+	// Replaces any previously set chat model handler; returns c for chaining.
+	c.chatModelHandler = handler
+	return c
+}
+
+// Embedding sets the embedding handler for the handler helper, which will be called when the embedding component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Embedding(handler *embedding.CallbackHandler) *HandlerHelper {
+	c.embeddingHandler = handler
+	return c
+}
+
+// Indexer sets the indexer handler for the handler helper, which will be called when the indexer component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Indexer(handler *indexer.CallbackHandler) *HandlerHelper {
+	c.indexerHandler = handler
+	return c
+}
+
+// Retriever sets the retriever handler for the handler helper, which will be called when the retriever component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Retriever(handler *retriever.CallbackHandler) *HandlerHelper {
+	c.retrieverHandler = handler
+	return c
+}
+
+// Loader sets the loader handler for the handler helper, which will be called when the loader component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Loader(handler *document.LoaderCallbackHandler) *HandlerHelper {
+	c.loaderHandler = handler
+	return c
+}
+
+// Transformer sets the transformer handler for the handler helper, which will be called when the transformer component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Transformer(handler *document.TransformerCallbackHandler) *HandlerHelper {
+	c.transformerHandler = handler
+	return c
+}
+
+// Tool sets the tool handler for the handler helper, which will be called when the tool component is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Tool(handler *tool.CallbackHandler) *HandlerHelper {
+	c.toolHandler = handler
+	return c
+}
+
+// Graph sets the graph handler for the handler helper, which will be called when the graph is executed.
+func (c *HandlerHelper) Graph(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfGraph, handler)
+}
+
+// StateGraph sets the state graph handler for the handler helper, which will be called when the state graph is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) StateGraph(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfStateGraph, handler)
+}
+
+// Chain sets the chain handler for the handler helper, which will be called when the chain is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Chain(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfChain, handler)
+}
+
+// Passthrough sets the passthrough handler for the handler helper, which will be called when the passthrough is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Passthrough(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfPassthrough, handler)
+}
+
+// ToolsNode sets the tools node handler for the handler helper, which will be called when the tools node is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) ToolsNode(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfToolsNode, handler)
+}
+
+// Lambda sets the lambda handler for the handler helper, which will be called when the lambda is executed.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Lambda(handler *DefaultCallbackHandler) *HandlerHelper {
+	return c.setComposeTemplate(compose.ComponentOfLambda, handler)
+}
+
+// setComposeTemplate registers handler for the given compose-level component
+// and returns the helper. The map is allocated on first use so that a
+// HandlerHelper created without NewHandlerHelper (for example
+// &HandlerHelper{}) does not panic on a nil-map write.
+func (c *HandlerHelper) setComposeTemplate(component components.Component, handler *DefaultCallbackHandler) *HandlerHelper {
+	if c.composeTemplates == nil {
+		c.composeTemplates = make(map[components.Component]*DefaultCallbackHandler)
+	}
+	c.composeTemplates[component] = handler
+	return c
+}
+
+// Fallback sets the fallback handler for the handler helper, which will be called when no other handlers are matched.
+// It returns the helper so calls can be chained.
+func (c *HandlerHelper) Fallback(handler *DefaultCallbackHandler) *HandlerHelper {
+	c.fallbackTemplate = handler
+	return c
+}
+
+// handlerTemplate is the callbacks.Handler returned by HandlerHelper.Handler.
+// It embeds the helper and dispatches on the running component.
+type handlerTemplate struct {
+	*HandlerHelper
+}
+
+// OnStart is the callback function for the start event of a component.
+// implement the callbacks Handler interface.
+func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
+	if info == nil {
+		return ctx
+	}
+
+	// Each case returns as soon as a component-specific OnStart has run.
+	// Leaving the switch means nothing matched, so the fallback gets its turn.
+	switch info.Component {
+	case components.ComponentOfPrompt:
+		if h := c.promptHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, prompt.ConvCallbackInput(input))
+		}
+	case components.ComponentOfChatModel:
+		if h := c.chatModelHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, model.ConvCallbackInput(input))
+		}
+	case components.ComponentOfEmbedding:
+		if h := c.embeddingHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, embedding.ConvCallbackInput(input))
+		}
+	case components.ComponentOfIndexer:
+		if h := c.indexerHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, indexer.ConvCallbackInput(input))
+		}
+	case components.ComponentOfRetriever:
+		if h := c.retrieverHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, retriever.ConvCallbackInput(input))
+		}
+	case components.ComponentOfLoader:
+		if h := c.loaderHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, document.ConvLoaderCallbackInput(input))
+		}
+	case components.ComponentOfTransformer:
+		if h := c.transformerHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, document.ConvTransformerCallbackInput(input))
+		}
+	case components.ComponentOfTool:
+		if h := c.toolHandler; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, tool.ConvCallbackInput(input))
+		}
+	case compose.ComponentOfGraph,
+		compose.ComponentOfStateGraph,
+		compose.ComponentOfChain,
+		compose.ComponentOfPassthrough,
+		compose.ComponentOfToolsNode,
+		compose.ComponentOfLambda:
+		// Compose-level components share the untyped DefaultCallbackHandler
+		// and receive the input without conversion.
+		if h := c.composeTemplates[info.Component]; h != nil && h.OnStart != nil {
+			return h.OnStart(ctx, info, input)
+		}
+	}
+
+	if fb := c.fallbackTemplate; fb != nil && fb.OnStart != nil {
+		return fb.OnStart(ctx, info, input)
+	}
+
+	return ctx
+}
+
+// OnEnd handles the end event of a component as part of the callbacks.Handler
+// interface. It runs the OnEnd callback registered for info.Component, or the
+// fallback's OnEnd when no such callback is set; a nil info is ignored.
+func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
+	if info == nil {
+		return ctx
+	}
+
+	// Same dispatch shape as OnStart: return on the first matching callback,
+	// otherwise fall through to the fallback.
+	switch info.Component {
+	case components.ComponentOfPrompt:
+		if h := c.promptHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, prompt.ConvCallbackOutput(output))
+		}
+	case components.ComponentOfChatModel:
+		if h := c.chatModelHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, model.ConvCallbackOutput(output))
+		}
+	case components.ComponentOfEmbedding:
+		if h := c.embeddingHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, embedding.ConvCallbackOutput(output))
+		}
+	case components.ComponentOfIndexer:
+		if h := c.indexerHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, indexer.ConvCallbackOutput(output))
+		}
+	case components.ComponentOfRetriever:
+		if h := c.retrieverHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, retriever.ConvCallbackOutput(output))
+		}
+	case components.ComponentOfLoader:
+		if h := c.loaderHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, document.ConvLoaderCallbackOutput(output))
+		}
+	case components.ComponentOfTransformer:
+		if h := c.transformerHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, document.ConvTransformerCallbackOutput(output))
+		}
+	case components.ComponentOfTool:
+		if h := c.toolHandler; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, tool.ConvCallbackOutput(output))
+		}
+	case compose.ComponentOfGraph,
+		compose.ComponentOfStateGraph,
+		compose.ComponentOfChain,
+		compose.ComponentOfPassthrough,
+		compose.ComponentOfToolsNode,
+		compose.ComponentOfLambda:
+		if h := c.composeTemplates[info.Component]; h != nil && h.OnEnd != nil {
+			return h.OnEnd(ctx, info, output)
+		}
+	}
+
+	if fb := c.fallbackTemplate; fb != nil && fb.OnEnd != nil {
+		return fb.OnEnd(ctx, info, output)
+	}
+
+	return ctx
+}
+
+// OnError is the callback function for the error event of a component.
+// implement the callbacks Handler interface.
+func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + if info == nil { + return ctx + } + + match := false + + switch info.Component { + case components.ComponentOfPrompt: + if c.promptHandler != nil && c.promptHandler.OnError != nil { + match = true + ctx = c.promptHandler.OnError(ctx, info, err) + } + case components.ComponentOfChatModel: + if c.chatModelHandler != nil && c.chatModelHandler.OnError != nil { + match = true + ctx = c.chatModelHandler.OnError(ctx, info, err) + } + case components.ComponentOfEmbedding: + if c.embeddingHandler != nil && c.embeddingHandler.OnError != nil { + match = true + ctx = c.embeddingHandler.OnError(ctx, info, err) + } + case components.ComponentOfIndexer: + if c.indexerHandler != nil && c.indexerHandler.OnError != nil { + match = true + ctx = c.indexerHandler.OnError(ctx, info, err) + } + case components.ComponentOfRetriever: + if c.retrieverHandler != nil && c.retrieverHandler.OnError != nil { + match = true + ctx = c.retrieverHandler.OnError(ctx, info, err) + } + case components.ComponentOfLoader: + if c.loaderHandler != nil && c.loaderHandler.OnError != nil { + match = true + ctx = c.loaderHandler.OnError(ctx, info, err) + } + case components.ComponentOfTransformer: + if c.transformerHandler != nil && c.transformerHandler.OnError != nil { + match = true + ctx = c.transformerHandler.OnError(ctx, info, err) + } + case components.ComponentOfTool: + if c.toolHandler != nil && c.toolHandler.OnError != nil { + match = true + ctx = c.toolHandler.OnError(ctx, info, err) + } + case compose.ComponentOfGraph, + compose.ComponentOfStateGraph, + compose.ComponentOfChain, + compose.ComponentOfPassthrough, + compose.ComponentOfToolsNode, + compose.ComponentOfLambda: + + if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnError != nil { + match = true + ctx = c.composeTemplates[info.Component].OnError(ctx, info, err) + } + default: + + } + + if !match 
&& c.fallbackTemplate != nil && c.fallbackTemplate.OnError != nil { + ctx = c.fallbackTemplate.OnError(ctx, info, err) + } + + return ctx +} + +// OnStartWithStreamInput is the callback function for the start event of a component with stream input. +// implement the callbacks Handler interface. +func (c *handlerTemplate) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { + match := false + defer func() { + if !match { + input.Close() + } + }() + + if info == nil { + return ctx + } + + switch info.Component { + // currently no components.Component receive stream as input + case compose.ComponentOfGraph, + compose.ComponentOfStateGraph, + compose.ComponentOfChain, + compose.ComponentOfPassthrough, + compose.ComponentOfToolsNode, + compose.ComponentOfLambda: + if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnStartWithStreamInput != nil { + match = true + ctx = c.composeTemplates[info.Component].OnStartWithStreamInput(ctx, info, input) + } + default: + + } + + if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnStartWithStreamInput != nil { + match = true + ctx = c.fallbackTemplate.OnStartWithStreamInput(ctx, info, input) + } + + return ctx +} + +// OnEndWithStreamOutput is the callback function for the end event of a component with stream output. +// implement the callbacks Handler interface. 
+func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { + match := false + defer func() { + if !match { + output.Close() + } + }() + + if info == nil { + return ctx + } + + switch info.Component { + case components.ComponentOfChatModel: + if c.chatModelHandler != nil && c.chatModelHandler.OnEndWithStreamOutput != nil { + match = true + ctx = c.chatModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { + return model.ConvCallbackOutput(item), nil + })) + } + + case components.ComponentOfTool: + if c.toolHandler != nil && c.toolHandler.OnEndWithStreamOutput != nil { + match = true + ctx = c.toolHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { + return tool.ConvCallbackOutput(item), nil + })) + } + case compose.ComponentOfGraph, + compose.ComponentOfStateGraph, + compose.ComponentOfChain, + compose.ComponentOfPassthrough, + compose.ComponentOfToolsNode, + compose.ComponentOfLambda: + if c.composeTemplates[info.Component] != nil && c.composeTemplates[info.Component].OnEndWithStreamOutput != nil { + match = true + ctx = c.composeTemplates[info.Component].OnEndWithStreamOutput(ctx, info, output) + } + + default: + + } + + if !match && c.fallbackTemplate != nil && c.fallbackTemplate.OnEndWithStreamOutput != nil { + match = true + ctx = c.fallbackTemplate.OnEndWithStreamOutput(ctx, info, output) + } + + return ctx +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch info.Component { + case components.ComponentOfChatModel: + if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfEmbedding: + if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfIndexer: + if c.indexerHandler != nil && c.indexerHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfLoader: + if c.loaderHandler != nil && c.loaderHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfPrompt: + if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfRetriever: + if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfTool: + if c.toolHandler != nil && c.toolHandler.Needed(ctx, info, timing) { + return true + } + case components.ComponentOfTransformer: + if c.transformerHandler != nil && c.transformerHandler.Needed(ctx, info, timing) { + return true + } + case compose.ComponentOfGraph, + compose.ComponentOfStateGraph, + compose.ComponentOfChain, + compose.ComponentOfPassthrough, + compose.ComponentOfToolsNode, + compose.ComponentOfLambda: + template := c.composeTemplates[info.Component] + if template != nil && template.Needed(ctx, info, timing) { + return true + } + default: + + } + + if c.fallbackTemplate != nil { + return c.fallbackTemplate.Needed(ctx, info, timing) + } + + return false +} diff --git a/callbacks/template/template_test.go b/callbacks/template/template_test.go new file mode 100644 index 0000000..c9f03b5 --- /dev/null +++ b/callbacks/template/template_test.go @@ -0,0 +1,335 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package template + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/schema" +) + +func TestNewComponentTemplate(t *testing.T) { + t.Run("test no fallback", func(t *testing.T) { + cnt := 0 + tpl := NewHandlerHelper() + tpl.ChatModel(&model.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }}). 
+ Embedding(&embedding.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *embedding.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + Prompt(&prompt.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + Retriever(&retriever.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + Tool(&tool.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *tool.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*tool.CallbackOutput]) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
+ Lambda(&DefaultCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnStartWithStreamInput: func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { + input.Close() + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }) + + typs := []components.Component{ + components.ComponentOfPrompt, + components.ComponentOfLoaderSplitter, + components.ComponentOfChatModel, + components.ComponentOfEmbedding, + components.ComponentOfRetriever, + components.ComponentOfTool, + compose.ComponentOfLambda, + } + + handler := tpl.Handler() + ctx := context.Background() + for _, typ := range typs { + handler.OnStart(ctx, &callbacks.RunInfo{Component: typ}, nil) + handler.OnEnd(ctx, &callbacks.RunInfo{Component: typ}, nil) + handler.OnError(ctx, &callbacks.RunInfo{Component: typ}, fmt.Errorf("mock err")) + + sir, siw := schema.Pipe[callbacks.CallbackInput](1) + siw.Close() + handler.OnStartWithStreamInput(ctx, &callbacks.RunInfo{Component: typ}, sir) + + sor, sow := schema.Pipe[callbacks.CallbackOutput](1) + sow.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) + } + + assert.Equal(t, 22, cnt) + + ctx = context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) + callbacks.OnStart(ctx, nil) + assert.Equal(t, 22, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: 
components.ComponentOfPrompt}) + callbacks.OnStart(ctx, nil) + assert.Equal(t, 23, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) + callbacks.OnEnd(ctx, nil) + assert.Equal(t, 23, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) + callbacks.OnError(ctx, nil) + assert.Equal(t, 24, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) + callbacks.OnStart(ctx, nil) + assert.Equal(t, 24, cnt) + + tpl.Transformer(&document.TransformerCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.TransformerCallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).Indexer(&indexer.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *indexer.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).Loader(&document.LoaderCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *document.LoaderCallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }) + + handler = tpl.Handler() + ctx = context.Background() + ctx = 
callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) + callbacks.OnEnd(ctx, nil) + assert.Equal(t, 25, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) + callbacks.OnStart(ctx, nil) + assert.Equal(t, 26, cnt) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) + callbacks.OnEnd(ctx, nil) + assert.Equal(t, 27, cnt) + }) + + t.Run("test fallback", func(t *testing.T) { + cnt, cntf := 0, 0 + tpl := NewHandlerHelper(). + Retriever(&retriever.CallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *retriever.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
+ Fallback(&DefaultCallbackHandler{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cntf++ + return ctx + }, + OnStartWithStreamInput: func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { + input.Close() + cntf++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cntf++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context { + output.Close() + cntf++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cntf++ + return ctx + }, + }) + + handler := tpl.Handler() + ctx := context.Background() + handler.OnStart(ctx, &callbacks.RunInfo{ + Component: compose.ComponentOfLambda, + }, nil) + + handler.OnEnd(ctx, &callbacks.RunInfo{ + Component: compose.ComponentOfLambda, + }, nil) + + handler.OnError(ctx, &callbacks.RunInfo{ + Component: compose.ComponentOfLambda, + }, fmt.Errorf("mock err")) + + sir, siw := schema.Pipe[callbacks.CallbackInput](1) + siw.Close() + handler.OnStartWithStreamInput(ctx, &callbacks.RunInfo{ + Component: compose.ComponentOfLambda, + }, sir) + + sor, sow := schema.Pipe[callbacks.CallbackOutput](1) + sow.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{ + Component: compose.ComponentOfLambda, + }, sor) + + assert.Equal(t, 0, cnt) + assert.Equal(t, 5, cntf) + + ctx = context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, handler) + callbacks.OnStart(ctx, nil) + callbacks.OnEnd(ctx, nil) + callbacks.OnError(ctx, nil) + callbacks.OnStartWithStreamInput(ctx, &schema.StreamReader[callbacks.CallbackInput]{}) + callbacks.OnEndWithStreamOutput(ctx, &schema.StreamReader[callbacks.CallbackOutput]{}) + 
assert.Equal(t, 10, cntf) + + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}) + callbacks.OnStart(ctx, nil) + assert.Equal(t, 1, cnt) + }) +} diff --git a/components/document/callback_extra_loader.go b/components/document/callback_extra_loader.go new file mode 100644 index 0000000..61a3265 --- /dev/null +++ b/components/document/callback_extra_loader.go @@ -0,0 +1,94 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package document + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// LoaderCallbackInput is the input for the loader callback. +type LoaderCallbackInput struct { + // Source is the source of the documents. + Source Source + + // Extra is the extra information for the callback. + Extra map[string]any +} + +// LoaderCallbackOutput is the output for the loader callback. +type LoaderCallbackOutput struct { + // Source is the source of the documents. + Source Source + + // Docs is the documents to be loaded. + Docs []*schema.Document + + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvLoaderCallbackInput converts the callback input to the loader callback input. 
+func ConvLoaderCallbackInput(src callbacks.CallbackInput) *LoaderCallbackInput { + switch t := src.(type) { + case *LoaderCallbackInput: + return t + case Source: + return &LoaderCallbackInput{ + Source: t, + } + default: + return nil + } +} + +// ConvLoaderCallbackOutput converts the callback output to the loader callback output. +func ConvLoaderCallbackOutput(src callbacks.CallbackOutput) *LoaderCallbackOutput { + switch t := src.(type) { + case *LoaderCallbackOutput: + return t + case []*schema.Document: + return &LoaderCallbackOutput{ + Docs: t, + } + default: + return nil + } +} + +// LoaderCallbackHandler is the handler for the loader callback. +type LoaderCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *LoaderCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *LoaderCallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *LoaderCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/document/callback_extra_transformer.go b/components/document/callback_extra_transformer.go new file mode 100644 index 0000000..474e363 --- /dev/null +++ b/components/document/callback_extra_transformer.go @@ -0,0 +1,91 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package document + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// TransformerCallbackInput is the input for the transformer callback. +type TransformerCallbackInput struct { + // Input is the input documents. + Input []*schema.Document + + // Extra is the extra information for the callback. + Extra map[string]any +} + +// TransformerCallbackOutput is the output for the transformer callback. +type TransformerCallbackOutput struct { + // Output is the output documents. + Output []*schema.Document + + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvTransformerCallbackInput converts the callback input to the transformer callback input. +func ConvTransformerCallbackInput(src callbacks.CallbackInput) *TransformerCallbackInput { + switch t := src.(type) { + case *TransformerCallbackInput: + return t + case []*schema.Document: + return &TransformerCallbackInput{ + Input: t, + } + default: + return nil + } +} + +// ConvTransformerCallbackOutput converts the callback output to the transformer callback output. +func ConvTransformerCallbackOutput(src callbacks.CallbackOutput) *TransformerCallbackOutput { + switch t := src.(type) { + case *TransformerCallbackOutput: + return t + case []*schema.Document: + return &TransformerCallbackOutput{ + Output: t, + } + default: + return nil + } +} + +// TransformerCallbackHandler is the handler for the transformer callback. 
+type TransformerCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *TransformerCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *TransformerCallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *TransformerCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/document/doc.go b/components/document/doc.go new file mode 100644 index 0000000..92dcc11 --- /dev/null +++ b/components/document/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package document diff --git a/components/document/interface.go b/components/document/interface.go new file mode 100644 index 0000000..9ab7e50 --- /dev/null +++ b/components/document/interface.go @@ -0,0 +1,48 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package document

import (
	"context"

	"github.com/cloudwego/eino/schema"
)

// Source is a document source.
// e.g. https://www.bytedance.com/docx/xxxx, https://xxx.xxx.xxx/xx.pdf.
// Make sure the URI can be reached by the service.
type Source struct {
	// URI locates the document.
	URI string
}

//go:generate mockgen -destination ../../internal/mock/components/document/document_mock.go --package document -source interface.go

// LoaderSplitter is a document loader and splitter.
//
// Deprecated: use Loader instead.
type LoaderSplitter interface {
	// LoadAndSplit takes a Source plus call options and returns documents or an error.
	LoadAndSplit(ctx context.Context, src Source, opts ...LoaderSplitterOption) ([]*schema.Document, error)
}

// Loader is a document loader.
type Loader interface {
	// Load takes a Source plus call options and returns documents or an error.
	Load(ctx context.Context, src Source, opts ...LoaderOption) ([]*schema.Document, error)
}

// Transformer converts documents, for example by splitting or filtering them.
type Transformer interface {
	// Transform takes documents plus call options and returns documents or an error.
	Transform(ctx context.Context, src []*schema.Document, opts ...TransformerOption) ([]*schema.Document, error)
}
diff --git a/components/document/option.go b/components/document/option.go
new file mode 100644
index 0000000..449587f
--- /dev/null
+++ b/components/document/option.go
@@ -0,0 +1,189 @@
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package document

// LoaderSplitterOption defines call option for LoaderSplitter component, which is part of the component interface signature.
// Each LoaderSplitter implementation could define its own options struct and option funcs within its own package,
// then wrap the impl specific option funcs into this type, before passing to LoadAndSplit.
//
// Deprecated: use LoaderOption instead.
type LoaderSplitterOption struct {
	// implSpecificOptFn holds the wrapped option function; it is stored as any
	// because its concrete type depends on the implementation's options struct.
	implSpecificOptFn any
}

// LoaderOption defines call option for Loader component, which is part of the component interface signature.
// Each Loader implementation could define its own options struct and option funcs within its own package,
// then wrap the impl specific option funcs into this type, before passing to Load.
type LoaderOption struct {
	// implSpecificOptFn holds the wrapped option function; it is stored as any
	// because its concrete type depends on the implementation's options struct.
	implSpecificOptFn any
}

// WrapImplSpecificOptFn wraps the impl specific option functions into LoaderSplitterOption type.
// T: the type of the impl specific options struct.
// LoaderSplitter implementations are required to use this function to convert its own option functions into the unified LoaderSplitterOption type.
// For example, if the LoaderSplitter impl defines its own options struct:
//
//	type customOptions struct {
//		conf string
//	}
//
// Then the impl needs to provide an option function as such:
//
//	func WithConf(conf string) LoaderSplitterOption {
//		return WrapImplSpecificOptFn(func(o *customOptions) {
//			o.conf = conf
//		})
//	}
//
// Deprecated: use WrapLoaderImplSpecificOptFn instead.
+func WrapImplSpecificOptFn[T any](optFn func(*T)) LoaderSplitterOption { + return LoaderSplitterOption{ + implSpecificOptFn: optFn, + } +} + +// WrapLoaderImplSpecificOptFn wraps the impl specific option functions into LoaderOption type. +// T: the type of the impl specific options struct. +// Loader implementations are required to use this function to convert its own option functions into the unified LoaderOption type. +// For example, if the Loader impl defines its own options struct: +// +// type customOptions struct { +// conf string +// } +// +// Then the impl needs to provide an option function as such: +// +// func WithConf(conf string) Option { +// return WrapLoaderImplSpecificOptFn(func(o *customOptions) { +// o.conf = conf +// } +// } +func WrapLoaderImplSpecificOptFn[T any](optFn func(*T)) LoaderOption { + return LoaderOption{ + implSpecificOptFn: optFn, + } +} + +// GetImplSpecificOptions provides LoaderSplitter author the ability to extract their own custom options from the unified LoaderSplitterOption type. +// T: the type of the impl specific options struct. +// This function should be used within the LoaderSplitter implementation's LoadAndSplit function. +// It is recommended to provide a base T as the first argument, within which the LoaderSplitter author can provide default values for the impl specific options. +// Deprecated: use GetLoaderImplSpecificOptions instead. +func GetImplSpecificOptions[T any](base *T, opts ...LoaderSplitterOption) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + s, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + s(base) + } + } + } + + return base +} + +// GetLoaderImplSpecificOptions provides Loader author the ability to extract their own custom options from the unified LoaderOption type. +// T: the type of the impl specific options struct. +// This function should be used within the Loader implementation's Load function. 
+// It is recommended to provide a base T as the first argument, within which the Loader author can provide default values for the impl specific options. +// eg. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// myOption = document.GetLoaderImplSpecificOptions(myOption, opts...) +func GetLoaderImplSpecificOptions[T any](base *T, opts ...LoaderOption) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + s, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + s(base) + } + } + } + + return base +} + +// TransformerOption defines call option for Transformer component, which is part of the component interface signature. +// Each Transformer implementation could define its own options struct and option funcs within its own package, +// then wrap the impl specific option funcs into this type, before passing to Transform. +type TransformerOption struct { + implSpecificOptFn any +} + +// WrapTransformerImplSpecificOptFn wraps the impl specific option functions into TransformerOption type. +// T: the type of the impl specific options struct. +// Transformer implementations are required to use this function to convert their own option functions into the unified TransformerOption type. +// For example, if the Transformer impl defines its own options struct: +// +// type customOptions struct { +// conf string +// } +// +// Then the impl needs to provide an option function as such: +// +// func WithConf(conf string) TransformerOption { +// return WrapTransformerImplSpecificOptFn(func(o *customOptions) { +// o.conf = conf +// }) +// } +// +// . +func WrapTransformerImplSpecificOptFn[T any](optFn func(*T)) TransformerOption { + return TransformerOption{ + implSpecificOptFn: optFn, + } +} + +// GetTransformerImplSpecificOptions provides Transformer author the ability to extract their own custom options from the unified TransformerOption type. +// T: the type of the impl specific options struct.
+// This function should be used within the Transformer implementation's Transform function. +// It is recommended to provide a base T as the first argument, within which the Transformer author can provide default values for the impl specific options. +// eg. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// myOption = document.GetTransformerImplSpecificOptions(myOption, opts...) +func GetTransformerImplSpecificOptions[T any](base *T, opts ...TransformerOption) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + s, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + s(base) + } + } + } + + return base +} diff --git a/components/document/option_test.go b/components/document/option_test.go new file mode 100644 index 0000000..68e8fe2 --- /dev/null +++ b/components/document/option_test.go @@ -0,0 +1,65 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package document + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestImplSpecificOpts(t *testing.T) { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + convey.Convey("TestLoaderImplSpecificOpts", t, func() { + documentOption1 := WrapLoaderImplSpecificOptFn(withConf("test_conf")) + documentOption2 := WrapLoaderImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetLoaderImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }) + }) + convey.Convey("TestTransformerImplSpecificOpts", t, func() { + documentOption1 := WrapTransformerImplSpecificOptFn(withConf("test_conf")) + documentOption2 := WrapTransformerImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetTransformerImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }) + }) +} diff --git a/components/document/parser/ext_parser.go b/components/document/parser/ext_parser.go new file mode 100644 index 0000000..a8e4c59 --- /dev/null +++ b/components/document/parser/ext_parser.go @@ -0,0 +1,131 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "context" + "errors" + "io" + "path/filepath" + + "github.com/cloudwego/eino/schema" +) + +// ExtParserConfig defines the configuration for the ExtParser. +type ExtParserConfig struct { + // ext -> parser. + // eg: map[string]Parser{ + // ".pdf": &PDFParser{}, + // ".md": &MarkdownParser{}, + // } + Parsers map[string]Parser + + // Fallback parser to use when no other parser is found. + // Default is TextParser if not set. + FallbackParser Parser +} + +// ExtParser is a parser that uses the file extension to determine which parser to use. +// You can register your own parsers through ExtParserConfig.Parsers. +// The fallback parser defaults to TextParser. +// Note: +// +// Parse looks up the parser by filepath.Ext(uri), so callers need to: +// 1. pass the URI with parser.WithURI on each Parse call; +// 2. make sure filepath.Ext can extract the expected extension from that URI. +// +// eg: +// +// pdf, _ := os.Open("./testdata/test.pdf") +// docs, err := ExtParser.Parse(ctx, pdf, parser.WithURI("./testdata/test.pdf")) +type ExtParser struct { + parsers map[string]Parser + + fallbackParser Parser +} + +// NewExtParser creates a new ExtParser. +func NewExtParser(ctx context.Context, conf *ExtParserConfig) (*ExtParser, error) { + if conf == nil { + conf = &ExtParserConfig{} + } + + p := &ExtParser{ + parsers: conf.Parsers, + fallbackParser: conf.FallbackParser, + } + + if p.fallbackParser == nil { + p.fallbackParser = TextParser{} + } + + if p.parsers == nil { + p.parsers = make(map[string]Parser) + } + + return p, nil +} + +// GetParsers returns a copy of the registered parsers.
+// It is safe to modify the returned parsers. +func (p *ExtParser) GetParsers() map[string]Parser { + + res := make(map[string]Parser, len(p.parsers)) + for k, v := range p.parsers { + res[k] = v + } + + return res +} + +// Parse parses the given reader and returns a list of documents. +func (p *ExtParser) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { + opt := GetCommonOptions(&Options{}, opts...) + + ext := filepath.Ext(opt.URI) + + parser, ok := p.parsers[ext] + + if !ok { + parser = p.fallbackParser + } + + if parser == nil { + return nil, errors.New("no parser found for extension " + ext) + } + + docs, err := parser.Parse(ctx, reader, opts...) + if err != nil { + return nil, err + } + + for _, doc := range docs { + if doc == nil { + continue + } + + if doc.MetaData == nil { + doc.MetaData = make(map[string]any) + } + + for k, v := range opt.ExtraMeta { + doc.MetaData[k] = v + } + } + + return docs, nil +} diff --git a/components/document/parser/interface.go b/components/document/parser/interface.go new file mode 100644 index 0000000..388417d --- /dev/null +++ b/components/document/parser/interface.go @@ -0,0 +1,29 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "context" + "io" + + "github.com/cloudwego/eino/schema" +) + +// Parser is a document parser, can be used to parse a document from a reader. 
+type Parser interface { + Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) +} diff --git a/components/document/parser/option.go b/components/document/parser/option.go new file mode 100644 index 0000000..7c23e49 --- /dev/null +++ b/components/document/parser/option.go @@ -0,0 +1,115 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +type Options struct { + // URI is the URI of the source. + URI string + + // ExtraMeta is extra metadata that is merged into each document. + ExtraMeta map[string]any +} + +// Option defines call option for Parser component, which is part of the component interface signature. +// Each Parser implementation could define its own options struct and option funcs within its own package, +// then wrap the impl specific option funcs into this type, before passing to Parse. +type Option struct { + apply func(opts *Options) + + implSpecificOptFn any +} + +// WithURI specifies the URI of the document. +// It is used to select the parser in ExtParser. +func WithURI(uri string) Option { + return Option{ + apply: func(opts *Options) { + opts.URI = uri + }, + } +} + +// WithExtraMeta specifies the extra meta data of the document.
+func WithExtraMeta(meta map[string]any) Option { + return Option{ + apply: func(opts *Options) { + opts.ExtraMeta = meta + }, + } +} + +// GetCommonOptions extracts parser Options from Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// WrapImplSpecificOptFn wraps the impl specific option functions into Option type. +// T: the type of the impl specific options struct. +// Parser implementations are required to use this function to convert their own option functions into the unified Option type. +// For example, if the Parser impl defines its own options struct: +// +// type customOptions struct { +// conf string +// } +// +// Then the impl needs to provide an option function as such: +// +// func WithConf(conf string) Option { +// return WrapImplSpecificOptFn(func(o *customOptions) { +// o.conf = conf +// }) +// } +// +// . +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetImplSpecificOptions provides Parser author the ability to extract their own custom options from the unified Option type. +// T: the type of the impl specific options struct. +// This function should be used within the Parser implementation's Parse function. +// It is recommended to provide a base T as the first argument, within which the Parser author can provide default values for the impl specific options.
+func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + s, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + s(base) + } + } + } + + return base +} diff --git a/components/document/parser/option_test.go b/components/document/parser/option_test.go new file mode 100644 index 0000000..d2ed699 --- /dev/null +++ b/components/document/parser/option_test.go @@ -0,0 +1,54 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package parser + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestImplSpecificOpts(t *testing.T) { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + convey.Convey("TestImplSpecificOpts", t, func() { + parserOption1 := WrapImplSpecificOptFn(withConf("test_conf")) + parserOption2 := WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, parserOption1, parserOption2) + + convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }) + }) +} diff --git a/components/document/parser/parser_test.go b/components/document/parser/parser_test.go new file mode 100644 index 0000000..7249b54 --- /dev/null +++ b/components/document/parser/parser_test.go @@ -0,0 +1,119 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package parser + +import ( + "context" + "io" + "os" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +type ParserForTest struct { + mock func() ([]*schema.Document, error) +} + +func (p *ParserForTest) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { + return p.mock() +} + +func TestParser(t *testing.T) { + ctx := context.Background() + + t.Run("Test default parser", func(t *testing.T) { + conf := &ExtParserConfig{} + + p, err := NewExtParser(ctx, conf) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open("testdata/test.md") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + docs, err := p.Parse(ctx, f, WithURI("testdata/test.md")) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, 1, len(docs)) + assert.Equal(t, "# Title\nhello world", docs[0].Content) + }) + + t.Run("test types", func(t *testing.T) { + mockParser := &ParserForTest{ + mock: func() ([]*schema.Document, error) { + return []*schema.Document{ + { + Content: "hello world", + MetaData: map[string]any{ + "type": "text", + }, + }, + }, nil + }, + } + + conf := &ExtParserConfig{ + Parsers: map[string]Parser{ + ".md": mockParser, + }, + } + + p, err := NewExtParser(ctx, conf) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open("testdata/test.md") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + docs, err := p.Parse(ctx, f, WithURI("x/test.md")) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, 1, len(docs)) + assert.Equal(t, "hello world", docs[0].Content) + assert.Equal(t, "text", docs[0].MetaData["type"]) + + }) + + t.Run("test get parsers", func(t *testing.T) { + p, err := NewExtParser(ctx, &ExtParserConfig{ + Parsers: map[string]Parser{ + ".md": &TextParser{}, + }, + }) + if err != nil { + t.Fatal(err) + } + + ps := p.GetParsers() + assert.Equal(t, 1, len(ps)) + }) +} diff --git a/components/document/parser/testdata/test.md 
b/components/document/parser/testdata/test.md new file mode 100644 index 0000000..fb76356 --- /dev/null +++ b/components/document/parser/testdata/test.md @@ -0,0 +1,2 @@ +# Title +hello world \ No newline at end of file diff --git a/components/document/parser/text_parser.go b/components/document/parser/text_parser.go new file mode 100644 index 0000000..947aaeb --- /dev/null +++ b/components/document/parser/text_parser.go @@ -0,0 +1,59 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package parser + +import ( + "context" + "io" + + "github.com/cloudwego/eino/schema" +) + +const ( + MetaKeySource = "_source" +) + +// TextParser is a simple parser that reads the text from a reader and returns a single document. +// eg: +// +// docs, err := TextParser.Parse(ctx, strings.NewReader("hello world")) +// fmt.Println(docs[0].Content) // "hello world" +type TextParser struct{} + +// Parse reads the text from a reader and returns a single document. +func (dp TextParser) Parse(ctx context.Context, reader io.Reader, opts ...Option) ([]*schema.Document, error) { + data, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + opt := GetCommonOptions(&Options{}, opts...) 
+ + meta := make(map[string]any) + meta[MetaKeySource] = opt.URI + + for k, v := range opt.ExtraMeta { + meta[k] = v + } + + doc := &schema.Document{ + Content: string(data), + MetaData: meta, + } + + return []*schema.Document{doc}, nil +} diff --git a/components/embedding/callback_extra.go b/components/embedding/callback_extra.go new file mode 100644 index 0000000..d091ed2 --- /dev/null +++ b/components/embedding/callback_extra.go @@ -0,0 +1,120 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" +) + +// TokenUsage is the token usage for the embedding. +type TokenUsage struct { + // PromptTokens is the number of prompt tokens. + PromptTokens int + // CompletionTokens is the number of completion tokens. + CompletionTokens int + // TotalTokens is the total number of tokens. + TotalTokens int +} + +// Config is the config for the embedding. +type Config struct { + // Model is the model name. + Model string + // EncodingFormat is the encoding format. + EncodingFormat string +} + +// ComponentExtra is the extra information for the embedding. +type ComponentExtra struct { + // Config is the config for the embedding. + Config *Config + // TokenUsage is the token usage for the embedding. + TokenUsage *TokenUsage +} + +// CallbackInput is the input for the embedding callback. 
+type CallbackInput struct { + // Texts is the texts to be embedded. + Texts []string + // Config is the config for the embedding. + Config *Config + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the embedding callback. +type CallbackOutput struct { + // Embeddings is the embeddings. + Embeddings [][]float64 + // Config is the config for creating the embedding. + Config *Config + // TokenUsage is the token usage for the embedding. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the embedding callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case []string: + return &CallbackInput{ + Texts: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the embedding callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case [][]float64: + return &CallbackOutput{ + Embeddings: t, + } + default: + return nil + } +} + +// CallbackHandler is the handler for the embedding callback. +type CallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/embedding/callback_extra_test.go b/components/embedding/callback_extra_test.go new file mode 100644 index 0000000..90e5cb9 --- /dev/null +++ b/components/embedding/callback_extra_test.go @@ -0,0 +1,33 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConvEmbedding(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]string{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput([][]float64{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/embedding/doc.go b/components/embedding/doc.go new file mode 100644 index 0000000..a97ab6a --- /dev/null +++ b/components/embedding/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding diff --git a/components/embedding/interface.go b/components/embedding/interface.go new file mode 100644 index 0000000..ed20492 --- /dev/null +++ b/components/embedding/interface.go @@ -0,0 +1,24 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding + +import "context" + +//go:generate mockgen -destination ../../internal/mock/components/embedding/Embedding_mock.go --package embedding -source interface.go +type Embedder interface { + EmbedStrings(ctx context.Context, texts []string, opts ...Option) ([][]float64, error) // invoke +} diff --git a/components/embedding/option.go b/components/embedding/option.go new file mode 100644 index 0000000..7fcba1d --- /dev/null +++ b/components/embedding/option.go @@ -0,0 +1,60 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding + +// Options is the options for the embedding. +type Options struct { + // Model is the model name for the embedding. + Model *string +} + +// Option is the call option for Embedder component. +type Option struct { + apply func(opts *Options) +} + +// WithModel is the option to set the model for the embedding. +func WithModel(model string) Option { + return Option{ + apply: func(opts *Options) { + opts.Model = &model + }, + } +} + +// GetCommonOptions extracts embedding Options from Option list, optionally providing a base Options with default values. +// eg. +// +// defaultModelName := "default_model" +// embeddingOption := &embedding.Options{ +// Model: &defaultModelName, +// } +// embeddingOption = embedding.GetCommonOptions(embeddingOption, opts...) +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} diff --git a/components/embedding/option_test.go b/components/embedding/option_test.go new file mode 100644 index 0000000..63c71af --- /dev/null +++ b/components/embedding/option_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package embedding + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOptions(t *testing.T) { + defaultModel := "default_model" + opts := GetCommonOptions(&Options{Model: &defaultModel}, WithModel("test_model")) + assert.NotNil(t, opts.Model) + assert.Equal(t, *opts.Model, "test_model") +} diff --git a/components/indexer/callback_extra.go b/components/indexer/callback_extra.go new file mode 100644 index 0000000..a596261 --- /dev/null +++ b/components/indexer/callback_extra.go @@ -0,0 +1,89 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package indexer + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// CallbackInput is the input for the indexer callback. +type CallbackInput struct { + // Docs is the documents to be indexed. + Docs []*schema.Document + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the indexer callback. 
+type CallbackOutput struct { + // IDs is the ids of the indexed documents returned by the indexer. + IDs []string + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the indexer callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case []*schema.Document: + return &CallbackInput{ + Docs: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the indexer callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case []string: + return &CallbackOutput{ + IDs: t, + } + default: + return nil + } +} + +// CallbackHandler is the handler for the indexer callback. +type CallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/indexer/callback_extra_test.go b/components/indexer/callback_extra_test.go new file mode 100644 index 0000000..2ac79be --- /dev/null +++ b/components/indexer/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package indexer + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvIndexer(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]*schema.Document{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput([]string{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/indexer/doc.go b/components/indexer/doc.go new file mode 100644 index 0000000..ad602fd --- /dev/null +++ b/components/indexer/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package indexer diff --git a/components/indexer/interface.go b/components/indexer/interface.go new file mode 100644 index 0000000..cef19b3 --- /dev/null +++ b/components/indexer/interface.go @@ -0,0 +1,32 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package indexer + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// Indexer is the interface for the indexer. +// Indexer is used to store the documents. +// +//go:generate mockgen -destination ../../internal/mock/components/indexer/indexer_mock.go --package indexer -source interface.go +type Indexer interface { + // Store stores the documents. + Store(ctx context.Context, docs []*schema.Document, opts ...Option) (ids []string, err error) // invoke +} diff --git a/components/indexer/option.go b/components/indexer/option.go new file mode 100644 index 0000000..546fe05 --- /dev/null +++ b/components/indexer/option.go @@ -0,0 +1,73 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package indexer + +import "github.com/cloudwego/eino/components/embedding" + +// Options is the options for the indexer. +type Options struct { + // SubIndexes is the sub indexes to be indexed. + SubIndexes []string + // Embedding is the embedding component. + Embedding embedding.Embedder +} + +// WithSubIndexes is the option to set the sub indexes for the indexer. 
+func WithSubIndexes(subIndexes []string) Option { + return Option{ + apply: func(opts *Options) { + opts.SubIndexes = subIndexes + }, + } +} + +// WithEmbedding is the option to set the embedder for the indexer, which converts documents to embeddings. +func WithEmbedding(emb embedding.Embedder) Option { + return Option{ + apply: func(opts *Options) { + opts.Embedding = emb + }, + } +} + +// Option is the call option for Indexer component. +type Option struct { + apply func(opts *Options) +} + +// GetCommonOptions extracts indexer Options from the Option list, optionally providing a base Options with default values. +// eg. +// +// indexerOption := &indexer.Options{ +// SubIndexes: []string{"default_sub_index"}, // default value +// } +// +// indexerOption = indexer.GetCommonOptions(indexerOption, opts...) +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} diff --git a/components/indexer/option_test.go b/components/indexer/option_test.go new file mode 100644 index 0000000..92929c2 --- /dev/null +++ b/components/indexer/option_test.go @@ -0,0 +1,45 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package indexer + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "github.com/cloudwego/eino/internal/mock/components/embedding" +) + +func TestOptions(t *testing.T) { + convey.Convey("test options", t, func() { + var ( + subIndexes = []string{"index_1", "index_2"} + e = &embedding.MockEmbedder{} + ) + + opts := GetCommonOptions( + &Options{}, + WithSubIndexes(subIndexes), + WithEmbedding(e), + ) + + convey.So(opts, convey.ShouldResemble, &Options{ + SubIndexes: subIndexes, + Embedding: e, + }) + }) +} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go new file mode 100644 index 0000000..a5f420c --- /dev/null +++ b/components/model/callback_extra.go @@ -0,0 +1,126 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// TokenUsage is the token usage for the model. +type TokenUsage struct { + // PromptTokens is the number of prompt tokens. + PromptTokens int + // CompletionTokens is the number of completion tokens. + CompletionTokens int + // TotalTokens is the total number of tokens. + TotalTokens int +} + +// Config is the config for the model. +type Config struct { + // Model is the model name. 
+ Model string + // MaxTokens is the max number of tokens; once it is reached, the model stops generating and usually returns a finish reason of "length". + MaxTokens int + // Temperature is the temperature, which controls the randomness of the model. + Temperature float32 + // TopP is the top p, which controls the diversity of the model. + TopP float32 + // Stop is the stop words, which controls the stopping condition of the model. + Stop []string +} + +// CallbackInput is the input for the model callback. +type CallbackInput struct { + // Messages is the messages to be sent to the model. + Messages []*schema.Message + // Tools is the tools to be used in the model. + Tools []*schema.ToolInfo + // ToolChoice is the tool choice, which controls the tool to be used in the model. + ToolChoice any // string / *schema.ToolInfo + // Config is the config for the model. + Config *Config + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the model callback. +type CallbackOutput struct { + // Message is the message generated by the model. + Message *schema.Message + // Config is the config for the model. + Config *Config + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the model callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case []*schema.Message: + return &CallbackInput{ + Messages: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the model callback output.
+func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case *schema.Message: + return &CallbackOutput{ + Message: t, + } + default: + return nil + } +} + +// CallbackHandler is the handler for the model callback. +type CallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*CallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} diff --git a/components/model/callback_extra_test.go b/components/model/callback_extra_test.go new file mode 100644 index 0000000..2fe8443 --- /dev/null +++ b/components/model/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvModel(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]*schema.Message{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&schema.Message{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/model/doc.go b/components/model/doc.go new file mode 100644 index 0000000..ab4f35a --- /dev/null +++ b/components/model/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model diff --git a/components/model/interface.go b/components/model/interface.go new file mode 100644 index 0000000..220473f --- /dev/null +++ b/components/model/interface.go @@ -0,0 +1,38 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// ChatModel supports openai and maas. +// Use Generate for complete output, and Stream for streaming output. +// +//go:generate mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model -source interface.go +type ChatModel interface { + Generate(ctx context.Context, input []*schema.Message, opts ...Option) (*schema.Message, error) + Stream(ctx context.Context, input []*schema.Message, opts ...Option) ( + *schema.StreamReader[*schema.Message], error) + + // BindTools binds tools to the model. + // Generally, call BindTools before requesting the ChatModel. + // Note that BindTools and Generate are not atomic. + BindTools(tools []*schema.ToolInfo) error +} diff --git a/components/model/option.go b/components/model/option.go new file mode 100644 index 0000000..f788703 --- /dev/null +++ b/components/model/option.go @@ -0,0 +1,132 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package model + +// Options is the common options for the model. +type Options struct { + // Temperature is the temperature for the model, which controls the randomness of the model. + Temperature *float32 + // MaxTokens is the max number of tokens; once it is reached, the model stops generating and usually returns a finish reason of "length". + MaxTokens *int + // Model is the model name. + Model *string + // TopP is the top p for the model, which controls the diversity of the model. + TopP *float32 + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string +} + +// Option is the call option for ChatModel component. +type Option struct { + apply func(opts *Options) + + implSpecificOptFn any +} + +// WithTemperature is the option to set the temperature for the model. +func WithTemperature(temperature float32) Option { + return Option{ + apply: func(opts *Options) { + opts.Temperature = &temperature + }, + } +} + +// WithMaxTokens is the option to set the max tokens for the model. +func WithMaxTokens(maxTokens int) Option { + return Option{ + apply: func(opts *Options) { + opts.MaxTokens = &maxTokens + }, + } +} + +// WithModel is the option to set the model name. +func WithModel(name string) Option { + return Option{ + apply: func(opts *Options) { + opts.Model = &name + }, + } +} + +// WithTopP is the option to set the top p for the model. +func WithTopP(topP float32) Option { + return Option{ + apply: func(opts *Options) { + opts.TopP = &topP + }, + } +} + +// WithStop is the option to set the stop words for the model. +func WithStop(stop []string) Option { + return Option{ + apply: func(opts *Options) { + opts.Stop = stop + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function.
+func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetCommonOptions extracts model Options from the Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// GetImplSpecificOptions extracts the implementation specific options from the Option list, optionally providing a base options with default values. +// eg. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption = model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} diff --git a/components/model/option_test.go b/components/model/option_test.go new file mode 100644 index 0000000..e1dda7d --- /dev/null +++ b/components/model/option_test.go @@ -0,0 +1,88 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package model + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestOptions(t *testing.T) { + convey.Convey("test options", t, func() { + var ( + modelName = "model" + temperature float32 = 0.9 + maxToken = 5000 + topP float32 = 0.8 + defaultModel = "default_model" + defaultTemperature float32 = 1.0 + defaultMaxTokens = 1000 + defaultTopP float32 = 0.5 + ) + + opts := GetCommonOptions( + &Options{ + Model: &defaultModel, + Temperature: &defaultTemperature, + MaxTokens: &defaultMaxTokens, + TopP: &defaultTopP, + }, + WithModel(modelName), + WithTemperature(temperature), + WithMaxTokens(maxToken), + WithTopP(topP), + WithStop([]string{"hello", "bye"}), + ) + + convey.So(opts, convey.ShouldResemble, &Options{ + Model: &modelName, + Temperature: &temperature, + MaxTokens: &maxToken, + TopP: &topP, + Stop: []string{"hello", "bye"}, + }) + }) +} + +type implOption struct { + userID int64 + name string +} + +func WithUserID(uid int64) Option { + return WrapImplSpecificOptFn[implOption](func(i *implOption) { + i.userID = uid + }) +} + +func WithName(n string) Option { + return WrapImplSpecificOptFn[implOption](func(i *implOption) { + i.name = n + }) +} + +func TestImplSpecificOption(t *testing.T) { + convey.Convey("impl_specific_option", t, func() { + opt := GetImplSpecificOptions(&implOption{}, WithUserID(101), WithName("Wang")) + + convey.So(opt, convey.ShouldEqual, &implOption{ + userID: 101, + name: "Wang", + }) + }) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go new file mode 100644 index 0000000..1d85e66 --- /dev/null +++ b/components/prompt/callback_extra.go @@ -0,0 +1,96 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// CallbackInput is the input for the callback. +type CallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // Templates is the templates for the callback. + Templates []schema.MessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the callback. +type CallbackOutput struct { + // Result is the result for the callback. + Result []*schema.Message + // Templates is the templates for the callback. + Templates []schema.MessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the prompt callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case map[string]any: + return &CallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the prompt callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case []*schema.Message: + return &CallbackOutput{ + Result: t, + } + default: + return nil + } +} + +// CallbackHandler is the handler for the callback. +type CallbackHandler struct { + // OnStart is the callback function for the start of the callback. 
+ OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context + // OnEnd is the callback function for the end of the callback. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context + // OnError is the callback function for the error of the callback. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go new file mode 100644 index 0000000..456297e --- /dev/null +++ b/components/prompt/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvPrompt(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(map[string]any{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/prompt/chat_template.go b/components/prompt/chat_template.go new file mode 100644 index 0000000..69cffec --- /dev/null +++ b/components/prompt/chat_template.go @@ -0,0 +1,89 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// DefaultChatTemplate is the default chat template implementation. +type DefaultChatTemplate struct { + // templates is the templates for the chat template. + templates []schema.MessagesTemplate + // formatType is the format type for the chat template. + formatType schema.FormatType +} + +// FromMessages creates a new DefaultChatTemplate from the given templates and format type. +// eg. 
+// +// template := prompt.FromMessages(schema.FString, &schema.Message{Content: "Hello, {name}!"}, &schema.Message{Content: "how are you?"}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.Message]() +// chain.AppendChatTemplate(template) +func FromMessages(formatType schema.FormatType, templates ...schema.MessagesTemplate) *DefaultChatTemplate { + return &DefaultChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +// Format formats the chat template with the given context and variables. +func (t *DefaultChatTemplate) Format(ctx context.Context, + vs map[string]any, _ ...Option) (result []*schema.Message, err error) { + + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + ctx = callbacks.OnStart(ctx, &CallbackInput{ + Variables: vs, + Templates: t.templates, + }) + + result = make([]*schema.Message, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &CallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the chat template (Default). +func (t *DefaultChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template. +func (t *DefaultChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/chat_template_test.go b/components/prompt/chat_template_test.go new file mode 100644 index 0000000..ec94681 --- /dev/null +++ b/components/prompt/chat_template_test.go @@ -0,0 +1,115 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestFormat(t *testing.T) { + pyFmtTestTemplate := []schema.MessagesTemplate{ + schema.SystemMessage( + "you are a helpful assistant.\n" + + "here is the context: {context}"), + schema.MessagesPlaceholder("chat_history", true), + schema.UserMessage("question: {question}"), + } + jinja2TestTemplate := []schema.MessagesTemplate{ + schema.SystemMessage( + "you are a helpful assistant.\n" + + "here is the context: {{context}}"), + schema.MessagesPlaceholder("chat_history", true), + schema.UserMessage("question: {{question}}"), + } + goFmtTestTemplate := []schema.MessagesTemplate{ + schema.SystemMessage( + "you are a helpful assistant.\n" + + "here is the context: {{.context}}"), + schema.MessagesPlaceholder("chat_history", true), + schema.UserMessage("question: {{.question}}"), + } + testValues := map[string]any{ + "context": "it's beautiful day", + "question": "how is the day today", + "chat_history": []*schema.Message{ + schema.UserMessage("who are you"), + schema.AssistantMessage("I'm a helpful assistant", nil), + }, + } + expected := []*schema.Message{ + schema.SystemMessage( + "you are a helpful assistant.\n" + + "here is the context: it's beautiful day"), + schema.UserMessage("who are you"), + schema.AssistantMessage("I'm a helpful assistant", nil), + schema.UserMessage("question: how is the day today"), + } + + // FString + chatTemplate := FromMessages(schema.FString, pyFmtTestTemplate...) 
+ msgs, err := chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // Jinja2 + chatTemplate = FromMessages(schema.Jinja2, jinja2TestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // GoTemplate + chatTemplate = FromMessages(schema.GoTemplate, goFmtTestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) +} + +func TestDocumentFormat(t *testing.T) { + docs := []*schema.Document{ + { + ID: "1", + Content: "qwe", + MetaData: map[string]any{ + "hello": 888, + }, + }, + { + ID: "2", + Content: "asd", + MetaData: map[string]any{ + "bye": 111, + }, + }, + } + + template := FromMessages(schema.FString, + schema.SystemMessage("all:{all_docs}\nsingle:{single_doc}"), + ) + + msgs, err := template.Format(context.Background(), map[string]any{ + "all_docs": docs, + "single_doc": docs[0], + }) + + assert.Nil(t, err) + t.Log(msgs) +} diff --git a/components/prompt/doc.go b/components/prompt/doc.go new file mode 100644 index 0000000..c7717ee --- /dev/null +++ b/components/prompt/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt diff --git a/components/prompt/interface.go b/components/prompt/interface.go new file mode 100644 index 0000000..0d43541 --- /dev/null +++ b/components/prompt/interface.go @@ -0,0 +1,29 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +var _ ChatTemplate = &DefaultChatTemplate{} + +type ChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) +} diff --git a/components/prompt/option.go b/components/prompt/option.go new file mode 100644 index 0000000..494989b --- /dev/null +++ b/components/prompt/option.go @@ -0,0 +1,48 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +// Option is the call option for ChatTemplate component. 
+type Option struct { + implSpecificOptFn any +} + +// WrapImplSpecificOptFn wraps the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetImplSpecificOptions extracts the implementation specific options from Option list, optionally providing a base options with default values. +func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + s, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + s(base) + } + } + } + + return base +} diff --git a/components/prompt/option_test.go b/components/prompt/option_test.go new file mode 100644 index 0000000..851a013 --- /dev/null +++ b/components/prompt/option_test.go @@ -0,0 +1,51 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +type implOption struct { + userID int64 + name string +} + +func WithUserID(uid int64) Option { + return WrapImplSpecificOptFn[implOption](func(i *implOption) { + i.userID = uid + }) +} + +func WithName(n string) Option { + return WrapImplSpecificOptFn[implOption](func(i *implOption) { + i.name = n + }) +} + +func TestImplSpecificOption(t *testing.T) { + convey.Convey("impl_specific_option", t, func() { + opt := GetImplSpecificOptions(&implOption{}, WithUserID(101), WithName("Wang")) + + convey.So(opt, convey.ShouldEqual, &implOption{ + userID: 101, + name: "Wang", + }) + }) +} diff --git a/components/retriever/callback_extra.go b/components/retriever/callback_extra.go new file mode 100644 index 0000000..76ade40 --- /dev/null +++ b/components/retriever/callback_extra.go @@ -0,0 +1,100 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retriever + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// CallbackInput is the input for the retriever callback. +type CallbackInput struct { + // Query is the query for the retriever. + Query string + + // TopK is the top k for the retriever, which means the top number of documents to retrieve. + TopK int + // Filter is the filter for the retriever. 
+ Filter string + // ScoreThreshold is the score threshold for the retriever, eg 0.5 means the score of the document must be greater than 0.5. + ScoreThreshold *float64 + + // Extra is the extra information for the retriever. + Extra map[string]any +} + +// CallbackOutput is the output for the retriever callback. +type CallbackOutput struct { + // Docs is the documents for the retriever. + Docs []*schema.Document + // Extra is the extra information for the retriever. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the retriever callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case string: + return &CallbackInput{ + Query: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the retriever callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case []*schema.Document: + return &CallbackOutput{ + Docs: t, + } + default: + return nil + } +} + +// CallbackHandler is the handler for the retriever callback. +type CallbackHandler struct { + // OnStart is the callback function for the start of the retriever. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *CallbackInput) context.Context + // OnEnd is the callback function for the end of the retriever. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *CallbackOutput) context.Context + // OnError is the callback function for the error of the retriever. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/retriever/callback_extra_test.go b/components/retriever/callback_extra_test.go new file mode 100644 index 0000000..83cb9b5 --- /dev/null +++ b/components/retriever/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package retriever + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvRetriever(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput("asd")) + assert.Nil(t, ConvCallbackInput([]string{})) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput([]*schema.Document{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/retriever/doc.go b/components/retriever/doc.go new file mode 100644 index 0000000..b12635f --- /dev/null +++ b/components/retriever/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retriever diff --git a/components/retriever/interface.go b/components/retriever/interface.go new file mode 100644 index 0000000..09ab1dd --- /dev/null +++ b/components/retriever/interface.go @@ -0,0 +1,42 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retriever + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +//go:generate mockgen -destination ../../internal/mock/components/retriever/retriever_mock.go --package retriever -source interface.go + +// Retriever is the interface for retriever. +// It is used to retrieve documents from a source. +// Implementations such as `vectorstore` and `fornaxknowledge` can be used as a retriever. +// +// e.g. +// +// retriever, err := fornaxknowledge.NewRetriever(ctx, &RetrieverConfig{}) +// if err != nil {...} +// docs, err := retriever.Retrieve(ctx, "query") // <= using directly +// docs, err := retriever.Retrieve(ctx, "query", retriever.WithTopK(3)) // <= using options +// +// graph := compose.NewGraph[inputType, outputType](compose.RunTypeDAG) +// graph.AddRetrieverNode("retriever_node_key", retriever) // <= using in graph +type Retriever interface { + Retrieve(ctx context.Context, query string, opts ...Option) ([]*schema.Document, error) +} diff --git a/components/retriever/option.go b/components/retriever/option.go new file mode 100644 index 0000000..ce22513 --- /dev/null +++ b/components/retriever/option.go @@ -0,0 +1,111 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package retriever + +import "github.com/cloudwego/eino/components/embedding" + +// Options is the options for the retriever. +type Options struct { + // Index is the index for the retriever, index in different retriever may be different. + Index *string + // SubIndex is the sub index for the retriever, sub index in different retriever may be different. + SubIndex *string + // TopK is the top k for the retriever, which means the top number of documents to retrieve. + TopK *int + // ScoreThreshold is the score threshold for the retriever, eg 0.5 means the score of the document must be greater than 0.5. + ScoreThreshold *float64 + // Embedding is the embedder for the retriever, which is used to embed the query for retrieval . + Embedding embedding.Embedder + + // DSLInfo is the dsl info for the retriever, which is used to retrieve the documents from the retriever. + // viking only + DSLInfo map[string]interface{} +} + +// WithIndex wraps the index option. +func WithIndex(index string) Option { + return Option{ + apply: func(opts *Options) { + opts.Index = &index + }, + } +} + +// WithSubIndex wraps the sub index option. +func WithSubIndex(subIndex string) Option { + return Option{ + apply: func(opts *Options) { + opts.SubIndex = &subIndex + }, + } +} + +// WithTopK wraps the top k option. +func WithTopK(topK int) Option { + return Option{ + apply: func(opts *Options) { + opts.TopK = &topK + }, + } +} + +// WithScoreThreshold wraps the score threshold option. 
+func WithScoreThreshold(threshold float64) Option { + return Option{ + apply: func(opts *Options) { + opts.ScoreThreshold = &threshold + }, + } +} + +// WithEmbedding wraps the embedder option. +func WithEmbedding(emb embedding.Embedder) Option { + return Option{ + apply: func(opts *Options) { + opts.Embedding = emb + }, + } +} + +// WithDSLInfo wraps the dsl info option. +func WithDSLInfo(dsl map[string]any) Option { + return Option{ + apply: func(opts *Options) { + opts.DSLInfo = dsl + }, + } +} + +// Option is the call option for Retriever component. +type Option struct { + apply func(opts *Options) +} + +// GetCommonOptions extract retriever Options from Option list, optionally providing a base Options with default values. +func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + if opts[i].apply != nil { + opts[i].apply(base) + } + } + + return base +} diff --git a/components/retriever/option_test.go b/components/retriever/option_test.go new file mode 100644 index 0000000..b332dc7 --- /dev/null +++ b/components/retriever/option_test.go @@ -0,0 +1,60 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package retriever + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" + + "github.com/cloudwego/eino/internal/mock/components/embedding" +) + +func TestOptions(t *testing.T) { + convey.Convey("test options", t, func() { + var ( + index = "index" + topK = 2 + scoreThreshold = 4.0 + subIndex = "sub_index" + dslInfo = map[string]any{"dsl": "dsl"} + e = &embedding.MockEmbedder{} + defaultTopK = 1 + ) + + opts := GetCommonOptions( + &Options{ + TopK: &defaultTopK, + }, + WithIndex(index), + WithTopK(topK), + WithScoreThreshold(scoreThreshold), + WithSubIndex(subIndex), + WithDSLInfo(dslInfo), + WithEmbedding(e), + ) + + convey.So(opts, convey.ShouldResemble, &Options{ + Index: &index, + TopK: &topK, + ScoreThreshold: &scoreThreshold, + SubIndex: &subIndex, + DSLInfo: dslInfo, + Embedding: e, + }) + }) +} diff --git a/components/tool/callback_extra.go b/components/tool/callback_extra.go new file mode 100644 index 0000000..bb8ee7a --- /dev/null +++ b/components/tool/callback_extra.go @@ -0,0 +1,88 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tool + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// CallbackInput is the input for the tool callback. +type CallbackInput struct { + // ArgumentsInJSON is the arguments in json format for the tool. + ArgumentsInJSON string + // Extra is the extra information for the tool. 
+ Extra map[string]any +} + +// CallbackOutput is the output for the tool callback. +type CallbackOutput struct { + // Response is the response for the tool. + Response string + // Extra is the extra information for the tool. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the tool callback input. +func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: + return t + case string: + return &CallbackInput{ArgumentsInJSON: t} + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the tool callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: + return t + case string: + return &CallbackOutput{Response: t} + default: + return nil + } +} + +// CallbackHandler is the handler for the tool callback. +type CallbackHandler struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *CallbackInput) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input *CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*CallbackOutput]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *CallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} diff --git a/components/tool/callback_extra_test.go b/components/tool/callback_extra_test.go new file mode 100644 index 0000000..dc51edf --- /dev/null +++ b/components/tool/callback_extra_test.go @@ -0,0 +1,37 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tool + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConvCallbackInput(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput("asd")) + assert.Nil(t, ConvCallbackInput(123)) + assert.Nil(t, ConvCallbackInput(nil)) +} + +func TestConvCallbackOutput(t *testing.T) { + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput("asd")) + assert.Nil(t, ConvCallbackOutput(123)) + assert.Nil(t, ConvCallbackOutput(nil)) +} diff --git a/components/tool/doc.go b/components/tool/doc.go new file mode 100644 index 0000000..be50356 --- /dev/null +++ b/components/tool/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tool diff --git a/components/tool/interface.go b/components/tool/interface.go new file mode 100644 index 0000000..e95fa09 --- /dev/null +++ b/components/tool/interface.go @@ -0,0 +1,45 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tool + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// BaseTool gets the tool info used for ChatModel intent recognition. +type BaseTool interface { + Info(ctx context.Context) (*schema.ToolInfo, error) +} + +// InvokableTool is the tool for ChatModel intent recognition and ToolsNode execution. +// nolint: byted_s_interface_name +type InvokableTool interface { + BaseTool + + // InvokableRun calls the function with arguments in JSON format. + InvokableRun(ctx context.Context, argumentsInJSON string, opts ...Option) (string, error) +} + +// StreamableTool is the streaming tool for ChatModel intent recognition and ToolsNode execution. +// nolint: byted_s_interface_name +type StreamableTool interface { + BaseTool + + StreamableRun(ctx context.Context, argumentsInJSON string, opts ...Option) (*schema.StreamReader[string], error) +} diff --git a/components/tool/option.go b/components/tool/option.go new file mode 100644 index 0000000..1458448 --- /dev/null +++ b/components/tool/option.go @@ -0,0 +1,78 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tool + +// Option defines call option for InvokableTool or StreamableTool component, which is part of component interface signature. +// Each tool implementation could define its own options struct and option funcs within its own package, +// then wrap the impl specific option funcs into this type, before passing to InvokableRun or StreamableRun. +type Option struct { + implSpecificOptFn any +} + +// WrapImplSpecificOptFn wraps the impl specific option functions into Option type. +// T: the type of the impl specific options struct. +// Tool implementations are required to use this function to convert its own option functions into the unified Option type. +// For example, if the tool defines its own options struct: +// +// type customOptions struct { +// conf string +// } +// +// Then the tool needs to provide an option function as such: +// +// func WithConf(conf string) Option { +// return WrapImplSpecificOptFn(func(o *customOptions) { +// o.conf = conf +// }) +// } +// +// . +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetImplSpecificOptions provides tool author the ability to extract their own custom options from the unified Option type. +// T: the type of the impl specific options struct. +// This function should be used within the tool implementation's InvokableRun or StreamableRun functions. +// It is recommended to provide a base T as the first argument, within which the tool author can provide default values for the impl specific options. +// eg. +// +// type customOptions struct { +// conf string +// } +// defaultOptions := &customOptions{} +// +// customOptions := tool.GetImplSpecificOptions(defaultOptions, opts...)
+func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} diff --git a/components/tool/option_test.go b/components/tool/option_test.go new file mode 100644 index 0000000..2a27604 --- /dev/null +++ b/components/tool/option_test.go @@ -0,0 +1,54 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tool + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestImplSpecificOpts(t *testing.T) { + convey.Convey("TestImplSpecificOpts", t, func() { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + toolOption1 := WrapImplSpecificOptFn(withConf("test_conf")) + toolOption2 := WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, toolOption1, toolOption2) + + convey.So(implSpecificOpts, convey.ShouldResemble, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }) + }) +} diff --git a/components/tool/utils/create_options.go b/components/tool/utils/create_options.go new file mode 100644 index 0000000..aadbb50 --- /dev/null +++ b/components/tool/utils/create_options.go @@ -0,0 +1,176 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/getkin/kin-openapi/openapi3" +) + +// UnmarshalArguments is the function type for unmarshalling the arguments. 
+type UnmarshalArguments func(ctx context.Context, arguments string) (interface{}, error) + +// MarshalOutput is the function type for marshalling the output. +type MarshalOutput func(ctx context.Context, output interface{}) (string, error) + +type toolOptions struct { + um UnmarshalArguments + m MarshalOutput + sc SchemaCustomizerFn +} + +// Option is the option func for the tool. +type Option func(o *toolOptions) + +// WithUnmarshalArguments wraps the unmarshal arguments option. +// when you want to unmarshal the arguments by yourself, you can use this option. +func WithUnmarshalArguments(um UnmarshalArguments) Option { + return func(o *toolOptions) { + o.um = um + } +} + +// WithMarshalOutput wraps the marshal output option. +// when you want to marshal the output by yourself, you can use this option. +func WithMarshalOutput(m MarshalOutput) Option { + return func(o *toolOptions) { + o.m = m + } +} + +// SchemaCustomizerFn is the schema customizer function for inferring tool parameter from tagged go struct. +// Within this function, end-user can parse custom go struct tags into corresponding openapi schema field. +// Parameters: +// 1. name: the name of current schema, usually the field name of the go struct. Specifically, the last 'name' visited is fixed to be '_root', which represents the entire go struct. Also, for array field, both the field itself and the element within the array will trigger this function. +// 2. t: the type of current schema, usually the field type of the go struct. +// 3. tag: the struct tag of current schema, usually the field tag of the go struct. Note that the element within an array field will use the same go struct tag as the array field itself. +// 4. schema: the current openapi schema object to be customized. 
+type SchemaCustomizerFn func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error + +// WithSchemaCustomizer sets a user-defined schema customizer for inferring tool parameter from tagged go struct. +// If this option is not set, the defaultSchemaCustomizer will be used. +func WithSchemaCustomizer(sc SchemaCustomizerFn) Option { + return func(o *toolOptions) { + o.sc = sc + } +} + +func getToolOptions(opt ...Option) *toolOptions { + opts := &toolOptions{ + um: nil, + m: nil, + } + for _, o := range opt { + o(opts) + } + return opts +} + +// defaultSchemaCustomizer is the default schema customizer when using reflect to infer tool parameter from tagged go struct. +// Supported struct tags: +// 1. jsonschema: "description=xxx" +// 2. jsonschema: "enum=xxx,enum=yyy,enum=zzz" +// 3. jsonschema: "required" +// 4. can also use json: "xxx,omitempty" to mark the field as not required, which means an absence of 'omitempty' in json tag means the field is required. +// If this defaultSchemaCustomizer is not sufficient or suitable to your specific need, define your own SchemaCustomizerFn and pass it to WithSchemaCustomizer during InferTool or InferStreamTool. 
+func defaultSchemaCustomizer(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + jsonS := tag.Get("jsonschema") + if len(jsonS) > 0 { + tags := strings.Split(jsonS, ",") + for _, t := range tags { + kv := strings.Split(t, "=") + if len(kv) == 2 { + if kv[0] == "description" { + schema.Description = kv[1] + } + if kv[0] == "enum" { + schema.WithEnum(kv[1]) + } + } else if len(kv) == 1 { + if kv[0] == "required" { + if schema.Extensions == nil { + schema.Extensions = make(map[string]any, 1) + } + schema.Extensions["x_required"] = true + } + } + } + } + + json := tag.Get("json") + if len(json) > 0 && !strings.Contains(json, "omitempty") { + if schema.Extensions == nil { + schema.Extensions = make(map[string]any, 1) + } + schema.Extensions["x_required"] = true + } + + if name == "_root" { + if err := setRequired(schema); err != nil { + return err + } + } + + return nil +} + +func setRequired(sc *openapi3.Schema) error { // check if properties are marked as required, set schema required to true accordingly + if sc.Type != openapi3.TypeObject && sc.Type != openapi3.TypeArray { + return nil + } + + if sc.Type == openapi3.TypeArray { + if sc.Items.Value.Extensions != nil { + if _, ok := sc.Items.Value.Extensions["x_required"]; ok { + delete(sc.Items.Value.Extensions, "x_required") + if len(sc.Items.Value.Extensions) == 0 { + sc.Items.Value.Extensions = nil + } + } + } + + if err := setRequired(sc.Items.Value); err != nil { + return fmt.Errorf("setRequired for array failed: %w", err) + } + } + + for k, p := range sc.Properties { + if p.Value.Extensions != nil { + if _, ok := p.Value.Extensions["x_required"]; ok { + sc.Required = append(sc.Required, k) + delete(p.Value.Extensions, "x_required") + if len(p.Value.Extensions) == 0 { + p.Value.Extensions = nil + } + } + + } + err := setRequired(p.Value) + if err != nil { + return fmt.Errorf("setRequired for nested property %s failed: %w", k, err) + } + } + + sort.Strings(sc.Required) + + 
return nil +} diff --git a/components/tool/utils/doc.go b/components/tool/utils/doc.go new file mode 100644 index 0000000..633cb56 --- /dev/null +++ b/components/tool/utils/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils diff --git a/components/tool/utils/invokable_func.go b/components/tool/utils/invokable_func.go new file mode 100644 index 0000000..88cb027 --- /dev/null +++ b/components/tool/utils/invokable_func.go @@ -0,0 +1,176 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + "fmt" + "strings" + + "github.com/bytedance/sonic" + "github.com/getkin/kin-openapi/openapi3gen" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +// InvokeFunc is the function type for the tool. 
+type InvokeFunc[T, D any] func(ctx context.Context, input T) (output D, err error) + +// InferTool creates an InvokableTool from a given function by inferring the ToolInfo from the function's request parameters. +// End-user can pass a SchemaCustomizerFn in opts to customize the go struct tag parsing process, overriding default behavior. +func InferTool[T, D any](toolName, toolDesc string, i InvokeFunc[T, D], opts ...Option) (tool.InvokableTool, error) { + ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) + if err != nil { + return nil, err + } + + return NewTool(ti, i, opts...), nil +} + +// GoStruct2ParamsOneOf converts a go struct to a ParamsOneOf. +// if you attempt to use ResponseFormat of some ChatModel to get StructuredOutput, you can infer the JSONSchema from the go struct. +func GoStruct2ParamsOneOf[T any](opts ...Option) (*schema.ParamsOneOf, error) { + return goStruct2ParamsOneOf[T](opts...) +} + +// GoStruct2ToolInfo converts a go struct to a ToolInfo. +// if you attempt to use BindTool to make ChatModel respond StructuredOutput, you can infer the ToolInfo from the go struct. +func GoStruct2ToolInfo[T any](toolName, toolDesc string, opts ...Option) (*schema.ToolInfo, error) { + return goStruct2ToolInfo[T](toolName, toolDesc, opts...) +} + +func goStruct2ToolInfo[T any](toolName, toolDesc string, opts ...Option) (*schema.ToolInfo, error) { + paramsOneOf, err := goStruct2ParamsOneOf[T](opts...) + if err != nil { + return nil, err + } + return &schema.ToolInfo{ + Name: toolName, + Desc: toolDesc, + ParamsOneOf: *paramsOneOf, + }, nil +} + +func goStruct2ParamsOneOf[T any](opts ...Option) (*schema.ParamsOneOf, error) { + options := getToolOptions(opts...) 
+ schemaCustomizer := defaultSchemaCustomizer + if options.sc != nil { + schemaCustomizer = options.sc + } + + sc, err := openapi3gen.NewSchemaRefForValue(generic.NewInstance[T](), nil, openapi3gen.SchemaCustomizer(schemaCustomizer)) + if err != nil { + return nil, fmt.Errorf("new SchemaRef from T failed: %w", err) + } + + paramsOneOf := schema.NewParamsOneOfByOpenAPIV3(sc.Value) + + return &paramsOneOf, nil + +} + +// NewTool creates a tool, where the input and output are both in JSON format. +func NewTool[T, D any](desc *schema.ToolInfo, i InvokeFunc[T, D], opts ...Option) tool.InvokableTool { + to := getToolOptions(opts...) + + return &invokableTool[T, D]{ + info: desc, + um: to.um, + m: to.m, + Fn: i, + } +} + +type invokableTool[T, D any] struct { + info *schema.ToolInfo + + um UnmarshalArguments + m MarshalOutput + + Fn InvokeFunc[T, D] +} + +func (i *invokableTool[T, D]) Info(ctx context.Context) (*schema.ToolInfo, error) { + return i.info, nil +} + +// InvokableRun invokes the tool with the given arguments. 
+func (i *invokableTool[T, D]) InvokableRun(ctx context.Context, arguments string, opts ...tool.Option) (output string, err error) { + + var inst T + if i.um != nil { + var val interface{} + val, err = i.um(ctx, arguments) + if err != nil { + return "", fmt.Errorf("[LocalFunc] failed to unmarshal arguments: %w", err) + } + gt, ok := val.(T) + if !ok { + return "", fmt.Errorf("[LocalFunc] expected %T, but given %T", inst, val) + } + inst = gt + } else { + inst = generic.NewInstance[T]() + + err = sonic.UnmarshalString(arguments, &inst) + if err != nil { + return "", fmt.Errorf("[LocalFunc] failed to unmarshal arguments in json: %w", err) + } + } + + resp, err := i.Fn(ctx, inst) + if err != nil { + return "", fmt.Errorf("[LocalFunc] failed to invoke tool: %w", err) + } + + if i.m != nil { + output, err = i.m(ctx, resp) + if err != nil { + return "", fmt.Errorf("[LocalFunc] failed to marshal output: %w", err) + } + } else { + output, err = sonic.MarshalString(resp) + if err != nil { + return "", fmt.Errorf("[LocalFunc] failed to marshal output in json: %w", err) + } + } + + return output, nil +} + +func (i *invokableTool[T, D]) GetType() string { + return snakeToCamel(i.info.Name) +} + +// snakeToCamel converts a snake_case string to CamelCase. +func snakeToCamel(s string) string { + if s == "" { + return "" + } + + parts := strings.Split(s, "_") + + for i := 0; i < len(parts); i++ { + if len(parts[i]) > 0 { + parts[i] = strings.ToUpper(string(parts[i][0])) + strings.ToLower(parts[i][1:]) + } + } + + return strings.Join(parts, "") +} diff --git a/components/tool/utils/invokable_func_test.go b/components/tool/utils/invokable_func_test.go new file mode 100644 index 0000000..612c334 --- /dev/null +++ b/components/tool/utils/invokable_func_test.go @@ -0,0 +1,275 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + "fmt" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +type Job struct { + Company string `json:"company" jsonschema:"description=the company where the user works"` + Position string `json:"position,omitempty" jsonschema:"description=the position of the user's job"` + ServiceLength float32 `json:"service_length,omitempty" jsonschema:"description=the year of user's service"` // 司龄,年 +} + +type Income struct { + Source string `json:"source" jsonschema:"description=the source of income"` + Amount int `json:"amount" jsonschema:"description=the amount of income"` + HasPayTax bool `json:"has_pay_tax" jsonschema:"description=whether the user has paid tax"` + Job *Job `json:"job,omitempty" jsonschema:"description=the job of the user when earning this income"` +} + +type User struct { + Name string `json:"name" jsonschema:"required,description=the name of the user"` + Age int `json:"age" jsonschema:"required,description=the age of the user"` + + Job *Job `json:"job,omitempty" jsonschema:"description=the job of the user"` + + Incomes []*Income `json:"incomes" jsonschema:"description=the incomes of the user"` +} + +type UserResult struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +var toolInfo = &schema.ToolInfo{ + Name: "update_user_info", + Desc: "full update user info", + ParamsOneOf: schema.NewParamsOneOfByOpenAPIV3( + &openapi3.Schema{ + Type: openapi3.TypeObject, + Required: []string{"age", "incomes", 
"name"}, + Properties: openapi3.Schemas{ + "name": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the name of the user", + }, + }, + "age": { + Value: &openapi3.Schema{ + Type: openapi3.TypeInteger, + Description: "the age of the user", + }, + }, + "job": { + Value: &openapi3.Schema{ + Type: openapi3.TypeObject, + Description: "the job of the user", + Required: []string{"company"}, + // Nullable: true, + Properties: openapi3.Schemas{ + "company": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the company where the user works", + }, + }, + "service_length": { + Value: &openapi3.Schema{ + Type: openapi3.TypeNumber, + Description: "the year of user's service", + Format: "float", + }, + }, + "position": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the position of the user's job", + }, + }, + }, + }, + }, + "incomes": { + Value: &openapi3.Schema{ + Type: openapi3.TypeArray, + Description: "the incomes of the user", + Items: &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: openapi3.TypeObject, + Required: []string{"amount", "has_pay_tax", "source"}, + Description: "the incomes of the user", + // Nullable: true, + Properties: openapi3.Schemas{ + "source": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the source of income", + }, + }, + "amount": { + Value: &openapi3.Schema{ + Type: openapi3.TypeInteger, + Description: "the amount of income", + }, + }, + "has_pay_tax": { + Value: &openapi3.Schema{ + Type: openapi3.TypeBoolean, + Description: "whether the user has paid tax", + }, + }, + "job": { + Value: &openapi3.Schema{ + Type: openapi3.TypeObject, + Description: "the job of the user when earning this income", + Required: []string{"company"}, + // Nullable: true, + Properties: openapi3.Schemas{ + "company": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the company where the user works", + }, + }, + "service_length": { + Value: 
&openapi3.Schema{ + Type: openapi3.TypeNumber, + Description: "the year of user's service", + Format: "float", + }, + }, + "position": { + Value: &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "the position of the user's job", + }, + }, + }, + }, + }, + }, + AdditionalProperties: openapi3.AdditionalProperties{}, + }, + }, + }, + }, + }, + AdditionalProperties: openapi3.AdditionalProperties{}, + }), +} + +func updateUserInfo(ctx context.Context, input *User) (output *UserResult, err error) { + return &UserResult{ + Code: 200, + Msg: fmt.Sprintf("update %v success", input.Name), + }, nil +} + +func TestInferTool(t *testing.T) { + t.Run("invoke_infer_tool", func(t *testing.T) { + ctx := context.Background() + + tl, err := InferTool("update_user_info", "full update user info", updateUserInfo) + assert.NoError(t, err) + + info, err := tl.Info(context.Background()) + assert.NoError(t, err) + assert.Equal(t, toolInfo, info) + + content, err := tl.InvokableRun(ctx, `{"name": "bruce lee"}`) + assert.NoError(t, err) + assert.JSONEq(t, `{"code":200,"msg":"update bruce lee success"}`, content) + }) + +} + +func TestNewTool(t *testing.T) { + ctx := context.Background() + type Input struct { + Name string `json:"name"` + } + type Output struct { + Name string `json:"name"` + } + + t.Run("struct_input_struct_output", func(t *testing.T) { + + tl := NewTool[Input, Output](nil, func(ctx context.Context, input Input) (output Output, err error) { + return Output{ + Name: input.Name, + }, nil + }) + + _, err := tl.InvokableRun(ctx, `{"name":"test"}`) + assert.Nil(t, err) + }) + + t.Run("pointer_input_pointer_output", func(t *testing.T) { + tl := NewTool[*Input, *Output](nil, func(ctx context.Context, input *Input) (output *Output, err error) { + return &Output{ + Name: input.Name, + }, nil + }) + + content, err := tl.InvokableRun(ctx, `{"name":"test"}`) + assert.NoError(t, err) + assert.Equal(t, `{"name":"test"}`, content) + }) + + t.Run("string_input_int64_output", 
func(t *testing.T) { + tl := NewTool(nil, func(ctx context.Context, input string) (output int64, err error) { + return 10, nil + }) + + content, err := tl.InvokableRun(ctx, `100`) // json unmarshal must contains double quote if is not json string. + assert.Error(t, err) + assert.Equal(t, "", content) + }) + + t.Run("string_pointer_input_int64_pointer_output", func(t *testing.T) { + tl := NewTool[*string, *int64](nil, func(ctx context.Context, input *string) (output *int64, err error) { + n := int64(10) + return &n, nil + }) + + content, err := tl.InvokableRun(ctx, `"100"`) + assert.NoError(t, err) + assert.Equal(t, `10`, content) + }) +} + +func TestSnakeToCamel(t *testing.T) { + t.Run("normal_case", func(t *testing.T) { + assert.Equal(t, "GoogleSearch3", snakeToCamel("google_search_3")) + }) + + t.Run("empty_case", func(t *testing.T) { + assert.Equal(t, "", snakeToCamel("")) + }) + + t.Run("single_word_case", func(t *testing.T) { + assert.Equal(t, "Google", snakeToCamel("google")) + }) + + t.Run("upper_case", func(t *testing.T) { + assert.Equal(t, "HttpHost", snakeToCamel("_HTTP_HOST_")) + }) + + t.Run("underscore_case", func(t *testing.T) { + assert.Equal(t, "", snakeToCamel("_")) + }) +} diff --git a/components/tool/utils/streamable_func.go b/components/tool/utils/streamable_func.go new file mode 100644 index 0000000..fa53fba --- /dev/null +++ b/components/tool/utils/streamable_func.go @@ -0,0 +1,128 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + "fmt" + + "github.com/bytedance/sonic" + + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +// StreamFunc is the function type for the streamable tool. +type StreamFunc[T, D any] func(ctx context.Context, input T) (output *schema.StreamReader[D], err error) + +// InferStreamTool creates a StreamableTool from a given function by inferring the ToolInfo from the function's request parameters. +// End users can pass a SchemaCustomizerFn in opts to customize the go struct tag parsing process, overriding the default behavior. +func InferStreamTool[T, D any](toolName, toolDesc string, s StreamFunc[T, D], opts ...Option) (tool.StreamableTool, error) { + ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...) + if err != nil { + return nil, err + } + + return NewStreamTool(ti, s, opts...), nil +} + +// NewStreamTool creates a streaming tool, where the input and output are both in JSON format. +// Each stream frame is converted to a string with the MarshalOutput from opts when one is set, and JSON-encoded otherwise. +func NewStreamTool[T, D any](desc *schema.ToolInfo, s StreamFunc[T, D], opts ...Option) tool.StreamableTool { + + to := getToolOptions(opts...) + + return &streamableTool[T, D]{ + info: desc, + + um: to.um, + m: to.m, + Fn: s, + } +} + +type streamableTool[T, D any] struct { + info *schema.ToolInfo + + um UnmarshalArguments + m MarshalOutput + + Fn StreamFunc[T, D] +} + +// Info returns the tool info, implementing the BaseTool interface. +func (s *streamableTool[T, D]) Info(ctx context.Context) (*schema.ToolInfo, error) { + return s.info, nil +} + +// StreamableRun invokes the tool with the given arguments, implementing the StreamableTool interface. 
+func (s *streamableTool[T, D]) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) ( + outStream *schema.StreamReader[string], err error) { + + var inst T + if s.um != nil { + var val interface{} + val, err = s.um(ctx, argumentsInJSON) + if err != nil { + return nil, fmt.Errorf("[LocalStreamFunc] failed to unmarshal arguments: %w", err) + } + + gt, ok := val.(T) + if !ok { + return nil, fmt.Errorf("[LocalStreamFunc] expected %T, but given %T", inst, val) + } + inst = gt + } else { + + inst = generic.NewInstance[T]() + + err = sonic.UnmarshalString(argumentsInJSON, &inst) + if err != nil { + return nil, fmt.Errorf("[LocalStreamFunc] failed to unmarshal arguments in json: %w", err) + } + } + + streamD, err := s.Fn(ctx, inst) + if err != nil { + return nil, err + } + + outStream = schema.StreamReaderWithConvert(streamD, func(d D) (string, error) { + var out string + var e error + if s.m != nil { + out, e = s.m(ctx, d) + if e != nil { + return "", fmt.Errorf("[LocalStreamFunc] failed to marshal output: %w", e) + } + } else { + out, e = sonic.MarshalString(d) + if e != nil { + return "", fmt.Errorf("[LocalStreamFunc] failed to marshal output in json: %w", e) + } + } + + return out, nil + }) + + return outStream, nil +} + +func (s *streamableTool[T, D]) GetType() string { + return snakeToCamel(s.info.Name) +} diff --git a/components/tool/utils/streamable_func_test.go b/components/tool/utils/streamable_func_test.go new file mode 100644 index 0000000..d388416 --- /dev/null +++ b/components/tool/utils/streamable_func_test.go @@ -0,0 +1,99 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestNewStreamableTool(t *testing.T) { + ctx := context.Background() + type Input struct { + Name string `json:"name"` + } + type Output struct { + Name string `json:"name"` + } + + t.Run("simple_case", func(t *testing.T) { + tl := NewStreamTool[*Input, *Output]( + &schema.ToolInfo{ + Name: "search_user", + Desc: "search user info", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Type: "string", + Desc: "user name", + }, + }), + }, + func(ctx context.Context, input *Input) (output *schema.StreamReader[*Output], err error) { + sr, sw := schema.Pipe[*Output](2) + sw.Send(&Output{ + Name: input.Name, + }, nil) + sw.Send(&Output{ + Name: "lee", + }, nil) + sw.Close() + + return sr, nil + }, + ) + + info, err := tl.Info(ctx) + assert.NoError(t, err) + assert.Equal(t, "search_user", info.Name) + assert.Equal(t, map[string]*schema.ParameterInfo{ + "name": { + Type: "string", + Desc: "user name", + }, + }, info.Params) + + sr, err := tl.StreamableRun(ctx, `{"name":"xxx"}`) + assert.NoError(t, err) + + defer sr.Close() + + idx := 0 + for { + m, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + assert.NoError(t, err) + + if idx == 0 { + assert.Equal(t, `{"name":"xxx"}`, m) + } else { + assert.Equal(t, `{"name":"lee"}`, m) + } + idx++ + } + + assert.Equal(t, 2, idx) + }) +} diff --git a/components/types.go b/components/types.go new file 
mode 100644 index 0000000..03c246c --- /dev/null +++ b/components/types.go @@ -0,0 +1,63 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package components defines the basic components supported by eino. +package components + +// Typer gets the type name of a component's implementation. +// If Typer is implemented, the full name of the component instance will be {Typer}{Component} by default. +// Using the Camel Case naming style for the type name is recommended. +type Typer interface { + GetType() string +} + +func GetType(component any) (string, bool) { + if typer, ok := component.(Typer); ok { + return typer.GetType(), true + } + + return "", false +} + +// Checker reports the callback aspect status of a component's implementation. +// When the Checker interface is implemented and returns true, the framework will not start the default aspect. +// Instead, the component will decide the callback execution location and the information to be injected. 
+type Checker interface { + IsCallbacksEnabled() bool +} + +func IsCallbacksEnabled(i any) bool { + if checker, ok := i.(Checker); ok { + return checker.IsCallbacksEnabled() + } + + return false +} + +// Component the name of different kinds of components +type Component string + +const ( + ComponentOfPrompt Component = "ChatTemplate" + ComponentOfChatModel Component = "ChatModel" + ComponentOfEmbedding Component = "Embedding" + ComponentOfIndexer Component = "Indexer" + ComponentOfRetriever Component = "Retriever" + ComponentOfLoaderSplitter Component = "LoaderSplitter" + ComponentOfLoader Component = "Loader" + ComponentOfTransformer Component = "DocumentTransformer" + ComponentOfTool Component = "Tool" +) diff --git a/compose/chain.go b/compose/chain.go new file mode 100644 index 0000000..1fd70a3 --- /dev/null +++ b/compose/chain.go @@ -0,0 +1,600 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/internal/gmap" + "github.com/cloudwego/eino/internal/gslice" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +// NewChain create a chain with input/output type. +func NewChain[I, O any]() *Chain[I, O] { + ch := &Chain[I, O]{ + gg: NewGraph[I, O](), + } + + ch.gg.graph.addNodeChecker = nodeCheckerOfForbidProcessor(baseNodeChecker) + ch.gg.graph.runtimeGraphKey = defaultGraphKey() + + return ch +} + +// Chain is a chain of components. +// Chain nodes can be parallel / branch / sequence components. +// Chain is designed to be used in a builder pattern (should Compile() before use). +// And the interface is `Chain style`, you can use it like: `chain.AppendXX(...).AppendXX(...)` +// +// Normal usage: +// 1. create a chain with input/output type: `chain := NewChain[inputType, outputType]()` +// 2. add components to chainable list: +// 2.1 add components: `chain.AppendChatTemplate(...).AppendChatModel(...).AppendToolsNode(...)` +// 2.2 add parallel or branch node if needed: `chain.AppendParallel()`, `chain.AppendBranch()` +// 3. compile: `r, err := c.Compile()` +// 4. 
run: +// 4.1 `one input & one output` use `r.Invoke(ctx, input)` +// 4.2 `one input & multi output chunk` use `r.Stream(ctx, input)` +// 4.3 `multi input chunk & one output` use `r.Collect(ctx, inputReader)` +// 4.4 `multi input chunk & multi output chunk` use `r.Transform(ctx, inputReader)` +// +// Using in graph or other chain: +// chain1 := NewChain[inputType, outputType]() +// graph := NewGraph[](runTypePregel) +// graph.AddGraph("key", chain1) // chain is an AnyGraph implementation +// +// // or in another chain: +// chain2 := NewChain[inputType, outputType]() +// chain2.AppendGraph(chain1) +type Chain[I, O any] struct { + err error + + gg *Graph[I, O] + + namePrefix string + nodeIdx int + + preNodeKeys []string +} + +// implements AnyGraph. +func (c *Chain[I, O]) compile(ctx context.Context, option *graphCompileOptions) (*composableRunnable, error) { + if c.err != nil { + return nil, c.err + } + + if !c.gg.isFrozen() { + err := c.addEnds() + if err != nil { + return nil, err + } + } + c.gg.compileChecker = wrapCompileChecker(c.gg.compileChecker, func(options *graphCompileOptions) error { + if len(option.nodeTriggerMode) != 0 && option.nodeTriggerMode != AnyPredecessor { + return errors.New("only support AnyPredecessor in chain") // dag not support branch + } + + return nil + }) + + return c.gg.compile(ctx, option) +} + +// addEnds add END edge of the chain/graph. +// only run once when compiling. +func (c *Chain[I, O]) addEnds() error { + if len(c.preNodeKeys) == 0 { + return fmt.Errorf("pre node keys not set, number of nodes in chain= %d", len(c.gg.nodes)) + } + + for _, nodeKey := range c.preNodeKeys { + err := c.gg.AddEdge(nodeKey, END) + if err != nil { + return err + } + } + + return nil +} + +// inputType returns the input type of the chain. +// implements AnyGraph. +func (c *Chain[I, O]) inputType() reflect.Type { + return generic.TypeOf[I]() +} + +// outputType returns the output type of the chain. +// implements AnyGraph. 
+func (c *Chain[I, O]) outputType() reflect.Type { + return generic.TypeOf[O]() +} + +// component returns the component type of the chain. +// implements AnyGraph. +func (c *Chain[I, O]) component() component { + return ComponentOfChain +} + +// Compile compiles the chain to a Runnable. +// Runnable can be used directly. +// e.g. +// +// chain := NewChain[string, string]() +// r, err := chain.Compile(ctx) +// if err != nil {} +// +// r.Invoke(ctx, input) // ping => pong +// r.Stream(ctx, input) // ping => stream out +// r.Collect(ctx, inputReader) // stream in => pong +// r.Transform(ctx, inputReader) // stream in => stream out +func (c *Chain[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { + if c.err != nil { + return nil, c.err + } + + opts = append(opts, withComponent(ComponentOfChain)) + + if !c.gg.isFrozen() { + err := c.addEnds() + if err != nil { + return nil, err + } + } + + c.gg.compileChecker = wrapCompileChecker(c.gg.compileChecker, func(options *graphCompileOptions) error { + if len(options.nodeTriggerMode) != 0 && options.nodeTriggerMode != AnyPredecessor { + return errors.New("only support AnyPredecessor in chain") // dag not support branch + } + + return nil + }) + + tr, err := c.gg.Compile(ctx, opts...) + if err != nil { + return nil, err + } + + return tr, nil +} + +// AppendChatModel adds a ChatModel node to the chain. +// e.g. +// +// model, err := openai.NewChatModel(ctx, config) +// if err != nil {...} +// chain.AppendChatModel(model) +func (c *Chain[I, O]) AppendChatModel(node model.ChatModel, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toChatModelNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendChatTemplate adds a ChatTemplate node to the chain. +// e.g. 
+// +// chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{ +// Role: schema.System, +// Content: "You are acting as a {role}.", +// }) +// +// chain.AppendChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toChatTemplateNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendToolsNode add a ToolsNode node to the chain. +// e.g. +// +// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ +// Tools: []tools.Tool{...}, +// }) +// +// chain.AppendToolsNode(toolsNode) +func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toToolsNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendDocumentTransformer add a DocumentTransformer node to the chain. +// e.g. +// +// markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) +// +// chain.AppendDocumentTransformer(markdownSplitter) +func (c *Chain[I, O]) AppendDocumentTransformer(node document.Transformer, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toDocumentTransformerNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendLambda add a Lambda node to the chain. +// Lambda is a node that can be used to implement custom logic. +// e.g. +// +// lambdaNode := compose.InvokableLambda(func(ctx context.Context, docs []*schema.Document) (string, error) {...}) +// chain.AppendLambda(lambdaNode) +// +// Note: +// to create a Lambda node, you need to use `compose.AnyLambda` or `compose.InvokableLambda` or `compose.StreamableLambda` or `compose.TransformableLambda`. +// if you want this node has real stream output, you need to use `compose.StreamableLambda` or `compose.TransformableLambda`, for example. +func (c *Chain[I, O]) AppendLambda(node *Lambda, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toLambdaNode(node, opts...) 
+ + c.addNode(n) + return c +} + +// AppendEmbedding add a Embedding node to the chain. +// e.g. +// +// embedder, err := openai.NewEmbedder(ctx, config) +// if err != nil {...} +// chain.AppendEmbedding(embedder) +func (c *Chain[I, O]) AppendEmbedding(node embedding.Embedder, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toEmbeddingNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendRetriever add a Retriever node to the chain. +// e.g. +// +// retriever, err := vectorstore.NewRetriever(ctx, config) +// if err != nil {...} +// chain.AppendRetriever(retriever) +// +// or using fornax knowledge as retriever: +// +// config := fornaxknowledge.Config{...} +// retriever, err := fornaxknowledge.NewKnowledgeRetriever(ctx, config) +// if err != nil {...} +// chain.AppendRetriever(retriever) +func (c *Chain[I, O]) AppendRetriever(node retriever.Retriever, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toRetrieverNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendLoaderSplitter add a LoaderSplitter node to the chain. +// Deprecated: use AppendLoader instead. +func (c *Chain[I, O]) AppendLoaderSplitter(node document.LoaderSplitter, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toLoaderSplitterNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendLoader adds a Loader node to the chain. +// e.g. +// +// loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{}) +// if err != nil {...} +// chain.AppendLoader(loader) +func (c *Chain[I, O]) AppendLoader(node document.Loader, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toLoaderNode(node, opts...) + c.addNode(n) + return c +} + +// AppendIndexer add an Indexer node to the chain. +// Indexer is a node that can store documents. +// e.g. 
+// +// vectorStoreImpl, err := vikingdb.NewVectorStorer(ctx, vikingdbConfig) // in components/vectorstore/vikingdb/vectorstore.go +// if err != nil {...} +// +// config := vectorstore.IndexerConfig{VectorStore: vectorStoreImpl} +// indexer, err := vectorstore.NewIndexer(ctx, config) +// if err != nil {...} +// +// chain.AppendIndexer(indexer) +func (c *Chain[I, O]) AppendIndexer(node indexer.Indexer, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toIndexerNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendBranch add a conditional branch to chain. +// Each branch within the ChainBranch can be an AnyGraph. +// All branches should either lead to END, or converge to another node within the Chain. +// e.g. +// +// cb := compose.NewChainBranch(conditionFunc) +// cb.AddChatTemplate("chat_template_key_01", chatTemplate) +// cb.AddChatTemplate("chat_template_key_02", chatTemplate2) +// chain.AppendBranch(cb) +func (c *Chain[I, O]) AppendBranch(b *ChainBranch) *Chain[I, O] { // nolint: byted_s_too_many_lines_in_func + if b == nil { + c.reportError(fmt.Errorf("append branch invalid, branch is nil")) + return c + } + + if b.err != nil { + c.reportError(fmt.Errorf("append branch error: %w", b.err)) + return c + } + + if len(b.key2BranchNode) == 0 { + c.reportError(fmt.Errorf("append branch invalid, nodeList is empty")) + return c + } + + if len(b.key2BranchNode) == 1 { + c.reportError(fmt.Errorf("append branch invalid, nodeList length = 1")) + return c + } + + var startNode string + if len(c.preNodeKeys) == 0 { // branch appended directly to START + startNode = START + } else if len(c.preNodeKeys) == 1 { + startNode = c.preNodeKeys[0] + } else { + c.reportError(fmt.Errorf("append branch invalid, multiple previous nodes: %v ", c.preNodeKeys)) + return c + } + + pName := c.nextNodeKey("Branch") + key2NodeKey := make(map[string]string, len(b.key2BranchNode)) + + for key := range b.key2BranchNode { + node := b.key2BranchNode[key] + nodeKey := fmt.Sprintf("%s[%s]_%s", 
pName, key, node.getNodeName()) + if err := c.gg.addNode(nodeKey, node); err != nil { + c.reportError(fmt.Errorf("add branch node[%s] to chain failed: %w", nodeKey, err)) + return c + } + + key2NodeKey[key] = nodeKey + } + + condition := &composableRunnable{ + i: b.condition.i, + t: b.condition.t, + inputType: b.condition.inputType, + inputStreamFilter: b.condition.inputStreamFilter, + outputType: b.condition.outputType, + optionType: b.condition.optionType, + isPassthrough: b.condition.isPassthrough, + meta: b.condition.meta, + nodeInfo: b.condition.nodeInfo, + } + + invokeCon := func(ctx context.Context, in any, opts ...any) (endNode any, err error) { + endKey, err := b.condition.i(ctx, in, opts...) + if err != nil { + return "", err + } + + endStr, ok := endKey.(string) + if !ok { + return "", fmt.Errorf("chain branch result not string, got: %T", endKey) + } + + nodeKey, ok := key2NodeKey[endStr] + if !ok { + return "", fmt.Errorf("chain branch result not in added keys: %s", endStr) + } + + return nodeKey, nil + } + condition.i = invokeCon + + transformCon := func(ctx context.Context, sr streamReader, opts ...any) (streamReader, error) { + iEndStream, err := b.condition.t(ctx, sr, opts...) 
+ if err != nil { + return nil, err + } + + if iEndStream.getChunkType() != reflect.TypeOf("") { + return nil, fmt.Errorf("chain branch result not string, got: %v", iEndStream.getChunkType()) + } + + endStream, ok := unpackStreamReader[string](iEndStream) + if !ok { + return nil, fmt.Errorf("unpack stream reader not ok") + } + + endStr, err := concatStreamReader(endStream) + if err != nil { + return nil, err + } + + nodeKey, ok := key2NodeKey[endStr] + if !ok { + return nil, fmt.Errorf("chain branch result not in added keys: %s", endStr) + } + + return packStreamReader(schema.StreamReaderFromArray([]string{nodeKey})), nil + } + condition.t = transformCon + + gBranch := &GraphBranch{ + condition: condition, + endNodes: gslice.ToMap(gmap.Values(key2NodeKey), func(k string) (string, bool) { + return k, true + }), + } + + if err := c.gg.AddBranch(startNode, gBranch); err != nil { + c.reportError(fmt.Errorf("chain append branch failed: %w", err)) + return c + } + + c.preNodeKeys = gmap.Values(key2NodeKey) + + return c +} + +// AppendParallel add a Parallel structure (multiple concurrent nodes) to the chain. +// e.g. +// +// parallel := compose.NewParallel() +// parallel.AddChatModel("openai", model1) // => "openai": *schema.Message{} +// parallel.AddChatModel("maas", model2) // => "maas": *schema.Message{} +// +// chain.AppendParallel(parallel) // => multiple concurrent nodes are added to the Chain +// +// The next node in the chain is either an END, or a node which accepts a map[string]any, where keys are `openai` `maas` as specified above. 
+func (c *Chain[I, O]) AppendParallel(p *Parallel) *Chain[I, O] { + if p == nil { + c.reportError(fmt.Errorf("append parallel invalid, parallel is nil")) + return c + } + + if p.err != nil { + c.reportError(fmt.Errorf("append parallel invalid, parallel error: %w", p.err)) + return c + } + + if len(p.nodes) <= 1 { + c.reportError(fmt.Errorf("append parallel invalid, not enough nodes, count = %d", len(p.nodes))) + return c + } + + var startNode string + if len(c.preNodeKeys) == 0 { // parallel appended directly to START + startNode = START + } else if len(c.preNodeKeys) == 1 { + startNode = c.preNodeKeys[0] + } else { + c.reportError(fmt.Errorf("append parallel invalid, multiple previous nodes: %v ", c.preNodeKeys)) + return c + } + + pName := c.nextNodeKey("Parallel") + var nodeKeys []string + + for i := range p.nodes { + node := p.nodes[i] + nodeKey := fmt.Sprintf("%s[%d]_%s", pName, i, node.getNodeName()) + if err := c.gg.addNode(nodeKey, node); err != nil { + c.reportError(fmt.Errorf("add parallel node[%s] to chain failed: %w", nodeKey, err)) + return c + } + if err := c.gg.AddEdge(startNode, nodeKey); err != nil { + c.reportError(fmt.Errorf("add parallel edge[%s]-[%s] to chain failed: %w", startNode, nodeKey, err)) + return c + } + nodeKeys = append(nodeKeys, nodeKey) + } + + c.preNodeKeys = nodeKeys + + return c +} + +// AppendGraph add a AnyGraph node to the chain. +// AnyGraph can be a chain or a graph. +// e.g. +// +// graph := compose.NewGraph[string, string]() +// chain.AppendGraph(graph) +func (c *Chain[I, O]) AppendGraph(node AnyGraph, opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toAnyGraphNode(node, opts...) + + c.addNode(n) + return c +} + +// AppendPassthrough add a Passthrough node to the chain. +// Could be used to connect multiple ChainBranch or Parallel. +// e.g. +// +// chain.AppendPassthrough() +func (c *Chain[I, O]) AppendPassthrough(opts ...GraphAddNodeOpt) *Chain[I, O] { + n := toPassthroughNode(opts...) 
+ + c.addNode(n) + return c +} + +// nextNodeKey. +// get the next node key for the chain. +// e.g. "Chain[1]_ChatModel" => represent the second node of the chain, and is a ChatModel node. +// e.g. "Chain[2]_NameByUser" => represent the third node of the chain, and the node name is set by user of `NameByUser`. +func (c *Chain[I, O]) nextNodeKey(name string) string { + if c.namePrefix == "" { + c.namePrefix = string(ComponentOfChain) + } + fullKey := fmt.Sprintf("%s[%d]_%s", c.namePrefix, c.nodeIdx, name) + c.nodeIdx++ + return fullKey +} + +// reportError. +// save the first error in the chain. +func (c *Chain[I, O]) reportError(err error) { + if c.err == nil { + c.err = err + } +} + +// addNode. +// add a node to the chain. +func (c *Chain[I, O]) addNode(node *graphNode) { + if c.err != nil { + return + } + + if node == nil { + c.reportError(fmt.Errorf("chain add node invalid, node is nil")) + return + } + + nodeKey := c.nextNodeKey(node.getNodeName()) + if node.nodeInfo.key != "" { + nodeKey = node.nodeInfo.key + } + err := c.gg.addNode(nodeKey, node) + c.reportError(err) + + if len(c.preNodeKeys) == 0 { + c.preNodeKeys = append(c.preNodeKeys, START) + } + + for _, preNodeKey := range c.preNodeKeys { + err := c.gg.AddEdge(preNodeKey, nodeKey) + if err != nil { + c.reportError(err) + return + } + } + + c.preNodeKeys = []string{nodeKey} +} diff --git a/compose/chain_branch.go b/compose/chain_branch.go new file mode 100644 index 0000000..c42989c --- /dev/null +++ b/compose/chain_branch.go @@ -0,0 +1,251 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" +) + +// ChainBranch represents a conditional branch in a chain of operations. +// It allows for dynamic routing of execution based on a condition. +// All branches within ChainBranch are expected to either end the Chain, or converge to another node in the Chain. +type ChainBranch struct { + key2BranchNode map[string]*graphNode + condition *composableRunnable + err error +} + +// NewChainBranch creates a new ChainBranch instance based on a given condition. +// It takes a generic type T and a GraphBranchCondition function for that type. +// The returned ChainBranch will have an empty key2BranchNode map and a condition function +// that wraps the provided cond to handle type assertions and error checking. +// eg. 
+// +// condition := func(ctx context.Context, in string, opts ...any) (endNode string, err error) { +// // logic to determine the next node +// return "some_next_node_key", nil +// } +// +// cb := NewChainBranch[string](condition) +// cb.AddPassthrough("next_node_key_01", xxx) // node in branch, represent one path of branch +// cb.AddPassthrough("next_node_key_02", xxx) // node in branch +func NewChainBranch[T any](cond GraphBranchCondition[T]) *ChainBranch { + invokeCond := func(ctx context.Context, in T, opts ...any) (endNode string, err error) { + return cond(ctx, in) + } + + return &ChainBranch{ + key2BranchNode: make(map[string]*graphNode), + condition: runnableLambda(invokeCond, nil, nil, nil, false), + } +} + +// NewStreamChainBranch creates a new ChainBranch instance based on a given stream condition. +// It takes a generic type T and a StreamGraphBranchCondition function for that type. +// The returned ChainBranch will have an empty key2BranchNode map and a condition function +// that wraps the provided cond to handle type assertions and error checking. +// eg. +// +// condition := func(ctx context.Context, in *schema.StreamReader[string], opts ...any) (endNode string, err error) { +// // logic to determine the next node, you can read the stream and make a decision. +// // to save time, usually read the first chunk of stream, then make a decision which path to go. +// return "some_next_node_key", nil +// } +// +// cb := NewStreamChainBranch[string](condition) +func NewStreamChainBranch[T any](cond StreamGraphBranchCondition[T]) *ChainBranch { + collectCon := func(ctx context.Context, in *schema.StreamReader[T], opts ...any) (endNode string, err error) { + return cond(ctx, in) + } + + return &ChainBranch{ + key2BranchNode: make(map[string]*graphNode), + condition: runnableLambda(nil, nil, collectCon, nil, false), + } +} + +// AddChatModel adds a ChatModel node to the branch. +// eg. 
+// +// chatModel01, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ +// Model: "gpt-4o", +// }) +// chatModel02, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddChatModel("chat_model_key_01", chatModel01) +// cb.AddChatModel("chat_model_key_02", chatModel02) +func (cb *ChainBranch) AddChatModel(key string, node model.ChatModel, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toChatModelNode(node, opts...)) +} + +// AddChatTemplate adds a ChatTemplate node to the branch. +// eg. +// +// chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{ +// Role: schema.System, +// Content: "You are acting as a {role}.", +// }) +// +// cb.AddChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2, err := prompt.FromMessages(schema.FString, &schema.Message{ +// Role: schema.System, +// Content: "You are acting as a {role}, you are not allowed to chat in other topics.", +// }) +// +// cb.AddChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toChatTemplateNode(node, opts...)) +} + +// AddToolsNode adds a ToolsNode to the branch. +// eg. +// +// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ +// Tools: []tools.Tool{...}, +// }) +// +// cb.AddToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toToolsNode(node, opts...)) +} + +// AddLambda adds a Lambda node to the branch. +// eg. 
+// +// lambdaFunc := func(ctx context.Context, in string, opts ...any) (out string, err error) { +// // logic to process the input +// return "processed_output", nil +// } +// +// cb.AddLambda("lambda_node_key", compose.InvokeLambda(lambdaFunc)) +func (cb *ChainBranch) AddLambda(key string, node *Lambda, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toLambdaNode(node, opts...)) +} + +// AddEmbedding adds an Embedding node to the branch. +// eg. +// +// embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ +// Model: "text-embedding-3-small", +// }) +// +// cb.AddEmbedding("embedding_node_key", embeddingNode) +func (cb *ChainBranch) AddEmbedding(key string, node embedding.Embedder, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toEmbeddingNode(node, opts...)) +} + +// AddRetriever adds a Retriever node to the branch. +// eg. +// +// retriever, err := volc_vikingdb.NewRetriever(ctx, &volc_vikingdb.RetrieverConfig{ +// Collection: "my_collection", +// }) +// +// cb.AddRetriever("retriever_node_key", retriever) +func (cb *ChainBranch) AddRetriever(key string, node retriever.Retriever, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toRetrieverNode(node, opts...)) +} + +// AddLoaderSplitter adds a LoaderSplitter node to the branch. +// Deprecated: use AddLoader instead. +func (cb *ChainBranch) AddLoaderSplitter(key string, node document.LoaderSplitter, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toLoaderSplitterNode(node, opts...)) +} + +// AddLoader adds a Loader node to the branch. +// eg. 
+// +// pdfParser, err := pdf.NewPDFParser() +// loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{ +// Parser: pdfParser, +// }) +// +// cb.AddLoader("loader_node_key", loader) +func (cb *ChainBranch) AddLoader(key string, node document.Loader, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toLoaderNode(node, opts...)) +} + +// AddIndexer adds an Indexer node to the branch. +// eg. +// +// indexer, err := volc_vikingdb.NewIndexer(ctx, &volc_vikingdb.IndexerConfig{ +// Collection: "my_collection", +// }) +// +// cb.AddIndexer("indexer_node_key", indexer) +func (cb *ChainBranch) AddIndexer(key string, node indexer.Indexer, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toIndexerNode(node, opts...)) +} + +// AddDocumentTransformer adds an Document Transformer node to the branch. +// eg. +// +// markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) +// +// cb.AddDocumentTransformer("document_transformer_node_key", markdownSplitter) +func (cb *ChainBranch) AddDocumentTransformer(key string, node document.Transformer, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toDocumentTransformerNode(node, opts...)) +} + +// AddGraph adds a generic Graph node to the branch. +// eg. +// +// graph, err := compose.NewGraph[string, string]() +// +// cb.AddGraph("graph_node_key", graph) +func (cb *ChainBranch) AddGraph(key string, node AnyGraph, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toAnyGraphNode(node, opts...)) +} + +// AddPassthrough adds a Passthrough node to the branch. +// eg. 
+// +// cb.AddPassthrough("passthrough_node_key") +func (cb *ChainBranch) AddPassthrough(key string, opts ...GraphAddNodeOpt) *ChainBranch { + return cb.addNode(key, toPassthroughNode(opts...)) +} + +func (cb *ChainBranch) addNode(key string, node *graphNode) *ChainBranch { + if cb.err != nil { + return cb + } + + if cb.key2BranchNode == nil { + cb.key2BranchNode = make(map[string]*graphNode) + } + + _, ok := cb.key2BranchNode[key] + if ok { + cb.err = fmt.Errorf("chain branch add node, duplicate branch node key= %s", key) + return cb + } + + cb.key2BranchNode[key] = node // nolint: byted_use_map_without_nilcheck + + return cb +} diff --git a/compose/chain_branch_test.go b/compose/chain_branch_test.go new file mode 100644 index 0000000..403595a --- /dev/null +++ b/compose/chain_branch_test.go @@ -0,0 +1,274 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/schema" +) + +func TestChainBranch(t *testing.T) { + cond := func(ctx context.Context, input string) (key string, err error) { + switch input { + case "one": + return "one_key", nil + case "two": + return "two_key", nil + case "three": + return "three_key", nil + default: + return "", fmt.Errorf("invalid input= %s", input) + } + } + + t.Run("nested chain", func(t *testing.T) { + inner := NewChain[string, string]() + inner.AppendBranch(NewChainBranch(cond). + AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in, nil + })). + AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in + in, nil + }))) + inner.AppendParallel(NewParallel(). + AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in, nil + })). 
+ AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in + in, nil + }))) + + outter := NewChain[string, string]() + outter.AppendGraph(inner) + _, err := outter.Compile(context.Background()) + assert.Error(t, err) + }) + + t.Run("bad param", func(t *testing.T) { + c := NewChain[string, string]() + c.AppendBranch(nil) + assert.NotNil(t, c.err) + + c = NewChain[string, string]() + c.AppendBranch(NewChainBranch[string](nil)) + assert.NotNil(t, c.err) + + c = NewChain[string, string]() + c.AppendBranch(NewChainBranch(cond).AddChatTemplate("template", prompt.FromMessages(schema.FString, schema.SystemMessage("hello")))) + assert.NotNil(t, c.err) + + c = NewChain[string, string]() + c.AppendBranch(NewChainBranch(cond).AddChatTemplate("1", prompt.FromMessages(schema.FString)).AddChatTemplate("1", prompt.FromMessages(schema.FString))) + assert.NotNil(t, c.err) + }) + + t.Run("different Node types in branch", func(t *testing.T) { + c := NewChain[string, string]() + c.AppendBranch(NewChainBranch(cond). + AddChatTemplate("t", prompt.FromMessages(schema.FString)). + AddGraph("c", NewChain[string, string]())) + assert.NotNil(t, c.err) + }) + + t.Run("type mismatch", func(t *testing.T) { + c := NewChain[int, string]() + c.AppendBranch(NewChainBranch(cond). + AddLambda("one_key", InvokableLambda(func(ctx context.Context, in int) (output string, err error) { + return strconv.Itoa(in), nil + })). + AddLambda("two_key", InvokableLambda(func(ctx context.Context, in int) (output string, err error) { + return strconv.Itoa(in), nil + }))) + _, err := c.Compile(context.Background()) + assert.NotNil(t, err) + }) + + t.Run("invoke", func(t *testing.T) { + c := NewChain[string, string]() + c.AppendBranch(NewChainBranch(cond). + AddLambda("one_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in, nil + })). 
+ AddLambda("two_key", InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in + in, nil + }))) + c.AppendLambda(InvokableLambda(func(ctx context.Context, in string) (output string, err error) { + return in + in, nil + })) + assert.Nil(t, c.err) + compiledChain, err := c.Compile(context.Background()) + assert.Nil(t, err) + + out, err := compiledChain.Invoke(context.Background(), "two") + assert.Nil(t, err) + assert.Equal(t, "twotwotwotwotwotwo", out) + + _, err = compiledChain.Invoke(context.Background(), "three") + assert.NotNil(t, err) + + _, err = compiledChain.Invoke(context.Background(), "four") + assert.NotNil(t, err) + }) + + t.Run("fake stream", func(t *testing.T) { + c := NewChain[string, string]() + c.AppendLambda(StreamableLambda(func(ctx context.Context, in string) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](utf8.RuneCountInString(in)) + + go func() { + for _, field := range strings.Fields(in) { + sw.Send(field, nil) + } + sw.Close() + }() + + return sr, nil + })) + c.AppendBranch(NewChainBranch[string](cond).AddLambda("one_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { + defer in.Close() + for { + v, err := in.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return "", err + } + + output += v + } + + return output + output, nil + })). 
+ AddLambda("two_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { + defer in.Close() + for { + v, err := in.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return "", err + } + + output += v + } + + return output + output + output, nil + }))) + + assert.Nil(t, c.err) + compiledChain, err := c.Compile(context.Background()) + assert.Nil(t, err) + + out, err := compiledChain.Invoke(context.Background(), "one") + assert.Nil(t, err) + assert.Equal(t, "oneone", out) + }) + + t.Run("real stream", func(t *testing.T) { + streamCon := func(ctx context.Context, sr *schema.StreamReader[string]) (key string, err error) { + msg, err := sr.Recv() + if err != nil { + return "", err + } + defer sr.Close() + + switch msg { + case "one": + return "one_key", nil + case "two": + return "two_key", nil + case "three": + return "three_key", nil + default: + return "", fmt.Errorf("invalid input= %s", msg) + } + } + + c := NewChain[string, string]() + c.AppendLambda(StreamableLambda(func(ctx context.Context, in string) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](utf8.RuneCountInString(in)) + + go func() { + for _, field := range strings.Fields(in) { + sw.Send(field, nil) + } + sw.Close() + }() + + return sr, nil + })) + c.AppendBranch(NewStreamChainBranch(streamCon).AddLambda("one_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { + defer in.Close() + for { + v, err := in.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return "", err + } + + output += v + } + + return output + output, nil + })). 
+ AddLambda("two_key", CollectableLambda(func(ctx context.Context, in *schema.StreamReader[string]) (output string, err error) { + defer in.Close() + for { + v, err := in.Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return "", err + } + + output += v + } + + return output + output + output, nil + }))) + + assert.Nil(t, c.err) + compiledChain, err := c.Compile(context.Background()) + assert.Nil(t, err) + + out, err := compiledChain.Stream(context.Background(), "one size fit all") + assert.Nil(t, err) + concat, err := concatStreamReader(out) + assert.Nil(t, err) + assert.Equal(t, "onesizefitallonesizefitall", concat) + }) +} diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go new file mode 100644 index 0000000..c262099 --- /dev/null +++ b/compose/chain_parallel.go @@ -0,0 +1,217 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "fmt" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" +) + +// NewParallel creates a new parallel type. +// it is useful when you want to run multiple nodes in parallel in a chain. 
+func NewParallel() *Parallel { + return &Parallel{ + nodes: make([]*graphNode, 0), + outputKeys: make(map[string]bool), + } +} + +// Parallel run multiple nodes in parallel +// +// use `NewParallel()` to create a new parallel type +// Example: +// +// parallel := NewParallel() +// parallel.AddChatModel("output_key01", chat01) +// parallel.AddChatModel("output_key01", chat02) +// +// chain := NewChain[any,any]() +// chain.AppendParallel(parallel) +type Parallel struct { + nodes []*graphNode + outputKeys map[string]bool + err error +} + +// AddChatModel adds a chat model to the parallel. +// eg. +// +// chatModel01, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ +// Model: "gpt-4o", +// }) +// +// chatModel02, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddChatModel("output_key01", chatModel01) +// p.AddChatModel("output_key02", chatModel02) +func (p *Parallel) AddChatModel(outputKey string, node model.ChatModel, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toChatModelNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddChatTemplate adds a chat template to the parallel. +// eg. +// +// chatTemplate01, err := prompt.FromMessages(schema.FString, &schema.Message{ +// Role: schema.System, +// Content: "You are acting as a {role}.", +// }) +// +// p.AddChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddToolsNode adds a tools node to the parallel. +// eg. 
+// +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddToolsNode("output_key01", toolsNode) +func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toToolsNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddLambda adds a lambda node to the parallel. +// eg. +// +// lambdaFunc := func(ctx context.Context, input *schema.Message) ([]*schema.Message, error) { +// return []*schema.Message{input}, nil +// } +// +// p.AddLambda("output_key01", compose.InvokeLambda(lambdaFunc)) +func (p *Parallel) AddLambda(outputKey string, node *Lambda, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toLambdaNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddEmbedding adds an embedding node to the parallel. +// eg. +// +// embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ +// Model: "text-embedding-3-small", +// }) +// +// p.AddEmbedding("output_key01", embeddingNode) +func (p *Parallel) AddEmbedding(outputKey string, node embedding.Embedder, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toEmbeddingNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddRetriever adds a retriever node to the parallel. +// eg. +// +// retriever, err := vikingdb.NewRetriever(ctx, &vikingdb.RetrieverConfig{}) +// +// p.AddRetriever("output_key01", retriever) +func (p *Parallel) AddRetriever(outputKey string, node retriever.Retriever, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toRetrieverNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddLoaderSplitter adds a loader splitter node to the parallel. +// Deprecated: use AddLoader instead. 
+func (p *Parallel) AddLoaderSplitter(outputKey string, node document.LoaderSplitter, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toLoaderSplitterNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddLoader adds a loader node to the parallel. +// eg. +// +// loader, err := file.NewLoader(ctx, &file.LoaderConfig{}) +// +// p.AddLoader("output_key01", loader) +func (p *Parallel) AddLoader(outputKey string, node document.Loader, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toLoaderNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddIndexer adds an indexer node to the parallel. +// eg. +// +// indexer, err := volc_vikingdb.NewIndexer(ctx, &volc_vikingdb.IndexerConfig{ +// Collection: "my_collection", +// }) +// +// p.AddIndexer("output_key01", indexer) +func (p *Parallel) AddIndexer(outputKey string, node indexer.Indexer, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toIndexerNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddDocumentTransformer adds an Document Transformer node to the parallel. +// eg. +// +// markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) +// +// p.AddDocumentTransformer("output_key01", markdownSplitter) +func (p *Parallel) AddDocumentTransformer(outputKey string, node document.Transformer, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toDocumentTransformerNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddGraph adds a graph node to the parallel. +// It is useful when you want to use a graph or a chain as a node in the parallel. +// eg. 
+// +// graph, err := compose.NewChain[any,any]() +// +// p.AddGraph("output_key01", graph) +func (p *Parallel) AddGraph(outputKey string, node AnyGraph, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toAnyGraphNode(node, append(opts, WithOutputKey(outputKey))...)) +} + +// AddPassthrough adds a passthrough node to the parallel. +// eg. +// +// p.AddPassthrough("output_key01") +func (p *Parallel) AddPassthrough(outputKey string, opts ...GraphAddNodeOpt) *Parallel { + return p.addNode(outputKey, toPassthroughNode(append(opts, WithOutputKey(outputKey))...)) +} + +func (p *Parallel) addNode(outputKey string, node *graphNode) *Parallel { + if p.err != nil { + return p + } + + if node == nil { + p.err = fmt.Errorf("chain parallel add node invalid, node is nil") + return p + } + + if p.outputKeys == nil { + p.outputKeys = make(map[string]bool) + } + + if _, ok := p.outputKeys[outputKey]; ok { + p.err = fmt.Errorf("parallel add node err, duplicate output key= %s", outputKey) + return p + } + + if node.nodeInfo == nil { + p.err = fmt.Errorf("chain parallel add node invalid, nodeInfo is nil") + return p + } + + node.nodeInfo.outputKey = outputKey + p.nodes = append(p.nodes, node) + p.outputKeys[outputKey] = true // nolint: byted_use_struct_without_nilcheck, byted_use_map_without_nilcheck + return p +} diff --git a/compose/chain_test.go b/compose/chain_test.go new file mode 100644 index 0000000..419c770 --- /dev/null +++ b/compose/chain_test.go @@ -0,0 +1,583 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/internal/mock/components/document" + "github.com/cloudwego/eino/internal/mock/components/embedding" + "github.com/cloudwego/eino/internal/mock/components/indexer" + "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/internal/mock/components/retriever" + "github.com/cloudwego/eino/schema" +) + +func TestChain(t *testing.T) { + + cm := &mockIntentChatModel{} + + // 构建 branch + branchCond := func(ctx context.Context, input map[string]any) (string, error) { + if rand.Intn(2) == 1 { + return "b1", nil + } + return "b2", nil + } + + b1 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("hello in branch lambda 01") + kvs["role"] = "cat" + return kvs, nil + }) + b2 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("hello in branch lambda 02") + kvs["role"] = "dog" + return kvs, nil + }) + + // 并发节点 + parallel := NewParallel() + parallel. + AddLambda("role", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + // may be change role to others by input kvs, for example (dentist/doctor...) + role := kvs["role"] + if role.(string) == "" { + role = "bird" + } + return role.(string), nil + })). + AddLambda("input", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + return "你的叫声是怎样的?", nil + })) + + // 顺序节点 + rolePlayChain := NewChain[map[string]any, *schema.Message]() + rolePlayChain. + AppendChatTemplate(prompt.FromMessages(schema.FString, schema.SystemMessage(`You are a {role}.`), schema.UserMessage(`{input}`))). 
+ AppendChatModel(cm) + + // 构建 chain + + chain := NewChain[map[string]any, string]() + chain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + // do some logic to prepare kv as variables for next Node + // just pass through + t.Log("in view lambda: ", kvs) + return kvs, nil + })). + AppendBranch(NewChainBranch[map[string]any](branchCond).AddLambda("b1", b1).AddLambda("b2", b2)). + AppendPassthrough(). + AppendParallel(parallel). + AppendGraph(rolePlayChain). + AppendLambda(InvokableLambda(func(ctx context.Context, m *schema.Message) (string, error) { + // do some logic to check the output or something + t.Log("in view of messages: ", m.Content) + + return m.Content, nil + })) + + r, err := chain.Compile(context.Background()) + assert.Nil(t, err) + + out, err := r.Invoke(context.Background(), map[string]any{}) + assert.Nil(t, err) + t.Log(err) + + t.Log("out is : ", out) +} + +func TestChainWithException(t *testing.T) { + chain := NewChain[map[string]any, string]() + chain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + // do some logic to prepare kv as variables for next Node + // just pass through + t.Log("in view lambda: ", kvs) + return kvs, nil + })) + + // items with parallels + parallel := NewParallel() + parallel. + AddLambda("hello", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + t.Log("in parallel item 01") + return "world", nil + })). + AddLambda("world", InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + t.Log("in parallel item 02") + return "hello", nil + })) + + // sequence items + nchain := NewChain[map[string]any, map[string]any]() + nchain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence item 01") + return kvs, nil + })). 
+ AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence item 02") + return kvs, nil + })) + + branchCond := func(ctx context.Context, input map[string]any) (string, error) { + if rand.Intn(2) == 1 { + return "b1", nil + } + return "b2", nil + } + + b1 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("hello in branch lambda 01") + kvs["role"] = "cat" + return kvs, nil + }) + b2 := InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + }) + + // sequence with branch + chain.AppendBranch(NewChainBranch[map[string]any](branchCond).AddLambda("b1", b1).AddLambda("b2", b2)) + + // parallel with sequence + parallel.AddGraph("test_sequence", nchain) + + // parallel with parallel + npara := NewParallel(). + AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })). + AddLambda("test_parallel2", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + + // parallel with graph + ngraph := NewChain[map[string]any, map[string]any](). + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in graph item 01") + return kvs, nil + })). + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in graph item 02") + return kvs, nil + })) + nc := NewChain[map[string]any, map[string]any]() + nc.AppendGraph(ngraph) + parallel.AddGraph("test_graph", nc) + + chain.AppendPassthrough() + + // sequence with parallel + chain.AppendParallel(npara) + + // 构建 chain + chain. + AppendGraph(nchain). + AppendParallel(parallel). 
+ AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + t.Log("in last view lambda: ", kvs) + return "hello last", nil + })) + + ctx := context.Background() + + r, err := chain.Compile(ctx) + assert.Nil(t, err) + + out, err := r.Invoke(ctx, map[string]any{"test": "test"}) + assert.Nil(t, err) + t.Log("out is : ", out) +} + +func TestEmptyList(t *testing.T) { + ctx := context.Background() + + // no nodes in chain + chain := NewChain[map[string]any, map[string]any]() + _, err := chain.Compile(ctx) + assert.Error(t, err) + + // no nodes in parallel + parallel := NewParallel() + chain = NewChain[map[string]any, map[string]any]() + chain.AppendParallel(parallel) + + _, err = chain.Compile(ctx) + assert.Error(t, err) + + // no nodes in sequence + emptyChain := NewChain[map[string]any, map[string]any]() + chain = NewChain[map[string]any, map[string]any]() + + chain. + AppendParallel(parallel). + AppendGraph(emptyChain). + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + + _, err = chain.Compile(ctx) + assert.Error(t, err) +} + +func TestChainList(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in view lambda: ", kvs) + return kvs, nil + })) + + // parallel + parallel := NewParallel() + parallel. + AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in parallel item 01") + return kvs, nil + })) + + // seq in parallel + nchain := NewChain[map[string]any, map[string]any]() + nchain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence in parallel item 01") + return kvs, nil + })). 
+ AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence in parallel item 02") + return kvs, nil + })) + + // seq in seq + nchainInChain := NewChain[map[string]any, map[string]any]() + nchainInChain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence in sequence item 01") + return kvs, nil + })). + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence in sequence item 02") + return kvs, nil + })) + + nchain.AppendGraph(nchainInChain) + + parallel.AddGraph("test_seq_in_parallel", nchain) + + chain.AppendParallel(parallel) + + r, err := chain.Compile(context.Background()) + assert.Nil(t, err) + out, err := r.Invoke(context.Background(), map[string]any{"test": "test"}) + assert.Nil(t, err) + t.Log("out is : ", out) +} + +func TestChainSingleNode(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in view lambda: ", kvs) + return kvs, nil + })) + + // single Node in chain (prepare for parallel) + singleNodeChain := NewChain[map[string]any, map[string]any]() + singleNodeChain. + AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in sequence item 01") + return kvs, nil + })) + + // add parallel + parallel := NewParallel() + parallel. 
+ AddLambda("test_parallel1_lambda", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in parallel item 01") + return kvs, nil + })) + + parallel.AddGraph("test_parallel2_chain", singleNodeChain) + + ctx := context.Background() + + chain.AppendParallel(parallel) + r, err := chain.Compile(ctx) + assert.Nil(t, err) + + out, err := r.Invoke(ctx, map[string]any{"test": "test"}) + assert.Nil(t, err) + t.Log("out is : ", out) +} + +func TestParallelModels(t *testing.T) { + cm := &mockIntentChatModel{} + chain := NewChain[map[string]any, map[string]any]() + chatSuite := NewChain[map[string]any, string]() + chatSuite. + AppendChatTemplate(prompt.FromMessages(schema.FString, schema.SystemMessage(`You are a {role}.`), schema.UserMessage(`{input}`))). + AppendChatModel(cm). + AppendLambda(InvokableLambda(func(ctx context.Context, msg *schema.Message) (string, error) { + t.Log("in parallel item 01") + return msg.Content, nil + })) + + parallel := NewParallel() + parallel. + AddGraph("time001", chatSuite). + AddGraph("time002", chatSuite). 
+ AddGraph("time003", chatSuite) + + chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + t.Log("in view lambda: ", kvs) + return kvs, nil + })) + + chain.AppendParallel(parallel) + + ctx := context.Background() + + r, err := chain.Compile(ctx) + assert.Nil(t, err) + + out, err := r.Invoke(ctx, map[string]any{"role": "cat", "input": "你怎么叫的?"}) + assert.Nil(t, err) + + t.Log("out is : ", out) +} + +func TestChainMultiNodes(t *testing.T) { + ctx := context.Background() + + t.Run("test embedding Node", func(t *testing.T) { + chain := NewChain[[]string, [][]float64]() + + mockCtrl := gomock.NewController(t) + eb := embedding.NewMockEmbedder(mockCtrl) + chain.AppendEmbedding(eb) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test retriever Node", func(t *testing.T) { + chain := NewChain[string, []*schema.Document]() + + chain.AppendRetriever(retriever.NewMockRetriever(gomock.NewController(t))) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test chat model", func(t *testing.T) { + chain := NewChain[[]*schema.Message, *schema.Message]() + + cm := &mockIntentChatModel{} + chain.AppendChatModel(cm) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test chat template", func(t *testing.T) { + chain := NewChain[map[string]any, []*schema.Message]() + + chatTemplate := prompt.FromMessages(schema.FString) + chain.AppendChatTemplate(chatTemplate) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test lambda", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + + chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test indexer", func(t 
*testing.T) { + chain := NewChain[[]*schema.Document, []string]() + + chain.AppendIndexer(indexer.NewMockIndexer(gomock.NewController(t))) + + r, err := chain.Compile(ctx) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test parallel", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + parallel := NewParallel() + parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + chain.AppendParallel(parallel) + _, err := chain.Compile(ctx) + assert.Error(t, err) + + chain = NewChain[map[string]any, map[string]any]() + parallel = NewParallel() + parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + chain.AppendParallel(parallel) + _, err = chain.Compile(ctx) + assert.Error(t, err) + + chain = NewChain[map[string]any, map[string]any]() + parallel = NewParallel() + parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + parallel.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + chain.AppendParallel(parallel) + _, err = chain.Compile(ctx) + assert.NoError(t, err) + + chain = NewChain[map[string]any, map[string]any]() + parallel = NewParallel() + parallel.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + parallel.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + chain.AppendParallel(parallel) + + parallel1 := NewParallel() + 
parallel1.AddLambda("test_parallel", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + parallel1.AddLambda("test_parallel1", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + chain.AppendParallel(parallel1) + + _, err = chain.Compile(ctx) + assert.Error(t, err) + }) + + t.Run("test tools Node", func(t *testing.T) { + ctx := context.Background() + chain := NewChain[map[string]any, map[string]any]() + toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{}) + assert.NoError(t, err) + chain.AppendToolsNode(toolsNode) + }) + + t.Run("test chain with compile option", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + r, err := chain.Compile(ctx, WithMaxRunSteps(10)) + assert.NoError(t, err) + assert.NotNil(t, r) + }) + + t.Run("test chain return type", func(t *testing.T) { + t.Run("test chain any output type", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (any, error) { + return 1, nil + })) + _, err := chain.Compile(ctx) + assert.Nil(t, err) + }) + + t.Run("test chain error output type", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain.AppendLambda(InvokableLambda(func(ctx context.Context, kvs map[string]any) (string, error) { + return "123", nil + })) + _, err := chain.Compile(ctx) + assert.Error(t, err) + }) + + t.Run("test chain error input type", func(t *testing.T) { + chain := NewChain[map[string]any, map[string]any]() + chain.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (map[string]any, error) { + return nil, nil + })) + _, err := chain.Compile(ctx) + assert.Error(t, err) + }) + }) + +} + +func 
TestParallelMultiNodes(t *testing.T) { + ctx := context.Background() + p := NewParallel() + p.AddLambda("lambda", InvokableLambda(func(ctx context.Context, kvs map[string]any) (map[string]any, error) { + return kvs, nil + })) + p.AddGraph("graph", NewChain[map[string]any, map[string]any]()) + p.AddIndexer("indexer", indexer.NewMockIndexer(gomock.NewController(t))) + p.AddLoader("loader", document.NewMockLoader(gomock.NewController(t))) + p.AddDocumentTransformer("document transformer", document.NewMockTransformer(gomock.NewController(t))) + p.AddRetriever("retriever", retriever.NewMockRetriever(gomock.NewController(t))) + p.AddChatModel("chatmodel", model.NewMockChatModel(gomock.NewController(t))) + p.AddChatTemplate("chatTemplate", prompt.FromMessages(schema.FString, schema.SystemMessage("hello"))) + p.AddEmbedding("embedding", embedding.NewMockEmbedder(gomock.NewController(t))) + p.AddPassthrough("passthrough") + toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{}) + assert.NoError(t, err) + p.AddToolsNode("tools", toolsNode) + + assert.Greater(t, len(p.nodes), 6) + + ctrl := gomock.NewController(t) + p = NewParallel() + p.AddIndexer("key", indexer.NewMockIndexer(ctrl)) + p.AddLoader("key", document.NewMockLoader(ctrl)) + p.AddRetriever("r", retriever.NewMockRetriever(ctrl)) + assert.NotNil(t, p.err) + + p = NewParallel() + p.addNode("k", nil) + assert.NotNil(t, p.err) + + p = &Parallel{ + outputKeys: nil, + } + p.addNode("k", &graphNode{}) + assert.NotNil(t, p.err) +} diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go new file mode 100644 index 0000000..5686f80 --- /dev/null +++ b/compose/component_to_graph_node.go @@ -0,0 +1,180 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/embedding"
+	"github.com/cloudwego/eino/components/indexer"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/prompt"
+	"github.com/cloudwego/eino/components/retriever"
+)
+
+func toEmbeddingNode(node embedding.Embedder, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.EmbedStrings
+	meta := parseExecutorInfoFromComponent(components.ComponentOfEmbedding, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.EmbedStrings, nil, nil, nil,
+		!meta.isComponentCallbackEnabled, // NOTE(review): presumably asks runnableLambda to supply callbacks when the component does not - confirm
+	)
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toRetrieverNode(node retriever.Retriever, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.Retrieve
+	meta := parseExecutorInfoFromComponent(components.ComponentOfRetriever, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Retrieve, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toLoaderSplitterNode(node document.LoaderSplitter, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.LoadAndSplit
+	meta := parseExecutorInfoFromComponent(components.ComponentOfLoaderSplitter, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.LoadAndSplit, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toLoaderNode(node document.Loader, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.Load
+	meta := parseExecutorInfoFromComponent(components.ComponentOfLoader, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Load, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toIndexerNode(node indexer.Indexer, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.Store
+	meta := parseExecutorInfoFromComponent(components.ComponentOfIndexer, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Store, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toChatModelNode(node model.ChatModel, opts ...GraphAddNodeOpt) *graphNode { // graph node with both node.Generate and node.Stream wired in
+	meta := parseExecutorInfoFromComponent(components.ComponentOfChatModel, node)
+	info := getNodeInfo(opts...)
+
+	run := runnableLambda(node.Generate, node.Stream, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.Format
+	meta := parseExecutorInfoFromComponent(components.ComponentOfPrompt, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Format, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) *graphNode { // graph node that runs node.Transform
+	meta := parseExecutorInfoFromComponent(components.ComponentOfTransformer, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Transform, nil, nil, nil,
+		!meta.isComponentCallbackEnabled,
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) *graphNode { // graph node with both node.Invoke and node.Stream wired in
+	meta := parseExecutorInfoFromComponent(ComponentOfToolsNode, node)
+	info := getNodeInfo(opts...)
+	run := runnableLambda(node.Invoke, node.Stream, nil, nil,
+		true, // constant here, whereas the component wrappers above derive this flag from meta
+	)
+
+	gn := toNode(info, run, nil, meta, node, opts...)
+
+	return gn
+}
+
+func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) *graphNode { // reuses the Lambda's own executor and executor meta
+	info := getNodeInfo(opts...)
+
+	gn := toNode(info, node.executor, nil, node.executor.meta, node, opts...)
+
+	return gn
+}
+
+func toAnyGraphNode(node AnyGraph, opts ...GraphAddNodeOpt) *graphNode { // sub-graph node: no executor, the graph itself is stored instead
+	meta := parseExecutorInfoFromComponent(node.component(), node)
+	info := getNodeInfo(opts...)
+
+	gn := toNode(info, nil, node, meta, node, opts...)
+
+	return gn
+}
+
+func toPassthroughNode(opts ...GraphAddNodeOpt) *graphNode { // node backed by composablePassthrough()
+	node := composablePassthrough()
+	info := getNodeInfo(opts...)
+	gn := toNode(info, node, nil, node.meta, node, opts...)
+	return gn
+}
+
+func toNode(nodeInfo *nodeInfo, executor *composableRunnable, graph AnyGraph, meta *executorMeta, instance any, opts ...GraphAddNodeOpt) *graphNode { // nolint: byted_s_args_length_limit
+	if meta == nil { // a nil meta is replaced by an empty one so executorMeta is never nil on the node
+		meta = &executorMeta{}
+	}
+
+	gn := &graphNode{
+		nodeInfo: nodeInfo,
+
+		cr:           executor, // set for component/lambda/passthrough nodes; nil for sub-graph nodes
+		g:            graph,    // set for sub-graph nodes; nil otherwise
+		executorMeta: meta,
+
+		instance: instance,
+		opts:     opts,
+	}
+
+	gn.nodeInfo.name = gn.getNodeName() // nodeInfo is dereferenced unchecked: callers must pass a non-nil nodeInfo
+
+	return gn
+}
diff --git a/compose/dag.go b/compose/dag.go
new file mode 100644
index 0000000..014e356
--- /dev/null
+++ b/compose/dag.go
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"context"
+	"fmt"
+)
+
+func dagChannelBuilder(dependencies []string) channel { // builds an empty dagChannel that waits for every key in dependencies
+	return &dagChannel{
+		values:   make(map[string]any),
+		waitList: dependencies,
+	}
+}
+
+type dagChannel struct {
+	values   map[string]any // inputs received so far, keyed as passed to update
+	waitList []string       // keys that must all be present in values before value is produced
+	value    any            // final value handed out by get; nil means "not ready"
+}
+
+func (ch *dagChannel) update(ctx context.Context, ins map[string]any) error { // stores ins and, once every waitList key has arrived, computes value
+	for k, v := range ins {
+		if _, ok := ch.values[k]; ok { // the same key must not be delivered twice between two clears
+			return fmt.Errorf("dag channel update, calculate node repeatedly: %s", k)
+		}
+		ch.values[k] = v
+	}
+
+	for i := range ch.waitList {
+		if _, ok := ch.values[ch.waitList[i]]; !ok { // still waiting for at least one key: not an error, just not ready yet
+			return nil
+		}
+	}
+
+	if len(ch.waitList) == 1 { // a single awaited key needs no merge: its value is used as-is
+		ch.value = ch.values[ch.waitList[0]]
+		return nil
+	}
+	v, err := mergeValues(mapToList(ch.values))
+	if err != nil {
+		return fmt.Errorf("dag channel merge value fail: %w", err)
+	}
+	ch.value = v
+
+	return nil
+}
+
+func (ch *dagChannel) get(ctx context.Context) (any, error) { // returns value, or an error while value is still nil
+	if ch.value == nil {
+		return nil, fmt.Errorf("dag channel not ready, value is nil")
+	}
+	return ch.value, nil
+}
+
+func (ch *dagChannel) ready(ctx context.Context) bool { // reports whether update has produced a value since the last clear
+	return ch.value != nil
+}
+
+func (ch *dagChannel) clear(ctx context.Context) { // resets the channel so it can collect a fresh round of inputs
+	ch.values = make(map[string]any) // must stay a non-nil map: update writes into it
+	ch.value = nil                   // was "ch.values = nil": left a stale value (ready() stayed true) and made the next update panic on a nil map
+}
diff --git a/compose/dag_test.go b/compose/dag_test.go
new file mode 100644
index 0000000..1b78f9c
--- /dev/null
+++ b/compose/dag_test.go
@@ -0,0 +1,228 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "io" + "testing" +) + +func TestDAG(t *testing.T) { + var err error + + g := NewGraph[string, string]() + err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithOutputKey("1")) + if err != nil { + t.Fatal(err) + } + + err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithOutputKey("2")) + if err != nil { + t.Fatal(err) + } + + err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { + if _, ok := input["1"]; !ok { + return "", fmt.Errorf("node 1 output fail: %+v", input) + } + if _, ok := input["2"]; !ok { + return "", fmt.Errorf("node 2 output fail: %+v", input) + } + return input["1"].(string) + input["2"].(string), nil + })) + if err != nil { + t.Fatal(err) + } + + err = g.AddLambdaNode("4", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })) + if err != nil { + t.Fatal(err) + } + + err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })) + if err != nil { + t.Fatal(err) + } + + err = g.AddLambdaNode("6", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithOutputKey("6")) + if err != nil { + t.Fatal(err) + } + + err = 
g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { + if _, ok := input["1"]; !ok { + return "", fmt.Errorf("7:node 1 output fail: %+v", input) + } + if _, ok := input["6"]; !ok { + return "", fmt.Errorf("7:node 6 output fail: %+v", input) + } + return input["1"].(string) + input["6"].(string), nil + })) + if err != nil { + t.Fatal(err) + } + + err = g.AddEdge("1", "3") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("2", "3") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("3", "4") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("4", "5") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("4", "6") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("6", "7") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", "7") + if err != nil { + t.Fatal(err) + } + + err = g.AddEdge(START, "1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("7", END) + if err != nil { + t.Fatal(err) + } + + r, err := g.compile(context.Background(), &graphCompileOptions{nodeTriggerMode: AllPredecessor, maxRunSteps: 10}) + if err != nil { + t.Fatal(err) + } + + // success + ctx := context.Background() + out, err := r.i(ctx, "hello") + if err != nil { + t.Fatal(err) + } + if out.(string) != "hellohellohello" { + t.Fatalf("node7 fail") + } + + // test Compile[I,O] + runner, err := g.Compile(context.Background(), WithMaxRunSteps(100), WithNodeTriggerMode(AllPredecessor)) + if err != nil { + t.Fatal(err) + } + result, err := runner.Invoke(ctx, "1") + if err != nil { + t.Fatal(err) + } + if result != "111" { + t.Fatalf("runner invoke fail, output: %s", result) + } + streamResult, err := runner.Stream(ctx, "1") + if err != nil { + t.Fatal(err) + } + defer streamResult.Close() + ret := "" + for { + chunk, err := streamResult.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + ret += chunk + } + if ret != 
"111" { + t.Fatalf("runner stream fail, output: %s", ret) + } + + // loop + gg := NewGraph[string, map[string]any]() + err = gg.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithOutputKey("1")) + if err != nil { + t.Fatal(err) + } + + err = gg.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input map[string]any) (output string, err error) { + return input["1"].(string), nil + })) + if err != nil { + t.Fatal(err) + } + + err = gg.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithOutputKey("3")) + if err != nil { + t.Fatal(err) + } + + err = gg.AddEdge("1", "2") + if err != nil { + t.Fatal(err) + } + err = gg.AddEdge("2", "3") + if err != nil { + t.Fatal(err) + } + err = gg.AddEdge("3", "2") + if err != nil { + t.Fatal(err) + } + err = gg.AddEdge(START, "1") + if err != nil { + t.Fatal(err) + } + err = gg.AddEdge("3", END) + if err != nil { + t.Fatal(err) + } + + _, err = gg.compile(ctx, &graphCompileOptions{nodeTriggerMode: AllPredecessor}) + if err == nil { + t.Fatal("cannot validate loop") + } +} diff --git a/compose/doc.go b/compose/doc.go new file mode 100644 index 0000000..e072b18 --- /dev/null +++ b/compose/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package compose
diff --git a/compose/error.go b/compose/error.go
new file mode 100644
index 0000000..db1cb18
--- /dev/null
+++ b/compose/error.go
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"errors"
+	"fmt"
+	"reflect"
+)
+
+// ErrExceedMaxSteps is the error a graph reports when the number of steps exceeds the maximum number of steps.
+var ErrExceedMaxSteps = errors.New("exceeds max steps")
+
+func newUnexpectedInputTypeErr(expected reflect.Type, got reflect.Type) error { // reports an input type mismatch, naming both types
+	return fmt.Errorf("unexpected input type. expected: %v, got: %v", expected, got)
+}
+
+type defaultImplErrCausedType string // the step that failed; fills the "when try to ..." part of newDefaultImplErr
+type defaultImplAction string        // the default implementation that was running; fills the quoted part of newDefaultImplErr
+
+const (
+	streamConcat defaultImplErrCausedType = "concat stream items"
+	internalCall defaultImplErrCausedType = "call internal action"
+)
+const ( // action names follow the pattern <Provided>By<Underlying>
+	actionInvokeByStream     defaultImplAction = "InvokeByStream"
+	actionInvokeByCollect    defaultImplAction = "InvokeByCollect"
+	actionInvokeByTransform  defaultImplAction = "InvokeByTransform"
+	actionStreamByInvoke     defaultImplAction = "StreamByInvoke"
+	actionStreamByTransform  defaultImplAction = "StreamByTransform"
+	actionStreamByCollect    defaultImplAction = "StreamByCollect"
+	actionCollectByTransform defaultImplAction = "CollectByTransform"
+	actionCollectByInvoke    defaultImplAction = "CollectByInvoke"
+	actionCollectByStream    defaultImplAction = "CollectByStream"
+	actionTransformByStream  defaultImplAction = "TransformByStream"
+	actionTransformByCollect defaultImplAction = "TransformByCollect"
+	actionTransformByInvoke  defaultImplAction = "TransformByInvoke"
+)
+
+func newDefaultImplErr(action defaultImplAction, causedType defaultImplErrCausedType, causedErr error) error { // wraps causedErr with %w, so errors.Is/As still reach it
+	return fmt.Errorf(
+		"default implementation: '%s' got error, when try to %s, err: \n%w", action, causedType, causedErr)
+}
+
+func newStreamReadError(err error) error { // wraps err with %w, so a caller can still match the underlying error
+	return fmt.Errorf("failed to read from stream. error: %w", err)
+}
diff --git a/compose/generic_graph.go b/compose/generic_graph.go
new file mode 100644
index 0000000..1aac926
--- /dev/null
+++ b/compose/generic_graph.go
@@ -0,0 +1,104 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/utils/generic" +) + +// NewGraph create a directed graph that can compose components, lambda, chain, parallel etc. +// simultaneously provide flexible and multi-granular aspect governance capabilities. +// I: the input type of graph compiled product +// O: the output type of graph compiled product +func NewGraph[I, O any]() *Graph[I, O] { + return &Graph[I, O]{ + newGraph( + generic.TypeOf[I](), + generic.TypeOf[O](), + defaultStreamMapFilter[I], + defaultValueChecker[I], + defaultValueChecker[O], + defaultStreamConverter[I], + defaultStreamConverter[O], + defaultGraphKey(), + )} +} + +// Graph is a generic graph that can be used to compose components. +// I: the input type of graph compiled product +// O: the output type of graph compiled product +type Graph[I, O any] struct { + *graph +} + +func (g *Graph[I, O]) component() component { + return ComponentOfGraph +} + +// Compile take the raw graph and compile it into a form ready to be run. +// eg. 
+// +// graph := compose.NewGraph[string, string]() +// // add nodes, edges and branches to the graph before compiling it +// +// runnable, err := graph.Compile(ctx, compose.WithGraphName("my_graph")) +// if err != nil {...} +// +// runnable.Invoke(ctx, "input") // invoke +// runnable.Stream(ctx, "input") // stream +// runnable.Collect(ctx, inputReader) // collect +// runnable.Transform(ctx, inputReader) // transform +func (g *Graph[I, O]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { + if len(globalGraphCompileCallbacks) > 0 { + opts = append([]GraphCompileOption{WithGraphCompileCallbacks(globalGraphCompileCallbacks...)}, opts...) + } + option := newGraphCompileOptions(opts...) + + cr, err := g.graph.compile(ctx, option) + if err != nil { + return nil, err + } + + // option component can override the default graph component. + comp := option.component + if len(comp) == 0 { + comp = g.component() + } + + cr.meta = &executorMeta{ + component: comp, + isComponentCallbackEnabled: true, + componentImplType: "", + } + + cr.nodeInfo = &nodeInfo{ + name: generateName(option.graphName, cr.meta), + } + + ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { + return initGraphCallbacks(ctx, cr.nodeInfo, cr.meta, opts...) + } + + rp, err := toGenericRunnable[I, O](cr, ctxWrapper) + if err != nil { + return nil, err + } + + return rp, nil +} diff --git a/compose/graph.go b/compose/graph.go new file mode 100644 index 0000000..63e4d1c --- /dev/null +++ b/compose/graph.go @@ -0,0 +1,1004 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime" + + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/embedding" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/internal/gmap" + "github.com/cloudwego/eino/schema" +) + +// START is the start node of the graph. You can add your first edge with START. +const START = "start" + +// END is the end node of the graph. You can add your last edge with END. +const END = "end" + +// errGraphFrozen is the error returned when a node, edge or branch is added to a graph that is already frozen. +var errGraphFrozen = errors.New("graph already frozen") + +// GraphBranchCondition is the condition type for the branch. +type GraphBranchCondition[T any] func(ctx context.Context, in T) (endNode string, err error) + +// StreamGraphBranchCondition is the condition type for the stream branch. +type StreamGraphBranchCondition[T any] func(ctx context.Context, in *schema.StreamReader[T]) (endNode string, err error) + +// GraphBranch is the branch type for the graph. +// It is used to determine the next node based on the condition. +type GraphBranch struct { + condition *composableRunnable + endNodes map[string]bool + idx int // used to distinguish branches in parallel +} + +// GetEndNode returns all end nodes of the branch. 
+func (gb *GraphBranch) GetEndNode() map[string]bool { + return gb.endNodes +} + +// NewGraphBranch creates a new graph branch. +// It is used to determine the next node based on the condition. +// eg. +// +// condition := func(ctx context.Context, in string) (string, error) { +// // logic to determine the next node +// return "next_node_key", nil +// } +// endNodes := map[string]bool{"path01": true, "path02": true} +// branch := compose.NewGraphBranch(condition, endNodes) +// +// graph.AddBranch("key_of_node_before_branch", branch) +func NewGraphBranch[T any](condition GraphBranchCondition[T], endNodes map[string]bool) *GraphBranch { + condRun := func(ctx context.Context, in T, opts ...any) (string, error) { + endNode, err := condition(ctx, in) + if err != nil { + return "", err + } + + if !endNodes[endNode] { + return "", fmt.Errorf("branch invocation returns unintended end node: %s", endNode) + } + + return endNode, nil + } + + r := runnableLambda(condRun, nil, nil, nil, false) + + return &GraphBranch{ + condition: r, + endNodes: endNodes, + } +} + +// NewStreamGraphBranch creates a new stream graph branch. +// It is used to determine the next node based on the condition of stream input. +// eg. +// +// condition := func(ctx context.Context, in *schema.StreamReader[T]) (string, error) { +// // logic to determine the next node. +// // to use the feature of stream, you can use the first chunk to determine the next node. 
+// return "next_node_key", nil +// } +// endNodes := map[string]bool{"path01": true, "path02": true} +// branch := compose.NewStreamGraphBranch(condition, endNodes) +// +// graph.AddBranch("key_of_node_before_branch", branch) +func NewStreamGraphBranch[T any](condition StreamGraphBranchCondition[T], + endNodes map[string]bool) *GraphBranch { + + condRun := func(ctx context.Context, in *schema.StreamReader[T], opts ...any) (string, error) { + endNode, err := condition(ctx, in) + if err != nil { + return "", err + } + + if !endNodes[endNode] { + return "", fmt.Errorf("stream branch invocation returns unintended end node: %s", endNode) + } + + return endNode, nil + } + + r := runnableLambda(nil, nil, condRun, nil, false) + + return &GraphBranch{ + condition: r, + endNodes: endNodes, + } +} + +// graphRunType is a custom type used to control the running mode of the graph. +type graphRunType string + +const ( + // runTypePregel is a running mode of the graph that is suitable for large-scale graph processing tasks. Can have cycles in graph. Compatible with NodeTriggerType.AnyPredecessor. + runTypePregel graphRunType = "Pregel" + // runTypeDAG is a running mode of the graph that represents the graph as a directed acyclic graph, suitable for tasks that can be represented as a directed acyclic graph. Compatible with NodeTriggerType.AllPredecessor. + runTypeDAG graphRunType = "DAG" +) + +// String returns the string representation of the graph run type. 
+func (g graphRunType) String() string { + return string(g) +} + +type graph struct { + nodes map[string]*graphNode + edges map[string][]string + branches map[string][]*GraphBranch + startNodes []string + endNodes []string + + toValidateMap map[string][]string + + frozen bool + + runCtx func(ctx context.Context) context.Context + + addNodeChecker nodeChecker + compileChecker func(options *graphCompileOptions) error + + expectedInputType, expectedOutputType reflect.Type + inputStreamFilter streamMapFilter + inputValueChecker valueChecker + inputStreamConverter streamConverter + outputValueChecker valueChecker + outputStreamConverter streamConverter + + runtimeCheckEdges map[string]map[string]bool + runtimeCheckBranches map[string][]bool + runtimeGraphKey string + + buildError error +} + +func newGraph( // nolint: byted_s_args_length_limit + inputType, outputType reflect.Type, + filter streamMapFilter, + inputChecker, outputChecker valueChecker, + inputConv, outputConv streamConverter, + graphKey string, +) *graph { + return &graph{ + nodes: make(map[string]*graphNode), + edges: make(map[string][]string), + branches: make(map[string][]*GraphBranch), + + toValidateMap: make(map[string][]string), + + addNodeChecker: nodeCheckerOfForbidProcessor(nodeCheckerOfForbidNodeKey(baseNodeChecker)), + compileChecker: defaultCompileChecker, + + expectedInputType: inputType, + expectedOutputType: outputType, + inputStreamFilter: filter, + inputValueChecker: inputChecker, + inputStreamConverter: inputConv, + outputValueChecker: outputChecker, + outputStreamConverter: outputConv, + + runtimeCheckEdges: make(map[string]map[string]bool), + runtimeCheckBranches: make(map[string][]bool), + runtimeGraphKey: graphKey, + } +} + +func (g *graph) freeze() { + g.frozen = true +} + +func (g *graph) isFrozen() bool { + return g.frozen +} + +func (g *graph) addNode(name string, node *graphNode) (err error) { + if g.buildError != nil { + return g.buildError + } + defer func() { + if err != nil { 
+ g.buildError = err + } + }() + + if g.frozen { + return errGraphFrozen + } + + if name == END || name == START { + return fmt.Errorf("node '%s' is reserved, cannot add manually", name) + } + + if _, ok := g.nodes[name]; ok { + return fmt.Errorf("node '%s' already present", name) + } + + if err = g.addNodeChecker(name, node); err != nil { + return err + } + + g.nodes[name] = node + + return nil +} + +// AddEdge adds an edge to the graph, edge means a data flow from startNode to endNode. +// the previous node's output type must be assignable to the next node's input type. +// NOTE: startNode and endNode must have been added to the graph before adding the edge, except for the reserved START and END. +// eg. +// +// graph.AddPassthroughNode("start_node_key") +// graph.AddPassthroughNode("end_node_key") +// +// err := graph.AddEdge("start_node_key", "end_node_key") +func (g *graph) AddEdge(startNode, endNode string) (err error) { + if g.buildError != nil { + return g.buildError + } + defer func() { + if err != nil { + g.buildError = err + } + }() + + if g.frozen { + return errGraphFrozen + } + + if startNode == END { + return errors.New("END cannot be a start node") + } + + if endNode == START { + return errors.New("START cannot be an end node") + } + + for i := range g.edges[startNode] { + if g.edges[startNode][i] == endNode { + return fmt.Errorf("edge[%s]-[%s] have been added yet", startNode, endNode) + } + } + + if _, ok := g.nodes[startNode]; !ok && startNode != START { + return fmt.Errorf("edge start node '%s' needs to be added to graph first", startNode) + } + + if _, ok := g.nodes[endNode]; !ok && endNode != END { + return fmt.Errorf("edge end node '%s' needs to be added to graph first", endNode) + } + + err = g.validateAndInferType(startNode, endNode) + if err != nil { + return err + } + + g.edges[startNode] = append(g.edges[startNode], endNode) + + if startNode == START { + g.startNodes = append(g.startNodes, endNode) + } + + if endNode == END { + g.endNodes = 
append(g.endNodes, startNode) + } + + err = g.updateToValidateMap() + if err != nil { + return err + } + + return nil +} + +// AddEmbeddingNode adds a node that implements embedding.Embedder. +// eg. +// +// embeddingNode, err := openai.NewEmbedder(ctx, &openai.EmbeddingConfig{ +// Model: "text-embedding-3-small", +// }) +// +// graph.AddEmbeddingNode("embedding_node_key", embeddingNode) +func (g *graph) AddEmbeddingNode(key string, node embedding.Embedder, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toEmbeddingNode(node, opts...)) +} + +// AddRetrieverNode adds a node that implements retriever.Retriever. +// eg. +// +// retrieverNode, err := vikingdb.NewRetriever(ctx, &vikingdb.RetrieverConfig{}) +// +// graph.AddRetrieverNode("retriever_node_key", retrieverNode) +func (g *graph) AddRetrieverNode(key string, node retriever.Retriever, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toRetrieverNode(node, opts...)) +} + +// Deprecated: use AddLoaderNode instead. +func (g *graph) AddLoaderSplitterNode(key string, node document.LoaderSplitter, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toLoaderSplitterNode(node, opts...)) +} + +// AddLoaderNode adds a node that implements document.Loader. +// eg. +// +// loader, err := file.NewLoader(ctx, &file.LoaderConfig{}) +// +// graph.AddLoaderNode("loader_node_key", loader) +func (g *graph) AddLoaderNode(key string, node document.Loader, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toLoaderNode(node, opts...)) +} + +// AddIndexerNode adds a node that implements indexer.Indexer. +// eg. +// +// indexer, err := vikingdb.NewIndexer(ctx, &vikingdb.IndexerConfig{}) +// +// graph.AddIndexerNode("indexer_node_key", indexer) +func (g *graph) AddIndexerNode(key string, node indexer.Indexer, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toIndexerNode(node, opts...)) +} + +// AddChatModelNode adds a node that implements model.ChatModel. +// eg. 
+// +// chatModel, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddChatModelNode("chat_model_node_key", chatModel) +func (g *graph) AddChatModelNode(key string, node model.ChatModel, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toChatModelNode(node, opts...)) +} + +// AddChatTemplateNode add node that implements prompt.ChatTemplate. +// eg. +// +// chatTemplate, err := prompt.FromMessages(schema.FString, &schema.Message{ +// Role: schema.System, +// Content: "You are acting as a {role}.", +// }) +// +// graph.AddChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toChatTemplateNode(node, opts...)) +} + +// AddToolsNode adds a node that implements tools.ToolsNode. +// eg. +// +// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// +// graph.AddToolsNode("tools_node_key", toolsNode) +func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toToolsNode(node, opts...)) +} + +// AddDocumentTransformerNode adds a node that implements document.Transformer. +// eg. +// +// markdownSplitter, err := markdown.NewHeaderSplitter(ctx, &markdown.HeaderSplitterConfig{}) +// +// graph.AddDocumentTransformerNode("document_transformer_node_key", markdownSplitter) +func (g *graph) AddDocumentTransformerNode(key string, node document.Transformer, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toDocumentTransformerNode(node, opts...)) +} + +// AddLambdaNode add node that implements at least one of Invoke[I, O], Stream[I, O], Collect[I, O], Transform[I, O]. +// due to the lack of supporting method generics, we need to use function generics to generate Lambda run as Runnable[I, O]. 
+// for Invoke[I, O], use compose.InvokableLambda() +// for Stream[I, O], use compose.StreamableLambda() +// for Collect[I, O], use compose.CollectableLambda() +// for Transform[I, O], use compose.TransformableLambda() +// for arbitrary combinations of 4 kinds of lambda, use compose.AnyLambda() +func (g *graph) AddLambdaNode(key string, node *Lambda, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toLambdaNode(node, opts...)) +} + +// AddGraphNode adds a Graph[I, O], Chain[I, O] or StateGraph[I, O, S] as a node. +// for Graph[I, O], comes from NewGraph[I, O]() +// for Chain[I, O], comes from NewChain[I, O]() +// for StateGraph[I, O, S], comes from NewStateGraph[I, O, S]() +func (g *graph) AddGraphNode(key string, node AnyGraph, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toAnyGraphNode(node, opts...)) +} + +// AddPassthroughNode adds a passthrough node to the graph. +// mostly used in pregel mode of graph. +// eg. +// +// graph.AddPassthroughNode("passthrough_node_key") +func (g *graph) AddPassthroughNode(key string, opts ...GraphAddNodeOpt) error { + return g.addNode(key, toPassthroughNode(opts...)) +} + +// AddBranch adds a branch to the graph. +// eg. 
+// +// condition := func(ctx context.Context, in string) (string, error) { +// return "next_node_key", nil +// } +// endNodes := map[string]bool{"path01": true, "path02": true} +// branch := compose.NewGraphBranch(condition, endNodes) +// +// graph.AddBranch("start_node_key", branch) +func (g *graph) AddBranch(startNode string, branch *GraphBranch) (err error) { + if g.buildError != nil { + return g.buildError + } + defer func() { + if err != nil { + g.buildError = err + } + }() + + if g.frozen { + return errGraphFrozen + } + + if startNode == END { + return errors.New("END cannot be a start node") + } + + if _, ok := g.nodes[startNode]; !ok && startNode != START { + return fmt.Errorf("branch start node '%s' needs to be added to graph first", startNode) + } + + if len(branch.endNodes) == 1 { + return fmt.Errorf("number of branches is 1") + } + + if _, ok := g.runtimeCheckBranches[startNode]; !ok { + g.runtimeCheckBranches[startNode] = []bool{} + } + branch.idx = len(g.runtimeCheckBranches[startNode]) + + // check branch condition type + result := checkAssignable(g.getNodeOutputType(startNode), branch.condition.inputType) + if result == assignableTypeMustNot { + return fmt.Errorf("condition input type[%s] and start node output type[%s] are mismatched", branch.condition.inputType.String(), g.getNodeOutputType(startNode).String()) + } else if result == assignableTypeMay { + g.runtimeCheckBranches[startNode] = append(g.runtimeCheckBranches[startNode], true) + } else { + g.runtimeCheckBranches[startNode] = append(g.runtimeCheckBranches[startNode], false) + } + + for endNode := range branch.endNodes { + if _, ok := g.nodes[endNode]; !ok { + if endNode != END { + return fmt.Errorf("branch end node '%s' needs to be added to graph first", endNode) + } + } + + err := g.validateAndInferType(startNode, endNode) + if err != nil { + return err + } + + if startNode == START { + g.startNodes = append(g.startNodes, endNode) + } + if endNode == END { + g.endNodes = 
append(g.endNodes, startNode) + } + + err = g.updateToValidateMap() + if err != nil { + return err + } + } + + g.branches[startNode] = append(g.branches[startNode], branch) + + return nil +} + +func (g *graph) validateAndInferType(startNode, endNode string) error { + startNodeOutputType := g.getNodeOutputType(startNode) + endNodeInputType := g.getNodeInputType(endNode) + + // assume that START and END type isn't empty + // check and update current node. if cannot validate, save edge to toValidateMap + if startNodeOutputType == nil && endNodeInputType == nil { + // type of passthrough have not been inferred yet. defer checking to compile. + g.toValidateMap[startNode] = append(g.toValidateMap[startNode], endNode) + } else if startNodeOutputType != nil && endNodeInputType == nil { + // end node is passthrough, propagate start node output type to it + g.nodes[endNode].cr.inputType = startNodeOutputType + g.nodes[endNode].cr.outputType = g.nodes[endNode].cr.inputType + } else if startNodeOutputType == nil /* redundant condition && endNodeInputType != nil */ { + // start node is passthrough, propagate end node input type to it + g.nodes[startNode].cr.inputType = endNodeInputType + g.nodes[startNode].cr.outputType = g.nodes[startNode].cr.inputType + } else { + // common node check + result := checkAssignable(startNodeOutputType, endNodeInputType) + if result == assignableTypeMustNot { + return fmt.Errorf("graph edge[%s]-[%s]: start node's output type[%s] and end node's input type[%s] mismatch", + startNode, endNode, startNodeOutputType.String(), endNodeInputType.String()) + } else if result == assignableTypeMay { + // add runtime check edges + if _, ok := g.runtimeCheckEdges[startNode]; !ok { + g.runtimeCheckEdges[startNode] = make(map[string]bool) + } + g.runtimeCheckEdges[startNode][endNode] = true + } + } + return nil +} + +// updateToValidateMap after update node, check validate map +// check again if nodes in toValidateMap have been updated. 
because when there are multiple linked passthrough nodes, in the worst scenario, only one node can be updated at a time. +func (g *graph) updateToValidateMap() error { + var startNodeOutputType, endNodeInputType reflect.Type + for { + hasChanged := false + for startNode := range g.toValidateMap { + startNodeOutputType = g.getNodeOutputType(startNode) + + for i := 0; i < len(g.toValidateMap[startNode]); i++ { + endNode := g.toValidateMap[startNode][i] + + endNodeInputType = g.getNodeInputType(endNode) + if startNodeOutputType == nil && endNodeInputType == nil { + continue + } + + // update toValidateMap + g.toValidateMap[startNode] = append(g.toValidateMap[startNode][:i], g.toValidateMap[startNode][i+1:]...) + i-- + + hasChanged = true + // assume that START and END type isn't empty + if startNodeOutputType != nil && endNodeInputType == nil { + g.nodes[endNode].cr.inputType = startNodeOutputType + g.nodes[endNode].cr.outputType = g.nodes[endNode].cr.inputType + } else if startNodeOutputType == nil /* redundant condition && endNodeInputType != nil */ { + g.nodes[startNode].cr.inputType = endNodeInputType + g.nodes[startNode].cr.outputType = g.nodes[startNode].cr.inputType + } else { + // common node check + result := checkAssignable(startNodeOutputType, endNodeInputType) + if result == assignableTypeMustNot { + return fmt.Errorf("graph edge[%s]-[%s]: start node's output type[%s] and end node's input type[%s] mismatch", + startNode, endNode, startNodeOutputType.String(), endNodeInputType.String()) + } else if result == assignableTypeMay { + // add runtime check edges + if _, ok := g.runtimeCheckEdges[startNode]; !ok { + g.runtimeCheckEdges[startNode] = make(map[string]bool) + } + g.runtimeCheckEdges[startNode][endNode] = true + } + } + } + } + if !hasChanged { + break + } + } + + return nil +} + +func (g *graph) getNodeInputType(name string) reflect.Type { + if name == START { + return g.inputType() + } else if name == END { + return g.outputType() + } + return 
g.nodes[name].inputType() +} + +func (g *graph) getNodeOutputType(name string) reflect.Type { + if name == START { + return g.inputType() + } else if name == END { + return g.outputType() + } + return g.nodes[name].outputType() +} + +func (g *graph) inputType() reflect.Type { + return g.expectedInputType +} + +func (g *graph) outputType() reflect.Type { + return g.expectedOutputType +} + +func (g *graph) compile(ctx context.Context, opt *graphCompileOptions) (*composableRunnable, error) { + if g.buildError != nil { + return nil, g.buildError + } + + err := g.compileChecker(opt) + if err != nil { + return nil, err + } + + runType := runTypePregel + cb := pregelChannelBuilder + if opt != nil { + if opt.nodeTriggerMode == AllPredecessor { + runType = runTypeDAG + cb = dagChannelBuilder + } + } + + if len(g.startNodes) == 0 { + return nil, errors.New("start node not set") + } + if len(g.endNodes) == 0 { + return nil, errors.New("end node not set") + } + + // toValidateMap isn't empty means there are nodes that cannot infer type + for _, v := range g.toValidateMap { + if len(v) > 0 { + return nil, fmt.Errorf("some node's input or output types cannot be inferred: %v", g.toValidateMap) + } + } + + // dag doesn't support branch + if runType == runTypeDAG && len(g.branches) > 0 { + return nil, fmt.Errorf("dag doesn't support branch for now") + } + + key2SubGraphs := g.beforeChildGraphsCompile(opt) + chanSubscribeTo := make(map[string]*chanCall) + for name, node := range g.nodes { + node.beforeChildGraphCompile(name, key2SubGraphs) + + r, err := node.compileIfNeeded(ctx) + if err != nil { + return nil, err + } + + writeTo := g.edges[name] + chCall := &chanCall{ + action: r, + writeTo: writeTo, + + preProcessor: node.nodeInfo.preProcessor, + postProcessor: node.nodeInfo.postProcessor, + } + + branches := g.branches[name] + if len(branches) > 0 { + branchRuns := make([]*GraphBranch, 0, len(branches)) + for _, branch := range branches { + branchRuns = append(branchRuns, branch) 
+ } + + chCall.writeToBranches = branchRuns + } + + chanSubscribeTo[name] = chCall + + } + + invertedEdges := make(map[string][]string) + for start, ends := range g.edges { + for _, end := range ends { + if _, ok := invertedEdges[end]; !ok { + invertedEdges[end] = []string{start} + } else { + invertedEdges[end] = append(invertedEdges[end], start) + } + + } + } + + inputChannels := &chanCall{ + writeTo: g.edges[START], + writeToBranches: make([]*GraphBranch, len(g.branches[START])), + } + for i := range g.branches[START] { + inputChannels.writeToBranches[i] = g.branches[START][i] + } + + // validate dag + if runType == runTypeDAG { + for _, node := range g.startNodes { + if len(invertedEdges[node]) != 1 { + return nil, fmt.Errorf("dag start node[%s] should not have predecessor other than 'start', but got: %v", node, invertedEdges[node]) + } + } + } + + r := &runner{ + invertedEdges: invertedEdges, + chanSubscribeTo: chanSubscribeTo, + inputChannels: inputChannels, + + runCtx: g.runCtx, + chanBuilder: cb, + + inputType: g.inputType(), + outputType: g.outputType(), + inputStreamFilter: g.inputStreamFilter, + inputValueChecker: g.inputValueChecker, + inputStreamConverter: g.inputStreamConverter, + outputValueChecker: g.outputValueChecker, + outputStreamConverter: g.outputStreamConverter, + + runtimeCheckEdges: g.runtimeCheckEdges, + runtimeCheckBranches: g.runtimeCheckBranches, + } + + if runType == runTypeDAG { + err = validateDAG(r.chanSubscribeTo, r.invertedEdges) + if err != nil { + return nil, err + } + } + + if opt != nil { + r.options = *opt + } + + // default options + if r.options.maxRunSteps == 0 { + r.options.maxRunSteps = len(r.chanSubscribeTo) + 10 + } + + g.freeze() + + g.onCompileFinish(ctx, opt, key2SubGraphs) + + return r.toComposableRunnable() +} + +type subGraphCompileCallback struct { + closure func(ctx context.Context, info *GraphInfo) +} + +// OnFinish is called when the graph is compiled. 
+func (s *subGraphCompileCallback) OnFinish(ctx context.Context, info *GraphInfo) { + s.closure(ctx, info) +} + +func (g *graph) beforeChildGraphsCompile(opt *graphCompileOptions) map[string]*GraphInfo { + if opt == nil || len(opt.callbacks) == 0 { + return nil + } + + return make(map[string]*GraphInfo) +} + +func (gn *graphNode) beforeChildGraphCompile(nodeKey string, key2SubGraphs map[string]*GraphInfo) { + if gn.g == nil || key2SubGraphs == nil { + return + } + + subGraphCallback := func(ctx2 context.Context, subGraph *GraphInfo) { + key2SubGraphs[nodeKey] = subGraph + } + + gn.nodeInfo.compileOption.callbacks = append(gn.nodeInfo.compileOption.callbacks, &subGraphCompileCallback{closure: subGraphCallback}) +} + +func (g *graph) toGraphInfo(ctx context.Context, opt *graphCompileOptions, key2SubGraphs map[string]*GraphInfo) *GraphInfo { + + graphKey := g.runtimeGraphKey + if opt.graphKey != "" { + graphKey = opt.graphKey + } + + gInfo := &GraphInfo{ + Key: graphKey, + CompileOptions: opt.origOpts, + Nodes: make(map[string]GraphNodeInfo, len(g.nodes)), + Edges: gmap.Clone(g.edges), + Branches: gmap.Map(g.branches, func(startNode string, branches []*GraphBranch) (string, []GraphBranch) { + branchInfo := make([]GraphBranch, 0, len(branches)) + for _, b := range branches { + branchInfo = append(branchInfo, GraphBranch{ + condition: b.condition, + endNodes: gmap.Clone(b.endNodes), + }) + } + return startNode, branchInfo + }), + InputType: g.expectedInputType, + OutputType: g.expectedOutputType, + } + + for key := range g.nodes { + gNode := g.nodes[key] + if gNode.executorMeta.component == ComponentOfPassthrough { + gInfo.Nodes[key] = GraphNodeInfo{ + Component: gNode.executorMeta.component, + GraphAddNodeOpts: gNode.opts, + InputType: gNode.cr.inputType, + OutputType: gNode.cr.outputType, + Name: gNode.getNodeName(), + InputKey: gNode.cr.nodeInfo.inputKey, + OutputKey: gNode.cr.nodeInfo.outputKey, + } + continue + } + + gNodeInfo := &GraphNodeInfo{ + Component: 
gNode.executorMeta.component, + Instance: gNode.instance, + GraphAddNodeOpts: gNode.opts, + InputType: gNode.cr.inputType, + OutputType: gNode.cr.outputType, + Name: gNode.getNodeName(), + InputKey: gNode.cr.nodeInfo.inputKey, + OutputKey: gNode.cr.nodeInfo.outputKey, + } + + if gi, ok := key2SubGraphs[key]; ok { + gNodeInfo.GraphInfo = gi + } + + gInfo.Nodes[key] = *gNodeInfo + } + + if g.runCtx != nil { + gInfo.GenStateFn = func(ctx context.Context) any { + stateCtx := g.runCtx(ctx) + state, err := GetState[any](stateCtx) + if err != nil { + return nil + } + + return state + } + } + + return gInfo +} + +func (g *graph) onCompileFinish(ctx context.Context, opt *graphCompileOptions, key2SubGraphs map[string]*GraphInfo) { + if opt == nil { + return + } + + if len(opt.callbacks) == 0 { + return + } + + gInfo := g.toGraphInfo(ctx, opt, key2SubGraphs) + + for _, cb := range opt.callbacks { + cb.OnFinish(ctx, gInfo) + } +} + +func (g *graph) GetType() string { + return "" +} + +func defaultCompileChecker(options *graphCompileOptions) error { + return nil +} + +func transferTask(script [][]string, invertedEdges map[string][]string) [][]string { + utilMap := map[string]bool{} + for i := len(script) - 1; i >= 0; i-- { + for j := 0; j < len(script[i]); j++ { + // deduplicate + if _, ok := utilMap[script[i][j]]; ok { + script[i] = append(script[i][:j], script[i][j+1:]...) + j-- + continue + } + utilMap[script[i][j]] = true + + target := i + for k := i + 1; k < len(script); k++ { + hasDependencies := false + for l := range script[k] { + for _, dependency := range invertedEdges[script[i][j]] { + if script[k][l] == dependency { + hasDependencies = true + break + } + } + if hasDependencies { + break + } + } + if hasDependencies { + break + } + target = k + } + if target != i { + script[target] = append(script[target], script[i][j]) + script[i] = append(script[i][:j], script[i][j+1:]...) 
+ j-- + } + } + } + + return script +} + +func validateDAG(chanSubscribeTo map[string]*chanCall, invertedEdges map[string][]string) error { + m := map[string]int{} + for node := range chanSubscribeTo { + if edges, ok := invertedEdges[node]; ok { + m[node] = len(edges) + for _, pre := range edges { + if pre == START { + m[node] -= 1 + } + } + } else { + m[node] = 0 + } + } + hasChanged := true + for hasChanged { + hasChanged = false + for node := range m { + if m[node] == 0 { + hasChanged = true + for _, subNode := range chanSubscribeTo[node].writeTo { + if subNode == END { + continue + } + m[subNode]-- + } + m[node] = -1 + } + } + } + + for k, v := range m { + if v > 0 { + return fmt.Errorf("DAG invalid, node[%s] has loop", k) + } + } + return nil +} + +func wrapCompileChecker(checkers ...func(options *graphCompileOptions) error) func(options *graphCompileOptions) error { + return func(options *graphCompileOptions) error { + for _, checker := range checkers { + if checker == nil { + continue + } + + if err := checker(options); err != nil { + return err + } + } + + return nil + } +} + +func defaultGraphKey() string { + pcs := make([]uintptr, 1) + _ = runtime.Callers(3, pcs) + frame, _ := runtime.CallersFrames(pcs).Next() + return fmt.Sprintf("%s:%d", frame.Function, frame.Line) +} diff --git a/compose/graph_add_node_options.go b/compose/graph_add_node_options.go new file mode 100644 index 0000000..e6b1c65 --- /dev/null +++ b/compose/graph_add_node_options.go @@ -0,0 +1,152 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+// graphAddNodeOpts gathers everything a GraphAddNodeOpt can configure:
+// the plain node settings and the optional state handlers.
+type graphAddNodeOpts struct {
+	nodeOptions *nodeOptions
+	processor   *processorOpts
+}
+
+// GraphAddNodeOpt customises a node while it is being added to a graph.
+// Example:
+//
+//	graph.AddNode("node_name", node, compose.WithInputKey("input_key"), compose.WithOutputKey("output_key"))
+type GraphAddNodeOpt func(o *graphAddNodeOpts)
+
+// nodeOptions holds the plain (non-handler) settings of a node.
+type nodeOptions struct {
+	// display name of the node, set by WithNodeName
+	nodeName string
+
+	// identity key of the node, set by WithNodeKey
+	nodeKey string
+
+	inputKey  string
+	outputKey string
+
+	// when this node is itself an AnyGraph, these options are used to compile it as a nested graph
+	graphCompileOption []GraphCompileOption
+}
+
+// WithNodeName gives the node a name.
+func WithNodeName(n string) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) { target.nodeOptions.nodeName = n }
+}
+
+// WithNodeKey assigns the key that identifies the node inside a chain.
+// It is meant for Chain/StateChain only.
+func WithNodeKey(key string) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) { target.nodeOptions.nodeKey = key }
+}
+
+// WithInputKey selects which entry of the upstream output becomes the node's input.
+// For example, if the previous node outputs map[string]any{"key01": "value01"}
+// and the input key is "key01", the node receives "value01".
+func WithInputKey(k string) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) { target.nodeOptions.inputKey = k }
+}
+
+// WithOutputKey sets the output key of the node.
+// this will change the output value of the node, for example, if the current node's output key is "key01",
+// then the node's output value will be map[string]any{"key01": value}.
+func WithOutputKey(k string) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) { target.nodeOptions.outputKey = k }
+}
+
+// WithGraphCompileOptions supplies the compile options to use when the node is
+// itself an AnyGraph. The given options replace any previously stored ones.
+// Example:
+//
+//	graph.AddNode("node_name", node, compose.WithGraphCompileOptions(compose.WithGraphName("my_sub_graph")))
+func WithGraphCompileOptions(opts ...GraphCompileOption) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) { target.nodeOptions.graphCompileOption = opts }
+}
+
+// WithStatePreHandler registers a handler that can rewrite the node's input (type I)
+// based on the state (type S), or record input information into the state; it is thread-safe.
+// Only valid for StateGraph: passing it to a plain Graph results in an error.
+// I: input type of the node (ChatModel, Lambda, Retriever, ...).
+// S: state type of the StateGraph.
+func WithStatePreHandler[I, S any](pre StatePreHandler[I, S]) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) {
+		// converted when the option is applied, not when it is created
+		target.processor.statePreHandler = convertPreHandler(pre)
+	}
+}
+
+// WithStatePostHandler registers a handler that can rewrite the node's output (type O)
+// based on the state (type S), or record output information into the state; it is thread-safe.
+// Only valid for StateGraph: passing it to a plain Graph results in an error.
+// O: output type of the node (ChatModel, Lambda, Retriever, ...).
+// S: state type of the StateGraph.
+func WithStatePostHandler[O, S any](post StatePostHandler[O, S]) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) {
+		// converted when the option is applied, not when it is created
+		target.processor.statePostHandler = convertPostHandler(post)
+	}
+}
+
+// WithStreamStatePreHandler modify node's streaming input of I according to state S and input or store input information into state, and it's thread-safe.
+// notice: this option is only for StateGraph. it will cause an error when passed to Graph.
+// when to use: when upstream node's output is an actual stream, and you want the current node's input to remain an actual stream after state pre handler.
+// caution: while StreamStatePreHandler is thread safe, modifying state within your own goroutine is NOT.
+// I: input type of the Node like ChatModel, Lambda, Retriever etc.
+// S: state type of StateGraph
+func WithStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) {
+		// shares the statePreHandler slot with WithStatePreHandler; the option applied last wins
+		target.processor.statePreHandler = streamConvertPreHandler(pre)
+	}
+}
+
+// WithStreamStatePostHandler registers a handler that can rewrite the node's streaming
+// output (type O) based on the state (type S), or record output information into the state;
+// it is thread-safe.
+// Only valid for StateGraph: passing it to a plain Graph results in an error.
+// Use it when the node's output is an actual stream and the downstream input should
+// remain an actual stream after the post handler ran.
+// Caution: the handler itself is thread-safe, modifying the state from your own goroutine is NOT.
+// O: output type of the node (ChatModel, Lambda, Retriever, ...).
+// S: state type of the StateGraph.
+func WithStreamStatePostHandler[O, S any](post StreamStatePostHandler[O, S]) GraphAddNodeOpt {
+	return func(target *graphAddNodeOpts) {
+		// shares the statePostHandler slot with WithStatePostHandler; the option applied last wins
+		target.processor.statePostHandler = streamConvertPostHandler(post)
+	}
+}
+
+// processorOpts holds the optional state handlers of a node; nil means "not set".
+type processorOpts struct {
+	statePreHandler  *composableRunnable
+	statePostHandler *composableRunnable
+}
+
+// getGraphAddNodeOpts starts from empty node options and empty processor options
+// and applies opts in order, so later options override earlier ones.
+func getGraphAddNodeOpts(opts ...GraphAddNodeOpt) *graphAddNodeOpts {
+	collected := &graphAddNodeOpts{
+		nodeOptions: &nodeOptions{},
+		processor:   &processorOpts{},
+	}
+	for i := range opts {
+		opts[i](collected)
+	}
+	return collected
+}
diff --git a/compose/graph_call_options.go b/compose/graph_call_options.go
new file mode 100644
index 0000000..38ea490
--- /dev/null
+++ b/compose/graph_call_options.go
@@ -0,0 +1,190 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except
in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"fmt"
+	"reflect"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/embedding"
+	"github.com/cloudwego/eino/components/indexer"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/prompt"
+	"github.com/cloudwego/eino/components/retriever"
+)
+
+// Option is a call-time option handed to a compiled graph (e.g. runnable.Invoke).
+// One Option may carry component options, callback handlers and/or a run-step limit,
+// and can be bound to specific nodes with DesignateNode.
+type Option struct {
+	// component-specific options (embedding.Option, model.Option, lambda options, ...)
+	options []any
+	// callback handlers set by WithCallbacks
+	handler []callbacks.Handler
+
+	nodeHandler []callbacks.Handler // deprecated
+	// node keys this option is designated to, see DesignateNode
+	keys []string
+
+	graphHandler []callbacks.Handler // deprecated
+	// step limit set by WithRuntimeMaxSteps; values <= 0 leave the compiled default in place
+	maxRunSteps int
+}
+
+// DesignateNode returns a copy of the option that is additionally designated to
+// the nodes identified by key. The receiver is left untouched.
+// eg.
+//
+//	embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small"))
+//	runnable.Invoke(ctx, "input", embeddingOption.DesignateNode("my_embedding_node"))
+func (o Option) DesignateNode(key ...string) Option {
+	if len(key) == 0 {
+		return o
+	}
+
+	// Build the key list in a fresh backing array: appending in place could write into
+	// spare capacity shared with other Options derived from the same receiver, so that
+	// a later DesignateNode call would overwrite the keys of an earlier result.
+	keys := make([]string, 0, len(o.keys)+len(key))
+	keys = append(keys, o.keys...)
+	o.keys = append(keys, key...)
+	return o
+}
+
+// WithEmbeddingOption wraps embedding component options into a graph call Option.
+// eg.
+//
+//	embeddingOption := compose.WithEmbeddingOption(embedding.WithModel("text-embedding-3-small"))
+//	runnable.Invoke(ctx, "input", embeddingOption)
+func WithEmbeddingOption(opts ...embedding.Option) Option {
+	return withComponentOption(opts...)
+}
+
+// WithRetrieverOption is a functional option type for retriever component.
+// eg.
+//
+//	retrieverOption := compose.WithRetrieverOption(retriever.WithIndex("my_index"))
+//	runnable.Invoke(ctx, "input", retrieverOption)
+func WithRetrieverOption(opts ...retriever.Option) Option {
+	return withComponentOption(opts...)
+}
+
+// WithLoaderSplitterOption wraps loader-splitter options into a graph call Option.
+//
+// Deprecated: use WithLoaderOption instead.
+func WithLoaderSplitterOption(opts ...document.LoaderSplitterOption) Option {
+	return withComponentOption(opts...)
+}
+
+// WithLoaderOption wraps document loader options into a graph call Option.
+// Example:
+//
+//	loaderOption := compose.WithLoaderOption(document.WithCollection("my_collection"))
+//	runnable.Invoke(ctx, "input", loaderOption)
+func WithLoaderOption(opts ...document.LoaderOption) Option {
+	return withComponentOption(opts...)
+}
+
+// WithDocumentTransformerOption wraps document transformer options into a graph call Option.
+func WithDocumentTransformerOption(opts ...document.TransformerOption) Option {
+	return withComponentOption(opts...)
+}
+
+// WithIndexerOption wraps indexer options into a graph call Option.
+// Example:
+//
+//	indexerOption := compose.WithIndexerOption(indexer.WithSubIndexes([]string{"my_sub_index"}))
+//	runnable.Invoke(ctx, "input", indexerOption)
+func WithIndexerOption(opts ...indexer.Option) Option {
+	return withComponentOption(opts...)
+}
+
+// WithChatModelOption wraps chat model options into a graph call Option.
+// Example:
+//
+//	chatModelOption := compose.WithChatModelOption(model.WithTemperature(0.7))
+//	runnable.Invoke(ctx, "input", chatModelOption)
+func WithChatModelOption(opts ...model.Option) Option {
+	return withComponentOption(opts...)
+}
+
+// WithChatTemplateOption wraps chat template options into a graph call Option.
+func WithChatTemplateOption(opts ...prompt.Option) Option {
+	return withComponentOption(opts...)
+}
+
+// WithToolsNodeOption wraps tools node options into a graph call Option.
+func WithToolsNodeOption(opts ...ToolsNodeOption) Option {
+	return withComponentOption(opts...)
+}
+
+// WithLambdaOption wraps arbitrary lambda options into a graph call Option.
+// The values are stored as given, without any type check at this point.
+func WithLambdaOption(opts ...any) Option {
+	return Option{options: opts, keys: []string{}}
+}
+
+// WithCallbacks builds an Option carrying the given callback handlers.
+// Example:
+//
+//	runnable.Invoke(ctx, "input", compose.WithCallbacks(&myCallbacks{}))
+func WithCallbacks(cbs ...callbacks.Handler) Option {
+	return Option{handler: cbs}
+}
+
+// WithNodeCallbacks builds an Option carrying node-level callback handlers.
+//
+// Deprecated: use WithCallbacks and perform the type checking for component within it instead
+func WithNodeCallbacks(cbs ...callbacks.Handler) Option {
+	return Option{nodeHandler: cbs}
+}
+
+// WithGraphCallbacks builds an Option carrying graph-level callback handlers.
+//
+// Deprecated: use WithCallbacks and perform the type checking for component within it instead
+func WithGraphCallbacks(cbs ...callbacks.Handler) Option {
+	return Option{graphHandler: cbs}
+}
+
+// WithGraphRunOption returns opt unchanged.
+//
+// Deprecated: use WithRuntimeMaxSteps directly instead.
+func WithGraphRunOption(opt Option) Option {
+	return opt
+}
+
+// WithRuntimeMaxSteps sets the maximum number of steps for the graph runtime.
+// eg.
+//
+//	runnable.Invoke(ctx, "input", compose.WithRuntimeMaxSteps(20))
+func WithRuntimeMaxSteps(maxSteps int) Option {
+	return Option{
+		maxRunSteps: maxSteps,
+	}
+}
+
+// withComponentOption boxes typed component options into a graph call Option.
+// Every element is stored as `any`; the key list starts empty, i.e. the option
+// is not designated to a particular node yet.
+func withComponentOption[TOption any](opts ...TOption) Option {
+	o := make([]any, 0, len(opts))
+	for i := range opts {
+		o = append(o, opts[i])
+	}
+	return Option{
+		options: o,
+		keys:    make([]string, 0),
+	}
+}
+
+// convertOption turns untyped options back into []TOption.
+// It returns (nil, nil) for an empty input and an error naming the expected and
+// the actual type as soon as one element is not a TOption.
+func convertOption[TOption any](opts ...any) ([]TOption, error) {
+	if len(opts) == 0 {
+		return nil, nil
+	}
+	ret := make([]TOption, 0, len(opts))
+	for i := range opts {
+		o, ok := opts[i].(TOption)
+		if !ok {
+			// %T prints the same text as reflect.TypeOf(x).String() but also copes with a nil
+			// element ("<nil>"), where reflect.TypeOf returns nil and calling String() would panic.
+			return nil, fmt.Errorf("unexpected component option type, expected:%s, actual:%T",
+				reflect.TypeOf((*TOption)(nil)).Elem().String(), opts[i])
+		}
+		ret = append(ret, o)
+	}
+	return ret, nil
+}
diff --git a/compose/graph_call_options_test.go b/compose/graph_call_options_test.go
new file mode 100644
index 0000000..dbd333f
--- /dev/null
+++ b/compose/graph_call_options_test.go
@@ -0,0 +1,303 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/mock/gomock"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components/document"
+	"github.com/cloudwego/eino/components/embedding"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/retriever"
+	mockDocument "github.com/cloudwego/eino/internal/mock/components/document"
+	mockEmbedding "github.com/cloudwego/eino/internal/mock/components/embedding"
+	mockRetriever "github.com/cloudwego/eino/internal/mock/components/retriever"
+	"github.com/cloudwego/eino/schema"
+)
+
+// optionSuccess is flipped to false by testModel as soon as one call receives unexpected options.
+var optionSuccess = true
+
+// idx toggles between 0 and 1 on every checkOption call that reaches the model check,
+// selecting which model name is expected next.
+var idx int
+
+// checkOption verifies the options one testModel call received: exactly two options,
+// TopP == 1.0 and a model name that alternates between "123" and "456" on successive calls.
+// It relies on the package-level idx, so the result depends on the call order.
+func checkOption(opts ...model.Option) bool {
+	if len(opts) != 2 {
+		return false
+	}
+	o := model.GetCommonOptions(&model.Options{}, opts...)
+	if o.TopP == nil || *o.TopP != 1.0 {
+		return false
+	}
+	if o.Model == nil {
+		return false
+	}
+	if idx == 0 {
+		idx = 1
+		if o.Model == nil || *o.Model != "123" {
+			return false
+		}
+	} else {
+		idx = 0
+		if o.Model == nil || *o.Model != "456" {
+			return false
+		}
+	}
+
+	return true
+}
+
+// testModel is a chat model stub that only validates the options it is called with.
+type testModel struct{}
+
+// BindTools accepts any tools and does nothing.
+func (t *testModel) BindTools(tools []*schema.ToolInfo) error {
+	return nil
+}
+
+// Generate records an option mismatch in optionSuccess and returns an empty message.
+func (t *testModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
+	if !checkOption(opts...) {
+		optionSuccess = false
+	}
+	return &schema.Message{}, nil
+}
+
+// Stream records an option mismatch in optionSuccess and returns a stream
+// to which a single nil message was sent before it was closed.
+func (t *testModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
+	if !checkOption(opts...) {
+		optionSuccess = false
+	}
+	sr, sw := schema.Pipe[*schema.Message](1)
+	sw.Send(nil, nil)
+	sw.Close()
+	return sr, nil
+}
+
+// TestCallOption builds the graph START -> "1" (lambda) -> "2" (model) -> "-" (lambda) -> "3" (model) -> END
+// and checks, for Invoke, Stream, Collect and Transform, that designated and undesignated
+// call options reach the expected nodes. Node callbacks designated to "3" are counted after Invoke.
+func TestCallOption(t *testing.T) {
+	g := NewGraph[[]*schema.Message, *schema.Message]()
+	err := g.AddLambdaNode("1", InvokableLambdaWithOption(func(ctx context.Context, input []*schema.Message, opts ...string) (output []*schema.Message, err error) {
+		if len(opts) != 1 || opts[0] != "1" {
+			t.Fatalf("lambda option length isn't 1 or content isn't '1': %v", opts)
+		}
+		return input, nil
+	}))
+	assert.Nil(t, err)
+
+	err = g.AddChatModelNode("2", &testModel{})
+	assert.Nil(t, err)
+
+	// adapter: turns the single message produced by node "2" back into a slice for node "3"
+	err = g.AddLambdaNode("-", InvokableLambda(func(ctx context.Context, input *schema.Message) (output []*schema.Message, err error) {
+		return []*schema.Message{input}, nil
+	}))
+	assert.Nil(t, err)
+
+	err = g.AddChatModelNode("3", &testModel{})
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = g.AddEdge(START, "1")
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = g.AddEdge("1", "2")
+	assert.Nil(t, err)
+
+	err = g.AddEdge("2", "-")
+	assert.Nil(t, err)
+
+	err = g.AddEdge("-", "3")
+	assert.Nil(t, err)
+
+	err = g.AddEdge("3", END)
+	assert.Nil(t, err)
+
+	ctx := context.Background()
+
+	r, err := g.Compile(ctx)
+	assert.Nil(t, err)
+
+	sessionKey := struct{}{}
+	startCnt := 0
+	endCnt := 0
+	// model "123" is designated to node "2", model "456" to node "3"; TopP is not designated,
+	// and checkOption expects each model node to end up with exactly two options.
+	opts := []Option{
+		WithChatModelOption(
+			model.WithModel("123"),
+		).DesignateNode("2"),
+		WithChatModelOption(
+			model.WithModel("456"),
+		).DesignateNode("3"),
+		WithChatModelOption(
+			model.WithTopP(1.0),
+		),
+		WithGraphCallbacks(),
+		WithNodeCallbacks(callbacks.NewHandlerBuilder().
+			OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
+				startCnt++
+				return context.WithValue(ctx, sessionKey, "start")
+			}).
+			OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
+				// only counts when the context returned by OnStartFn is the one handed to OnEndFn
+				if ctx.Value(sessionKey).(string) == "start" {
+					endCnt++
+					return context.WithValue(ctx, sessionKey, "end")
+				}
+				return ctx
+			}).Build()).DesignateNode("3"),
+		WithLambdaOption("1").DesignateNode("1"),
+	}
+
+	_, err = r.Invoke(ctx, []*schema.Message{},
+		opts...)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !optionSuccess {
+		t.Fatal("invoke option fail")
+	}
+	if startCnt != 1 {
+		t.Fatal("node callback fail")
+	}
+	if endCnt != 1 {
+		t.Fatal("node callback fail")
+	}
+	_, err = r.Stream(ctx, []*schema.Message{},
+		opts...)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !optionSuccess {
+		t.Fatal("stream option fail")
+	}
+
+	srOfCollect, swOfCollect := schema.Pipe[[]*schema.Message](1)
+	swOfCollect.Send([]*schema.Message{}, nil)
+	swOfCollect.Close()
+	_, err = r.Collect(ctx, srOfCollect, opts...)
+	assert.Nil(t, err)
+
+	if !optionSuccess {
+		t.Fatal("collect option fail")
+	}
+
+	srOfTransform, swOfTransform := schema.Pipe[[]*schema.Message](1)
+	swOfTransform.Send([]*schema.Message{}, nil)
+	swOfTransform.Close()
+	_, err = r.Transform(ctx, srOfTransform, opts...)
+	assert.Nil(t, err)
+
+	if !optionSuccess {
+		t.Fatal("transform option fail")
+	}
+}
+
+// TestCallOptionsOneByOne checks each option wrapper on its own: the option passed to
+// Invoke must arrive at the mocked component inside a single-node chain.
+func TestCallOptionsOneByOne(t *testing.T) {
+	ctx := context.Background()
+	t.Run("common_option", func(t *testing.T) {
+		type option struct {
+			uid int64
+		}
+
+		opt := withComponentOption(&option{uid: 100})
+		assert.Len(t, opt.options, 1)
+		assert.IsType(t, &option{}, opt.options[0])
+		assert.Equal(t, &option{uid: 100}, opt.options[0])
+	})
+
+	t.Run("embedding_option", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		inst := mockEmbedding.NewMockEmbedder(ctrl)
+		// opt captures the options the mock actually received
+		var opt *embedding.Options
+		inst.EXPECT().EmbedStrings(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
+				opt = embedding.GetCommonOptions(&embedding.Options{}, opts...)
+				return nil, nil
+			}).Times(1)
+		ch := NewChain[map[string]any, map[string]any]()
+		ch.AppendEmbedding(inst, WithInputKey("input"), WithOutputKey("output"))
+		r, err := ch.Compile(ctx)
+		assert.NoError(t, err)
+		outs, err := r.Invoke(ctx,
+			map[string]any{"input": []string{}},
+			WithEmbeddingOption(embedding.WithModel("123")),
+		)
+		assert.NoError(t, err)
+		assert.Contains(t, outs, "output")
+
+		assert.NotNil(t, opt.Model)
+		assert.Equal(t, "123", *opt.Model)
+	})
+
+	t.Run("retriever_option", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		inst := mockRetriever.NewMockRetriever(ctrl)
+		// opt captures the options the mock actually received
+		var opt *retriever.Options
+		inst.EXPECT().Retrieve(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+				opt = retriever.GetCommonOptions(&retriever.Options{}, opts...)
+				return nil, nil
+			}).
+			Times(1)
+		ch := NewChain[map[string]any, map[string]any]()
+		ch.AppendRetriever(inst, WithInputKey("input"), WithOutputKey("output"))
+		r, err := ch.Compile(ctx)
+		assert.NoError(t, err)
+		outs, err := r.Invoke(ctx,
+			map[string]any{"input": "hi"},
+			WithRetrieverOption(retriever.WithIndex("123")),
+		)
+		assert.NoError(t, err)
+		assert.Contains(t, outs, "output")
+
+		assert.NotNil(t, opt.Index)
+		assert.Equal(t, "123", *opt.Index)
+	})
+
+	t.Run("loader_option", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		inst := mockDocument.NewMockLoader(ctrl)
+		// implOption stands in for an implementation-specific loader option struct
+		type implOption struct {
+			uid int64
+		}
+
+		type implOptFn func(o *implOption)
+
+		withUID := func(uid int64) document.LoaderOption {
+			return document.WrapLoaderImplSpecificOptFn[implOption](func(i *implOption) {
+				i.uid = uid
+			})
+		}
+
+		var opt *implOption
+
+		inst.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(ctx context.Context, src document.Source, opts ...document.LoaderOption) ([]*schema.Document, error) {
+				// base value uid=1 must be overwritten by the option passed at call time
+				opt = document.GetLoaderImplSpecificOptions[implOption](&implOption{uid: 1}, opts...)
+				return nil, nil
+			}).
+			Times(1)
+		ch := NewChain[map[string]any, map[string]any]()
+		ch.AppendLoader(inst, WithInputKey("input"), WithOutputKey("output"))
+		r, err := ch.Compile(ctx)
+		assert.NoError(t, err)
+		outs, err := r.Invoke(ctx,
+			map[string]any{"input": document.Source{}},
+			WithLoaderOption(withUID(123)),
+		)
+		assert.NoError(t, err)
+		assert.Contains(t, outs, "output")
+
+		assert.Equal(t, int64(123), opt.uid)
+	})
+}
diff --git a/compose/graph_compile_options.go b/compose/graph_compile_options.go
new file mode 100644
index 0000000..5d96b1e
--- /dev/null
+++ b/compose/graph_compile_options.go
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+// GraphCompileOption options for compiling AnyGraph.
+type GraphCompileOption func(*graphCompileOptions)
+
+// WithMaxRunSteps records maxSteps as the step limit in the compile options.
+func WithMaxRunSteps(maxSteps int) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.maxRunSteps = maxSteps }
+}
+
+// WithGraphName records graphName as the name of the graph in the compile options.
+func WithGraphName(graphName string) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.graphName = graphName }
+}
+
+// WithGraphKey records graphKey as the key of the graph in the compile options.
+func WithGraphKey(graphKey string) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.graphKey = graphKey }
+}
+
+// WithNodeTriggerMode chooses the node trigger mode of the graph.
+// The mode influences execution order and result for certain graphs, e.g. graphs
+// whose parallel branches contain a different number of nodes.
+func WithNodeTriggerMode(triggerMode NodeTriggerMode) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.nodeTriggerMode = triggerMode }
+}
+
+// WithGraphCompileCallbacks adds callbacks for graph compilation.
+// Callbacks accumulate: they are appended to the ones already configured.
+func WithGraphCompileCallbacks(cbs ...GraphCompileCallback) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.callbacks = append(cfg.callbacks, cbs...) }
+}
+
+// withComponent records the component type of the graph. ONLY FOR INTERNAL.
+func withComponent(component component) GraphCompileOption {
+	return func(cfg *graphCompileOptions) { cfg.component = component }
+}
+
+// InitGraphCompileCallbacks replaces the global graph compile callbacks,
+// which ONLY will be added to top level graph compile options.
+// The slice is stored as given, not copied.
+func InitGraphCompileCallbacks(cbs []GraphCompileCallback) {
+	globalGraphCompileCallbacks = cbs
+}
+
+// globalGraphCompileCallbacks is the package-wide list installed by InitGraphCompileCallbacks.
+var globalGraphCompileCallbacks []GraphCompileCallback
diff --git a/compose/graph_node.go b/compose/graph_node.go
new file mode 100644
index 0000000..cbe4d15
--- /dev/null
+++ b/compose/graph_node.go
@@ -0,0 +1,188 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"context"
+	"errors"
+	"reflect"
+
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/utils/generic"
+)
+
+// executorMeta describes the original executable object supplied by the user.
+type executorMeta struct {
+
+	// automatically identified based on the way of addNode
+	component component
+
+	// indicates whether the executable object user provided could execute the callback aspect itself.
+	// if it could, the callback in the corresponding graph node won't be executed
+	// for components, the value comes from callbacks.Checker
+	isComponentCallbackEnabled bool
+
+	// for components, the value comes from components.Typer
+	// for lambda, the value comes from the user's explicit config
+	// if componentImplType is empty, then the class name or func name in the instance will be inferred, but no guarantee.
+	componentImplType string
+}
+
+// nodeInfo holds the per-node settings collected from the GraphAddNodeOpt values (see getNodeInfo).
+type nodeInfo struct {
+
+	// the name of graph node for display purposes, not unique.
+	// passed from WithNodeName()
+	// if not set, it will be inferred from the component type and component name
+	name string
+
+	// node key: the identity key of the node in a graph
+	// passed from WithNodeKey()
+	key string
+
+	inputKey  string
+	outputKey string
+
+	preProcessor, postProcessor *composableRunnable
+
+	compileOption *graphCompileOptions // if the node is an AnyGraph, it will need compile options of its own
+}
+
+// graphNode the complete information of the node in graph
+type graphNode struct {
+	cr *composableRunnable
+
+	g AnyGraph
+
+	nodeInfo     *nodeInfo
+	executorMeta *executorMeta
+
+	instance any
+	opts     []GraphAddNodeOpt
+}
+
+// isPassthrough reports whether the node is a passthrough: it has no nested graph
+// and its runnable is flagged as passthrough.
+func (gn *graphNode) isPassthrough() bool {
+	// priority follow compile
+	return gn.g == nil && gn.cr != nil && gn.cr.isPassthrough
+}
+
+// inputType returns the node's input type: map[string]any when an input key is set,
+// otherwise the type of the nested graph or of the runnable; nil when neither exists.
+func (gn *graphNode) inputType() reflect.Type {
+	if gn.nodeInfo != nil && len(gn.nodeInfo.inputKey) != 0 {
+		return generic.TypeOf[map[string]any]()
+	}
+	// priority follow compile
+	if gn.g != nil {
+		return gn.g.inputType()
+	} else if gn.cr != nil {
+		return gn.cr.inputType
+	}
+
+	return nil
+}
+
+// outputType returns the node's output type: map[string]any when an output key is set,
+// otherwise the type of the nested graph or of the runnable; nil when neither exists.
+func (gn *graphNode) outputType() reflect.Type {
+	if gn.nodeInfo != nil && len(gn.nodeInfo.outputKey) != 0 {
+		return generic.TypeOf[map[string]any]()
+	}
+	// priority follow compile
+	if gn.g != nil {
+		return gn.g.outputType()
+	} else if gn.cr != nil {
+		return gn.cr.outputType
+	}
+
+	return nil
+}
+
+// setOutputKey overrides the output key; gn.nodeInfo must not be nil.
+func (gn *graphNode) setOutputKey(key string) {
+	gn.nodeInfo.outputKey = key
+}
+
+// compileIfNeeded returns the runnable of the node.
+// A nested graph is compiled with the node's own compile options and the result is kept in gn.cr;
+// otherwise the existing gn.cr is used. It fails when the node has neither.
+// The runnable gets the node's meta and info attached and is then wrapped for the
+// output key first and the input key second, when those keys are set.
+func (gn *graphNode) compileIfNeeded(ctx context.Context) (*composableRunnable, error) {
+	var r *composableRunnable
+	if gn.g != nil {
+		cr, err := gn.g.compile(ctx, gn.nodeInfo.compileOption)
+		if err != nil {
+			return nil, err
+		}
+
+		r = cr
+		gn.cr = cr
+	} else if gn.cr != nil {
+		r = gn.cr
+	} else {
+		return nil, errors.New("no graph or component provided")
+	}
+
+	r.meta = gn.executorMeta
+	r.nodeInfo = gn.nodeInfo
+
+	if gn.nodeInfo.outputKey != "" {
+		r = outputKeyedComposableRunnable(gn.nodeInfo.outputKey, r)
+	}
+
+	if gn.nodeInfo.inputKey != "" {
+		r = inputKeyedComposableRunnable(gn.nodeInfo.inputKey, r)
+	}
+
+	return r, nil
+}
+
+// getNodeName returns the configured node name or, if empty, a name derived from the executor meta.
+func (gn *graphNode) getNodeName() string {
+	return generateName(gn.nodeInfo.name, gn.executorMeta)
+}
+
+// generateName returns name when it is set; otherwise the implementation type followed by
+// the component kind, or just the component kind when no implementation type is known.
+func generateName(name string, meta *executorMeta) string {
+	if name != "" {
+		return name
+	}
+
+	if meta.componentImplType != "" {
+		return meta.componentImplType + string(meta.component)
+	}
+
+	return string(meta.component)
+
+}
+
+// parseExecutorInfoFromComponent builds the executorMeta for a component instance.
+// The implementation type comes from components.GetType and falls back to the
+// parsed Go type name of the executor.
+func parseExecutorInfoFromComponent(c component, executor any) *executorMeta {
+
+	componentImplType, ok := components.GetType(executor)
+	if !ok {
+		componentImplType = generic.ParseTypeName(reflect.ValueOf(executor))
+	}
+
+	return &executorMeta{
+		component:                  c,
+		isComponentCallbackEnabled: components.IsCallbacksEnabled(executor),
+		componentImplType:          componentImplType,
+	}
+}
+
+// getNodeInfo applies opts and copies the resulting settings into a nodeInfo,
+// turning the nested-graph compile options into a graphCompileOptions value.
+func getNodeInfo(opts ...GraphAddNodeOpt) *nodeInfo {
+
+	opt := getGraphAddNodeOpts(opts...)
+
+	return &nodeInfo{
+		name:          opt.nodeOptions.nodeName,
+		key:           opt.nodeOptions.nodeKey,
+		inputKey:      opt.nodeOptions.inputKey,
+		outputKey:     opt.nodeOptions.outputKey,
+		preProcessor:  opt.processor.statePreHandler,
+		postProcessor: opt.processor.statePostHandler,
+		compileOption: newGraphCompileOptions(opt.nodeOptions.graphCompileOption...),
+	}
+}
diff --git a/compose/graph_node_checker.go b/compose/graph_node_checker.go
new file mode 100644
index 0000000..ea09e80
--- /dev/null
+++ b/compose/graph_node_checker.go
@@ -0,0 +1,34 @@
+package compose
+
+import (
+	"fmt"
+)
+
+// nodeChecker validates a node that is about to be added under nodeKey.
+type nodeChecker func(nodeKey string, node *graphNode) error
+
+// baseNodeChecker rejects a passthrough node that has an input key configured.
+func baseNodeChecker(nodeKey string, node *graphNode) error {
+
+	if node.executorMeta.component == ComponentOfPassthrough && len(node.nodeInfo.inputKey) > 0 {
+		return fmt.Errorf("passthrough cannot be set input key, nodeKey=%v", nodeKey)
+	}
+
+	return nil
+}
+
+// nodeCheckerOfForbidProcessor wraps next with a check that rejects nodes carrying
+// a state pre or post processor.
+func nodeCheckerOfForbidProcessor(next nodeChecker) nodeChecker {
+	return func(nodeKey string, node *graphNode) error {
+		if node.nodeInfo.preProcessor != nil || node.nodeInfo.postProcessor != nil {
+			return fmt.Errorf("only StateGraph support pre/post processor, nodeKey=%v", nodeKey)
+		}
+		return next(nodeKey, node)
+	}
+}
+
+// nodeCheckerOfForbidNodeKey wraps next with a check that rejects nodes configured via WithNodeKey.
+func nodeCheckerOfForbidNodeKey(next nodeChecker) nodeChecker {
+	return func(nodeKey string, node *graphNode) error {
+		if node.nodeInfo.key != "" {
+			return fmt.Errorf("only Chain support WithNodeKey(), nodeKey=%v", nodeKey)
+		}
+		return next(nodeKey, node)
+	}
+}
diff --git a/compose/graph_run.go b/compose/graph_run.go
new file mode 100644
index 0000000..efec570
--- /dev/null
+++ b/compose/graph_run.go
@@ -0,0 +1,544 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package compose
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"reflect"
+	"runtime/debug"
+	"sync"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/schema"
+	"github.com/cloudwego/eino/utils/safe"
+)
+
+// graphCompileOptions is the resolved form of the GraphCompileOption values of one graph.
+type graphCompileOptions struct {
+	maxRunSteps     int
+	graphName       string
+	graphKey        string
+	nodeTriggerMode NodeTriggerMode // default to AnyPredecessor (pregel)
+
+	callbacks []GraphCompileCallback
+
+	// the options exactly as they were passed to newGraphCompileOptions
+	origOpts []GraphCompileOption
+
+	component component
+}
+
+// newGraphCompileOptions applies opts in order to an empty graphCompileOptions
+// and keeps the original option list in origOpts.
+func newGraphCompileOptions(opts ...GraphCompileOption) *graphCompileOptions {
+	option := &graphCompileOptions{}
+
+	for _, o := range opts {
+		o(option)
+	}
+
+	option.origOpts = opts
+
+	return option
+}
+
+// chanCall describes what runs for one node and where its output goes:
+// the action itself, the node keys written to directly, the branches that pick
+// further targets, and the optional pre/post processors around the action.
+type chanCall struct {
+	action          *composableRunnable
+	writeTo         []string
+	writeToBranches []*GraphBranch
+
+	preProcessor, postProcessor *composableRunnable
+}
+
+// channel is the per-node input buffer used while a graph runs.
+// NOTE(review): the exact contract of each method is defined by the implementations
+// (e.g. pregelChannel), which are not part of this file section.
+type channel interface {
+	update(context.Context, map[string]any) error
+	get(context.Context) (any, error)
+	ready(context.Context) bool
+	clear(context.Context)
+}
+
+// chanBuilder creates the channel of a node; buildChannels passes the node's
+// entry of invertedEdges as d.
+type chanBuilder func(d []string) channel
+
+// runner is the executable form of a compiled graph.
+type runner struct {
+	chanSubscribeTo map[string]*chanCall
+	invertedEdges   map[string][]string
+	inputChannels   *chanCall
+
+	chanBuilder chanBuilder // could be nil
+
+	// optional hook applied to the context at the start of every run
+	runCtx func(ctx context.Context) context.Context
+
+	options graphCompileOptions
+
+	inputType  reflect.Type
+	outputType reflect.Type
+
+	inputStreamFilter     streamMapFilter
+	inputValueChecker     valueChecker
+	inputStreamConverter  streamConverter
+	outputValueChecker    valueChecker
+	outputStreamConverter streamConverter
+
+	runtimeCheckEdges    map[string]map[string]bool
+	runtimeCheckBranches map[string][]bool
+}
+
+// toComposableRunnable exposes the runner as a composableRunnable.
+// Both entry points convert the untyped options to []Option (failing on any other type)
+// and delegate to invoke / transform; they are then wrapped with the generic callback helpers.
+// The returned error is always nil.
+func (r *runner) toComposableRunnable() (*composableRunnable, error) {
+	cr := &composableRunnable{
+		i: func(ctx context.Context, input any, opts ...any) (output any, err error) {
+			tos, err := convertOption[Option](opts...)
+			if err != nil {
+				return nil, err
+			}
+			return r.invoke(ctx, input, tos...)
+		},
+		t: func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) {
+			tos, err := convertOption[Option](opts...)
+			if err != nil {
+				return nil, err
+			}
+			return r.transform(ctx, input, tos...)
+		},
+
+		inputType:            r.inputType,
+		outputType:           r.outputType,
+		inputStreamFilter:    r.inputStreamFilter,
+		inputValueChecker:    r.inputValueChecker,
+		inputStreamConverter: r.inputStreamConverter,
+		optionType:           nil, // if option type is nil, graph will transmit all options.
+
+		isPassthrough: false,
+	}
+
+	cr.i = genericInvokeWithCallbacks(cr.i)
+	cr.t = genericTransformWithCallbacks(cr.t)
+
+	return cr, nil
+}
+
+// buildChannels creates one channel per node in chanSubscribeTo plus one for END.
+// Without a configured chanBuilder every channel is a pregelChannel.
+func (r *runner) buildChannels() map[string]channel {
+	builder := r.chanBuilder
+	if builder == nil {
+		builder = func(d []string) channel {
+			return &pregelChannel{}
+		}
+	}
+
+	chs := make(map[string]channel)
+	for ch := range r.chanSubscribeTo {
+		chs[ch] = builder(r.invertedEdges[ch])
+	}
+
+	chs[END] = builder(r.invertedEdges[END])
+
+	return chs
+}
+
+// runnableCallWrapper abstracts over the two ways a composableRunnable is called
+// (invoke on a value, transform on a stream), so run can use one code path for both.
+type runnableCallWrapper func(context.Context, *composableRunnable, any, ...any) (any, error)
+
+// runnableInvoke calls the runnable's invoke entry point.
+func runnableInvoke(ctx context.Context, r *composableRunnable, input any, opts ...any) (any, error) {
+	return r.i(ctx, input, opts...)
+}
+
+// runnableTransform calls the runnable's transform entry point.
+// input must be a streamReader, otherwise the type assertion panics.
+func runnableTransform(ctx context.Context, r *composableRunnable, input any, opts ...any) (any, error) {
+	return r.t(ctx, input.(streamReader), opts...)
+}
+
+// invoke runs the graph in non-stream mode.
+func (r *runner) invoke(ctx context.Context, input any, opts ...Option) (any, error) {
+	return r.run(ctx, false, input, opts...)
+}
+
+// transform runs the graph in stream mode and returns the result as a streamReader.
+func (r *runner) transform(ctx context.Context, input streamReader, opts ...Option) (streamReader, error) {
+	s, err := r.run(ctx, true, input, opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	return s.(streamReader), nil
+}
+
+// copyItem returns a slice holding item for n consumers.
+// For n < 2 the slice has the single original item. A streamReader is split with copy(n)
+// so that every consumer gets its own reader; any other value is repeated as is.
+func copyItem(item any, n int) []any {
+	if n < 2 {
+		return []any{item}
+	}
+
+	ret := make([]any, n)
+	if s, ok := item.(streamReader); ok {
+		ss := s.copy(n)
+		for i := range ret {
+			ret[i] = ss[i]
+		}
+
+		return ret
+	}
+
+	for i := range ret {
+		ret[i] = item
+	}
+
+	return ret
+}
+
+func (r *runner) run(ctx context.Context, isStream bool, input any, opts ...Option) (any, error) {
+	var err error
+	var runWrapper runnableCallWrapper
+	runWrapper = runnableInvoke
+	if isStream {
+		runWrapper = runnableTransform
+	}
+
+	chs := r.buildChannels()
+
+	maxSteps := r.options.maxRunSteps
+	for i := range opts {
+		if opts[i].maxRunSteps > 0 {
+			maxSteps = opts[i].maxRunSteps
+		}
+	}
+
+	if maxSteps < 1 {
+		return nil, errors.New("recursion_limit must be at least 1")
+	}
+
+	if r.runCtx != nil {
+		ctx = r.runCtx(ctx)
+	}
+
+	optMap, err := extractOption(r.chanSubscribeTo, opts...)
+	if err != nil {
+		return nil, fmt.Errorf("graph extract option fail: %w", err)
+	}
+
+	type task struct {
+		nodeKey string
+		call    *chanCall
+		input   any
+		output  any
+		option  []any
+		err     error
+	}
+
+	taskPreProcessor := func(ctx context.Context, t *task) error {
+		if t.call.preProcessor == nil {
+			return nil
+		}
+		var e error
+		t.input, e = runWrapper(ctx, t.call.preProcessor, t.input, t.option...)
+		return e
+	}
+
+	taskPostProcessor := func(ctx context.Context, t *task) error {
+		if t.call.postProcessor == nil {
+			return nil
+		}
+		var e error
+		t.output, e = runWrapper(ctx, t.call.postProcessor, t.output, t.option...)
+		if e != nil {
+			t.err = e
+			t.output = nil
+			return e
+		}
+		return nil
+	}
+
+	run := func(ctx context.Context, t *task) {
+		defer func() {
+			panicInfo := recover()
+			if panicInfo != nil {
+				t.output = nil
+				t.err = safe.NewPanicErr(panicInfo, debug.Stack())
+			}
+		}()
+
+		// callback
+		ctx = initNodeCallbacks(ctx, t.nodeKey, t.call.action.nodeInfo, t.call.action.meta, opts...)
+
+		out, e := runWrapper(ctx, t.call.action, t.input, t.option...)
+		if e != nil {
+			t.output = out
+			t.err = e
+			return
+		}
+
+		t.output = out
+		t.err = nil
+	}
+
+	nextTasks := make([]*task, 0)
+	// init start task
+	nextTasks = append(nextTasks, &task{
+		nodeKey: START,
+		call:    r.inputChannels,
+		output:  input,
+	})
+
+	for step := 0; ; step++ {
+		select {
+		case <-ctx.Done():
+			return nil, fmt.Errorf("context has been canceled, error: %w", ctx.Err())
+		default:
+		}
+
+		if step == maxSteps {
+			return nil, ErrExceedMaxSteps
+		}
+
+		// calculate next tasks
+		wChValues := make(map[string]map[string]any)
+		for _, t := range nextTasks {
+			// update channel & new_next_tasks
+			vs_ := copyItem(t.output, len(t.call.writeTo)+len(t.call.writeToBranches)*2)
+			nexts, err_ := r.calculateNext(ctx, t.nodeKey, t.call, runWrapper,
+				vs_[len(t.call.writeTo)+len(t.call.writeToBranches):], isStream)
+			if err_ != nil {
+				return nil, fmt.Errorf("calculate next step fail, node: %s, error: %w", t.nodeKey, err_)
+			}
+			for i, next := range nexts {
+				if _, ok := wChValues[next]; !ok {
+					wChValues[next] = make(map[string]any)
+				}
+				// check type if needed
+				vs_[i], err = r.parserOrValidateTypeIfNeeded(t.nodeKey, next, isStream, vs_[i])
+				if err != nil {
+					return nil, err
+				}
+
+				wChValues[next][t.nodeKey] = vs_[i]
+			}
+		}
+
+		// return directly when arrive end.
+ if values, ok := wChValues[END]; ok { + err = chs[END].update(ctx, values) + if err != nil { + return nil, err + } + break + } + + var newNextTasks []*task + for wCh, values := range wChValues { + ch, ok := chs[wCh] + if !ok { + return nil, fmt.Errorf("write_to_channel (node): %s not present in the graph", wCh) + } + + err = ch.update(ctx, values) + if err != nil { + return nil, err + } + + if ch.ready(ctx) { + in, e := ch.get(ctx) + if e != nil { + return nil, fmt.Errorf("get node[%s] input from channel fail: %w", wCh, e) + } + var call *chanCall + call, ok = r.chanSubscribeTo[wCh] + if !ok { + return nil, fmt.Errorf("node[%s] has not been registered", wCh) + } + newNextTasks = append(newNextTasks, &task{nodeKey: wCh, call: call, input: in, option: optMap[wCh]}) + } + } + nextTasks = newNextTasks + + if len(nextTasks) == 0 { + return nil, errors.New("no tasks to execute") + } + + for i := 0; i < len(nextTasks); i++ { + e := taskPreProcessor(ctx, nextTasks[i]) + if e != nil { + return nil, fmt.Errorf("pre-process[%s] input error: %w", nextTasks[i].nodeKey, e) + } + } + + if len(nextTasks) == 1 { + run(ctx, nextTasks[0]) + } else { + var wg sync.WaitGroup + for i := 1; i < len(nextTasks); i++ { + wg.Add(1) + go func(t *task) { + defer wg.Done() + defer func() { + panicErr := recover() + if panicErr != nil { + t.err = safe.NewPanicErr(panicErr, debug.Stack()) // nolint: byted_returned_err_should_do_check + } + }() + run(ctx, t) + }(nextTasks[i]) + } + run(ctx, nextTasks[0]) + wg.Wait() + } + + for i := 0; i < len(nextTasks); i++ { + t := nextTasks[i] + if t.err != nil { + return nil, fmt.Errorf("node[%s] execute fail: \n%w", t.nodeKey, t.err) + } + + e := taskPostProcessor(ctx, t) + if e != nil { + return nil, fmt.Errorf("post-process[%s] input error: %w", t.nodeKey, e) + } + } + } + + if !chs[END].ready(ctx) { + return nil, fmt.Errorf("arrives at END node but its value is not ready") + } + out, err := chs[END].get(ctx) + if err != nil { + return nil, err + } + + 
return out, nil +} + +func (r *runner) calculateNext(ctx context.Context, curNodeKey string, startChan *chanCall, runWrapper runnableCallWrapper, input []any, isStream bool) ([]string, error) { // nolint: byted_s_args_length_limit + if len(input) < len(startChan.writeToBranches) { + // unreachable + return nil, errors.New("calculate next input length is shorter than branches") + } + + ret := make([]string, 0, len(startChan.writeTo)) + ret = append(ret, startChan.writeTo...) + + for i, branch := range startChan.writeToBranches { + // check branch input type if needed + if r.runtimeCheckBranches[curNodeKey][branch.idx] { + if isStream { + input[i] = branch.condition.inputStreamConverter(input[i].(streamReader)) + } else { + err := branch.condition.inputValueChecker(input[i]) + if err != nil { + return nil, fmt.Errorf("branch[%s]-[%d] runtime value check fail: %w", curNodeKey, branch.idx, err) + } + } + } + + wCh, e := runWrapper(ctx, branch.condition, input[i]) + if e != nil { // nolint:byted_s_too_many_nests_in_func + return nil, fmt.Errorf("branch run error: %w", e) + } + + // process branch output + var w string + var ok bool + if isStream { // nolint:byted_s_too_many_nests_in_func + var sr streamReader + var csr *schema.StreamReader[string] + sr, ok = wCh.(streamReader) + if !ok { + return nil, errors.New("stream branch return isn't IStreamReader") + } + csr, ok = unpackStreamReader[string](sr) + if !ok { + return nil, errors.New("unpack branch result fail") + } + + var se error + w, se = concatStreamReader(csr) + if se != nil { + return nil, fmt.Errorf("concat branch result error: %w", se) + } + } else { // nolint:byted_s_too_many_nests_in_func + w, ok = wCh.(string) + if !ok { + return nil, errors.New("invoke branch result isn't string") + } + } + ret = append(ret, w) + } + return ret, nil +} + +func (r *runner) parserOrValidateTypeIfNeeded(cur, next string, isStream bool, value any) (any, error) { + if _, ok := r.runtimeCheckEdges[cur]; !ok { + return value, 
nil + } + if _, ok := r.runtimeCheckEdges[cur][next]; !ok { + return value, nil + } + + if next == END { + if isStream { + value = r.outputStreamConverter(value.(streamReader)) + return value, nil + } + err := r.outputValueChecker(value) + if err != nil { + return nil, fmt.Errorf("edge[%s]-[%s] runtime value check fail: %w", cur, next, err) + } + return value, nil + + } + if isStream { + value = r.chanSubscribeTo[next].action.inputStreamConverter(value.(streamReader)) + return value, nil + } + err := r.chanSubscribeTo[next].action.inputValueChecker(value) + if err != nil { + return nil, fmt.Errorf("edge[%s]-[%s] runtime value check fail: %w", cur, next, err) + } + return value, nil +} + +func initNodeCallbacks(ctx context.Context, key string, info *nodeInfo, meta *executorMeta, opts ...Option) context.Context { + ri := &callbacks.RunInfo{} + if meta != nil { + ri.Component = meta.component + ri.Type = meta.componentImplType + } + + if info != nil { + ri.Name = info.name + } + + var cbs []callbacks.Handler + for i := range opts { + if len(opts[i].nodeHandler) != 0 { + if len(opts[i].keys) == 0 { + cbs = append(cbs, opts[i].nodeHandler...) + } else { + for _, k := range opts[i].keys { + if k == key { + cbs = append(cbs, opts[i].nodeHandler...) + break + } + } + } + } + + if len(opts[i].handler) != 0 { + if len(opts[i].keys) == 0 { + cbs = append(cbs, opts[i].handler...) + } else { + for _, k := range opts[i].keys { + if k == key { + cbs = append(cbs, opts[i].handler...) + break + } + } + } + } + } + + return callbacks.InitCallbacks(ctx, ri, cbs...) +} diff --git a/compose/graph_test.go b/compose/graph_test.go new file mode 100644 index 0000000..d40b22f --- /dev/null +++ b/compose/graph_test.go @@ -0,0 +1,1456 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/schema" +) + +func TestRuntimeGraphKey(t *testing.T) { + ctx := context.Background() + lbd := InvokableLambda[string, string](func(ctx context.Context, input string) (output string, err error) { + return input, nil + }) + + t.Run("Graph", func(t *testing.T) { + c := &cb{} + g := NewGraph[string, string]() + _ = g.AddLambdaNode("lbd", lbd) + _ = g.AddEdge(START, "lbd") + _ = g.AddEdge("lbd", END) + _, err := g.Compile(ctx, []GraphCompileOption{WithGraphCompileCallbacks(c)}...) + assert.NoError(t, err) + assert.Equal(t, "github.com/cloudwego/eino/compose.TestRuntimeGraphKey.func2:43", c.gInfo.Key) + }) + + t.Run("Chain", func(t *testing.T) { + c := &cb{} + ch := NewChain[string, string]().AppendLambda(lbd) + _, err := ch.Compile(ctx, []GraphCompileOption{WithGraphCompileCallbacks(c)}...) + assert.NoError(t, err) + assert.Equal(t, "github.com/cloudwego/eino/compose.TestRuntimeGraphKey.func3:54", c.gInfo.Key) + }) + + t.Run("StateGraph", func(t *testing.T) { + c := &cb{} + g := NewStateGraph[string, string, string](func(ctx context.Context) (state string) { + return "" + }) + _ = g.AddLambdaNode("lbd", lbd) + _ = g.AddEdge(START, "lbd") + _ = g.AddEdge("lbd", END) + _, err := g.Compile(ctx, []GraphCompileOption{WithGraphCompileCallbacks(c)}...) 
+ assert.NoError(t, err) + assert.Equal(t, "github.com/cloudwego/eino/compose.TestRuntimeGraphKey.func4:62", c.gInfo.Key) + }) +} + +func TestSingleGraph(t *testing.T) { + + const ( + nodeOfModel = "model" + nodeOfPrompt = "prompt" + ) + + ctx := context.Background() + g := NewGraph[map[string]any, *schema.Message]() + + pt := prompt.FromMessages(schema.FString, + schema.UserMessage("what's the weather in {location}?"), + ) + + err := g.AddChatTemplateNode("prompt", pt) + assert.NoError(t, err) + + cm := &chatModel{ + msgs: []*schema.Message{ + { + Role: schema.Assistant, + Content: "the weather is good", + }, + }, + } + + err = g.AddChatModelNode(nodeOfModel, cm, WithNodeName("MockChatModel")) + assert.NoError(t, err) + + err = g.AddEdge(START, nodeOfPrompt) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfPrompt, nodeOfModel) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfModel, END) + assert.NoError(t, err) + + r, err := g.Compile(context.Background(), WithMaxRunSteps(10)) + assert.NoError(t, err) + + in := map[string]any{"location": "beijing"} + ret, err := r.Invoke(ctx, in) + assert.NoError(t, err) + t.Logf("invoke result: %v", ret) + + // stream + s, err := r.Stream(ctx, in) + assert.NoError(t, err) + + msg, err := concatStreamReader(s) + assert.NoError(t, err) + t.Logf("stream result: %v", msg) + + sr, sw := schema.Pipe[map[string]any](1) + _ = sw.Send(in, nil) + sw.Close() + + // transform + s, err = r.Transform(ctx, sr) + assert.NoError(t, err) + + msg, err = concatStreamReader(s) + assert.NoError(t, err) + t.Logf("transform result: %v", msg) + + // error test + in = map[string]any{"wrong key": 1} + _, err = r.Invoke(ctx, in) + assert.Errorf(t, err, "could not find key: location") + t.Logf("invoke error: %v", err) + + _, err = r.Stream(ctx, in) + assert.Errorf(t, err, "could not find key: location") + t.Logf("stream error: %v", err) + + sr, sw = schema.Pipe[map[string]any](1) + _ = sw.Send(in, nil) + sw.Close() + + _, err = r.Transform(ctx, sr) + 
assert.Errorf(t, err, "could not find key: location") + t.Logf("transform error: %v", err) +} + +type person interface { + Say() string +} + +type doctor struct { + say string +} + +func (d *doctor) Say() string { + return d.say +} + +func TestGraphWithImplementableType(t *testing.T) { + + const ( + node1 = "1st" + node2 = "2nd" + ) + + ctx := context.Background() + + g := NewGraph[string, string]() + + err := g.AddLambdaNode(node1, InvokableLambda(func(ctx context.Context, input string) (output *doctor, err error) { + return &doctor{say: input}, nil + })) + assert.NoError(t, err) + + err = g.AddLambdaNode(node2, InvokableLambda(func(ctx context.Context, input person) (output string, err error) { + return input.Say(), nil + })) + assert.NoError(t, err) + + err = g.AddEdge(START, node1) + assert.NoError(t, err) + + err = g.AddEdge(node1, node2) + assert.NoError(t, err) + + err = g.AddEdge(node2, END) + assert.NoError(t, err) + + r, err := g.Compile(context.Background(), WithMaxRunSteps(10)) + assert.NoError(t, err) + + out, err := r.Invoke(ctx, "how are you", WithRuntimeMaxSteps(1)) + assert.Error(t, err) + assert.Equal(t, ErrExceedMaxSteps, err) + + out, err = r.Invoke(ctx, "how are you", WithGraphRunOption(WithRuntimeMaxSteps(1))) + assert.Error(t, err) + assert.Equal(t, ErrExceedMaxSteps, err) + + out, err = r.Invoke(ctx, "how are you") + assert.NoError(t, err) + assert.Equal(t, "how are you", out) + + outStream, err := r.Stream(ctx, "i'm fine") + assert.NoError(t, err) + defer outStream.Close() + + say, err := outStream.Recv() + assert.NoError(t, err) + assert.Equal(t, "i'm fine", say) +} + +func TestNestedGraph(t *testing.T) { + const ( + nodeOfLambda1 = "lambda1" + nodeOfLambda2 = "lambda2" + nodeOfSubGraph = "sub_graph" + nodeOfModel = "model" + nodeOfPrompt = "prompt" + ) + + ctx := context.Background() + g := NewGraph[string, *schema.Message]() + sg := NewGraph[map[string]any, *schema.Message]() + + l1 := InvokableLambda[string, map[string]any]( + func(ctx 
context.Context, input string) (output map[string]any, err error) { + return map[string]any{"location": input}, nil + }) + + l2 := InvokableLambda[*schema.Message, *schema.Message]( + func(ctx context.Context, input *schema.Message) (output *schema.Message, err error) { + input.Content = fmt.Sprintf("after lambda 2: %s", input.Content) + return input, nil + }) + + pt := prompt.FromMessages(schema.FString, + schema.UserMessage("what's the weather in {location}?"), + ) + + err := sg.AddChatTemplateNode("prompt", pt) + assert.NoError(t, err) + + cm := &chatModel{ + msgs: []*schema.Message{ + { + Role: schema.Assistant, + Content: "the weather is good", + }, + }, + } + + err = sg.AddChatModelNode(nodeOfModel, cm, WithNodeName("MockChatModel")) + assert.NoError(t, err) + + err = sg.AddEdge(START, nodeOfPrompt) + assert.NoError(t, err) + + err = sg.AddEdge(nodeOfPrompt, nodeOfModel) + assert.NoError(t, err) + + err = sg.AddEdge(nodeOfModel, END) + assert.NoError(t, err) + + err = g.AddLambdaNode(nodeOfLambda1, l1, WithNodeName("Lambda1")) + assert.NoError(t, err) + + err = g.AddGraphNode(nodeOfSubGraph, sg, WithNodeName("SubGraphName")) + assert.NoError(t, err) + + err = g.AddLambdaNode(nodeOfLambda2, l2, WithNodeName("Lambda2")) + assert.NoError(t, err) + + err = g.AddEdge(START, nodeOfLambda1) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfLambda1, nodeOfSubGraph) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfSubGraph, nodeOfLambda2) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfLambda2, END) + assert.NoError(t, err) + + r, err := g.Compile(context.Background(), + WithMaxRunSteps(10), + WithGraphName("GraphName"), + ) + assert.NoError(t, err) + + ck := "depth" + cb := callbacks.NewHandlerBuilder(). 
+ OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + v, ok := ctx.Value(ck).(int) + if ok { + v++ + } + + t.Logf("Name=%s, Component=%v, Type=%v, Depth=%d", info.Name, info.Component, info.Type, v) + + return context.WithValue(ctx, ck, v) + }). + OnStartWithStreamInputFn(func(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context { + input.Close() + + v, ok := ctx.Value(ck).(int) + if ok { + v++ + } + + t.Logf("Name=%s, Component=%v, Type=%v, Depth=%d", info.Name, info.Component, info.Type, v) + + return context.WithValue(ctx, ck, v) + }).Build() + + // invoke + ri, err := r.Invoke(ctx, "london", WithNodeCallbacks(cb)) + assert.NoError(t, err) + t.Log(ri) + + // stream + rs, err := r.Stream(ctx, "london", WithNodeCallbacks(cb)) + assert.NoError(t, err) + for { + ri, err = rs.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + t.Log(ri) + } + + // collect + sr, sw := schema.Pipe[string](5) + _ = sw.Send("london", nil) + sw.Close() + + rc, err := r.Collect(ctx, sr, WithNodeCallbacks(cb)) + assert.NoError(t, err) + t.Log(rc) + + // transform + sr, sw = schema.Pipe[string](5) + _ = sw.Send("london", nil) + sw.Close() + + rt, err := r.Transform(ctx, sr, WithNodeCallbacks(cb)) + assert.NoError(t, err) + for { + ri, err = rt.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + t.Log(ri) + } +} + +type chatModel struct { + msgs []*schema.Message +} + +func (c *chatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func (c *chatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return c.msgs[0], nil +} + +func (c *chatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + sr, sw := schema.Pipe[*schema.Message](len(c.msgs)) + go func() { + for _, msg := range c.msgs { 
+ sw.Send(msg, nil) + } + sw.Close() + }() + return sr, nil +} + +func TestValidate(t *testing.T) { + // test unmatched nodes + g := NewGraph[string, string]() + err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil })) + assert.NoError(t, err) + + err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil })) + assert.NoError(t, err) + + err = g.AddEdge("1", "2") + assert.ErrorContains(t, err, "graph edge[1]-[2]: start node's output type[string] and end node's input type[int] mismatch") + + // test unmatched passthrough node + g = NewGraph[string, string]() + err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { return "", nil })) + assert.NoError(t, err) + + err = g.AddPassthroughNode("2") + assert.NoError(t, err) + + err = g.AddLambdaNode("3", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil })) + assert.NoError(t, err) + + err = g.AddEdge("1", "2") + assert.NoError(t, err) + + err = g.AddEdge("2", "3") + assert.ErrorContains(t, err, "graph edge[2]-[3]: start node's output type[string] and end node's input type[int] mismatch") + + // test unmatched graph type + g = NewGraph[string, string]() + err = g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input int) (output string, err error) { return "", nil })) + assert.NoError(t, err) + + err = g.AddLambdaNode("2", InvokableLambda(func(ctx context.Context, input string) (output int, err error) { return 0, nil })) + assert.NoError(t, err) + + err = g.AddEdge("1", "2") + assert.NoError(t, err) + + err = g.AddEdge(START, "1") + assert.ErrorContains(t, err, "graph edge[start]-[1]: start node's output type[string] and end node's input type[int] mismatch") + + // sub graph implement + type A interface { + A() + } + type B interface { + B() + } + + type AB interface { + 
A + B + } + lA := InvokableLambda(func(ctx context.Context, input A) (output string, err error) { return "", nil }) + lB := InvokableLambda(func(ctx context.Context, input B) (output string, err error) { return "", nil }) + lAB := InvokableLambda(func(ctx context.Context, input string) (output AB, err error) { return nil, nil }) + + p := NewParallel().AddLambda("1", lA).AddLambda("2", lB) + c := NewChain[string, map[string]any]().AppendLambda(lAB).AppendParallel(p) + _, err = c.Compile(context.Background()) + assert.NoError(t, err) + + // error usage + p = NewParallel().AddLambda("1", lA).AddLambda("2", lAB) + c = NewChain[string, map[string]any]().AppendParallel(p) + _, err = c.Compile(context.Background()) + assert.ErrorContains(t, err, "add parallel edge[start]-[Chain[0]_Parallel[0]_Lambda] to chain failed: graph edge[start]-[Chain[0]_Parallel[0]_Lambda]: start node's output type[string] and end node's input type[compose.A] mismatch") + + // test graph output type check + gg := NewGraph[string, A]() + err = gg.AddLambdaNode("nodeA", InvokableLambda(func(ctx context.Context, input string) (output A, err error) { return nil, nil })) + assert.NoError(t, err) + + err = gg.AddLambdaNode("nodeA2", InvokableLambda(func(ctx context.Context, input string) (output A, err error) { return nil, nil })) + assert.NoError(t, err) + + err = gg.AddLambdaNode("nodeB", InvokableLambda(func(ctx context.Context, input string) (output B, err error) { return nil, nil })) + assert.NoError(t, err) + + err = gg.AddEdge("nodeA", END) + assert.NoError(t, err) + + err = gg.AddEdge("nodeB", END) + assert.ErrorContains(t, err, "graph edge[nodeB]-[end]: start node's output type[compose.B] and end node's input type[compose.A] mismatch") + + err = gg.AddEdge("nodeA2", END) + assert.ErrorContains(t, err, "graph edge[nodeB]-[end]: start node's output type[compose.B] and end node's input type[compose.A] mismatch") + + // test any type + anyG := NewGraph[any, string]() + err = 
anyG.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node1", nil })) + assert.NoError(t, err) + + err = anyG.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node2", nil })) + assert.NoError(t, err) + + err = anyG.AddEdge(START, "node1") + assert.NoError(t, err) + + err = anyG.AddEdge("node1", "node2") + assert.NoError(t, err) + + err = anyG.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + r, err := anyG.Compile(context.Background()) + assert.NoError(t, err) + + result, err := r.Invoke(context.Background(), "start") + assert.NoError(t, err) + assert.Equal(t, "startnode1node2", result) + + streamResult, err := r.Stream(context.Background(), "start") + assert.NoError(t, err) + + result = "" + for { + chunk, err := streamResult.Recv() + if err != nil { + if err == io.EOF { + break + } + assert.NoError(t, err) + } + result += chunk + } + + assert.Equal(t, "startnode1node2", result) + + // test any type runtime error + anyG = NewGraph[any, string]() + err = anyG.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return 123, nil })) + if err != nil { + t.Fatal(err) + } + err = anyG.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node2", nil })) + if err != nil { + t.Fatal(err) + } + err = anyG.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = anyG.AddEdge("node1", "node2") + if err != nil { + t.Fatal(err) + } + err = anyG.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + r, err = anyG.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = r.Invoke(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test any type runtime error fail, error is nil or error doesn't contain key word runtime") 
+ } + _, err = r.Stream(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test any type runtime error fail, error is nil or error doesn't contain key word runtime") + } + + // test branch any type + // success + g = NewGraph[string, string]() + err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node1", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node2", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node3", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return "node2", nil + }, map[string]bool{"node2": true, "node3": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node3", END) + if err != nil { + t.Fatal(err) + } + rr, err := g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + ret, err := rr.Invoke(context.Background(), "start") + if err != nil { + t.Fatal(err) + } + if ret != "startnode1node2" { + t.Fatal("test branch any type fail, result is unexpected") + } + streamResult, err = rr.Stream(context.Background(), "start") + if err != nil { + t.Fatal(err) + } + ret, err = concatStreamReader(streamResult) + if err != nil { + t.Fatal(err) + } + if ret != "startnode1node2" { + t.Fatal("test branch any type fail, result is unexpected") + } + // fail + g = NewGraph[string, string]() + err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return 1 
/*error type*/, nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node2", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node3", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return "node2", nil + }, map[string]bool{"node2": true, "node3": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node3", END) + if err != nil { + t.Fatal(err) + } + rr, err = g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = rr.Invoke(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test branch any type fail, haven't report runtime error") + } + _, err = rr.Stream(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test branch any type fail, haven't report runtime error") + } +} + +func TestValidateMultiAnyValueBranch(t *testing.T) { + // success + g := NewGraph[string, map[string]any]() + err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node1", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node2": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node3": true}, nil 
+ })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node4", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node4": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node5", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node5": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return "node2", nil + }, map[string]bool{"node2": true, "node3": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return "node4", nil + }, map[string]bool{"node4": true, "node5": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node3", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node4", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node5", END) + if err != nil { + t.Fatal(err) + } + rr, err := g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + ret, err := rr.Invoke(context.Background(), "start") + if err != nil { + t.Fatal(err) + } + if !ret["node2"].(bool) || !ret["node4"].(bool) { + t.Fatal("test branch any type fail, result is unexpected") + } + streamResult, err := rr.Stream(context.Background(), "start") + if err != nil { + t.Fatal(err) + } + ret, err = concatStreamReader(streamResult) + if err != nil { + t.Fatal(err) + } + if !ret["node2"].(bool) || !ret["node4"].(bool) { + t.Fatal("test branch any type fail, result is unexpected") + } + + // fail + g = NewGraph[string, map[string]any]() + err = g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) 
(output any, err error) { return input + "node1", nil })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node2": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node3", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node3": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node4", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node4": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node5", InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{"node5": true}, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return "node2", nil + }, map[string]bool{"node2": true, "node3": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch("node1", NewGraphBranch(func(ctx context.Context, in int /*error type*/) (endNode string, err error) { + return "node4", nil + }, map[string]bool{"node4": true, "node5": true})) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node3", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node4", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node5", END) + if err != nil { + t.Fatal(err) + } + rr, err = g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + _, err = rr.Invoke(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test multi branch any type fail, haven't 
report runtime error") + } + _, err = rr.Stream(context.Background(), "start") + if err == nil || !strings.Contains(err.Error(), "runtime") { + t.Fatal("test multi branch any type fail, haven't report runtime error") + } +} + +func TestAnyTypeWithKey(t *testing.T) { + g := NewGraph[any, map[string]any]() + err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node1", nil }), WithInputKey("node1")) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output any, err error) { return input + "node2", nil }), WithOutputKey("node2")) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node1", "node2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + r, err := g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + result, err := r.Invoke(context.Background(), map[string]any{"node1": "start"}) + if err != nil { + t.Fatal(err) + } + if result["node2"] != "startnode1node2" { + t.Fatal("test any type with key fail, result is unexpected") + } + + streamResult, err := r.Stream(context.Background(), map[string]any{"node1": "start"}) + if err != nil { + t.Fatal(err) + } + ret, err := concatStreamReader(streamResult) + if err != nil { + t.Fatal(err) + } + if ret["node2"] != "startnode1node2" { + t.Fatal("test any type with key fail, result is unexpected") + } +} + +func TestInputKey(t *testing.T) { + g := NewGraph[map[string]any, map[string]any]() + err := g.AddChatTemplateNode("1", prompt.FromMessages(schema.FString, schema.UserMessage("{var1}")), WithOutputKey("1"), WithInputKey("1")) + if err != nil { + t.Fatal(err) + } + err = g.AddChatTemplateNode("2", prompt.FromMessages(schema.FString, schema.UserMessage("{var2}")), WithOutputKey("2"), WithInputKey("2")) + if err != nil { + 
t.Fatal(err) + } + err = g.AddChatTemplateNode("3", prompt.FromMessages(schema.FString, schema.UserMessage("{var3}")), WithOutputKey("3"), WithInputKey("3")) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "3") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("2", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("3", END) + if err != nil { + t.Fatal(err) + } + r, err := g.Compile(context.Background(), WithMaxRunSteps(100)) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + result, err := r.Invoke(ctx, map[string]any{ + "1": map[string]any{"var1": "a"}, + "2": map[string]any{"var2": "b"}, + "3": map[string]any{"var3": "c"}, + }) + if err != nil { + t.Fatal(err) + } + if result["1"].([]*schema.Message)[0].Content != "a" || + result["2"].([]*schema.Message)[0].Content != "b" || + result["3"].([]*schema.Message)[0].Content != "c" { + t.Fatal("invoke different") + } + + sr, sw := schema.Pipe[map[string]any](10) + sw.Send(map[string]any{"1": map[string]any{"var1": "a"}}, nil) + sw.Send(map[string]any{"2": map[string]any{"var2": "b"}}, nil) + sw.Send(map[string]any{"3": map[string]any{"var3": "c"}}, nil) + sw.Close() + + streamResult, err := r.Transform(ctx, sr) + if err != nil { + t.Fatal(err) + } + defer streamResult.Close() + + result = make(map[string]any) + for { + chunk, err := streamResult.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + for k, v := range chunk { + result[k] = v + } + } + if result["1"].([]*schema.Message)[0].Content != "a" || + result["2"].([]*schema.Message)[0].Content != "b" || + result["3"].([]*schema.Message)[0].Content != "c" { + t.Fatal("transform different") + } +} + +func TestTransferTask(t *testing.T) { + in := [][]string{ + { + "1", + "2", + }, + { + "3", + "4", 
+ "5", + "6", + }, + { + "5", + "6", + "7", + }, + { + "7", + "8", + }, + { + "8", + }, + } + invertedEdges := map[string][]string{ + "1": {"3", "4"}, + "2": {"5", "6"}, + "3": {"5"}, + "4": {"6"}, + "5": {"7"}, + "7": {"8"}, + } + in = transferTask(in, invertedEdges) + + if !reflect.DeepEqual( + [][]string{ + { + "1", + }, + { + "3", + "2", + }, + { + "5", + }, + { + "7", + "4", + }, + { + "8", + "6", + }, + }, in) { + t.Fatal("not equal") + } +} + +func TestPregelEnd(t *testing.T) { + g := NewGraph[string, string]() + err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node1", nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddLambdaNode("node2", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node2", nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node1", END) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node1", "node2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node2", END) + if err != nil { + t.Fatal(err) + } + runner, err := g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + + out, err := runner.Invoke(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if out != "node1" { + t.Fatal("graph output is unexpected") + } +} + +func TestGraphAddNodeChecker(t *testing.T) { + t.Run("graph_checker_failed", func(t *testing.T) { + g := NewGraph[string, string]() + err := g.AddLambdaNode("node1", + InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node1", nil + }), + WithStatePostHandler(func(ctx context.Context, out string, state string) (string, error) { + return out, nil + }), + ) + assert.ErrorContains(t, err, "only StateGraph support pre/post processor") + }) + + t.Run("state_graph_checker_success", func(t *testing.T) { + g := NewStateGraph[string, 
string, string](func(ctx context.Context) (state string) { + return "" + }) + err := g.AddLambdaNode("node1", + InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node1", nil + }), + WithStatePostHandler(func(ctx context.Context, out string, state string) (string, error) { + return out, nil + }), + ) + assert.NoError(t, err) + }) +} + +type cb struct { + gInfo *GraphInfo +} + +func (c *cb) OnFinish(ctx context.Context, info *GraphInfo) { + c.gInfo = info +} + +func TestGraphCompileCallback(t *testing.T) { + t.Run("graph compile callback", func(t *testing.T) { + type s struct{} + + g := NewStateGraph[map[string]any, map[string]any, *s](func(ctx context.Context) *s { return &s{} }) + + lambda := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node1", nil + }) + lambdaOpts := []GraphAddNodeOpt{WithNodeName("lambda_1"), WithInputKey("input_key")} + err := g.AddLambdaNode("node1", lambda, lambdaOpts...) + assert.NoError(t, err) + + err = g.AddPassthroughNode("pass1") + assert.NoError(t, err) + err = g.AddPassthroughNode("pass2") + assert.NoError(t, err) + + condition := func(ctx context.Context, input string) (string, error) { + return input, nil + } + + branch := NewGraphBranch(condition, map[string]bool{"pass1": true, "pass2": true}) + err = g.AddBranch("node1", branch) + assert.NoError(t, err) + + err = g.AddEdge(START, "node1") + assert.NoError(t, err) + + lambda2 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node2", nil + }) + lambdaOpts2 := []GraphAddNodeOpt{WithNodeName("lambda_2")} + subSubGraph := NewGraph[string, string]() + err = subSubGraph.AddLambdaNode("sub1", lambda2, lambdaOpts2...) 
+ assert.NoError(t, err) + err = subSubGraph.AddEdge(START, "sub1") + assert.NoError(t, err) + err = subSubGraph.AddEdge("sub1", END) + assert.NoError(t, err) + + subGraph := NewGraph[string, string]() + ssGraphCompileOpts := []GraphCompileOption{WithGraphKey("k3")} + ssGraphOpts := []GraphAddNodeOpt{WithGraphCompileOptions(ssGraphCompileOpts...)} + err = subGraph.AddGraphNode("sub_sub_1", subSubGraph, ssGraphOpts...) + assert.NoError(t, err) + err = subGraph.AddEdge(START, "sub_sub_1") + assert.NoError(t, err) + err = subGraph.AddEdge("sub_sub_1", END) + assert.NoError(t, err) + + subGraphCompileOpts := []GraphCompileOption{WithMaxRunSteps(2), WithGraphKey("k2")} + subGraphOpts := []GraphAddNodeOpt{WithGraphCompileOptions(subGraphCompileOpts...)} + err = g.AddGraphNode("sub_graph", subGraph, subGraphOpts...) + assert.NoError(t, err) + + err = g.AddEdge("pass1", "sub_graph") + assert.NoError(t, err) + err = g.AddEdge("pass2", "sub_graph") + assert.NoError(t, err) + + lambda3 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node3", nil + }) + lambdaOpts3 := []GraphAddNodeOpt{WithNodeName("lambda_3"), WithOutputKey("lambda_3")} + err = g.AddLambdaNode("node3", lambda3, lambdaOpts3...) + assert.NoError(t, err) + + lambda4 := InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return "node4", nil + }) + lambdaOpts4 := []GraphAddNodeOpt{WithNodeName("lambda_4"), WithOutputKey("lambda_4")} + err = g.AddLambdaNode("node4", lambda4, lambdaOpts4...) + assert.NoError(t, err) + + err = g.AddEdge("sub_graph", "node3") + assert.NoError(t, err) + err = g.AddEdge("sub_graph", "node4") + assert.NoError(t, err) + err = g.AddEdge("node3", END) + assert.NoError(t, err) + err = g.AddEdge("node4", END) + assert.NoError(t, err) + + c := &cb{} + opt := []GraphCompileOption{WithGraphCompileCallbacks(c), WithGraphKey("k1")} + _, err = g.Compile(context.Background(), opt...) 
+ assert.NoError(t, err) + expected := &GraphInfo{ + Key: "k1", + CompileOptions: append(opt, withComponent(g.component())), + Nodes: map[string]GraphNodeInfo{ + "node1": { + Component: ComponentOfLambda, + Instance: lambda, + GraphAddNodeOpts: lambdaOpts, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "lambda_1", + InputKey: "input_key", + }, + "pass1": { + Component: ComponentOfPassthrough, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "PassthroughPassthrough", + }, + "pass2": { + Component: ComponentOfPassthrough, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "PassthroughPassthrough", + }, + "sub_graph": { + Component: ComponentOfGraph, + Instance: subGraph, + GraphAddNodeOpts: subGraphOpts, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "Graph", + GraphInfo: &GraphInfo{ + Key: "k2", + CompileOptions: subGraphCompileOpts, + Nodes: map[string]GraphNodeInfo{ + "sub_sub_1": { + Component: ComponentOfGraph, + Instance: subSubGraph, + GraphAddNodeOpts: ssGraphOpts, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "Graph", + GraphInfo: &GraphInfo{ + Key: "k3", + CompileOptions: ssGraphCompileOpts, + Nodes: map[string]GraphNodeInfo{ + "sub1": { + Component: ComponentOfLambda, + Instance: lambda2, + GraphAddNodeOpts: lambdaOpts2, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "lambda_2", + }, + }, + Edges: map[string][]string{ + START: {"sub1"}, + "sub1": {END}, + }, + Branches: map[string][]GraphBranch{}, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + }, + }, + }, + Edges: map[string][]string{ + START: {"sub_sub_1"}, + "sub_sub_1": {END}, + }, + Branches: map[string][]GraphBranch{}, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + }, + }, + "node3": { + Component: ComponentOfLambda, + Instance: lambda3, + GraphAddNodeOpts: lambdaOpts3, + InputType: 
reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "lambda_3", + OutputKey: "lambda_3", + }, + "node4": { + Component: ComponentOfLambda, + Instance: lambda4, + GraphAddNodeOpts: lambdaOpts4, + InputType: reflect.TypeOf(""), + OutputType: reflect.TypeOf(""), + Name: "lambda_4", + OutputKey: "lambda_4", + }, + }, + Edges: map[string][]string{ + START: {"node1"}, + "pass1": {"sub_graph"}, + "pass2": {"sub_graph"}, + "sub_graph": {"node3", "node4"}, + "node3": {END}, + "node4": {END}, + }, + Branches: map[string][]GraphBranch{ + "node1": {*branch}, + }, + InputType: reflect.TypeOf(map[string]any{}), + OutputType: reflect.TypeOf(map[string]any{}), + } + + stateFn := c.gInfo.GenStateFn + assert.NotNil(t, stateFn) + assert.Equal(t, &s{}, stateFn(context.Background())) + + c.gInfo.GenStateFn = nil + + actualCompileOptions := newGraphCompileOptions(c.gInfo.CompileOptions...) + expectedCompileOptions := newGraphCompileOptions(expected.CompileOptions...) + assert.Equal(t, len(expectedCompileOptions.callbacks), len(actualCompileOptions.callbacks)) + assert.Same(t, expectedCompileOptions.callbacks[0], actualCompileOptions.callbacks[0]) + actualCompileOptions.callbacks = nil + actualCompileOptions.origOpts = nil + expectedCompileOptions.callbacks = nil + expectedCompileOptions.origOpts = nil + assert.Equal(t, expectedCompileOptions, actualCompileOptions) + + c.gInfo.CompileOptions = nil + expected.CompileOptions = nil + assert.Equal(t, expected, c.gInfo) + }) +} + +func TestCheckAddEdge(t *testing.T) { + g := NewGraph[string, string]() + err := g.AddPassthroughNode("1") + if err != nil { + t.Fatal(err) + } + err = g.AddPassthroughNode("2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", "2") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", "2") + if err == nil { + t.Fatal("add edge repeatedly haven't report error") + } +} + +func TestStartWithEnd(t *testing.T) { + g := NewGraph[string, string]() + err := g.AddLambdaNode("1", 
InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddBranch(START, NewGraphBranch(func(ctx context.Context, in string) (endNode string, err error) { + return END, nil + }, map[string]bool{"1": true, END: true})) + if err != nil { + t.Fatal(err) + } + r, err := g.Compile(context.Background()) + if err != nil { + t.Fatal(err) + } + sr, sw := schema.Pipe[string](1) + sw.Send("test", nil) + sw.Close() + result, err := r.Transform(context.Background(), sr) + if err != nil { + t.Fatal(err) + } + for { + chunk, err := result.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if chunk != "test" { + t.Fatal("result is out of expect") + } + } +} + +func TestToString(t *testing.T) { + ps := runTypePregel.String() + assert.Equal(t, "Pregel", ps) + + ds := runTypeDAG + assert.Equal(t, "DAG", ds.String()) +} + +func TestInputKeyError(t *testing.T) { + g := NewGraph[map[string]any, string]() + err := g.AddLambdaNode("node1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithInputKey("node1")) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "node1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("node1", END) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + r, err := g.Compile(ctx) + if err != nil { + t.Fatal(err) + } + // invoke + _, err = r.Invoke(ctx, map[string]any{"unknown": "123"}) + if err == nil || !strings.Contains(err.Error(), "cannot find input key: node1") { + t.Fatal("cannot report input key error correctly") + } + + // transform + sr, sw := schema.Pipe[map[string]any](1) + sw.Send(map[string]any{"unknown": "123"}, nil) + sw.Close() + _, err = r.Transform(ctx, sr) + if err == nil || !strings.Contains(err.Error(), "stream reader is empty, concat fail") { + t.Fatal("cannot report input key error correctly") + } +} + +func 
TestContextCancel(t *testing.T) { + ctx := context.Background() + g := NewGraph[string, string]() + err := g.AddLambdaNode("1", InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + })) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", END) + if err != nil { + t.Fatal(err) + } + r, err := g.Compile(ctx) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(ctx) + cancel() + _, err = r.Invoke(ctx, "test") + if !strings.Contains(err.Error(), "context has been canceled") { + t.Fatal("graph have not returned canceled error") + } +} diff --git a/compose/introspect.go b/compose/introspect.go new file mode 100644 index 0000000..8368138 --- /dev/null +++ b/compose/introspect.go @@ -0,0 +1,54 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "reflect" + + "github.com/cloudwego/eino/components" +) + +// GraphNodeInfo the info which end users pass in when they are adding nodes to graph. 
type GraphNodeInfo struct {
	// Component is the component kind of the node (e.g. lambda, passthrough, graph).
	Component components.Component
	// Instance is the object the user handed to the graph when adding the node.
	Instance any
	// GraphAddNodeOpts are the options the user supplied when adding the node.
	GraphAddNodeOpts      []GraphAddNodeOpt
	InputType, OutputType reflect.Type // mainly for lambda, whose input and output types cannot be inferred by component type
	// Name is the display name of the node.
	Name string
	// InputKey and OutputKey are the map keys configured for the node's input and output, if any.
	InputKey, OutputKey string
	// GraphInfo carries the nested graph's info.
	// NOTE(review): presumably set only when the node is itself a graph — confirm against the compile path.
	GraphInfo *GraphInfo
}

// GraphInfo the info which end users pass in when they are compiling a graph.
// it is used in compile callback for user to get the node info and instance.
// you may need all details info of the graph for observation.
type GraphInfo struct {
	Key                   string // graph key, default $CallerFunctionName:$LineNumber
	CompileOptions        []GraphCompileOption
	Nodes                 map[string]GraphNodeInfo // node key -> node info
	Edges                 map[string][]string      // edge start node key -> edge end node key
	Branches              map[string][]GraphBranch // branch start node key -> branch
	InputType, OutputType reflect.Type

	// GenStateFn takes a context and returns a value typed as any.
	// NOTE(review): presumably the state generator of a state graph — confirm where it is assigned.
	GenStateFn func(context.Context) any
}

// GraphCompileCallback is the callback which will be called when graph compilation finishes.
type GraphCompileCallback interface { // nolint: byted_s_interface_name
	// OnFinish receives the GraphInfo describing the compiled graph.
	OnFinish(ctx context.Context, info *GraphInfo)
}
diff --git a/compose/pregel.go b/compose/pregel.go
new file mode 100644
index 0000000..e648152
--- /dev/null
+++ b/compose/pregel.go
@@ -0,0 +1,72 @@
/*
 * Copyright 2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ */ + +package compose + +import ( + "context" + "fmt" +) + +func pregelChannelBuilder(dependencies []string) channel { + return &pregelChannel{} +} + +type pregelChannel struct { + value any +} + +func (ch *pregelChannel) update(_ context.Context, ins map[string]any) error { + if len(ins) == 0 { + ch.value = nil + return nil + } + + values := make([]any, 0, len(ins)) + for _, v := range ins { + values = append(values, v) + } + + if len(values) == 1 { + ch.value = values[0] + return nil + } + + // merge + v, err := mergeValues(values) + if err != nil { + return err + } + + ch.value = v + + return nil +} + +func (ch *pregelChannel) get(_ context.Context) (any, error) { + if ch.value == nil { + return nil, fmt.Errorf("pregel channel not ready, value is nil") + } + return ch.value, nil +} + +func (ch *pregelChannel) ready(_ context.Context) bool { + return ch.value != nil +} + +func (ch *pregelChannel) clear(_ context.Context) { + ch.value = nil +} diff --git a/compose/runnable.go b/compose/runnable.go new file mode 100644 index 0000000..64ef9a4 --- /dev/null +++ b/compose/runnable.go @@ -0,0 +1,626 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "reflect" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +// Runnable is the interface for an executable object. 
// Graph, StateGraph, Chain, StateChain can be compiled into Runnable.
// Runnable is the core concept of eino. The four data flow patterns below are
// kept mutually compatible by falling back from one to another, so components
// that implement only some of the methods can still be connected automatically.
// e.g. if a component only implements Stream(), you can still call Invoke(), and
// the streamed output is converted into a single invoke output.
type Runnable[I, O any] interface {
	// Invoke takes a single input and returns a single output.
	Invoke(ctx context.Context, input I, opts ...Option) (output O, err error)
	// Stream takes a single input and returns a stream of outputs.
	Stream(ctx context.Context, input I, opts ...Option) (output *schema.StreamReader[O], err error)
	// Collect takes a stream of inputs and returns a single output.
	Collect(ctx context.Context, input *schema.StreamReader[I], opts ...Option) (output O, err error)
	// Transform takes a stream of inputs and returns a stream of outputs.
	Transform(ctx context.Context, input *schema.StreamReader[I], opts ...Option) (output *schema.StreamReader[O], err error)
}

// invoke is the type-erased counterpart of Runnable.Invoke: input, output and
// options are carried as any.
type invoke func(ctx context.Context, input any, opts ...any) (output any, err error)

// transform is the type-erased counterpart of Runnable.Transform: input and
// output are carried as streamReader, options as any.
type transform func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error)

// streamMapFilter derives a stream for the given key from isr; the bool result
// reports whether the derivation was possible.
type streamMapFilter func(key string, isr streamReader) (streamReader, bool)

// streamConverter maps one streamReader to another streamReader.
type streamConverter func(isr streamReader) streamReader

// valueChecker validates a value and returns an error when it is not acceptable.
type valueChecker func(value any) error

// composableRunnable is the wrapper for every executable object directly provided by the user.
// One instance corresponds to exactly one instance of the executable object.
// All of its information comes from the executable object itself, without any other dimension of information.
// It is used for graphNode, ChainBranch, StatePreHandler, StatePostHandler, etc.
+type composableRunnable struct { + i invoke + t transform + + // used for passing generic type, is empty in passthrough + inputStreamFilter streamMapFilter + inputStreamConverter streamConverter + inputValueChecker valueChecker + + inputType reflect.Type + outputType reflect.Type + optionType reflect.Type + + isPassthrough bool + + meta *executorMeta + + // only available when in Graph node + // if composableRunnable not in Graph node, this field would be nil + nodeInfo *nodeInfo +} + +// nolint: byted_s_args_length_limit +func runnableLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], c Collect[I, O, TOption], + t Transform[I, O, TOption], enableCallback bool) *composableRunnable { + rp := newRunnablePacker(i, s, c, t, enableCallback) + + return rp.toComposableRunnable() +} + +type runnablePacker[I, O, TOption any] struct { + i Invoke[I, O, TOption] + s Stream[I, O, TOption] + c Collect[I, O, TOption] + t Transform[I, O, TOption] +} + +func (rp *runnablePacker[I, O, TOption]) wrapRunnableCtx(ctxWrapper func(ctx context.Context, opts ...TOption) context.Context) { + i, s, c, t := rp.i, rp.s, rp.c, rp.t + rp.i = func(ctx context.Context, input I, opts ...TOption) (output O, err error) { + ctx = ctxWrapper(ctx, opts...) + return i(ctx, input, opts...) + } + rp.s = func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { + ctx = ctxWrapper(ctx, opts...) + return s(ctx, input, opts...) + } + rp.c = func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + ctx = ctxWrapper(ctx, opts...) + return c(ctx, input, opts...) + } + + rp.t = func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { + ctx = ctxWrapper(ctx, opts...) + return t(ctx, input, opts...) 
+ } +} + +func defaultStreamMapFilter[T any](key string, isr streamReader) (streamReader, bool) { + sr, ok := unpackStreamReader[map[string]any](isr) + if !ok { + return nil, false + } + + convert := func(m map[string]any) (T, error) { + var t T + v, ok_ := m[key] + if !ok_ { + return t, schema.ErrNoValue + } + vv, ok_ := v.(T) + if !ok_ { + return t, fmt.Errorf( + "[defaultStreamMapFilter]fail, key[%s]'s value type[%s] isn't expected type[%s]", + key, reflect.TypeOf(v).String(), + generic.TypeOf[T]().String()) + } + return vv, nil + } + + ret := schema.StreamReaderWithConvert[map[string]any, T](sr, convert) + + return packStreamReader(ret), true +} + +func defaultStreamConverter[T any](reader streamReader) streamReader { + return packStreamReader(schema.StreamReaderWithConvert(reader.toAnyStreamReader(), func(v any) (T, error) { + vv, ok := v.(T) + if !ok { + var t T + return t, fmt.Errorf("runtime type check fail, expected type: %T, actual type: %T", t, v) + } + return vv, nil + })) +} + +func defaultValueChecker[T any](v any) error { + _, ok := v.(T) + if !ok { + var t T + return fmt.Errorf("runtime type check fail, expected type: %T, actual type: %T", t, v) + } + return nil +} + +func (rp *runnablePacker[I, O, TOption]) toComposableRunnable() *composableRunnable { + inputType := generic.TypeOf[I]() + outputType := generic.TypeOf[O]() + optionType := generic.TypeOf[TOption]() + c := &composableRunnable{ + inputStreamFilter: defaultStreamMapFilter[I], + inputStreamConverter: defaultStreamConverter[I], + inputValueChecker: defaultValueChecker[I], + inputType: inputType, + outputType: outputType, + optionType: optionType, + } + + i := func(ctx context.Context, input any, opts ...any) (output any, err error) { + in, ok := input.(I) + if !ok { + panic(newUnexpectedInputTypeErr(inputType, reflect.TypeOf(input))) + } + + tos, err := convertOption[TOption](opts...) + if err != nil { + return nil, err + } + return rp.Invoke(ctx, in, tos...) 
+ } + + t := func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + in, ok := unpackStreamReader[I](input) + if !ok { + panic(newUnexpectedInputTypeErr(reflect.TypeOf(in), input.getType())) + } + + tos, err := convertOption[TOption](opts...) + if err != nil { + return nil, err + } + + out, err := rp.Transform(ctx, in, tos...) + if err != nil { + return nil, err + } + + return packStreamReader(out), nil + } + + c.i = i + c.t = t + + return c +} + +// Invoke works like `ping => pong`. +func (rp *runnablePacker[I, O, TOption]) Invoke(ctx context.Context, + input I, opts ...TOption) (output O, err error) { + return rp.i(ctx, input, opts...) +} + +// Stream works like `ping => stream output`. +func (rp *runnablePacker[I, O, TOption]) Stream(ctx context.Context, + input I, opts ...TOption) (output *schema.StreamReader[O], err error) { + + return rp.s(ctx, input, opts...) +} + +// Collect works like `stream input => pong`. +func (rp *runnablePacker[I, O, TOption]) Collect(ctx context.Context, + input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + return rp.c(ctx, input, opts...) +} + +// Transform works like `stream input => stream output`. +func (rp *runnablePacker[I, O, TOption]) Transform(ctx context.Context, + input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) { + return rp.t(ctx, input, opts...) +} + +func defaultImplConcatStreamReader[T any]( + sr *schema.StreamReader[T], action defaultImplAction) (T, error) { + + c, err := concatStreamReader(sr) + if err != nil { + var t T + return t, newDefaultImplErr(action, streamConcat, err) + } + + return c, nil +} + +func invokeByStream[I, O, TOption any](s Stream[I, O, TOption]) Invoke[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { + action := actionInvokeByStream + + sr, err := s(ctx, input, opts...) 
+ if err != nil { + return output, newDefaultImplErr(action, internalCall, err) + } + + return defaultImplConcatStreamReader(sr, action) + } +} + +func invokeByCollect[I, O, TOption any](c Collect[I, O, TOption]) Invoke[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { + action := actionInvokeByCollect + + sr := schema.StreamReaderFromArray([]I{input}) + + output, err = c(ctx, sr, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + } + + return output, err + } +} + +func invokeByTransform[I, O, TOption any](t Transform[I, O, TOption]) Invoke[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { + action := actionInvokeByTransform + + srInput := schema.StreamReaderFromArray([]I{input}) + + srOutput, err := t(ctx, srInput, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + return output, err + } + + return defaultImplConcatStreamReader(srOutput, action) + } +} + +func streamByTransform[I, O, TOption any](t Transform[I, O, TOption]) Stream[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { + action := actionStreamByTransform + + srInput := schema.StreamReaderFromArray([]I{input}) + + output, err = t(ctx, srInput, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + } + + return output, err + } +} + +func streamByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Stream[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { + action := actionStreamByInvoke + + out, err := i(ctx, input, opts...) 
+ if err != nil { + err = newDefaultImplErr(action, internalCall, err) + + return nil, err + } + + return schema.StreamReaderFromArray([]O{out}), nil + } +} + +func streamByCollect[I, O, TOpion any](c Collect[I, O, TOpion]) Stream[I, O, TOpion] { + return func(ctx context.Context, input I, opts ...TOpion) (output *schema.StreamReader[O], err error) { + action := actionStreamByCollect + + srInput := schema.StreamReaderFromArray([]I{input}) + out, err := c(ctx, srInput, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + + return nil, err + } + + return schema.StreamReaderFromArray([]O{out}), nil + } +} + +func collectByTransform[I, O, TOption any](t Transform[I, O, TOption]) Collect[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + action := actionCollectByTransform + + srOutput, err := t(ctx, input, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + + return output, err + } + + return defaultImplConcatStreamReader(srOutput, action) + } +} + +func collectByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Collect[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + action := actionCollectByInvoke + + in, err := defaultImplConcatStreamReader(input, action) + if err != nil { + return output, err + } + + output, err = i(ctx, in, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + } + + return output, err + } +} + +func collectByStream[I, O, TOption any](s Stream[I, O, TOption]) Collect[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + action := actionCollectByStream + + in, err := defaultImplConcatStreamReader(input, action) + if err != nil { + return output, err + } + + srOutput, err := s(ctx, in, opts...) 
+ if err != nil { + err = newDefaultImplErr(action, internalCall, err) + + return output, err + } + + return defaultImplConcatStreamReader(srOutput, action) + } +} + +func transformByStream[I, O, TOption any](s Stream[I, O, TOption]) Transform[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], + opts ...TOption) (output *schema.StreamReader[O], err error) { + + action := actionTransformByStream + + in, err := defaultImplConcatStreamReader(input, action) + if err != nil { + return output, err + } + + output, err = s(ctx, in, opts...) + + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + } + + return output, err + } +} + +func transformByCollect[I, O, TOption any](c Collect[I, O, TOption]) Transform[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], + opts ...TOption) (output *schema.StreamReader[O], err error) { + + action := actionTransformByCollect + + out, err := c(ctx, input, opts...) + if err != nil { + err = newDefaultImplErr(action, internalCall, err) + return output, err + } + + return schema.StreamReaderFromArray([]O{out}), nil + } +} + +func transformByInvoke[I, O, TOption any](i Invoke[I, O, TOption]) Transform[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], + opts ...TOption) (output *schema.StreamReader[O], err error) { + + action := actionTransformByInvoke + + in, err := defaultImplConcatStreamReader(input, action) + if err != nil { + return output, err + } + + out, err := i(ctx, in, opts...) 
+ if err != nil { + err = newDefaultImplErr(action, internalCall, err) + return output, err + } + + return schema.StreamReaderFromArray([]O{out}), nil + } +} + +func newRunnablePacker[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], + c Collect[I, O, TOption], t Transform[I, O, TOption], enableCallback bool) *runnablePacker[I, O, TOption] { + + r := &runnablePacker[I, O, TOption]{} + + if enableCallback { + if i != nil { + i = invokeWithCallbacks(i) + } + + if s != nil { + s = streamWithCallbacks(s) + } + + if c != nil { + c = collectWithCallbacks(c) + } + + if t != nil { + t = transformWithCallbacks(t) + } + } + + if i != nil { + r.i = i + } else if s != nil { + r.i = invokeByStream(s) + } else if c != nil { + r.i = invokeByCollect(c) + } else { + r.i = invokeByTransform(t) + } + + if s != nil { + r.s = s + } else if t != nil { + r.s = streamByTransform(t) + } else if i != nil { + r.s = streamByInvoke(i) + } else { + r.s = streamByCollect(c) + } + + if c != nil { + r.c = c + } else if t != nil { + r.c = collectByTransform(t) + } else if i != nil { + r.c = collectByInvoke(i) + } else { + r.c = collectByStream(s) + } + + if t != nil { + r.t = t + } else if s != nil { + r.t = transformByStream(s) + } else if c != nil { + r.t = transformByCollect(c) + } else { + r.t = transformByInvoke(i) + } + + return r +} + +func toGenericRunnable[I, O any](cr *composableRunnable, ctxWrapper func(ctx context.Context, opts ...Option) context.Context) ( + *runnablePacker[I, O, Option], error) { + i := func(ctx context.Context, input I, opts ...Option) (output O, err error) { + out, err := cr.i(ctx, input, toAnyList(opts)...) + if err != nil { + return output, err + } + + return out.(O), err + } + + t := func(ctx context.Context, input *schema.StreamReader[I], + opts ...Option) (output *schema.StreamReader[O], err error) { + in := packStreamReader(input) + out, err := cr.t(ctx, in, toAnyList(opts)...) 
+ + if err != nil { + return nil, err + } + + output, ok := unpackStreamReader[O](out) + if !ok { + panic("impossible") + } + + return output, nil + } + + r := newRunnablePacker(i, nil, nil, t, false) + r.wrapRunnableCtx(ctxWrapper) + + return r, nil +} + +// inputKeyedComposableRunnable returns a copy of r whose declared input type is map[string]any. +// On invoke it passes only the value stored under key on to r; on transform it narrows the +// input stream with r.inputStreamFilter(key, ...) before handing it to r. +// A missing key (invoke) or a failed filter (transform) is reported as an error. +func inputKeyedComposableRunnable(key string, r *composableRunnable) *composableRunnable { + wrapper := *r + wrapper.inputValueChecker = defaultValueChecker[map[string]any] + wrapper.inputStreamConverter = defaultStreamConverter[map[string]any] + i := r.i + wrapper.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { + v, ok := input.(map[string]any)[key] + if !ok { + return nil, fmt.Errorf("cannot find input key: %s", key) + } + out, err := i(ctx, v, opts...) + if err != nil { + return nil, err + } + + return out, nil + } + + t := r.t + wrapper.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + nInput, ok := r.inputStreamFilter(key, input) + if !ok { + // the filter reports failure only through ok; the named return err is still nil here, + // so it must not be formatted with %w (that would print "%!w(<nil>)") + return nil, fmt.Errorf("inputStreamFilter failed, key= %s, node name= %s", key, r.nodeInfo.name) + } + out, err := t(ctx, nInput, opts...) + if err != nil { + return nil, err + } + + return out, nil + } + + wrapper.inputType = generic.TypeOf[map[string]any]() + return &wrapper +} + +// outputKeyedComposableRunnable returns a copy of r whose declared output type is map[string]any: +// the invoke result is wrapped as map[string]any{key: out}, and the stream result is tagged via withKey(key). +func outputKeyedComposableRunnable(key string, r *composableRunnable) *composableRunnable { + wrapper := *r + i := r.i + wrapper.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { + out, err := i(ctx, input, opts...) + if err != nil { + return nil, err + } + + return map[string]any{key: out}, nil + } + + t := r.t + wrapper.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + out, err := t(ctx, input, opts...) 
+ if err != nil { + return nil, err + } + + return out.withKey(key), nil + } + + wrapper.outputType = generic.TypeOf[map[string]any]() + + return &wrapper +} + +// composablePassthrough special runnable that passthrough input to output +func composablePassthrough() *composableRunnable { + r := &composableRunnable{isPassthrough: true, nodeInfo: &nodeInfo{}} + + r.i = func(ctx context.Context, input any, opts ...any) (output any, err error) { + return input, nil + } + + r.t = func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + return input, nil + } + + r.meta = &executorMeta{ + component: ComponentOfPassthrough, + isComponentCallbackEnabled: false, + componentImplType: "Passthrough", + } + + return r +} diff --git a/compose/runnable_test.go b/compose/runnable_test.go new file mode 100644 index 0000000..95744b7 --- /dev/null +++ b/compose/runnable_test.go @@ -0,0 +1,210 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "context" + "errors" + "fmt" + "io" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestRunnableLambda(t *testing.T) { + ctx := context.Background() + + t.Run("invoke_to_runnable", func(t *testing.T) { + rl := runnableLambda( + func(ctx context.Context, input int, opts ...Option) (output string, err error) { + return strconv.Itoa(input) + "+" + opts[0].options[0].(string), nil + }, + nil, nil, nil, false) + + ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { + return ctx + } + gr, err := toGenericRunnable[int, string](rl, ctxWrapper) + assert.NoError(t, err) + out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sri, swi := schema.Pipe[int](1) + _ = swi.Send(10, nil) + swi.Close() + sriArr := sri.Copy(2) + + out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + }) + + t.Run("stream_to_runnable", func(t *testing.T) { + rl := runnableLambda(nil, + func(ctx context.Context, input int, opts ...Option) (output *schema.StreamReader[string], err error) { + sro, swo := schema.Pipe[string](3) + _ = swo.Send(strconv.Itoa(input), nil) + _ = swo.Send("+", nil) + _ = swo.Send(opts[0].options[0].(string), nil) + swo.Close() + return sro, nil + }, nil, nil, false) + + ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { + return ctx + } + gr, err := toGenericRunnable[int, string](rl, ctxWrapper) + assert.NoError(t, err) + out, err := 
gr.Invoke(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sri, swi := schema.Pipe[int](1) + _ = swi.Send(10, nil) + swi.Close() + sriArr := sri.Copy(2) + + out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + }) + + t.Run("transform_to_runnable", func(t *testing.T) { + rl := runnableLambda( + nil, nil, nil, + func(ctx context.Context, input *schema.StreamReader[int], opts ...Option) (output *schema.StreamReader[string], err error) { + + in, e := input.Recv() + if errors.Is(e, io.EOF) { + return nil, fmt.Errorf("unpected EOF") + } + input.Close() + + sro, swo := schema.Pipe[string](3) + _ = swo.Send(strconv.Itoa(in), nil) + _ = swo.Send("+", nil) + _ = swo.Send(opts[0].options[0].(string), nil) + swo.Close() + return sro, nil + }, + false) + + ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { + return ctx + } + gr, err := toGenericRunnable[int, string](rl, ctxWrapper) + assert.NoError(t, err) + out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sri, swi := schema.Pipe[int](1) + _ = swi.Send(10, nil) + swi.Close() + sriArr := sri.Copy(2) + + out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) + 
assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + }) + + t.Run("collect_to_runnable", func(t *testing.T) { + rl := runnableLambda(nil, nil, + func(ctx context.Context, input *schema.StreamReader[int], opts ...Option) (output string, err error) { + in, e := input.Recv() + if errors.Is(e, io.EOF) { + return "", fmt.Errorf("unpected EOF") + } + input.Close() + + return strconv.Itoa(in) + "+" + opts[0].options[0].(string), nil + }, + nil, false) + + ctxWrapper := func(ctx context.Context, opts ...Option) context.Context { + return ctx + } + + gr, err := toGenericRunnable[int, string](rl, ctxWrapper) + assert.NoError(t, err) + out, err := gr.Invoke(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err := gr.Stream(ctx, 10, WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sri, swi := schema.Pipe[int](1) + _ = swi.Send(10, nil) + swi.Close() + sriArr := sri.Copy(2) + + out, err = gr.Collect(ctx, sriArr[0], WithLambdaOption("100")) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + + sr, err = gr.Transform(ctx, sriArr[1], WithLambdaOption("100")) + assert.NoError(t, err) + out, err = concatStreamReader(sr) + assert.NoError(t, err) + assert.Equal(t, "10+100", out) + }) +} diff --git a/compose/state.go b/compose/state.go new file mode 100644 index 0000000..1d17daa --- /dev/null +++ b/compose/state.go @@ -0,0 +1,235 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "reflect" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +// NewStateGraph creates a new state graph. It requires a func of GenLocalState to generate the state. +// The state type parameter S must match what gen returns (a pointer type in the example below). +// eg. +// +// type testState struct { +// UserInfo *UserInfo +// KVs map[string]any +// } +// +// genStateFunc := func(ctx context.Context) *testState { +// return &testState{} +// } +// +// graph := compose.NewStateGraph[string, string, *testState](genStateFunc) +// +// // you can use WithPreHandler and WithPostHandler to do something with state of this graph. +// graph.AddNode("node1", someNode, compose.WithPreHandler(func(ctx context.Context, in string, state *testState) (string, error) { +// // do something with state +// return in, nil +// }), compose.WithPostHandler(func(ctx context.Context, out string, state *testState) (string, error) { +// // do something with state +// return out, nil +// })) +func NewStateGraph[I, O, S any](gen GenLocalState[S]) *StateGraph[I, O, S] { + sg := &StateGraph[I, O, S]{NewGraph[I, O]()} + + sg.graph.runtimeGraphKey = defaultGraphKey() + // gen is called each time runCtx runs; the resulting state is stored in the context under stateKey{} + sg.runCtx = func(ctx context.Context) context.Context { + state := gen(ctx) + return context.WithValue(ctx, stateKey{}, state) + } + + sg.addNodeChecker = nodeCheckerOfForbidNodeKey(baseNodeChecker) + + sg.compileChecker = func(options *graphCompileOptions) error { + return nil + } + + return sg +} + +// StateGraph is a graph that shares state between nodes. It's useful when you want to share some data across nodes. 
+type StateGraph[I, O, S any] struct { + *Graph[I, O] +} + +// component reports ComponentOfStateGraph as the component kind of this graph. +func (s *StateGraph[I, O, S]) component() component { + return ComponentOfStateGraph +} + +// Compile the graph to runnable. +// It appends withComponent(s.component()) to opts and delegates to the embedded Graph's Compile. +func (s *StateGraph[I, O, S]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { + opts = append(opts, withComponent(s.component())) + + return s.Graph.Compile(ctx, opts...) +} + +// NewStateChain creates a new state chain. It requires a func of GenLocalState to generate the state. +// The state type parameter S must match what gen returns (a pointer type in the example below). +// eg. +// +// genStateFunc := func(ctx context.Context) *testState { +// // or maybe you can create the state by params in ctx. +// return &testState{} +// } +// +// chain := compose.NewStateChain[string, string, *testState](genStateFunc) +// +// chain.AppendXXX(someNode, compose.WithPreHandler(func(ctx context.Context, in string, state *testState) (string, error) { +// // do something with state +// return in, nil +// }), compose.WithPostHandler(func(ctx context.Context, out string, state *testState) (string, error) { +// // do something with state +// return out, nil +// })) +func NewStateChain[I, O, S any](gen GenLocalState[S]) *StateChain[I, O, S] { + sc := &StateChain[I, O, S]{NewChain[I, O]()} + + // gen is called each time runCtx runs; the resulting state is stored in the context under stateKey{} + sc.gg.runCtx = func(ctx context.Context) context.Context { + state := gen(ctx) + return context.WithValue(ctx, stateKey{}, state) + } + + sc.gg.addNodeChecker = baseNodeChecker + + sc.gg.compileChecker = func(options *graphCompileOptions) error { + return nil + } + + return sc +} + +// StateChain is a chain that shares state between nodes. State is shared between nodes in the chain. +// It's useful when you want to share some data across nodes in a chain. +// you can use WithPreHandler and WithPostHandler to do something with state of this chain. +type StateChain[I, O, S any] struct { + *Chain[I, O] +} + +// component reports ComponentOfStateChain as the component kind of this chain. +func (s *StateChain[I, O, S]) component() component { + return ComponentOfStateChain +} + +// Compile the chain to runnable. 
+func (s *StateChain[I, O, S]) Compile(ctx context.Context, opts ...GraphCompileOption) (Runnable[I, O], error) { + opts = append(opts, withComponent(s.component())) + + return s.Chain.Compile(ctx, opts...) +} + +// GenLocalState is a function that generates the state. +type GenLocalState[S any] func(ctx context.Context) (state S) + +type stateKey struct{} + +// StatePreHandler is a function that is called before the node is executed. +// Notice: if user called Stream but with StatePreHandler, the StatePreHandler will read all stream chunks and merge them into a single object. +type StatePreHandler[I, S any] func(ctx context.Context, in I, state S) (I, error) + +// StatePostHandler is a function that is called after the node is executed. +// Notice: if user called Stream but with StatePostHandler, the StatePostHandler will read all stream chunks and merge them into a single object. +type StatePostHandler[O, S any] func(ctx context.Context, out O, state S) (O, error) + +// StreamStatePreHandler is a function that is called before the node is executed with stream input and output. +type StreamStatePreHandler[I, S any] func(ctx context.Context, in *schema.StreamReader[I], state S) (*schema.StreamReader[I], error) + +// StreamStatePostHandler is a function that is called after the node is executed with stream input and output. 
+type StreamStatePostHandler[O, S any] func(ctx context.Context, out *schema.StreamReader[O], state S) (*schema.StreamReader[O], error) + +func convertPreHandler[I, S any](handler StatePreHandler[I, S]) *composableRunnable { + rf := func(ctx context.Context, in I, opts ...any) (I, error) { + cState, err := GetState[S](ctx) + if err != nil { + return in, err + } + + return handler(ctx, in, cState) + } + + return runnableLambda[I, I](rf, nil, nil, nil, false) +} + +func convertPostHandler[O, S any](handler StatePostHandler[O, S]) *composableRunnable { + rf := func(ctx context.Context, out O, opts ...any) (O, error) { + cState, err := GetState[S](ctx) + if err != nil { + return out, err + } + + return handler(ctx, out, cState) + } + + return runnableLambda[O, O](rf, nil, nil, nil, false) +} + +func streamConvertPreHandler[I, S any](handler StreamStatePreHandler[I, S]) *composableRunnable { + rf := func(ctx context.Context, in *schema.StreamReader[I], opts ...any) (*schema.StreamReader[I], error) { + cState, err := GetState[S](ctx) + if err != nil { + return in, err + } + + return handler(ctx, in, cState) + } + + return runnableLambda[I, I](nil, nil, nil, rf, false) +} + +func streamConvertPostHandler[O, S any](handler StreamStatePostHandler[O, S]) *composableRunnable { + rf := func(ctx context.Context, out *schema.StreamReader[O], opts ...any) (*schema.StreamReader[O], error) { + cState, err := GetState[S](ctx) + if err != nil { + return out, err + } + + return handler(ctx, out, cState) + } + + return runnableLambda[O, O](nil, nil, nil, rf, false) +} + +// GetState gets the state from the context. +// When using this method to read or write state in custom nodes, it may lead to data race because other nodes may concurrently access the state. +// You need to be aware of and resolve this situation, typically by adding a mutex. +// It's recommended to only READ the returned state. 
If you want to WRITE to state, consider using StatePreHandler / StatePostHandler because they are concurrency safe out of the box. +// An error is returned when the context holds no value under the state key or the stored value is not of type S. +// eg. +// +// lambdaFunc := func(ctx context.Context, in string, opts ...any) (string, error) { +// state, err := compose.GetState[*testState](ctx) +// if err != nil { +// return "", err +// } +// // do something with state +// return in, nil +// } +// +// stateGraph := compose.NewStateGraph[string, string, *testState](genStateFunc) +// stateGraph.AddNode("node1", lambdaFunc) +func GetState[S any](ctx context.Context) (S, error) { + state := ctx.Value(stateKey{}) + + cState, ok := state.(S) + if !ok { + var s S + return s, fmt.Errorf("unexpected state type. expected: %v, got: %v", + generic.TypeOf[S](), reflect.TypeOf(state)) + } + + return cState, nil +} diff --git a/compose/state_test.go b/compose/state_test.go new file mode 100644 index 0000000..2f0fae8 --- /dev/null +++ b/compose/state_test.go @@ -0,0 +1,321 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "context" + "io" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +type midStr string + +func TestStateGraphWithEdge(t *testing.T) { + + ctx := context.Background() + + const ( + nodeOfL1 = "invokable" + nodeOfL2 = "streamable" + nodeOfL3 = "transformable" + ) + + type testState struct { + ms []string + } + + gen := func(ctx context.Context) *testState { + return &testState{} + } + + sg := NewStateGraph[string, string, *testState](gen) + + l1 := InvokableLambda(func(ctx context.Context, in string) (out midStr, err error) { + return midStr("InvokableLambda: " + in), nil + }) + + l1StateToInput := func(ctx context.Context, in string, state *testState) (string, error) { + state.ms = append(state.ms, in) + return in, nil + } + + l1StateToOutput := func(ctx context.Context, out midStr, state *testState) (midStr, error) { + state.ms = append(state.ms, string(out)) + return out, nil + } + + err := sg.AddLambdaNode(nodeOfL1, l1, + WithStatePreHandler(l1StateToInput), WithStatePostHandler(l1StateToOutput)) + assert.NoError(t, err) + + l2 := StreamableLambda(func(ctx context.Context, input midStr) (output *schema.StreamReader[string], err error) { + outStr := "StreamableLambda: " + string(input) + + sr, sw := schema.Pipe[string](utf8.RuneCountInString(outStr)) + + go func() { + for _, field := range strings.Fields(outStr) { + sw.Send(field+" ", nil) + } + sw.Close() + }() + + return sr, nil + }) + + l2StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) { + state.ms = append(state.ms, out) + return out, nil + } + + err = sg.AddLambdaNode(nodeOfL2, l2, WithStatePostHandler(l2StateToOutput)) + assert.NoError(t, err) + + l3 := TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) ( + output *schema.StreamReader[string], err error) { + + prefix := "TransformableLambda: " + sr, sw := 
schema.Pipe[string](20) + + go func() { + for _, field := range strings.Fields(prefix) { + sw.Send(field+" ", nil) + } + defer input.Close() + + for { + chunk, err := input.Recv() + if err != nil { + if err == io.EOF { + break + } + // TODO: how to trace this kind of error in the goroutine of processing stream + sw.Send(chunk, err) + break + } + + sw.Send(chunk, nil) + + } + sw.Close() + }() + + return sr, nil + }) + + l3StateToOutput := func(ctx context.Context, out string, state *testState) (string, error) { + state.ms = append(state.ms, out) + t.Logf("state result: ") + for idx, m := range state.ms { + t.Logf(" %vth: %v", idx, m) + } + assert.Len(t, state.ms, 4) + return out, nil + } + + err = sg.AddLambdaNode(nodeOfL3, l3, WithStatePostHandler(l3StateToOutput)) + assert.NoError(t, err) + + err = sg.AddEdge(START, nodeOfL1) + assert.NoError(t, err) + + err = sg.AddEdge(nodeOfL1, nodeOfL2) + assert.NoError(t, err) + + err = sg.AddEdge(nodeOfL2, nodeOfL3) + assert.NoError(t, err) + + err = sg.AddEdge(nodeOfL3, END) + assert.NoError(t, err) + + run, err := sg.Compile(ctx) + assert.NoError(t, err) + + out, err := run.Invoke(ctx, "how are you") + assert.NoError(t, err) + t.Logf("invoke result: %v", out) + assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) + + stream, err := run.Stream(ctx, "how are you") + assert.NoError(t, err) + out, err = concatStreamReader(stream) + assert.NoError(t, err) + t.Logf("stream result: %v", out) + assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) + + sr, sw := schema.Pipe[string](1) + sw.Send("how are you", nil) + sw.Close() + + stream, err = run.Transform(ctx, sr) + assert.NoError(t, err) + out, err = concatStreamReader(stream) + assert.NoError(t, err) + t.Logf("transform result: %v", out) + assert.Equal(t, "TransformableLambda: StreamableLambda: InvokableLambda: how are you ", out) +} + +func TestStateGraphUtils(t *testing.T) { + 
t.Run("getState_success", func(t *testing.T) { + type testStruct struct { + UserID int64 + } + + ctx := context.Background() + + ctx = context.WithValue(ctx, stateKey{}, &testStruct{UserID: 10}) + + ts, err := GetState[*testStruct](ctx) + assert.NoError(t, err) + assert.Equal(t, int64(10), ts.UserID) + }) + + t.Run("getState_nil", func(t *testing.T) { + type testStruct struct { + UserID int64 + } + + ctx := context.Background() + + _, err := GetState[*testStruct](ctx) + assert.ErrorContains(t, err, "unexpected state type. expected: *compose.testStruct, got: ") + }) + + t.Run("getState_type_error", func(t *testing.T) { + type testStruct struct { + UserID int64 + } + + ctx := context.Background() + ctx = context.WithValue(ctx, stateKey{}, &testStruct{UserID: 10}) + + _, err := GetState[string](ctx) + assert.ErrorContains(t, err, "unexpected state type. expected: string, got: *compose.testStruct") + + }) +} + +func TestStateChain(t *testing.T) { + ctx := context.Background() + type testState struct { + Field1 string + Field2 string + } + sc := NewStateChain[string, string, *testState](func(ctx context.Context) (state *testState) { + return &testState{} + }) + + r, err := sc.AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + s, err := GetState[*testState](ctx) + if err != nil { + return "", err + } + s.Field1 = "node1" + return input, nil + }), WithStatePostHandler(func(ctx context.Context, out string, state *testState) (string, error) { + state.Field2 = "node2" + return out, nil + })). 
+ AppendLambda(InvokableLambda(func(ctx context.Context, input string) (output string, err error) { + return input, nil + }), WithStatePreHandler(func(ctx context.Context, in string, state *testState) (string, error) { + return in + state.Field1 + state.Field2, nil + })).Compile(ctx) + if err != nil { + t.Fatal(err) + } + result, err := r.Invoke(ctx, "start") + if err != nil { + t.Fatal(err) + } + if result != "startnode1node2" { + t.Fatal("result is unexpected") + } +} + +func TestStreamState(t *testing.T) { + type testState struct { + Field1 string + } + ctx := context.Background() + s := &testState{Field1: "1"} + g := NewStateGraph[string, string, *testState](func(ctx context.Context) (state *testState) { return s }) + err := g.AddLambdaNode("1", TransformableLambda(func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { + return input, nil + }), WithStreamStatePreHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) { + sr, sw := schema.Pipe[string](5) + for i := 0; i < 5; i++ { + sw.Send(state.Field1, nil) + } + sw.Close() + return sr, nil + }), WithStreamStatePostHandler(func(ctx context.Context, in *schema.StreamReader[string], state *testState) (*schema.StreamReader[string], error) { + ss := in.Copy(2) + for { + chunk, err := ss[0].Recv() + if err == io.EOF { + return ss[1], nil + } + if err != nil { + return nil, err + } + state.Field1 += chunk + } + })) + if err != nil { + t.Fatal(err) + } + err = g.AddEdge(START, "1") + if err != nil { + t.Fatal(err) + } + err = g.AddEdge("1", END) + if err != nil { + t.Fatal(err) + } + r, err := g.Compile(ctx) + if err != nil { + t.Fatal(err) + } + sr, _ := schema.Pipe[string](1) + streamResult, err := r.Transform(ctx, sr) + if err != nil { + t.Fatal(err) + } + if s.Field1 != "111111" { + t.Fatal("state is unexpected") + } + for i := 0; i < 5; i++ { + chunk, err := streamResult.Recv() + if err != 
nil { + t.Fatal(err) + } + if chunk != "1" { + t.Fatal("result is unexpected") + } + } + _, err = streamResult.Recv() + if err != io.EOF { + t.Fatal("result is unexpected") + } +} diff --git a/compose/stream_concat.go b/compose/stream_concat.go new file mode 100644 index 0000000..79fc9ce --- /dev/null +++ b/compose/stream_concat.go @@ -0,0 +1,271 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "fmt" + "io" + "reflect" + "strings" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +var ( + concatFuncs = map[reflect.Type]any{ + generic.TypeOf[*schema.Message](): schema.ConcatMessages, + generic.TypeOf[string](): concatStrings, + generic.TypeOf[[]*schema.Message](): concatMessageArray, + } +) + +func concatStrings(ss []string) (string, error) { + var n int + for _, s := range ss { + n += len(s) + } + + var b strings.Builder + b.Grow(n) + for _, s := range ss { + _, err := b.WriteString(s) + if err != nil { + return "", err + } + } + + return b.String(), nil +} + +func concatMessageArray(mas [][]*schema.Message) ([]*schema.Message, error) { + arrayLen := len(mas[0]) + + ret := make([]*schema.Message, arrayLen) + slicesToConcat := make([][]*schema.Message, arrayLen) + + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. 
// toSliceValue converts vs into a reflect.Value holding a slice whose element
// type is the dynamic type of vs[0]. Every element must have exactly that
// dynamic type.
//
// It returns an error, instead of panicking, when vs is empty, when the first
// element is a nil interface (which has no type to build a slice from), or
// when any element's dynamic type differs from the first one's.
func toSliceValue(vs []any) (reflect.Value, error) {
	if len(vs) == 0 {
		return reflect.Value{}, fmt.Errorf("cannot build slice value from empty input")
	}

	// reflect.TypeOf returns nil for a nil interface value; reflect.SliceOf
	// would panic on it, so reject it up front.
	typ := reflect.TypeOf(vs[0])
	if typ == nil {
		return reflect.Value{}, fmt.Errorf("unexpected nil slice element at index 0")
	}

	ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs))
	for i, v := range vs {
		// "Got" is the offending element's type, "expected" is the type
		// fixed by the first element.
		if vt := reflect.TypeOf(v); vt != typ {
			return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", vt, typ)
		}

		ret.Index(i).Set(reflect.ValueOf(v))
	}

	return ret, nil
}
+} + +func concatSliceValue(val reflect.Value) (reflect.Value, error) { + elmType := val.Type().Elem() + + if val.Len() == 1 { + return val.Index(0), nil + } + + f := getConcatFunc(elmType) + if f == nil { + return reflect.Value{}, fmt.Errorf("cannot concat value of type %s", elmType) + } + + return f(val) +} diff --git a/compose/stream_concat_test.go b/compose/stream_concat_test.go new file mode 100644 index 0000000..a081efa --- /dev/null +++ b/compose/stream_concat_test.go @@ -0,0 +1,213 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// concatIntForTest is a test-only concat function for int chunks: it returns
// the sum of all items and never fails.
func concatIntForTest(items []int) (int, error) {
	sum := 0
	for idx := 0; idx < len(items); idx++ {
		sum += items[idx]
	}

	return sum, nil
}
// TestMapConcat checks that a stream of map[string]any chunks is concatenated
// key by key: string values are joined, tStreamConcatItemForTest and int values
// go through the concat functions registered at the top of the test, and nested
// maps are merged by the same rules.
func TestMapConcat(t *testing.T) {
	// Register the custom concat functions the map values below rely on.
	RegisterStreamChunkConcatFunc(concatTStreamForTest)
	RegisterStreamChunkConcatFunc(concatIntForTest)

	t.Run("simple map", func(t *testing.T) {
		sr, sw := schema.Pipe[map[string]any](10)

		go func() {
			for i := 0; i < 10; i++ {
				sw.Send(map[string]any{
					"string":        strconv.Itoa(i),
					"custom_concat": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)},
					"count":         i,
				}, nil)
			}
			sw.Close()
			t.Log("send finish")
		}()

		lastVal, err := concatStreamReader(sr)
		assert.Nil(t, err)

		// Each key is merged independently across the ten chunks.
		assert.Equal(t, "0123456789", lastVal["string"])
		assert.Equal(t, "9876543210", lastVal["custom_concat"].(tStreamConcatItemForTest).s)
		assert.Equal(t, 45, lastVal["count"])

	})

	t.Run("complex map", func(t *testing.T) {
		sr, sw := schema.Pipe[map[string]any](10)

		go func() {
			for i := 0; i < 10; i++ {
				// Nested map: only the first level is allowed to merge by type; the second level is overwritten directly.
				// NOTE(review): the assertions below expect the nested map to be merged key by key as well,
				// which contradicts this statement — confirm which behavior is intended.
				sw.Send(map[string]any{ // nested map
					"string": strconv.Itoa(i),
					"deep_map": map[string]any{
						"message": &schema.Message{
							Content: strconv.Itoa(i),
						},
						"custom_concat_deep": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)},
						"count":              i,
					},
					"custom_concat": tStreamConcatItemForTest{s: strconv.Itoa(9 - i)},
					"count":         i,
				}, nil)
			}
			sw.Close()
			t.Log("send finish")
		}()

		lastVal, err := concatStreamReader(sr)
		assert.Nil(t, err)

		// Top-level keys.
		assert.Equal(t, "0123456789", lastVal["string"])
		assert.Equal(t, 45, lastVal["count"])
		// Keys inside the nested map.
		assert.Equal(t, "0123456789", lastVal["deep_map"].(map[string]any)["message"].(*schema.Message).Content)
		assert.Equal(t, "9876543210", lastVal["deep_map"].(map[string]any)["custom_concat_deep"].(tStreamConcatItemForTest).s)
		assert.Equal(t, 45, lastVal["deep_map"].(map[string]any)["count"])
	})
}
err) + }) + + t.Run("map type not equal", func(t *testing.T) { + a := map[string]any{ + "str": "string_01", + "x": "string_in_a", + } + + b := map[string]any{ + "str": "string_02", + "x": 123, + } + _, err := concatItems([]map[string]any{a, b}) + assert.NotNil(t, err) + }) + + t.Run("merge error", func(t *testing.T) { + RegisterStreamChunkConcatFunc(concatTStreamError) + + _, err := concatItems([]tConcatErrForTest{{}, {}}) + assert.NotNil(t, err) + }) +} diff --git a/compose/stream_reader.go b/compose/stream_reader.go new file mode 100644 index 0000000..ff16884 --- /dev/null +++ b/compose/stream_reader.go @@ -0,0 +1,116 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "reflect" + + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/generic" +) + +type streamReader interface { + copy(n int) []streamReader + getType() reflect.Type + getChunkType() reflect.Type + merge([]streamReader) streamReader + withKey(string) streamReader + close() + toAnyStreamReader() *schema.StreamReader[any] +} + +type streamReaderPacker[T any] struct { + sr *schema.StreamReader[T] +} + +func (srp streamReaderPacker[T]) thisIsStreamReader() {} + +func (srp streamReaderPacker[T]) close() { + srp.sr.Close() +} + +func (srp streamReaderPacker[T]) copy(n int) []streamReader { + ret := make([]streamReader, n) + srs := srp.sr.Copy(n) + + for i := 0; i < n; i++ { + ret[i] = streamReaderPacker[T]{srs[i]} + } + + return ret +} + +func (srp streamReaderPacker[T]) getType() reflect.Type { + return reflect.TypeOf(srp.sr) +} + +func (srp streamReaderPacker[T]) getChunkType() reflect.Type { + return generic.TypeOf[T]() +} + +func (srp streamReaderPacker[T]) merge(isrs []streamReader) streamReader { + srs := make([]*schema.StreamReader[T], len(isrs)+1) + srs[0] = srp.sr + for i := 1; i < len(srs); i++ { + sr, ok := unpackStreamReader[T](isrs[i-1]) + if !ok { + return nil + } + + srs[i] = sr + } + + sr := schema.MergeStreamReaders(srs) + + return packStreamReader(sr) +} + +func (srp streamReaderPacker[T]) withKey(key string) streamReader { + convert := func(v T) (map[string]any, error) { + return map[string]any{key: v}, nil + } + + ret := schema.StreamReaderWithConvert[T, map[string]any](srp.sr, convert) + + return packStreamReader(ret) +} + +func (srp streamReaderPacker[T]) toAnyStreamReader() *schema.StreamReader[any] { + return schema.StreamReaderWithConvert(srp.sr, func(t T) (any, error) { + return t, nil + }) +} + +func packStreamReader[T any](sr *schema.StreamReader[T]) streamReader { + return streamReaderPacker[T]{sr} +} + +func unpackStreamReader[T any](isr streamReader) (*schema.StreamReader[T], bool) { + 
c, ok := isr.(streamReaderPacker[T]) + if ok { + return c.sr, true + } + + typ := generic.TypeOf[T]() + if typ.Kind() == reflect.Interface { + return schema.StreamReaderWithConvert(isr.toAnyStreamReader(), func(t any) (T, error) { + return t.(T), nil + }), true + } + + return nil, false +} diff --git a/compose/stream_reader_test.go b/compose/stream_reader_test.go new file mode 100644 index 0000000..d104a50 --- /dev/null +++ b/compose/stream_reader_test.go @@ -0,0 +1,102 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "io" + "reflect" + "testing" + + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +func TestArrayStreamMerge(t *testing.T) { + + t.Run("unpack_to_equal_type", func(t *testing.T) { + a1 := []int{1, 2, 3} + a2 := []int{4, 5, 6} + a3 := []int{7, 8, 9} + s1 := schema.StreamReaderFromArray(a1) + s2 := schema.StreamReaderFromArray(a2) + s3 := schema.StreamReaderFromArray(a3) + + sp1 := streamReaderPacker[int]{sr: s1} + sp2 := streamReaderPacker[int]{sr: s2} + sp3 := streamReaderPacker[int]{sr: s3} + + sp := sp1.merge([]streamReader{sp2, sp3}) + + sr, ok := unpackStreamReader[int](sp) + if !ok { + t.Fatal("unexpected") + } + + defer sr.Close() + + var result []int + for { + chunk, err := sr.Recv() + if err == io.EOF { + break + } + assert.Nil(t, err) + result = append(result, chunk) + } + if !reflect.DeepEqual(result, append(append(a1, a2...), a3...)) { + t.Fatalf("result: %v error", result) + } + }) + + t.Run("unpack_to_father_type", func(t *testing.T) { + a1 := []*doctor{{say: "a"}, {say: "b"}, {say: "c"}} + a2 := []*doctor{{say: "d"}, {say: "e"}, {say: "f"}} + a3 := []*doctor{{say: "g"}, {say: "h"}, {say: "i"}} + s1 := schema.StreamReaderFromArray(a1) + s2 := schema.StreamReaderFromArray(a2) + s3 := schema.StreamReaderFromArray(a3) + + sp1 := streamReaderPacker[*doctor]{sr: s1} + sp2 := streamReaderPacker[*doctor]{sr: s2} + sp3 := streamReaderPacker[*doctor]{sr: s3} + + sp := sp1.merge([]streamReader{sp2, sp3}) + + sr, ok := unpackStreamReader[person](sp) + assert.True(t, ok) + + defer sr.Close() + + var result []person + for { + chunk, err := sr.Recv() + if err == io.EOF { + break + } + assert.Nil(t, err) + result = append(result, chunk) + } + + baseline := append(append(a1, a2...), a3...) 
+ + assert.Len(t, result, len(baseline)) + + for idx := range result { + assert.Equal(t, baseline[idx].say, result[idx].Say()) + } + }) +} diff --git a/compose/tool_node.go b/compose/tool_node.go new file mode 100644 index 0000000..03685ff --- /dev/null +++ b/compose/tool_node.go @@ -0,0 +1,290 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "errors" + "fmt" + "runtime/debug" + "sync" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/utils/safe" +) + +type toolsNodeOptions struct { + ToolOptions []tool.Option +} + +// ToolsNodeOption is the option func type for ToolsNode. +type ToolsNodeOption func(o *toolsNodeOptions) + +// WithToolOption adds tool options to the ToolsNode. +func WithToolOption(opts ...tool.Option) ToolsNodeOption { + return func(o *toolsNodeOptions) { + o.ToolOptions = append(o.ToolOptions, opts...) + } +} + +// ToolsNode a node that can run tools in a graph. 
the interface in Graph Node as below: +// +// Invoke(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) ([]*schema.Message, error) +// Stream(ctx context.Context, input *schema.Message, opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) +type ToolsNode struct { + runners []*runnablePacker[string, string, tool.Option] + toolsMeta []*executorMeta + indexes map[string]int // toolName vs index in runners +} + +// ToolsNodeConfig is the config for ToolsNode. It requires a list of tools. +// Tools are BaseTool but must implement InvokableTool or StreamableTool. +type ToolsNodeConfig struct { + Tools []tool.BaseTool +} + +// NewToolNode creates a new ToolsNode. +// eg. +// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewToolNode(ctx, conf) +func NewToolNode(ctx context.Context, conf *ToolsNodeConfig) (*ToolsNode, error) { + rps := make([]*runnablePacker[string, string, tool.Option], len(conf.Tools)) + toolsMeta := make([]*executorMeta, len(conf.Tools)) + indexes := make(map[string]int) + + for idx, bt := range conf.Tools { + + tl, err := bt.Info(ctx) + if err != nil { + return nil, fmt.Errorf("(NewToolNode) failed to get tool info at idx= %d: %w", idx, err) + } + + toolName := tl.Name + + var ( + st tool.StreamableTool + it tool.InvokableTool + + invokable func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) + streamable func(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) + + ok bool + meta *executorMeta + ) + + if st, ok = bt.(tool.StreamableTool); ok { + streamable = st.StreamableRun + } + + if it, ok = bt.(tool.InvokableTool); ok { + invokable = it.InvokableRun + } + + if st == nil && it == nil { + return nil, fmt.Errorf("tool %s is not invokable or streamable", toolName) + } + + if st != nil { + meta = parseExecutorInfoFromComponent(components.ComponentOfTool, st) + 
} else { + meta = parseExecutorInfoFromComponent(components.ComponentOfTool, it) + } + + toolsMeta[idx] = meta + rps[idx] = newRunnablePacker(invokable, streamable, + nil, nil, !meta.isComponentCallbackEnabled) + indexes[toolName] = idx + } + + return &ToolsNode{ + runners: rps, + toolsMeta: toolsMeta, + indexes: indexes, + }, nil +} + +type toolCallTask struct { + // in + r *runnablePacker[string, string, tool.Option] + meta *executorMeta + name string + arg string + callID string + + // out + output string + sOutput *schema.StreamReader[string] + err error +} + +func (tn *ToolsNode) genToolCallTasks(input *schema.Message) ([]toolCallTask, error) { + if input.Role != schema.Assistant { + return nil, fmt.Errorf("expected message role is Assistant, got %s", input.Role) + } + + n := len(input.ToolCalls) + if n == 0 { + return nil, errors.New("no tool call found in input message") + } + + toolCallTasks := make([]toolCallTask, n) + + for i := 0; i < n; i++ { + toolCall := input.ToolCalls[i] + index, ok := tn.indexes[toolCall.Function.Name] + if !ok { + return nil, fmt.Errorf("tool %s not found in toolsNode indexes", toolCall.Function.Name) + } + + toolCallTasks[i].r = tn.runners[index] + toolCallTasks[i].meta = tn.toolsMeta[index] + toolCallTasks[i].name = toolCall.Function.Name + toolCallTasks[i].arg = toolCall.Function.Arguments + toolCallTasks[i].callID = toolCall.ID + } + + return toolCallTasks, nil +} + +func runToolCallTaskByInvoke(ctx context.Context, task *toolCallTask, opts ...tool.Option) { + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{ + Name: task.name, + Type: task.meta.componentImplType, + Component: task.meta.component, + }) + task.output, task.err = task.r.Invoke(ctx, task.arg, opts...) 
// nolint: byted_returned_err_should_do_check +} + +func runToolCallTaskByStream(ctx context.Context, task *toolCallTask, opts ...tool.Option) { + ctx = callbacks.SwitchRunInfo(ctx, &callbacks.RunInfo{ + Name: task.name, + Type: task.meta.componentImplType, + Component: task.meta.component, + }) + task.sOutput, task.err = task.r.Stream(ctx, task.arg, opts...) // nolint: byted_returned_err_should_do_check +} + +func parallelRunToolCall(ctx context.Context, + run func(ctx2 context.Context, callTask *toolCallTask, opts ...tool.Option), tasks []toolCallTask, opts ...tool.Option) { + + if len(tasks) == 1 { + run(ctx, &tasks[0], opts...) + return + } + + var wg sync.WaitGroup + for i := 1; i < len(tasks); i++ { + wg.Add(1) + go func(ctx_ context.Context, t *toolCallTask, opts ...tool.Option) { + defer wg.Done() + defer func() { + panicErr := recover() + if panicErr != nil { + t.err = safe.NewPanicErr(panicErr, debug.Stack()) // nolint: byted_returned_err_should_do_check + } + }() + run(ctx_, t, opts...) + }(ctx, &tasks[i], opts...) + } + + run(ctx, &tasks[0], opts...) + wg.Wait() +} + +// Invoke calls the tools and collects the results of invokable tools. +// it's parallel if there are multiple tool calls in the input message. +func (tn *ToolsNode) Invoke(ctx context.Context, input *schema.Message, + opts ...ToolsNodeOption) ([]*schema.Message, error) { + + opt := getToolsNodeOptions(opts...) + + tasks, err := tn.genToolCallTasks(input) + if err != nil { + return nil, err + } + + parallelRunToolCall(ctx, runToolCallTaskByInvoke, tasks, opt.ToolOptions...) + + n := len(tasks) + output := make([]*schema.Message, n) + for i := 0; i < n; i++ { + if tasks[i].err != nil { + return nil, fmt.Errorf("failed to invoke tool call %s: %w", tasks[i].callID, tasks[i].err) + } + + output[i] = schema.ToolMessage(tasks[i].output, tasks[i].callID) + } + + return output, nil +} + +// Stream calls the tools and collects the results of stream readers. 
+// it's parallel if there are multiple tool calls in the input message. +func (tn *ToolsNode) Stream(ctx context.Context, input *schema.Message, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.Message], error) { + + opt := getToolsNodeOptions(opts...) + + tasks, err := tn.genToolCallTasks(input) + if err != nil { + return nil, err + } + + parallelRunToolCall(ctx, runToolCallTaskByStream, tasks, opt.ToolOptions...) + + n := len(tasks) + sOutput := make([]*schema.StreamReader[[]*schema.Message], n) + + for i := 0; i < n; i++ { + if tasks[i].err != nil { + return nil, fmt.Errorf("failed to stream tool call %s: %w", tasks[i].callID, tasks[i].err) + } + + index := i + callID := tasks[i].callID + convert := func(s string) ([]*schema.Message, error) { + ret := make([]*schema.Message, n) + ret[index] = schema.ToolMessage(s, callID) + + return ret, nil + } + + sOutput[i] = schema.StreamReaderWithConvert(tasks[i].sOutput, convert) + } + + return schema.MergeStreamReaders(sOutput), nil +} + +func (tn *ToolsNode) GetType() string { + return "" +} + +func getToolsNodeOptions(opts ...ToolsNodeOption) *toolsNodeOptions { + o := &toolsNodeOptions{ + ToolOptions: make([]tool.Option, 0), + } + for _, opt := range opts { + opt(o) + } + return o +} diff --git a/compose/tool_node_test.go b/compose/tool_node_test.go new file mode 100644 index 0000000..f682118 --- /dev/null +++ b/compose/tool_node_test.go @@ -0,0 +1,518 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + "io" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/components/tool/utils" + "github.com/cloudwego/eino/schema" +) + +const ( + toolNameOfUserCompany = "user_company" + toolIDOfUserCompany = "call_TRZhlagwBS0LpWbWPeZOvIXc" + + toolNameOfUserSalary = "user_salary" + toolIDOfUserSalary = "call_AqfoRW6fuF98k0o7696k2nzm" +) + +func TestToolsNode(t *testing.T) { + var err error + ctx := context.Background() + + userCompanyToolInfo := &schema.ToolInfo{ + Name: toolNameOfUserCompany, + Desc: "根据用户的姓名和邮箱,查询用户的公司和职位信息", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Type: "string", + Desc: "用户的姓名", + }, + "email": { + Type: "string", + Desc: "用户的邮箱", + }, + }), + } + + userSalaryToolInfo := &schema.ToolInfo{ + Name: toolNameOfUserSalary, + Desc: "根据用户的姓名和邮箱,查询用户的薪酬信息", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Type: "string", + Desc: "用户的姓名", + }, + "email": { + Type: "string", + Desc: "用户的邮箱", + }, + }), + } + + t.Run("success", func(t *testing.T) { + const ( + nodeOfTools = "tools" + nodeOfModel = "model" + ) + g := NewGraph[[]*schema.Message, []*schema.Message]() + + err = g.AddChatModelNode(nodeOfModel, &mockIntentChatModel{}) + assert.NoError(t, err) + + ui := utils.NewTool(userCompanyToolInfo, queryUserCompany) + us := utils.NewStreamTool(userSalaryToolInfo, queryUserSalary) + + toolsNode, err := NewToolNode(ctx, &ToolsNodeConfig{ + Tools: []tool.BaseTool{ui, us}, + }) + assert.NoError(t, err) + + err = g.AddToolsNode(nodeOfTools, toolsNode) + assert.NoError(t, err) + + err = g.AddEdge(START, nodeOfModel) + assert.NoError(t, err) + + err = 
g.AddEdge(nodeOfModel, nodeOfTools) + assert.NoError(t, err) + + err = g.AddEdge(nodeOfTools, END) + assert.NoError(t, err) + + r, err := g.Compile(ctx) + assert.NoError(t, err) + + out, err := r.Invoke(ctx, []*schema.Message{}) + assert.NoError(t, err) + t.Logf("tool message: %v", out) + + assert.Equal(t, toolIDOfUserCompany, findMsgByToolCallID(out, toolIDOfUserCompany).ToolCallID) + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, + findMsgByToolCallID(out, toolIDOfUserCompany).Content) + + assert.Equal(t, toolIDOfUserSalary, findMsgByToolCallID(out, toolIDOfUserSalary).ToolCallID) + assert.Contains(t, findMsgByToolCallID(out, toolIDOfUserSalary).Content, + `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`) + + // 测试流式调用 + reader, err := r.Stream(ctx, []*schema.Message{}) + assert.NoError(t, err) + loops := 0 + userSalaryTimes := 0 + + defer reader.Close() + + for ; loops < 10; loops++ { + msgs, err := reader.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + + t.Logf("stream message[%v]: %v", loops, msgs) + + assert.Len(t, msgs, 2) + if msg := findMsgByToolCallID(out, toolIDOfUserCompany); msg != nil { + assert.Equal(t, schema.Tool, msg.Role) + assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, + msg.Content) + } else if msg := findMsgByToolCallID(out, toolIDOfUserSalary); msg != nil { + assert.Equal(t, schema.Tool, msg.Role) + assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) + + switch userSalaryTimes { + case 0: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}`, + msg.Content) + case 1: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}`, + msg.Content) + 
case 2: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`, + msg.Content) + } + + userSalaryTimes++ + } else { + assert.Fail(t, "unexpected tool name") + } + } + + assert.Equal(t, 4, loops) + + sr, sw := schema.Pipe[[]*schema.Message](2) + sw.Send([]*schema.Message{ + { + Role: schema.User, + Content: `hi, how are you`, + }, + }, nil) + sw.Send([]*schema.Message{ + { + Role: schema.User, + Content: `i'm fine'`, + }, + }, nil) + sw.Close() + + reader, err = r.Transform(ctx, sr) + assert.NoError(t, err) + + defer reader.Close() + + loops = 0 + userSalaryTimes = 0 + + for ; loops < 10; loops++ { + msgs, err := reader.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + + t.Logf("stream message[%v]: %v", loops, msgs) + + assert.Len(t, msgs, 2) + if msg := findMsgByToolCallID(out, toolIDOfUserCompany); msg != nil { + assert.Equal(t, schema.Tool, msg.Role) + assert.Equal(t, toolIDOfUserCompany, msg.ToolCallID) + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","gender":"male","company":"bytedance","position":"CEO"}`, + msg.Content) + } else if msg := findMsgByToolCallID(out, toolIDOfUserSalary); msg != nil { + assert.Equal(t, schema.Tool, msg.Role) + assert.Equal(t, toolIDOfUserSalary, msg.ToolCallID) + + switch userSalaryTimes { + case 0: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":5000}`, + msg.Content) + case 1: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":3000}`, + msg.Content) + case 2: + assert.JSONEq(t, `{"user_id":"zhangsan-zhangsan@bytedance.com","salary":2000}`, + msg.Content) + } + + userSalaryTimes++ + } else { + assert.Fail(t, "unexpected tool name") + } + } + + assert.Equal(t, 4, loops) + }) +} + +type userCompanyRequest struct { + Name string `json:"name"` + Email string `json:"email"` +} + +type userCompanyResponse struct { + UserID string `json:"user_id"` + Gender string `json:"gender"` + Company string `json:"company"` + Position 
string `json:"position"` +} + +func queryUserCompany(ctx context.Context, req *userCompanyRequest) (resp *userCompanyResponse, err error) { + return &userCompanyResponse{ + UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), + Gender: "male", + Company: "bytedance", + Position: "CEO", + }, nil +} + +type userSalaryRequest struct { + Name string `json:"name"` + Email string `json:"email"` +} + +type userSalaryResponse struct { + UserID string `json:"user_id"` + Salary int `json:"salary"` +} + +func queryUserSalary(ctx context.Context, req *userSalaryRequest) (resp *schema.StreamReader[*userSalaryResponse], err error) { + sr, sw := schema.Pipe[*userSalaryResponse](10) + sw.Send(&userSalaryResponse{ + UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), + Salary: 5000, + }, nil) + + sw.Send(&userSalaryResponse{ + UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), + Salary: 3000, + }, nil) + + sw.Send(&userSalaryResponse{ + UserID: fmt.Sprintf("%v-%v", req.Name, req.Email), + Salary: 2000, + }, nil) + sw.Close() + return sr, nil +} + +type mockIntentChatModel struct{} + +func (m *mockIntentChatModel) BindTools(tools []*schema.ToolInfo) error { + return nil +} + +func (m *mockIntentChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + return &schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: toolNameOfUserCompany, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + { + ID: toolIDOfUserSalary, + Function: schema.FunctionCall{ + Name: toolNameOfUserSalary, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + }, + }, nil +} + +func (m *mockIntentChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + sr, sw := schema.Pipe[*schema.Message](2) + 
sw.Send(&schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: toolNameOfUserCompany, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + }, + }, nil) + + sw.Send(&schema.Message{ + Role: schema.Assistant, + Content: "", + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserSalary, + Function: schema.FunctionCall{ + Name: toolNameOfUserSalary, + Arguments: `{"name": "zhangsan", "email": "zhangsan@bytedance.com"}`, + }, + }, + }, + }, nil) + + sw.Close() + + return sr, nil +} + +func TestToolsNodeOptions(t *testing.T) { + ctx := context.Background() + + t.Run("tool_option", func(t *testing.T) { + + g := NewGraph[*schema.Message, []*schema.Message]() + + mt := &mockTool{} + + tn, err := NewToolNode(ctx, &ToolsNodeConfig{ + Tools: []tool.BaseTool{mt}, + }) + assert.NoError(t, err) + + err = g.AddToolsNode("tools", tn) + assert.NoError(t, err) + + err = g.AddEdge(START, "tools") + assert.NoError(t, err) + err = g.AddEdge("tools", END) + assert.NoError(t, err) + + r, err := g.Compile(ctx) + assert.NoError(t, err) + + out, err := r.Invoke(ctx, &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: "mock_tool", + Arguments: `{"name": "jack"}`, + }, + }, + }, + }, WithToolsNodeOption(WithToolOption(WithAge(10)))) + assert.NoError(t, err) + assert.Len(t, out, 1) + assert.JSONEq(t, `{"echo": "jack: 10"}`, out[0].Content) + + outMessages := make([][]*schema.Message, 0) + outStream, err := r.Stream(ctx, &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + ID: toolIDOfUserCompany, + Function: schema.FunctionCall{ + Name: "mock_tool", + Arguments: `{"name": "jack"}`, + }, + }, + }, + }, WithToolsNodeOption(WithToolOption(WithAge(10)))) + + assert.NoError(t, err) + + for { + msgs, err := outStream.Recv() + if err == io.EOF { 
+ break + } + assert.NoError(t, err) + outMessages = append(outMessages, msgs) + } + outStream.Close() + + msgs, err := concatMessageArray(outMessages) + assert.NoError(t, err) + + assert.Len(t, msgs, 1) + assert.JSONEq(t, `{"echo":"jack: 10"}`, msgs[0].Content) + }) + +} + +func findMsgByToolCallID(msgs []*schema.Message, toolCallID string) *schema.Message { + for _, msg := range msgs { + if msg.ToolCallID == toolCallID { + return msg + } + } + + return nil +} + +type mockToolOptions struct { + Age int +} + +func WithAge(age int) tool.Option { + return tool.WrapImplSpecificOptFn(func(o *mockToolOptions) { + o.Age = age + }) +} + +type mockToolRequest struct { + Name string `json:"name"` +} + +type mockToolResponse struct { + Echo string `json:"echo"` +} + +type mockTool struct{} + +func (m *mockTool) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "mock_tool", + Desc: "mock tool", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Type: "string", + Desc: "name", + Required: true, + }, + }), + }, nil +} + +func (m *mockTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + opt := tool.GetImplSpecificOptions(&mockToolOptions{}, opts...) + + req := &mockToolRequest{} + + if e := sonic.UnmarshalString(argumentsInJSON, req); e != nil { + return "", e + } + + resp := &mockToolResponse{ + Echo: fmt.Sprintf("%v: %v", req.Name, opt.Age), + } + + return sonic.MarshalString(resp) +} + +func (m *mockTool) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (*schema.StreamReader[string], error) { + sr, sw := schema.Pipe[string](1) + go func() { + defer sw.Close() + + opt := tool.GetImplSpecificOptions(&mockToolOptions{}, opts...) 
+ + req := &mockToolRequest{} + + if e := sonic.UnmarshalString(argumentsInJSON, req); e != nil { + sw.Send("", e) + return + } + + resp := mockToolResponse{ + Echo: fmt.Sprintf("%v: %v", req.Name, opt.Age), + } + + output, err := sonic.MarshalString(resp) + if err != nil { + sw.Send("", err) + return + } + + for i := 0; i < len(output); i++ { + sw.Send(string(output[i]), nil) + } + }() + + return sr, nil +} diff --git a/compose/types.go b/compose/types.go new file mode 100644 index 0000000..1697966 --- /dev/null +++ b/compose/types.go @@ -0,0 +1,48 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "github.com/cloudwego/eino/components" +) + +type component = components.Component + +// built-in component types in graph node. +// it represents the type of the most primitive executable object provided by the user. +const ( + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfStateGraph component = "StateGraph" + ComponentOfChain component = "Chain" + ComponentOfStateChain component = "StateChain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfLambda component = "Lambda" +) + +// NodeTriggerMode controls the triggering mode of graph nodes. 
+type NodeTriggerMode string + +const ( + // AnyPredecessor means that the current node will be triggered as long as any of its predecessor nodes has finished running. + // Note that the actual implementation organizes node execution in batches. + // In this context, 'any predecessor finishes' means the other nodes of the same batch need to be finished too. + AnyPredecessor NodeTriggerMode = "any_predecessor" + // AllPredecessor means that the current node will only be triggered when all of its predecessor nodes have finished running. + AllPredecessor NodeTriggerMode = "all_predecessor" +) diff --git a/compose/types_composable.go b/compose/types_composable.go new file mode 100644 index 0000000..e5f3839 --- /dev/null +++ b/compose/types_composable.go @@ -0,0 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "reflect" +) + +// AnyGraph is the identifier for composable and compilable types such as Graph[I, O] and Chain[I, O] in Eino.
+type AnyGraph interface { + compile(ctx context.Context, options *graphCompileOptions) (*composableRunnable, error) + inputType() reflect.Type + outputType() reflect.Type + component() component +} diff --git a/compose/types_lambda.go b/compose/types_lambda.go new file mode 100644 index 0000000..1adfab1 --- /dev/null +++ b/compose/types_lambda.go @@ -0,0 +1,265 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/schema" +) + +// Invoke is the type of the invokable lambda function. +type Invoke[I, O, TOption any] func(ctx context.Context, input I, opts ...TOption) (output O, err error) + +// Stream is the type of the streamable lambda function. +type Stream[I, O, TOption any] func(ctx context.Context, + input I, opts ...TOption) (output *schema.StreamReader[O], err error) + +// Collect is the type of the collectable lambda function. +type Collect[I, O, TOption any] func(ctx context.Context, + input *schema.StreamReader[I], opts ...TOption) (output O, err error) + +// Transform is the type of the transformable lambda function. +type Transform[I, O, TOption any] func(ctx context.Context, + input *schema.StreamReader[I], opts ...TOption) (output *schema.StreamReader[O], err error) + +// InvokeWOOpt is the type of the invokable lambda function without options. 
+type InvokeWOOpt[I, O any] func(ctx context.Context, input I) (output O, err error) + +// StreamWOOpt is the type of the streamable lambda function without options. +type StreamWOOpt[I, O any] func(ctx context.Context, + input I) (output *schema.StreamReader[O], err error) + +// CollectWOOpt is the type of the collectable lambda function without options. +type CollectWOOpt[I, O any] func(ctx context.Context, + input *schema.StreamReader[I]) (output O, err error) + +// TransformWOOpts is the type of the transformable lambda function without options. +type TransformWOOpts[I, O any] func(ctx context.Context, + input *schema.StreamReader[I]) (output *schema.StreamReader[O], err error) + +// Lambda is the node that wraps the user provided lambda function. +// It can be used as a node in Graph or Chain (include Parallel and Branch). +// Create a Lambda by using AnyLambda/InvokableLambda/StreamableLambda/CollectableLambda/TransformableLambda. +// eg. +// +// lambda := compose.InvokableLambda(func(ctx context.Context, input string) (output string, err error) { +// return input, nil +// }) +type Lambda struct { + executor *composableRunnable +} + +type lambdaOpts struct { + // same as executorMeta.isComponentCallbackEnabled + // indicates whether the executable lambda user provided could execute the callback aspect itself. + // if it could, the callback in the corresponding graph node won't be executed anymore + enableComponentCallback bool + + // same as executorMeta.componentImplType + // for AnyLambda, the value comes from the user's explicit config + // if componentImplType is empty, then the class name or func name in the instance will be inferred, but no guarantee. + componentImplType string +} + +// LambdaOpt is the option for creating a Lambda. +type LambdaOpt func(o *lambdaOpts) + +// WithLambdaCallbackEnable enables the callback aspect of the lambda function. 
+func WithLambdaCallbackEnable(y bool) LambdaOpt { + return func(o *lambdaOpts) { + o.enableComponentCallback = y + } +} + +// WithLambdaType sets the type of the lambda function. +func WithLambdaType(t string) LambdaOpt { + return func(o *lambdaOpts) { + o.componentImplType = t + } +} + +type unreachableOption struct{} + +// InvokableLambdaWithOption creates a Lambda with invokable lambda function and options. +func InvokableLambdaWithOption[I, O, TOption any](i Invoke[I, O, TOption], opts ...LambdaOpt) *Lambda { + return anyLambda(i, nil, nil, nil, opts...) +} + +// InvokableLambda creates a Lambda with invokable lambda function without options. +func InvokableLambda[I, O any](i InvokeWOOpt[I, O], opts ...LambdaOpt) *Lambda { + f := func(ctx context.Context, input I, opts_ ...unreachableOption) (output O, err error) { + return i(ctx, input) + } + + return anyLambda(f, nil, nil, nil, opts...) +} + +// StreamableLambdaWithOption creates a Lambda with streamable lambda function and options. +func StreamableLambdaWithOption[I, O, TOption any](s Stream[I, O, TOption], opts ...LambdaOpt) *Lambda { + return anyLambda(nil, s, nil, nil, opts...) +} + +// StreamableLambda creates a Lambda with streamable lambda function without options. +func StreamableLambda[I, O any](s StreamWOOpt[I, O], opts ...LambdaOpt) *Lambda { + f := func(ctx context.Context, input I, opts_ ...unreachableOption) ( + output *schema.StreamReader[O], err error) { + + return s(ctx, input) + } + + return anyLambda(nil, f, nil, nil, opts...) +} + +// CollectableLambdaWithOption creates a Lambda with collectable lambda function and options. +func CollectableLambdaWithOption[I, O, TOption any](c Collect[I, O, TOption], opts ...LambdaOpt) *Lambda { + return anyLambda(nil, nil, c, nil, opts...) +} + +// CollectableLambda creates a Lambda with collectable lambda function without options. 
+func CollectableLambda[I, O any](c CollectWOOpt[I, O], opts ...LambdaOpt) *Lambda { + f := func(ctx context.Context, input *schema.StreamReader[I], + opts_ ...unreachableOption) (output O, err error) { + + return c(ctx, input) + } + + return anyLambda(nil, nil, f, nil, opts...) +} + +// TransformableLambdaWithOption creates a Lambda with transformable lambda function and options. +func TransformableLambdaWithOption[I, O, TOption any](t Transform[I, O, TOption], opts ...LambdaOpt) *Lambda { + return anyLambda(nil, nil, nil, t, opts...) +} + +// TransformableLambda creates a Lambda with transformable lambda function without options. +func TransformableLambda[I, O any](t TransformWOOpts[I, O], opts ...LambdaOpt) *Lambda { + + f := func(ctx context.Context, input *schema.StreamReader[I], + opts_ ...unreachableOption) (output *schema.StreamReader[O], err error) { + + return t(ctx, input) + } + + return anyLambda(nil, nil, nil, f, opts...) +} + +// AnyLambda creates a Lambda with any lambda function. +// you can only implement one or more of the four lambda functions, and the rest use nil. +// eg. +// +// invokeFunc := func(ctx context.Context, input string, opts ...myOption) (output string, err error) { +// // ... +// } +// streamFunc := func(ctx context.Context, input string, opts ...myOption) (output *schema.StreamReader[string], err error) { +// // ... 
+// } +// +// lambda := compose.AnyLambda(invokeFunc, streamFunc, nil, nil) +func AnyLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], + c Collect[I, O, TOption], t Transform[I, O, TOption], opts ...LambdaOpt) (*Lambda, error) { + + if i == nil && s == nil && c == nil && t == nil { + return nil, fmt.Errorf("needs to have at least one of four lambda types: invoke/stream/collect/transform, got none") + } + + return anyLambda(i, s, c, t, opts...), nil +} + +func anyLambda[I, O, TOption any](i Invoke[I, O, TOption], s Stream[I, O, TOption], + c Collect[I, O, TOption], t Transform[I, O, TOption], opts ...LambdaOpt) *Lambda { + + opt := getLambdaOpt(opts...) + + executor := runnableLambda(i, s, c, t, + !opt.enableComponentCallback, + ) + executor.meta = &executorMeta{ + component: ComponentOfLambda, + isComponentCallbackEnabled: opt.enableComponentCallback, + componentImplType: opt.componentImplType, + } + + return &Lambda{ + executor: executor, + } +} + +func getLambdaOpt(opts ...LambdaOpt) *lambdaOpts { + opt := &lambdaOpts{ + enableComponentCallback: false, + componentImplType: "", + } + + for _, optFn := range opts { + optFn(opt) + } + return opt +} + +// ToList creates a Lambda that converts input I to a []I. +// It's useful when you want to convert a single input to a list of inputs. +// eg.
+// +// lambda := compose.ToList[*schema.Message]() +// chain := compose.NewChain[[]*schema.Message, []*schema.Message]() +// +// chain.AppendChatModel(chatModel) // chatModel returns *schema.Message, but we need []*schema.Message +// chain.AppendLambda(lambda) // convert *schema.Message to []*schema.Message +func ToList[I any](opts ...LambdaOpt) *Lambda { + i := func(ctx context.Context, input I, opts_ ...unreachableOption) (output []I, err error) { + return []I{input}, nil + } + + f := func(ctx context.Context, inputS *schema.StreamReader[I], opts_ ...unreachableOption) (outputS *schema.StreamReader[[]I], err error) { + return schema.StreamReaderWithConvert(inputS, func(i I) ([]I, error) { + return []I{i}, nil + }), nil + } + + return anyLambda(i, nil, nil, f, opts...) +} + +// MessageParser creates a lambda that parses a message into an object T, usually used after a chatmodel. +// usage: +// +// parser := schema.NewMessageJSONParser[MyStruct](&schema.MessageJSONParseConfig{ +// ParseFrom: schema.MessageParseFromContent, +// }) +// parserLambda := MessageParser(parser) +// +// chain := NewChain[*schema.Message, MyStruct]() +// chain.AppendChatModel(chatModel) +// chain.AppendLambda(parserLambda) +// +// r, err := chain.Compile(context.Background()) +// +// // parsed is a MyStruct object +// parsed, err := r.Invoke(context.Background(), &schema.Message{ +// Role: schema.MessageRoleUser, +// Content: "return a json string for my struct", +// }) +func MessageParser[T any](p schema.MessageParser[T], opts ...LambdaOpt) *Lambda { + i := func(ctx context.Context, input *schema.Message, opts_ ...unreachableOption) (output T, err error) { + return p.Parse(ctx, input) + } + + opts = append([]LambdaOpt{WithLambdaType("MessageParse")}, opts...) + + return anyLambda(i, nil, nil, nil, opts...)
+} diff --git a/compose/types_lambda_test.go b/compose/types_lambda_test.go new file mode 100644 index 0000000..96afc0b --- /dev/null +++ b/compose/types_lambda_test.go @@ -0,0 +1,212 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestLambda(t *testing.T) { + t.Run("InvokableLambda", func(t *testing.T) { + ld := InvokableLambdaWithOption( + func(ctx context.Context, input string, opts ...any) (output string, err error) { + return "good", nil + }, + WithLambdaCallbackEnable(false), + WithLambdaType("ForTest"), + ) + + assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) + + ld = InvokableLambda( + func(ctx context.Context, input string) (output string, err error) { + return "good", nil + }, + WithLambdaCallbackEnable(false), + WithLambdaType("ForTest"), + ) + + assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) + }) + + t.Run("StreamableLambda", func(t *testing.T) { + ld := StreamableLambdaWithOption( + func(ctx context.Context, input string, opts ...any) (output 
*schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + WithLambdaCallbackEnable(false), + WithLambdaType("ForTest"), + ) + + assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) + + ld = StreamableLambda( + func(ctx context.Context, input string) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + WithLambdaCallbackEnable(false), + WithLambdaType("ForTest"), + ) + + assert.Equal(t, false, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) + }) + + t.Run("CollectableLambda", func(t *testing.T) { + ld := CollectableLambdaWithOption( + func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output string, err error) { + return "good", nil + }, + WithLambdaCallbackEnable(true), + ) + + assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "", ld.executor.meta.componentImplType) + + ld = CollectableLambda( + func(ctx context.Context, input *schema.StreamReader[string]) (output string, err error) { + return "good", nil + }, + WithLambdaCallbackEnable(true), + ) + + assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "", ld.executor.meta.componentImplType) + }) + + t.Run("TransformableLambda", func(t *testing.T) { + ld := TransformableLambdaWithOption( + func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + 
WithLambdaCallbackEnable(true), + ) + + assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "", ld.executor.meta.componentImplType) + + ld = TransformableLambda( + func(ctx context.Context, input *schema.StreamReader[string]) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + WithLambdaCallbackEnable(true), + ) + + assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "", ld.executor.meta.componentImplType) + }) + + t.Run("AnyLambda", func(t *testing.T) { + ld, err := AnyLambda[string, string]( + func(ctx context.Context, input string, opts ...any) (output string, err error) { + return "good", nil + }, + func(ctx context.Context, input string, opts ...any) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output string, err error) { + return "good", nil + }, + func(ctx context.Context, input *schema.StreamReader[string], opts ...any) (output *schema.StreamReader[string], err error) { + sr, sw := schema.Pipe[string](1) + sw.Close() + return sr, nil + }, + WithLambdaCallbackEnable(true), + WithLambdaType("ForTest"), + ) + assert.NoError(t, err) + + assert.Equal(t, true, ld.executor.meta.isComponentCallbackEnabled) + assert.Equal(t, ComponentOfLambda, ld.executor.meta.component) + assert.Equal(t, "ForTest", ld.executor.meta.componentImplType) + }) +} + +type TestStructForParse struct { + ID int `json:"id"` +} + +func TestMessageParser(t *testing.T) { + t.Run("parse from content", func(t *testing.T) { + parser := schema.NewMessageJSONParser[TestStructForParse](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromContent, + }) + + parserLambda := 
MessageParser(parser) + + chain := NewChain[*schema.Message, TestStructForParse]() + chain.AppendLambda(parserLambda) + + r, err := chain.Compile(context.Background()) + assert.Nil(t, err) + + parsed, err := r.Invoke(context.Background(), &schema.Message{ + Content: `{"id": 1}`, + }) + assert.Nil(t, err) + assert.Equal(t, 1, parsed.ID) + }) + + t.Run("parse from tool call", func(t *testing.T) { + parser := schema.NewMessageJSONParser[*TestStructForParse](&schema.MessageJSONParseConfig{ + ParseFrom: schema.MessageParseFromToolCall, + }) + + parserLambda := MessageParser(parser) + + chain := NewChain[*schema.Message, *TestStructForParse]() + chain.AppendLambda(parserLambda) + + r, err := chain.Compile(context.Background()) + assert.Nil(t, err) + + parsed, err := r.Invoke(context.Background(), &schema.Message{ + ToolCalls: []schema.ToolCall{ + {Function: schema.FunctionCall{Arguments: `{"id": 1}`}}, + }, + }) + assert.Nil(t, err) + assert.Equal(t, 1, parsed.ID) + }) +} diff --git a/compose/utils.go b/compose/utils.go new file mode 100644 index 0000000..c300b54 --- /dev/null +++ b/compose/utils.go @@ -0,0 +1,339 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package compose + +import ( + "context" + "fmt" + "reflect" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +func mergeMap(vs []any) (any, error) { + typ := reflect.TypeOf(vs[0]) + merged := reflect.MakeMap(typ) + for _, v := range vs { + if reflect.TypeOf(v) != typ { + return nil, fmt.Errorf( + "(mergeMap) field type mismatch. expected: '%v', got: '%v'", typ, reflect.TypeOf(v)) + } + + iter := reflect.ValueOf(v).MapRange() + for iter.Next() { + key, val := iter.Key(), iter.Value() + if merged.MapIndex(key).IsValid() { + return nil, fmt.Errorf("(mergeMap) duplicated key ('%v') found", key.Interface()) + } + merged.SetMapIndex(key, val) + } + } + + return merged.Interface(), nil +} + +// mergeValues merges maps of one type (duplicate keys are an error) or stream readers sharing one map chunk type; the caller should ensure len(vs) > 1 +func mergeValues(vs []any) (any, error) { + v0 := reflect.ValueOf(vs[0]) + t0 := v0.Type() + k0 := t0.Kind() + + if k0 == reflect.Map { + return mergeMap(vs) + } + + if s, ok := vs[0].(streamReader); ok { + if s.getChunkType().Kind() != reflect.Map { + return nil, fmt.Errorf("(mergeValues | stream type)"+ + " unsupported chunk type: %v", s.getChunkType()) + } + + ss := make([]streamReader, len(vs)-1) + for i := 0; i < len(ss); i++ { + s_, ok_ := vs[i+1].(streamReader) + if !ok_ { + return nil, fmt.Errorf("(mergeStream) unexpected type. "+ + "expect: %v, got: %v", t0, reflect.TypeOf(vs[i+1])) + } + + if s_.getChunkType() != s.getChunkType() { + return nil, fmt.Errorf("(mergeStream) chunk type mismatch. "+ + "expect: %v, got: %v", s.getChunkType(), s_.getChunkType()) + } + + ss[i] = s_ + } + + ms := s.merge(ss) + + return ms, nil + } + + return nil, fmt.Errorf("(mergeValues) unsupported type: %v", t0) +} + +func invokeWithCallbacks[I, O, TOption any](i Invoke[I, O, TOption]) Invoke[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output O, err error) { + if !callbacks.Needed(ctx) { + return i(ctx, input, opts...)
+ } + + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + + _ = callbacks.OnEnd(ctx, output) + + }() + + ctx = callbacks.OnStart(ctx, input) + + return i(ctx, input, opts...) + } +} + +func genericInvokeWithCallbacks(i invoke) invoke { + return func(ctx context.Context, input any, opts ...any) (output any, err error) { + if !callbacks.Needed(ctx) { + return i(ctx, input, opts...) + } + + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + + _ = callbacks.OnEnd(ctx, output) + + }() + + ctx = callbacks.OnStart(ctx, input) + + return i(ctx, input, opts...) + } +} + +func streamWithCallbacks[I, O, TOption any](s Stream[I, O, TOption]) Stream[I, O, TOption] { + return func(ctx context.Context, input I, opts ...TOption) (output *schema.StreamReader[O], err error) { + if !callbacks.Needed(ctx) { + return s(ctx, input, opts...) + } + + ctx = callbacks.OnStart(ctx, input) + + output, err = s(ctx, input, opts...) + if err != nil { + _ = callbacks.OnError(ctx, err) + return output, err + } + + _, newS := callbacks.OnEndWithStreamOutput(ctx, output) + + return newS, nil + } +} + +func collectWithCallbacks[I, O, TOption any](c Collect[I, O, TOption]) Collect[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], opts ...TOption) (output O, err error) { + if !callbacks.Needed(ctx) { + return c(ctx, input, opts...) + } + + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + return + } + _ = callbacks.OnEnd(ctx, output) + }() + + ctx, newS := callbacks.OnStartWithStreamInput(ctx, input) + + return c(ctx, newS, opts...) + } +} + +func transformWithCallbacks[I, O, TOption any](t Transform[I, O, TOption]) Transform[I, O, TOption] { + return func(ctx context.Context, input *schema.StreamReader[I], + opts ...TOption) (output *schema.StreamReader[O], err error) { + + if !callbacks.Needed(ctx) { + return t(ctx, input, opts...) 
+ } + + ctx, input = callbacks.OnStartWithStreamInput(ctx, input) + + output, err = t(ctx, input, opts...) + if err != nil { + _ = callbacks.OnError(ctx, err) + return output, err + } + + _, output = callbacks.OnEndWithStreamOutput(ctx, output) + + return output, nil + } +} + +func genericTransformWithCallbacks(t transform) transform { + return func(ctx context.Context, input streamReader, opts ...any) (output streamReader, err error) { + if !callbacks.Needed(ctx) { + return t(ctx, input, opts...) + } + + inArr := input.copy(2) + is, ok := unpackStreamReader[callbacks.CallbackInput](inArr[1]) + if !ok { // unexpected + return t(ctx, inArr[0], opts...) + } + + ctx, is = callbacks.OnStartWithStreamInput(ctx, is) + + output, err = t(ctx, inArr[0], opts...) + if err != nil { + _ = callbacks.OnError(ctx, err) + return output, err + } + + outArr := output.copy(2) + os, ok := unpackStreamReader[callbacks.CallbackOutput](outArr[1]) + if !ok { // unexpected + return outArr[0], nil + } + + _, _ = callbacks.OnEndWithStreamOutput(ctx, os) + + return outArr[0], nil + } +} + +func initGraphCallbacks(ctx context.Context, info *nodeInfo, meta *executorMeta, opts ...Option) context.Context { + ri := &callbacks.RunInfo{} + if meta != nil { + ri.Component = meta.component + ri.Type = meta.componentImplType + } + + if info != nil { + ri.Name = info.name + } + + var cbs []callbacks.Handler + for i := range opts { + if len(opts[i].graphHandler) != 0 { + cbs = append(cbs, opts[i].graphHandler...) + } + + if len(opts[i].handler) != 0 && len(opts[i].keys) == 0 { + cbs = append(cbs, opts[i].handler...) + } + } + + return callbacks.InitCallbacks(ctx, ri, cbs...) 
+} + +func streamChunkConvertForCBOutput[O any](o O) (callbacks.CallbackOutput, error) { + return o, nil +} + +func streamChunkConvertForCBInput[I any](i I) (callbacks.CallbackInput, error) { + return i, nil +} + +func toAnyList[T any](in []T) []any { + ret := make([]any, len(in)) + for i := range in { + ret[i] = in[i] + } + return ret +} + +type assignableType uint8 + +const ( + assignableTypeMustNot assignableType = iota + assignableTypeMust + assignableTypeMay +) + +func checkAssignable(input, arg reflect.Type) assignableType { + if arg == nil || input == nil { + return assignableTypeMustNot + } + + if arg == input { + return assignableTypeMust + } + + if arg.Kind() == reflect.Interface && input.Implements(arg) { + return assignableTypeMust + } + if input.Kind() == reflect.Interface { + if arg.Implements(input) { + return assignableTypeMay + } + return assignableTypeMustNot + } + + return assignableTypeMustNot +} + +func extractOption(nodes map[string]*chanCall, opts ...Option) (map[string][]any, error) { + optMap := map[string][]any{} + for _, opt := range opts { + if len(opt.options) == 0 { + continue + } + if len(opt.keys) == 0 { + // common option, check type + for name, c := range nodes { + if reflect.TypeOf(opt.options[0]) == c.action.optionType { // assume that types of options are the same + optMap[name] = append(optMap[name], opt.options...) + } + } + } + for _, key := range opt.keys { + if _, ok := nodes[key]; !ok { + return nil, fmt.Errorf("option has designated an unknown node: %s", key) + } + if nodes[key].action.optionType != reflect.TypeOf(opt.options[0]) { // assume that types of options are the same; either type may be nil (e.g. a sub graph node), so format with %v instead of calling String() on it + return nil, fmt.Errorf("option type[%v] is different from which the designated node[%s] expects[%v]", + reflect.TypeOf(opt.options[0]), key, nodes[key].action.optionType) + } + optMap[key] = append(optMap[key], opt.options...)
+ } + } + for k, v := range nodes { + if v.action.optionType == nil { + // sub graph + optMap[k] = toAnyList(opts) + } + } + return optMap, nil +} + +func mapToList(m map[string]any) []any { + ret := make([]any, 0, len(m)) + for _, v := range m { + ret = append(ret, v) + } + return ret +} diff --git a/compose/utils_test.go b/compose/utils_test.go new file mode 100644 index 0000000..4a627f7 --- /dev/null +++ b/compose/utils_test.go @@ -0,0 +1,163 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package compose
+
+import (
+	"io"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/cloudwego/eino/schema"
+	"github.com/cloudwego/eino/utils/generic"
+)
+
+// TestMergeValues exercises mergeValues with maps and with packed stream readers.
+// It expects disjoint maps to merge, and expects an error when a key is repeated
+// or when the map types differ.
+func TestMergeValues(t *testing.T) {
+	// merge maps
+	m1 := map[int]int{1: 1, 2: 2, 3: 3, 4: 4}
+	m2 := map[int]int{5: 5, 6: 6, 7: 7, 8: 8}
+	m3 := map[int]int{9: 9, 10: 10, 11: 11}
+	mergedM, err := mergeValues([]any{m1, m2, m3})
+	assert.Nil(t, err)
+
+	m := mergedM.(map[int]int)
+
+	// len(m) == len(m1) + len(m2) + len(m3)
+	assert.Equal(t, len(m1)+len(m2)+len(m3), len(m))
+
+	// key 1 already exists in m1
+	_, err = mergeValues([]any{m1, m2, m3, map[int]int{1: 1}})
+	assert.NotNil(t, err)
+
+	// value type differs from the other maps
+	_, err = mergeValues([]any{m1, m2, m3, map[int]string{1: "1"}})
+	assert.NotNil(t, err)
+
+	// merge stream
+	ass := []any{
+		packStreamReader(schema.StreamReaderFromArray[map[int]bool]([]map[int]bool{{1: true}})),
+		packStreamReader(schema.StreamReaderFromArray[map[int]bool]([]map[int]bool{{2: true}})),
+		packStreamReader(schema.StreamReaderFromArray[map[int]bool]([]map[int]bool{{3: true}})),
+	}
+	isr, err := mergeValues(ass)
+	if err != nil {
+		// isr is about to be type-asserted, so stop here instead of panicking
+		t.Fatalf("merge stream readers: %v", err)
+	}
+
+	// check if merge ret is StreamReader before touching it:
+	// ret must not be used (or closed) when the unpack failed
+	ret, ok := unpackStreamReader[map[int]bool](isr.(streamReader))
+	if !ok {
+		t.Fatal("merged value is not a stream reader of map[int]bool")
+	}
+	defer ret.Close()
+
+	for i := 1; i <= 3; i++ {
+		num, err := ret.Recv()
+		assert.Nil(t, err)
+
+		if !num[i] {
+			t.Fatalf("stream read num:%d is out of expect", i)
+		}
+	}
+	_, err = ret.Recv()
+	if err != io.EOF {
+		t.Fatalf("stream reader isn't return EOF as expect: %v", err)
+	}
+}
+
+// good is the reference interface for the assignability checks below.
+type good interface {
+	ThisIsGood() bool
+}
+
+// good2 has a method set unrelated to good.
+type good2 interface {
+	ThisIsGood2() bool
+}
+
+// good3 has the same method set as good under a different name.
+type good3 interface {
+	ThisIsGood() bool
+}
+
+// goodImpl implements good (and good3) through its pointer receiver.
+type goodImpl struct{}
+
+func (g *goodImpl) ThisIsGood() bool {
+	return true
+}
+
+// goodNotImpl implements none of the interfaces above.
+type goodNotImpl struct{}
+
+// TestValidateType pins the result of checkAssignable(input, arg) for concrete
+// types, interfaces and their combinations.
+func TestValidateType(t *testing.T) {
+
+	t.Run("equal_type", func(t *testing.T) {
+		arg := generic.TypeOf[int]()
+		input := generic.TypeOf[int]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMust, result)
+	})
+
+	t.Run("unequal_type", func(t *testing.T) {
+		arg := generic.TypeOf[int]()
+		input := generic.TypeOf[string]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMustNot, result)
+	})
+
+	t.Run("implement_interface", func(t *testing.T) {
+		arg := generic.TypeOf[good]()
+		input := generic.TypeOf[*goodImpl]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMust, result)
+	})
+
+	t.Run("may_implement_interface", func(t *testing.T) {
+		arg := generic.TypeOf[*goodImpl]()
+		input := generic.TypeOf[good]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMay, result)
+	})
+
+	t.Run("not_implement_interface", func(t *testing.T) {
+		arg := generic.TypeOf[good]()
+		input := generic.TypeOf[*goodNotImpl]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMustNot, result)
+	})
+
+	t.Run("interface_unequal_interface", func(t *testing.T) {
+		arg := generic.TypeOf[good]()
+		input := generic.TypeOf[good2]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMustNot, result)
+	})
+
+	t.Run("interface_equal_interface", func(t *testing.T) {
+		arg := generic.TypeOf[good]()
+		input := generic.TypeOf[good3]()
+
+		result := checkAssignable(input, arg)
+		assert.Equal(t, assignableTypeMust, result)
+	})
+}
+
+// TestStreamChunkConvert expects the callback chunk converters to hand back an
+// int chunk as-is.
+func TestStreamChunkConvert(t *testing.T) {
+	o, err := streamChunkConvertForCBOutput(1)
+	assert.Nil(t, err)
+	assert.Equal(t, 1, o)
+
+	i, err := streamChunkConvertForCBInput(1)
+	assert.Nil(t, err)
+	assert.Equal(t, 1, i)
+}
diff --git a/doc.go b/doc.go
new file mode 100644
index 0000000..a007d81
--- /dev/null
+++ b/doc.go
@@ -0,0 +1,17 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package eino
diff --git a/flow/agent/agent_option.go b/flow/agent/agent_option.go
new file mode 100644
index 0000000..fa68d8e
--- /dev/null
+++ b/flow/agent/agent_option.go
@@ -0,0 +1,70 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package agent
+
+import "github.com/cloudwego/eino/compose"
+
+// AgentOption is the option type shared by agent and multi-agent implementations.
+// An option carries either compose options (see WithComposeOptions) or a function
+// that mutates an implementation-specific options struct (see WrapImplSpecificOptFn).
+type AgentOption struct {
+	// implSpecificOptFn holds a func(*T) for some implementation-defined T.
+	implSpecificOptFn any
+	// composeOptions are collected by GetComposeOptions.
+	composeOptions []compose.Option
+}
+
+// GetComposeOptions concatenates the compose options of all given agent options,
+// in argument order. The result is nil when no option carries compose options.
+func GetComposeOptions(opts ...AgentOption) []compose.Option {
+	var collected []compose.Option
+	for i := range opts {
+		collected = append(collected, opts[i].composeOptions...)
+	}
+
+	return collected
+}
+
+// WithComposeOptions wraps the given compose options in an agent option.
+func WithComposeOptions(opts ...compose.Option) AgentOption {
+	var o AgentOption
+	o.composeOptions = opts
+	return o
+}
+
+// WrapImplSpecificOptFn wraps a function that modifies an implementation-specific
+// options struct of type T in an agent option.
+func WrapImplSpecificOptFn[T any](optFn func(*T)) AgentOption {
+	var o AgentOption
+	o.implSpecificOptFn = optFn
+	return o
+}
+
+// GetImplSpecificOptions applies, in order, every option function whose type is
+// func(*T) to base and returns base. Options wrapping a different T are ignored.
+// A nil base is replaced by a zero-valued T.
+func GetImplSpecificOptions[T any](base *T, opts ...AgentOption) *T {
+	if base == nil {
+		base = new(T)
+	}
+
+	for _, opt := range opts {
+		if opt.implSpecificOptFn == nil {
+			continue
+		}
+		// only functions targeting this T are applicable
+		if apply, matches := opt.implSpecificOptFn.(func(*T)); matches {
+			apply(base)
+		}
+	}
+
+	return base
+}
diff --git a/flow/agent/multiagent/host/callback.go b/flow/agent/multiagent/host/callback.go
new file mode 100644
index 0000000..eca57f6
--- /dev/null
+++ b/flow/agent/multiagent/host/callback.go
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package host
+
+import (
+	"context"
+	"io"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/callbacks/template"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/flow/agent"
+	"github.com/cloudwego/eino/schema"
+)
+
+// MultiAgentCallback is the callback interface for host multi-agent.
+type MultiAgentCallback interface { // nolint: byted_s_interface_name
+	// OnHandOff is invoked when the host hands off to a specialist.
+	// The returned context replaces ctx for the callbacks that follow.
+	OnHandOff(ctx context.Context, info *HandOffInfo) context.Context
+}
+
+// HandOffInfo is the info which will be passed to MultiAgentCallback.OnHandOff, representing a hand off event.
+type HandOffInfo struct {
+	ToAgentName string // name of the first tool call in the host's message
+	Argument    string // arguments of that tool call, as produced by the model
+}
+
+// convertCallbacks reads the agent options, extracts the registered
+// host.MultiAgentCallback values and converts them into a callbacks.Handler
+// that fires on chat model end (both invoke and stream output).
+// It returns nil when no MultiAgentCallback is registered.
+func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
+	agentOptions := agent.GetImplSpecificOptions(&options{}, opts...)
+	if len(agentOptions.agentCallbacks) == 0 {
+		return nil
+	}
+
+	// handOff notifies every registered callback when msg is an assistant
+	// message carrying at least one tool call; only the first tool call is
+	// reported. Any other message (including nil) leaves ctx untouched.
+	handOff := func(ctx context.Context, msg *schema.Message) context.Context {
+		if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
+			return ctx
+		}
+
+		call := msg.ToolCalls[0].Function
+		for _, cb := range agentOptions.agentCallbacks {
+			ctx = cb.OnHandOff(ctx, &HandOffInfo{
+				ToAgentName: call.Name,
+				Argument:    call.Arguments,
+			})
+		}
+
+		return ctx
+	}
+
+	onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
+		if output == nil || info == nil {
+			return ctx
+		}
+
+		return handOff(ctx, output.Message)
+	}
+
+	onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
+		if output == nil || info == nil {
+			return ctx
+		}
+
+		defer output.Close()
+
+		// drain the stream; the hand-off is only known once all chunks are concatenated
+		var msgs []*schema.Message
+		for {
+			chunk, err := output.Recv()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				return ctx
+			}
+
+			// a nil chunk carries nothing to concatenate; skip it rather than dereference it
+			if chunk == nil || chunk.Message == nil {
+				continue
+			}
+
+			msgs = append(msgs, chunk.Message)
+		}
+
+		if len(msgs) == 0 {
+			return ctx
+		}
+
+		msg, err := schema.ConcatMessages(msgs)
+		if err != nil {
+			return ctx
+		}
+
+		return handOff(ctx, msg)
+	}
+
+	return template.NewHandlerHelper().ChatModel(&model.CallbackHandler{
+		OnEnd:                 onChatModelEnd,
+		OnEndWithStreamOutput: onChatModelEndWithStreamOutput,
+	}).Handler()
+}
diff --git a/flow/agent/multiagent/host/compose.go b/flow/agent/multiagent/host/compose.go
new file mode 100644
index 0000000..f05b17b
--- /dev/null
+++ b/flow/agent/multiagent/host/compose.go
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package host
+
+import (
+	"context"
+	"fmt"
+	"io"
+
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/schema"
+)
+
+const (
+	// hostName is the graph node key (and node name) of the host agent.
+	hostName = "host"
+	// defaultHostPrompt is the host system prompt used when none is configured.
+	defaultHostPrompt = "decide which tool is best for the task and call only the best tool."
+)
+
+// state is the per-run graph state; msgs keeps the original input messages so
+// that specialists receive them instead of the host's tool call message.
+type state struct {
+	msgs []*schema.Message
+}
+
+// NewMultiAgent creates a new host multi-agent system.
+//
+// It validates config, then builds a state graph with one node per specialist,
+// a host chat model node, and a converter node that turns the host's message
+// into a one-element list. Two branches route the host output: directly to END,
+// or through the converter to the specialist named by the tool call.
+// Any error from validation, node/branch registration or compilation is returned as-is.
+func NewMultiAgent(ctx context.Context, config *MultiAgentConfig) (*MultiAgent, error) {
+	if err := config.validate(); err != nil {
+		return nil, err
+	}
+
+	g := compose.NewStateGraph[[]*schema.Message, *schema.Message](func(context.Context) *state { return &state{} })
+
+	// every specialist is exposed to the host as a tool taking a single "reason" parameter
+	agentTools := make([]*schema.ToolInfo, 0, len(config.Specialists))
+	agentMap := make(map[string]bool, len(config.Specialists)+1)
+	for i := range config.Specialists {
+		specialist := config.Specialists[i]
+
+		agentTools = append(agentTools, &schema.ToolInfo{
+			Name: specialist.Name,
+			Desc: specialist.IntendedUse,
+			ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{
+				"reason": {
+					Type: schema.String,
+					Desc: "the reason to call this tool",
+				},
+			}),
+		})
+
+		if err := addSpecialistAgent(specialist, g); err != nil {
+			return nil, err
+		}
+
+		// agentMap lists the valid end nodes of the specialists branch
+		agentMap[specialist.Name] = true
+	}
+
+	if err := addHostAgent(config, agentTools, g); err != nil {
+		return nil, err
+	}
+
+	// the converter wraps the host's single message into a list, the input type of the specialists branch
+	const convertorName = "msg2MsgList"
+	if err := g.AddLambdaNode(convertorName, compose.ToList[*schema.Message](), compose.WithNodeName("converter")); err != nil {
+		return nil, err
+	}
+
+	if err := addDirectAnswerBranch(convertorName, g); err != nil {
+		return nil, err
+	}
+
+	if err := addSpecialistsBranch(convertorName, agentMap, g); err != nil {
+		return nil, err
+	}
+
+	r, err := g.Compile(ctx, compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(config.Name))
+	if err != nil {
+		return nil, err
+	}
+
+	return &MultiAgent{
+		runnable: r,
+	}, nil
+}
+
+// addSpecialistAgent adds one specialist as a graph node keyed by its name and
+// connects it to END. Invokable/Streamable take precedence over ChatModel when
+// both are set. If none of the three is set no node is added before AddEdge is
+// called (MultiAgentConfig.validate rejects such a specialist beforehand).
+func addSpecialistAgent(specialist *Specialist, g *compose.StateGraph[[]*schema.Message, *schema.Message, *state]) error {
+	if specialist.Invokable != nil || specialist.Streamable != nil {
+		lambda, err := compose.AnyLambda(specialist.Invokable, specialist.Streamable, nil, nil, compose.WithLambdaType("Specialist"))
+		if err != nil {
+			return err
+		}
+		preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
+			return state.msgs, nil // replace the tool call message with input msgs stored in state
+		}
+		if err := g.AddLambdaNode(specialist.Name, lambda, compose.WithStatePreHandler(preHandler), compose.WithNodeName(specialist.Name)); err != nil {
+			return err
+		}
+	} else if specialist.ChatModel != nil {
+		preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
+			// prepend the specialist's own system prompt, if any, to the original input
+			if len(specialist.SystemPrompt) > 0 {
+				return append([]*schema.Message{{
+					Role:    schema.System,
+					Content: specialist.SystemPrompt,
+				}}, state.msgs...), nil
+			}
+
+			return state.msgs, nil // replace the tool call message with input msgs stored in state
+		}
+		if err := g.AddChatModelNode(specialist.Name, specialist.ChatModel, compose.WithStatePreHandler(preHandler), compose.WithNodeName(specialist.Name)); err != nil {
+			return err
+		}
+	}
+
+	return g.AddEdge(specialist.Name, compose.END)
+}
+
+// addHostAgent binds the specialist tools to the host chat model, adds the host
+// node and connects START to it. The node's pre-handler stores the input
+// messages in the graph state and prepends the host system prompt when one is set.
+func addHostAgent(config *MultiAgentConfig, agentTools []*schema.ToolInfo, g *compose.StateGraph[[]*schema.Message, *schema.Message, *state]) error {
+	if err := config.Host.ChatModel.BindTools(agentTools); err != nil {
+		return err
+	}
+
+	preHandler := func(_ context.Context, input []*schema.Message, state *state) ([]*schema.Message, error) {
+		// remember the original input so specialists can be fed with it later
+		state.msgs = input
+		if len(config.Host.SystemPrompt) == 0 {
+			return input, nil
+		}
+		return append([]*schema.Message{{
+			Role:    schema.System,
+			Content: config.Host.SystemPrompt,
+		}}, input...), nil
+	}
+	if err := g.AddChatModelNode(hostName, config.Host.ChatModel, compose.WithStatePreHandler(preHandler), compose.WithNodeName(hostName)); err != nil {
+		return err
+	}
+
+	return g.AddEdge(compose.START, hostName)
+}
+
+// addDirectAnswerBranch adds the branch leaving the host node. The branch
+// condition reads the host's output stream and returns convertorName as soon as
+// a chunk carries exactly one tool call with a non-empty function name; if the
+// stream ends without such a chunk it returns compose.END. A non-assistant
+// chunk, a chunk with more than one tool call, or a stream error yields an error.
+func addDirectAnswerBranch(convertorName string, g *compose.StateGraph[[]*schema.Message, *schema.Message, *state]) error {
+	// handles the case where the host agent returns a direct answer, instead of handling off to any specialist
+	branch := compose.NewStreamGraphBranch(func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) {
+		defer sr.Close()
+
+		for {
+			msg, e := sr.Recv()
+			if e == io.EOF {
+				break
+			}
+
+			if e != nil {
+				return "", e
+			}
+
+			if msg.Role != schema.Assistant {
+				return "", fmt.Errorf("host agent should output assistant message, actual type= %s", msg.Role)
+			}
+
+			// no tool call in this chunk yet; keep reading
+			if len(msg.ToolCalls) == 0 {
+				continue
+			}
+
+			if len(msg.ToolCalls) > 1 {
+				handOffs := make([]string, 0, len(msg.ToolCalls))
+				for _, t := range msg.ToolCalls {
+					handOffs = append(handOffs, t.Function.Name)
+				}
+				return "", fmt.Errorf("host agent returns multiple handoff candidates: %v", handOffs)
+			}
+
+			// the tool call exists but its name has not arrived in this chunk; keep reading
+			function := msg.ToolCalls[0].Function
+			if len(function.Name) == 0 {
+				continue
+			}
+
+			return convertorName, nil
+		}
+
+		return compose.END, nil
+	}, map[string]bool{convertorName: true, compose.END: true})
+
+	return g.AddBranch(hostName, branch)
+}
+
+// addSpecialistsBranch adds the branch leaving the converter node. The branch
+// condition requires exactly one message with exactly one tool call and returns
+// that tool call's function name as the next node; agentMap is the set of
+// allowed end nodes passed to compose.NewGraphBranch.
+func addSpecialistsBranch(convertorName string, agentMap map[string]bool, g *compose.StateGraph[[]*schema.Message, *schema.Message, *state]) error {
+	branch := compose.NewGraphBranch(func(ctx context.Context, input []*schema.Message) (string, error) {
+		if len(input) != 1 {
+			return "", fmt.Errorf("host agent output %d messages, but expected 1", len(input))
+		}
+
+		if len(input[0].ToolCalls) != 1 {
+			return "", fmt.Errorf("host agent output %d tool calls, but expected 1", len(input[0].ToolCalls))
+		}
+
+		return input[0].ToolCalls[0].Function.Name, nil
+	}, agentMap)
+
+	return g.AddBranch(convertorName, branch)
+}
diff --git a/flow/agent/multiagent/host/compose_test.go b/flow/agent/multiagent/host/compose_test.go
new file mode 100644
index 0000000..ebb1427
--- /dev/null
+++ b/flow/agent/multiagent/host/compose_test.go
@@ -0,0 +1,338 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package host
+
+import (
+	"context"
+	"io"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/mock/gomock"
+
+	"github.com/cloudwego/eino/flow/agent"
+	"github.com/cloudwego/eino/internal/mock/components/model"
+	"github.com/cloudwego/eino/schema"
+	"github.com/cloudwego/eino/utils/generic"
+)
+
+// TestHostMultiAgent builds a host multi-agent with a mocked host model, one
+// chat-model specialist and one Invokable/Streamable specialist, then checks
+// direct answers and hand-offs for both Generate and Stream, including the
+// HandOffInfo values delivered to the registered callback.
+func TestHostMultiAgent(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	mockHostLLM := model.NewMockChatModel(ctrl)
+	mockSpecialistLLM1 := model.NewMockChatModel(ctrl)
+
+	// specialist backed by a (mocked) chat model
+	specialist1 := &Specialist{
+		ChatModel:    mockSpecialistLLM1,
+		SystemPrompt: "You are a helpful assistant.",
+		AgentMeta: AgentMeta{
+			Name:        "specialist 1",
+			IntendedUse: "do stuff that works",
+		},
+	}
+
+	// specialist backed by plain functions; invoke and stream return different contents
+	// so the test can tell which path was taken
+	specialist2 := &Specialist{
+		Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
+			return &schema.Message{
+				Role:    schema.Assistant,
+				Content: "specialist2 invoke answer",
+			}, nil
+		},
+		Streamable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
+			sr, sw := schema.Pipe[*schema.Message](0)
+			go func() {
+				sw.Send(&schema.Message{
+					Role:    schema.Assistant,
+					Content: "specialist2 stream answer",
+				}, nil)
+				sw.Close()
+			}()
+			return sr, nil
+		},
+		AgentMeta: AgentMeta{
+			Name:        "specialist 2",
+			IntendedUse: "do stuff that works too",
+		},
+	}
+
+	ctx := context.Background()
+
+	mockHostLLM.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes()
+
+	hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
+		Host: Host{
+			ChatModel: mockHostLLM,
+		},
+		Specialists: []*Specialist{
+			specialist1,
+			specialist2,
+		},
+	})
+
+	assert.NoError(t, err)
+
+	t.Run("generate direct answer from host", func(t *testing.T) {
+		directAnswerMsg := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "direct answer",
+		}
+
+		mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(directAnswerMsg, nil).Times(1)
+
+		mockCallback := &mockAgentCallback{}
+
+		out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+		assert.Equal(t, "direct answer", out.Content)
+		// no tool call in the host output, so no hand-off is reported
+		assert.Empty(t, mockCallback.infos)
+	})
+
+	t.Run("stream direct answer from host", func(t *testing.T) {
+		directAnswerMsg1 := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "direct ",
+		}
+
+		directAnswerMsg2 := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "answer",
+		}
+
+		sr, sw := schema.Pipe[*schema.Message](0)
+		go func() {
+			sw.Send(directAnswerMsg1, nil)
+			sw.Send(directAnswerMsg2, nil)
+			sw.Close()
+		}()
+
+		mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
+
+		mockCallback := &mockAgentCallback{}
+		outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+		assert.Empty(t, mockCallback.infos)
+
+		var msgs []*schema.Message
+		for {
+			msg, err := outStream.Recv()
+			if err == io.EOF {
+				break
+			}
+			assert.NoError(t, err)
+			msgs = append(msgs, msg)
+		}
+
+		outStream.Close()
+
+		msg, err := schema.ConcatMessages(msgs)
+		assert.NoError(t, err)
+		assert.Equal(t, "direct answer", msg.Content)
+	})
+
+	t.Run("generate hand off", func(t *testing.T) {
+		handOffMsg := &schema.Message{
+			Role: schema.Assistant,
+			ToolCalls: []schema.ToolCall{
+				{
+					Index: generic.PtrOf(0),
+					Function: schema.FunctionCall{
+						Name:      specialist1.Name,
+						Arguments: `{"reason": "specialist 1 is the best"}`,
+					},
+				},
+			},
+		}
+
+		specialistMsg := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "specialist 1 answer",
+		}
+
+		mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
+		mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
+
+		mockCallback := &mockAgentCallback{}
+
+		out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+		assert.Equal(t, "specialist 1 answer", out.Content)
+		assert.Equal(t, []*HandOffInfo{
+			{
+				ToAgentName: specialist1.Name,
+				Argument:    `{"reason": "specialist 1 is the best"}`,
+			},
+		}, mockCallback.infos)
+
+		// reuse the same message, now pointing at the function-based specialist
+		handOffMsg.ToolCalls[0].Function.Name = specialist2.Name
+		handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`
+		mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
+
+		mockCallback = &mockAgentCallback{}
+
+		out, err = hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+		assert.Equal(t, "specialist2 invoke answer", out.Content)
+		assert.Equal(t, []*HandOffInfo{
+			{
+				ToAgentName: specialist2.Name,
+				Argument:    `{"reason": "specialist 2 is even better"}`,
+			},
+		}, mockCallback.infos)
+	})
+
+	t.Run("stream hand off to chat model", func(t *testing.T) {
+		// the host stream delivers the tool call gradually: no tool call, then a
+		// tool call without a function name (twice), then the complete call
+		handOffMsg1 := &schema.Message{
+			Role: schema.Assistant,
+		}
+
+		handOffMsg2 := &schema.Message{
+			Role: schema.Assistant,
+			ToolCalls: []schema.ToolCall{
+				{
+					Index: generic.PtrOf(0),
+				},
+			},
+		}
+
+		handOffMsg3 := &schema.Message{
+			Role: schema.Assistant,
+			ToolCalls: []schema.ToolCall{
+				{
+					Index:    generic.PtrOf(0),
+					Function: schema.FunctionCall{},
+				},
+			},
+		}
+
+		handOffMsg4 := &schema.Message{
+			Role: schema.Assistant,
+			ToolCalls: []schema.ToolCall{
+				{
+					Index: generic.PtrOf(0),
+					Function: schema.FunctionCall{
+						Name:      specialist1.Name,
+						Arguments: `{"reason": "specialist 1 is the best"}`,
+					},
+				},
+			},
+		}
+
+		sr, sw := schema.Pipe[*schema.Message](0)
+		go func() {
+			sw.Send(handOffMsg1, nil)
+			sw.Send(handOffMsg2, nil)
+			sw.Send(handOffMsg3, nil)
+			sw.Send(handOffMsg4, nil)
+			sw.Close()
+		}()
+
+		specialistMsg1 := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "specialist ",
+		}
+
+		specialistMsg2 := &schema.Message{
+			Role:    schema.Assistant,
+			Content: "1 answer",
+		}
+
+		sr1, sw1 := schema.Pipe[*schema.Message](0)
+		go func() {
+			sw1.Send(specialistMsg1, nil)
+			sw1.Send(specialistMsg2, nil)
+			sw1.Close()
+		}()
+
+		mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
+		mockSpecialistLLM1.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr1, nil).Times(1)
+
+		mockCallback := &mockAgentCallback{}
+		outStream, err := hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+
+		var msgs []*schema.Message
+		for {
+			msg, err := outStream.Recv()
+			if err == io.EOF {
+				break
+			}
+			assert.NoError(t, err)
+			msgs = append(msgs, msg)
+		}
+
+		outStream.Close()
+
+		msg, err := schema.ConcatMessages(msgs)
+		assert.NoError(t, err)
+		assert.Equal(t, "specialist 1 answer", msg.Content)
+
+		assert.Equal(t, []*HandOffInfo{
+			{
+				ToAgentName: specialist1.Name,
+				Argument:    `{"reason": "specialist 1 is the best"}`,
+			},
+		}, mockCallback.infos)
+
+		// second round: same chunk sequence, final chunk now targets the function-based specialist
+		handOffMsg4.ToolCalls[0].Function.Name = specialist2.Name
+		handOffMsg4.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`
+		sr, sw = schema.Pipe[*schema.Message](0)
+		go func() {
+			sw.Send(handOffMsg1, nil)
+			sw.Send(handOffMsg2, nil)
+			sw.Send(handOffMsg3, nil)
+			sw.Send(handOffMsg4, nil)
+			sw.Close()
+		}()
+
+		mockHostLLM.EXPECT().Stream(gomock.Any(), gomock.Any()).Return(sr, nil).Times(1)
+
+		mockCallback = &mockAgentCallback{}
+		outStream, err = hostMA.Stream(ctx, nil, WithAgentCallbacks(mockCallback))
+		assert.NoError(t, err)
+
+		msgs = nil
+		for {
+			msg, err := outStream.Recv()
+			if err == io.EOF {
+				break
+			}
+			assert.NoError(t, err)
+			msgs = append(msgs, msg)
+		}
+
+		outStream.Close()
+
+		msg, err = schema.ConcatMessages(msgs)
+		assert.NoError(t, err)
+		assert.Equal(t, "specialist2 stream answer", msg.Content)
+
+		assert.Equal(t, []*HandOffInfo{
+			{
+				ToAgentName: specialist2.Name,
+				Argument:    `{"reason": "specialist 2 is even better"}`,
+			},
+		}, mockCallback.infos)
+	})
+}
+
+// mockAgentCallback records every HandOffInfo it receives, in call order.
+type mockAgentCallback struct {
+	infos []*HandOffInfo
+}
+
+// OnHandOff appends info to m.infos and returns ctx unchanged.
+func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) context.Context {
+	m.infos = append(m.infos, info)
+	return ctx
+}
diff --git a/flow/agent/multiagent/host/doc.go b/flow/agent/multiagent/host/doc.go
new file mode 100644
index 0000000..3d205ce
--- /dev/null
+++ b/flow/agent/multiagent/host/doc.go
@@ -0,0 +1,17 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package host
diff --git a/flow/agent/multiagent/host/options.go b/flow/agent/multiagent/host/options.go
new file mode 100644
index 0000000..125989c
--- /dev/null
+++ b/flow/agent/multiagent/host/options.go
@@ -0,0 +1,29 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package host
+
+import "github.com/cloudwego/eino/flow/agent"
+
+// options holds the host-specific settings carried by agent.AgentOption values.
+type options struct {
+	agentCallbacks []MultiAgentCallback
+}
+
+// WithAgentCallbacks returns an agent option that appends the given
+// MultiAgentCallback values to the host multi-agent's callback list.
+func WithAgentCallbacks(agentCallbacks ...MultiAgentCallback) agent.AgentOption {
+	register := func(o *options) {
+		o.agentCallbacks = append(o.agentCallbacks, agentCallbacks...)
+	}
+
+	return agent.WrapImplSpecificOptFn(register)
+}
diff --git a/flow/agent/multiagent/host/types.go b/flow/agent/multiagent/host/types.go
new file mode 100644
index 0000000..81c0331
--- /dev/null
+++ b/flow/agent/multiagent/host/types.go
@@ -0,0 +1,141 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Package host implements the host pattern for multi-agent system.
+package host
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/flow/agent"
+	"github.com/cloudwego/eino/schema"
+)
+
+// MultiAgent is a host multi-agent system.
+// A host agent is responsible for deciding which specialist to 'hand off' the task to.
+// One or more specialist agents are responsible for completing the task.
+type MultiAgent struct {
+	// runnable is the compiled host/specialists graph.
+	runnable compose.Runnable[[]*schema.Message, *schema.Message]
+}
+
+// Generate runs the multi-agent graph in invoke mode and returns the final message.
+// Compose options are taken from opts; registered MultiAgentCallback values are
+// attached as a callback handler designated to the host node.
+func (ma *MultiAgent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
+	callOpts := agent.GetComposeOptions(opts...)
+	if h := convertCallbacks(opts...); h != nil {
+		callOpts = append(callOpts, compose.WithCallbacks(h).DesignateNode(hostName))
+	}
+
+	return ma.runnable.Invoke(ctx, input, callOpts...)
+}
+
+// Stream runs the multi-agent graph in stream mode and returns the output stream.
+// Options are handled the same way as in Generate.
+func (ma *MultiAgent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.StreamReader[*schema.Message], error) {
+	callOpts := agent.GetComposeOptions(opts...)
+	if h := convertCallbacks(opts...); h != nil {
+		callOpts = append(callOpts, compose.WithCallbacks(h).DesignateNode(hostName))
+	}
+
+	return ma.runnable.Stream(ctx, input, callOpts...)
+}
+
+// MultiAgentConfig is the config for host multi-agent system.
+type MultiAgentConfig struct {
+	Host        Host
+	Specialists []*Specialist
+
+	Name string // the name of the host multi agent
+}
+
+// validate checks the config and fills in defaults.
+// It returns an error when conf is nil, the host has no ChatModel, there are no
+// specialists, a specialist entry is nil, a specialist has none of ChatModel /
+// Invokable / Streamable, or a specialist's AgentMeta is invalid.
+// As a side effect an empty Host.SystemPrompt is set to defaultHostPrompt and
+// an empty Name is set to "host multi agent".
+func (conf *MultiAgentConfig) validate() error {
+	if conf == nil {
+		return errors.New("host multi agent config is nil")
+	}
+
+	if conf.Host.ChatModel == nil {
+		return errors.New("host multi agent host ChatModel is nil")
+	}
+
+	if len(conf.Specialists) == 0 {
+		return errors.New("host multi agent specialists are empty")
+	}
+
+	if len(conf.Host.SystemPrompt) == 0 {
+		conf.Host.SystemPrompt = defaultHostPrompt
+	}
+
+	for i, s := range conf.Specialists {
+		// a nil entry would otherwise panic on the field accesses below
+		if s == nil {
+			return fmt.Errorf("host multi agent specialist at index %d is nil", i)
+		}
+
+		if s.ChatModel == nil && s.Invokable == nil && s.Streamable == nil {
+			return fmt.Errorf("specialist %s has no chat model or Invokable or Streamable", s.Name)
+		}
+
+		if err := s.AgentMeta.validate(); err != nil {
+			return err
+		}
+	}
+
+	if len(conf.Name) == 0 {
+		conf.Name = "host multi agent"
+	}
+
+	return nil
+}
+
+// AgentMeta is the meta information of an agent within a multi-agent system.
+type AgentMeta struct {
+	Name        string // the name of the agent, should be unique within multi-agent system
+	IntendedUse string // the intended use-case of the agent, used as the reason for the multi-agent system to hand over control to this agent
+}
+
+// validate returns an error when Name or IntendedUse is empty.
+func (am AgentMeta) validate() error {
+	if len(am.Name) == 0 {
+		return errors.New("agent meta name is empty")
+	}
+
+	if len(am.IntendedUse) == 0 {
+		return errors.New("agent meta intended use is empty")
+	}
+
+	return nil
+}
+
+// Host is the host agent within a multi-agent system.
+// Currently, it can only be a model.ChatModel.
+type Host struct {
+	ChatModel    model.ChatModel
+	SystemPrompt string
+}
+
+// Specialist is a specialist agent within a host multi-agent system.
+// It can be a model.ChatModel or any Invokable and/or Streamable, such as react.Agent.
+// ChatModel and (Invokable / Streamable) are mutually exclusive, only one should be provided.
+// If Invokable is provided but not Streamable, then the Specialist will be compose.InvokableLambda.
+// If Streamable is provided but not Invokable, then the Specialist will be compose.StreamableLambda.
+// If both Invokable and Streamable are provided, then the Specialist will be compose.AnyLambda.
+type Specialist struct {
+	AgentMeta
+
+	// ChatModel and SystemPrompt describe a chat-model based specialist.
+	ChatModel    model.ChatModel
+	SystemPrompt string
+
+	// Invokable and Streamable describe a function based specialist.
+	Invokable  compose.Invoke[[]*schema.Message, *schema.Message, agent.AgentOption]
+	Streamable compose.Stream[[]*schema.Message, *schema.Message, agent.AgentOption]
+}
diff --git a/flow/agent/react/callback.go b/flow/agent/react/callback.go
new file mode 100644
index 0000000..63758f7
--- /dev/null
+++ b/flow/agent/react/callback.go
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package react
+
+import (
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/callbacks/template"
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/tool"
+)
+
+// BuildAgentCallback combines a chat model callback handler and a tool callback
+// handler into a single callbacks.Handler for the agent.
+// e.g.
+//
+//	callback := BuildAgentCallback(modelHandler, toolHandler)
+//	agent, err := react.NewAgent(ctx, &AgentConfig{})
+//	agent.Generate(ctx, input, agent.WithComposeOptions(compose.WithCallbacks(callback)))
+func BuildAgentCallback(modelHandler *model.CallbackHandler, toolHandler *tool.CallbackHandler) callbacks.Handler {
+	withModel := template.NewHandlerHelper().ChatModel(modelHandler)
+	withTool := withModel.Tool(toolHandler)
+
+	return withTool.Handler()
+}
diff --git a/flow/agent/react/doc.go b/flow/agent/react/doc.go
new file mode 100644
index 0000000..1e012c6
--- /dev/null
+++ b/flow/agent/react/doc.go
@@ -0,0 +1,17 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package react
diff --git a/flow/agent/react/react.go b/flow/agent/react/react.go
new file mode 100644
index 0000000..ace334e
--- /dev/null
+++ b/flow/agent/react/react.go
@@ -0,0 +1,371 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package react + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/flow/agent" + "github.com/cloudwego/eino/schema" +) + +type nodeState struct { + Messages []*schema.Message +} + +// MessageModifier modify the input messages before the model is called. +type MessageModifier func(ctx context.Context, input []*schema.Message) []*schema.Message + +// AgentConfig is the config for react agent. +type AgentConfig struct { + // Model is the chat model to be used for handling user messages. + Model model.ChatModel + // ToolsConfig is the config for tools node. + ToolsConfig compose.ToolsNodeConfig + + // MessageModifier. + // modify the input messages before the model is called, it's useful when you want to add some system prompt or other messages. + MessageModifier MessageModifier + + // MaxStep. + // default 12 of steps in pregel (node num + 10). + MaxStep int `json:"max_step"` + + // Tools that will make agent return directly when the tool is called. + ToolReturnDirectly map[string]struct{} +} + +// NewPersonaModifier add the system prompt as persona before the model is called. +// example: +// +// persona := "You are an expert in golang." +// config := AgentConfig{ +// Model: model, +// MessageModifier: NewPersonaModifier(persona), +// } +// agent, err := NewAgent(ctx, config) +// if err != nil {return} +// msg, err := agent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "how to build agent with eino"}}) +// if err != nil {return} +// println(msg.Content) +func NewPersonaModifier(persona string) MessageModifier { + return func(ctx context.Context, input []*schema.Message) []*schema.Message { + res := make([]*schema.Message, 0, len(input)+1) + + res = append(res, schema.SystemMessage(persona)) + res = append(res, input...) + return res + } +} + +// NewAgent creates a react agent. 
+func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { + agent := &Agent{} + + runnable, err := agent.build(ctx, config) + if err != nil { + return nil, err + } + + agent.runnable = runnable + + return agent, nil +} + +// Agent is the react agent. +// React agent is a simple agent that handles user messages with a chat model and tools. +// react will call the chat model, if the message contains tool calls, it will call the tools. +// if the tool is configured to return directly, react will return directly. +// otherwise, react will continue to call the chat model until the message contains no tool calls. +// eg. +// +// agent, err := react.NewAgent(ctx, &react.AgentConfig{}) +// if err != nil {...} +// msg, err := agent.Generate(ctx, []*schema.Message{{Role: schema.User, Content: "how to build agent with eino"}}) +// if err != nil {...} +// println(msg.Content) +type Agent struct { + runnable compose.Runnable[[]*schema.Message, *schema.Message] +} + +func (r *Agent) build(ctx context.Context, config *AgentConfig) (compose.Runnable[[]*schema.Message, *schema.Message], error) { + var ( + nodeKeyTools = "tools" + nodeKeyChatModel = "chat" + ) + + if config.MessageModifier == nil { + config.MessageModifier = func(ctx context.Context, input []*schema.Message) []*schema.Message { + return input + } + } + + toolInfos := make([]*schema.ToolInfo, 0, len(config.ToolsConfig.Tools)) + for _, t := range config.ToolsConfig.Tools { + tl, err := t.Info(ctx) + if err != nil { + return nil, err + } + + toolInfos = append(toolInfos, tl) + } + + err := config.Model.BindTools(toolInfos) + if err != nil { + return nil, err + } + + // graph + graph := compose.NewStateGraph[[]*schema.Message, *schema.Message](func(ctx context.Context) *nodeState { + s := &nodeState{ + Messages: make([]*schema.Message, 0, 3), + } + return s + }) + + err = graph.AddChatModelNode(nodeKeyChatModel, config.Model, + compose.WithStatePreHandler(func(ctx context.Context, input 
[]*schema.Message, state *nodeState) ([]*schema.Message, error) { + state.Messages = append(state.Messages, input...) + + modifiedInput := make([]*schema.Message, 0, len(input)) + modifiedInput = append(modifiedInput, state.Messages...) + modifiedInput = config.MessageModifier(ctx, modifiedInput) + + return modifiedInput, nil + }), + ) + if err != nil { + return nil, err + } + + toolsNode, err := compose.NewToolNode(ctx, &config.ToolsConfig) + if err != nil { + return nil, err + } + + err = graph.AddToolsNode(nodeKeyTools, toolsNode, compose.WithStatePreHandler(func(ctx context.Context, input *schema.Message, state *nodeState) (*schema.Message, error) { + state.Messages = append(state.Messages, input) + + if len(config.ToolReturnDirectly) > 0 { + if err := checkReturnDirectlyBeforeToolsNode(input, config); err != nil { + return nil, err + } + } + + if err := cacheToolCallInfo(ctx, input.ToolCalls); err != nil { + return nil, err + } + + return input, nil + })) + if err != nil { + return nil, err + } + + err = graph.AddEdge(compose.START, nodeKeyChatModel) + if err != nil { + return nil, err + } + + err = graph.AddBranch(nodeKeyChatModel, compose.NewStreamGraphBranch(func(ctx context.Context, sr *schema.StreamReader[*schema.Message]) (endNode string, err error) { + msg, err := sr.Recv() + if err != nil { + return "", err + } + defer sr.Close() + + if len(msg.ToolCalls) == 0 { + return compose.END, nil + } + + return nodeKeyTools, nil + }, map[string]bool{nodeKeyTools: true, compose.END: true})) + if err != nil { + return nil, err + } + + if len(config.ToolReturnDirectly) > 0 { + returnDirectlyConvertor := func(ctx context.Context, msgs *schema.StreamReader[[]*schema.Message]) (*schema.StreamReader[*schema.Message], error) { + flattened := schema.StreamReaderWithConvert(msgs, func(msgs []*schema.Message) (*schema.Message, error) { + if len(msgs) != 1 { + return nil, fmt.Errorf("return directly tools node output expected to have only one msg, but got %d", len(msgs)) + 
} + return msgs[0], nil + }) + + return flattened, nil + } + + nodeKeyConvertor := "convertor" + err = graph.AddLambdaNode(nodeKeyConvertor, compose.TransformableLambda(returnDirectlyConvertor)) + if err != nil { + return nil, err + } + + err = graph.AddBranch(nodeKeyTools, compose.NewStreamGraphBranch(func(ctx context.Context, msgsStream *schema.StreamReader[[]*schema.Message]) (endNode string, err error) { + defer msgsStream.Close() + + msgs, err := msgsStream.Recv() + if err != nil { + return "", fmt.Errorf("receive first packet from tools node result returns err: %w", err) + } + + if len(msgs) == 0 { + return "", errors.New("receive first package from tools node result returns empty msgs") + } + + msg := msgs[0] + toolCallID := msg.ToolCallID + if len(toolCallID) == 0 { + return "", errors.New("receive first package from tools node result returns empty tool call id") + } + + toolCall, err := getToolCallInfo(ctx, toolCallID) + if err != nil { + return "", fmt.Errorf("get tool call info for tool call id: %s returns err: %w", toolCallID, err) + } + + if _, ok := config.ToolReturnDirectly[toolCall.Function.Name]; ok { // return directly will appear in first message + return nodeKeyConvertor, nil + } + + return nodeKeyChatModel, nil + }, map[string]bool{nodeKeyChatModel: true, nodeKeyConvertor: true})) + if err != nil { + return nil, err + } + + if err = graph.AddEdge(nodeKeyConvertor, compose.END); err != nil { + return nil, err + } + } else { + if err = graph.AddEdge(nodeKeyTools, nodeKeyChatModel); err != nil { + return nil, err + } + } + + var opts []compose.GraphCompileOption + if config.MaxStep > 0 { + opts = append(opts, compose.WithMaxRunSteps(config.MaxStep)) + } + + runnable, err := graph.Compile(ctx, opts...) 
+ if err != nil { + return nil, err + } + + return runnable, nil +} + +type toolCallInfoKey struct{} + +func cacheToolCallInfo(ctx context.Context, toolCalls []schema.ToolCall) error { + info := ctx.Value(toolCallInfoKey{}) + if info == nil { + return errors.New("tool call info not found in context") + } + + toolCallInfo, ok := info.(*map[string]schema.ToolCall) + if !ok { + return fmt.Errorf("tool call info type error, not atomic.Value: %v", reflect.TypeOf(info)) + } + + m := make(map[string]schema.ToolCall, len(toolCalls)) + for i := range toolCalls { + m[toolCalls[i].ID] = toolCalls[i] + } + + *toolCallInfo = m + + return nil +} + +func getToolCallInfo(ctx context.Context, toolCallID string) (*schema.ToolCall, error) { + info := ctx.Value(toolCallInfoKey{}) + if info == nil { + return nil, errors.New("tool call info not found in context") + } + + toolCallInfo, ok := info.(*map[string]schema.ToolCall) + if !ok { + return nil, fmt.Errorf("tool call info type error, not map[string]schema.ToolCall: %v", reflect.TypeOf(info)) + } + + if toolCallInfo == nil { + return nil, errors.New("tool call info is nil") + } + + toolCall, ok := (*toolCallInfo)[toolCallID] + if !ok { + return nil, fmt.Errorf("tool call info not found for tool call id: %s", toolCallID) + } + + return &toolCall, nil +} + +func checkReturnDirectlyBeforeToolsNode(input *schema.Message, config *AgentConfig) error { + if len(input.ToolCalls) > 1 { // check if a return directly tool call belongs to a batch of parallel tool calls, which is not supported for now + var returnDirectly bool + toolCalls := input.ToolCalls + toolNames := make([]string, 0, len(toolCalls)) + for i := range toolCalls { + toolNames = append(toolNames, toolCalls[i].Function.Name) + + if _, ok := config.ToolReturnDirectly[toolCalls[i].Function.Name]; ok { + returnDirectly = true + } + } + + if returnDirectly { + return fmt.Errorf("return directly tool call is not allowed when there are parallel tool calls: %v", toolNames) + } + } + + 
return nil +} + +// Generate generates a response from the agent. +func (r *Agent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (output *schema.Message, err error) { + m := make(map[string]schema.ToolCall, 0) + ctx = context.WithValue(ctx, toolCallInfoKey{}, &m) + + output, err = r.runnable.Invoke(ctx, input, agent.GetComposeOptions(opts...)...) + if err != nil { + return nil, err + } + + return output, nil +} + +// Stream calls the agent and returns a stream response. +func (r *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) ( + output *schema.StreamReader[*schema.Message], err error) { + m := make(map[string]schema.ToolCall, 0) + ctx = context.WithValue(ctx, toolCallInfoKey{}, &m) + + res, err := r.runnable.Stream(ctx, input, agent.GetComposeOptions(opts...)...) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/flow/agent/react/react_test.go b/flow/agent/react/react_test.go new file mode 100644 index 0000000..bad99f3 --- /dev/null +++ b/flow/agent/react/react_test.go @@ -0,0 +1,605 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package react + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/tool" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/flow/agent" + mockModel "github.com/cloudwego/eino/internal/mock/components/model" + "github.com/cloudwego/eino/schema" +) + +func TestReact(t *testing.T) { + ctx := context.Background() + + fakeTool := &fakeToolGreetForTest{ + tarCount: 3, + } + + info, err := fakeTool.Info(ctx) + assert.NoError(t, err) + + ctrl := gomock.NewController(t) + cm := mockModel.NewMockChatModel(ctrl) + + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + times += 1 + if times <= 2 { + info, _ := fakeTool.Info(ctx) + + return schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + }).AnyTimes() + cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() + + err = cm.BindTools([]*schema.ToolInfo{info}) + assert.NoError(t, err) + + a, err := NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + + MaxStep: 40, + }) + assert.Nil(t, err) + + out, err := a.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + assert.Nil(t, err) + + if out != nil { + t.Log(out.Content) + } + + // test return directly + times = 0 + 
a, err = NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + + MaxStep: 40, + ToolReturnDirectly: map[string]struct{}{info.Name: {}}, + }) + assert.Nil(t, err) + + out, err = a.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + assert.Nil(t, err) + + if out != nil { + t.Log(out.Content) + } +} + +func TestReactStream(t *testing.T) { + ctx := context.Background() + + fakeTool := &fakeToolGreetForTest{ + tarCount: 20, + } + + fakeStreamTool := &fakeStreamToolGreetForTest{ + tarCount: 20, + } + + ctrl := gomock.NewController(t) + cm := mockModel.NewMockChatModel(ctrl) + + times := 0 + cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + sr, sw := schema.Pipe[*schema.Message](1) + defer sw.Close() + + info, _ := fakeTool.Info(ctx) + streamInfo, _ := fakeStreamTool.Info(ctx) + + times += 1 + if times <= 2 { + sw.Send(schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStr()), + }, + }, + }), + nil) + return sr, nil + } else if times == 3 { + sw.Send(schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: streamInfo.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStr()), + }, + }, + }), + nil) + return sr, nil + } else if times == 4 { // parallel tool call + sw.Send(schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: 
schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "tool"}`, randStr()), + }, + }, + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: streamInfo.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "stream tool"}`, randStr()), + }, + }, + }), + nil) + return sr, nil + } + + sw.Send(schema.AssistantMessage("bye", nil), nil) + return sr, nil + }).AnyTimes() + + a, err := NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, + }, + + MaxStep: 40, + }) + assert.Nil(t, err) + + out, err := a.Stream(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + if err != nil { + t.Fatal(err) + } + + defer out.Close() + + msgs := make([]*schema.Message, 0) + for { + msg, err := out.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + t.Fatal(err) + } + + msgs = append(msgs, msg) + } + + assert.Equal(t, 1, len(msgs)) + + msg, err := schema.ConcatMessages(msgs) + if err != nil { + t.Fatal(err) + } + + t.Log(msg.Content) + + info, err := fakeStreamTool.Info(ctx) + assert.NoError(t, err) + + // test return directly + a, err = NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool, fakeStreamTool}, + }, + + MaxStep: 40, + ToolReturnDirectly: map[string]struct{}{info.Name: {}}, // one of the two tools is return directly + }) + assert.Nil(t, err) + + times = 0 + out, err = a.Stream(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + if err != nil { + t.Fatal(err) + } + + defer out.Close() + + msgs = 
make([]*schema.Message, 0) + for { + msg, err := out.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + t.Fatal(err) + } + + msgs = append(msgs, msg) + } + + assert.Equal(t, 1, len(msgs)) + + msg, err = schema.ConcatMessages(msgs) + if err != nil { + t.Fatal(err) + } + + t.Log(msg.Content) + + // return directly tool call within parallel tool calls + _, err = a.Stream(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "使用 greet tool 持续打招呼,直到得到一个 bye 的回复,打招呼名字按照以下顺序: max、bob、alice、john、marry、joe、ken、lily, 请直接开始!请直接开始!请直接开始!", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + assert.Error(t, err) + assert.Contains(t, err.Error(), "return directly tool call is not allowed when there are parallel tool calls") +} + +func TestReactWithModifier(t *testing.T) { + ctx := context.Background() + + fakeTool := &fakeToolGreetForTest{} + ctrl := gomock.NewController(t) + cm := mockModel.NewMockChatModel(ctrl) + + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + times += 1 + if times <= 2 { + info, _ := fakeTool.Info(ctx) + + return schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + }).AnyTimes() + cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() + + a, err := NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool}, + }, + MessageModifier: NewPersonaModifier("you are a helpful assistant"), + + MaxStep: 40, + }) + + assert.Nil(t, err) + + out, err := a.Generate(ctx, []*schema.Message{ + { + Role: schema.User, + Content: "hello", + }, + }, agent.WithComposeOptions(compose.WithCallbacks(callbackForTest))) + if err != nil { + t.Fatal(err) + } + + if out != nil { + t.Log(out.Content) + } +} + +func TestAgentInGraph(t *testing.T) { + t.Run("agent generate in chain", func(t *testing.T) { + ctx := context.Background() + + fakeTool := &fakeToolGreetForTest{} + ctrl := gomock.NewController(t) + cm := mockModel.NewMockChatModel(ctrl) + + times := 0 + cm.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + + times += 1 + if times <= 2 { + info, _ := fakeTool.Info(ctx) + + return schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), + }, + }, + }), + nil + } + + return schema.AssistantMessage("bye", nil), nil + + }).Times(3) + cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() + + agent, err := NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{fakeTool, &fakeStreamToolGreetForTest{}}, + }, + + MaxStep: 40, + }) + assert.Nil(t, err) + + chain := compose.NewChain[[]*schema.Message, string]() + agentLambda, err := compose.AnyLambda(agent.Generate, agent.Stream, nil, nil) + assert.Nil(t, err) + + chain. + AppendLambda(agentLambda). + AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (string, error) { + t.Log("got agent response: ", input.Content) + return input.Content, nil + })) + r, err := chain.Compile(ctx) + assert.Nil(t, err) + + res, err := r.Invoke(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, + compose.WithCallbacks(callbackForTest)) + assert.Nil(t, err) + + t.Log(res) + }) + + t.Run("agent stream in chain", func(t *testing.T) { + + fakeStreamTool := &fakeStreamToolGreetForTest{} + ctx := context.Background() + ctrl := gomock.NewController(t) + cm := mockModel.NewMockChatModel(ctrl) + + times := 0 + cm.EXPECT().Stream(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, input []*schema.Message, opts ...model.Option) ( + *schema.StreamReader[*schema.Message], error) { + sr, sw := schema.Pipe[*schema.Message](1) + defer sw.Close() + + times += 1 + if times <= 2 { + info, _ := fakeStreamTool.Info(ctx) + sw.Send(schema.AssistantMessage("hello max", + []schema.ToolCall{ + { + ID: randStr(), + Function: schema.FunctionCall{ + Name: info.Name, + Arguments: fmt.Sprintf(`{"name": "%s", "hh": "123"}`, randStr()), + }, + }, + }), + nil) + return sr, nil + } + + sw.Send(schema.AssistantMessage("bye", nil), nil) + return sr, nil + }).Times(3) + cm.EXPECT().BindTools(gomock.Any()).Return(nil).AnyTimes() + + agent, err := NewAgent(ctx, &AgentConfig{ + Model: cm, + ToolsConfig: compose.ToolsNodeConfig{ + Tools: []tool.BaseTool{&fakeToolGreetForTest{}, fakeStreamTool}, + }, + + MaxStep: 40, + }) + assert.Nil(t, err) + + chain := compose.NewChain[[]*schema.Message, string]() + agentLambda, err := compose.AnyLambda(agent.Generate, agent.Stream, nil, nil) + assert.Nil(t, err) + + chain. + AppendLambda(agentLambda). 
+ AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *schema.Message) (string, error) { + t.Log("got agent response: ", input.Content) + return input.Content, nil + })) + r, err := chain.Compile(ctx) + assert.Nil(t, err) + + outStream, err := r.Stream(ctx, []*schema.Message{{Role: schema.User, Content: "hello"}}, + compose.WithCallbacks(callbackForTest)) + if err != nil { + t.Fatal(err) + } + + defer outStream.Close() + + msg := "" + for { + msgItem, err := outStream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + + t.Fatal(err) + } + + msg += msgItem + } + + t.Log(msg) + }) + +} + +type fakeStreamToolGreetForTest struct { + tarCount int + curCount int +} + +func (t *fakeStreamToolGreetForTest) StreamableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) ( + *schema.StreamReader[string], error) { + p := &fakeToolInput{} + err := sonic.UnmarshalString(argumentsInJSON, p) + if err != nil { + return nil, err + } + + if t.curCount >= t.tarCount { + s := schema.StreamReaderFromArray([]string{`{"say": "bye"}`}) + return s, nil + } + t.curCount += 1 + s := schema.StreamReaderFromArray([]string{fmt.Sprintf(`{"say": "hello %v"}`, p.Name)}) + return s, nil +} + +type fakeToolGreetForTest struct { + tarCount int + curCount int +} + +func (t *fakeToolGreetForTest) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "greet", + Desc: "greet with name", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Desc: "user name who to greet", + Required: true, + Type: schema.String, + }, + }), + }, nil +} + +func (t *fakeStreamToolGreetForTest) Info(ctx context.Context) (*schema.ToolInfo, error) { + return &schema.ToolInfo{ + Name: "greet in stream", + Desc: "greet with name in stream", + ParamsOneOf: schema.NewParamsOneOfByParams( + map[string]*schema.ParameterInfo{ + "name": { + Desc: "user name who to greet", + Required: true, + Type: 
schema.String, + }, + }), + }, nil +} + +func (t *fakeToolGreetForTest) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { + p := &fakeToolInput{} + err := sonic.UnmarshalString(argumentsInJSON, p) + if err != nil { + return "", err + } + + if t.curCount >= t.tarCount { + return `{"say": "bye"}`, nil + } + + t.curCount += 1 + return fmt.Sprintf(`{"say": "hello %v"}`, p.Name), nil +} + +type fakeToolInput struct { + Name string `json:"name"` +} + +func randStr() string { + seeds := []rune("abcdefghijklmnopqrstuvwxyz") + b := make([]rune, 8) + for i := range b { + b[i] = seeds[rand.Intn(len(seeds))] + } + return string(b) +} + +var callbackForTest = BuildAgentCallback(&model.CallbackHandler{}, &tool.CallbackHandler{}) diff --git a/flow/retriever/multiquery/multi_query.go b/flow/retriever/multiquery/multi_query.go new file mode 100644 index 0000000..b51b40c --- /dev/null +++ b/flow/retriever/multiquery/multi_query.go @@ -0,0 +1,211 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package multiquery + +import ( + "context" + "fmt" + "strings" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/model" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/compose" + "github.com/cloudwego/eino/flow/retriever/utils" + "github.com/cloudwego/eino/schema" +) + +const ( + defaultRewritePrompt = `You are an helpful assistant. + Your role is to create three different versions of the user query to retrieve relevant documents from store. + Your goal is to improve the performance of similarity search by generating text from different perspectives based on the user query. + Only provide the generated queries and separate them by newlines. + user query: {{query}}` + defaultQueryVariable = "query" + defaultMaxQueriesNum = 5 +) + +var deduplicateFusion = func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error) { + m := map[string]bool{} + var ret []*schema.Document + for i := range docs { + for j := range docs[i] { + if _, ok := m[docs[i][j].ID]; !ok { + m[docs[i][j].ID] = true + ret = append(ret, docs[i][j]) + } + } + } + return ret, nil +} + +// NewRetriever creates a multi-query retriever. +// multi-query retriever is useful when you want to retrieve documents from multiple retrievers with different queries. +// eg. +// +// multiRetriever := multiquery.NewRetriever(ctx, &multiquery.Config{}) +// docs, err := multiRetriever.Retrieve(ctx, "how to build agent with eino") +// if err != nil { +// ... 
+// } +// println(docs) +// +// for more info: https://bytedance.larkoffice.com/wiki/G8T2w5bYuigJ4LkMi1ycw6VznAh#A4PqdcJmpoveWcxv8NPc70TanLb +func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) { + var err error + + // config validate + if config.OrigRetriever == nil { + return nil, fmt.Errorf("OrigRetriever is required") + } + if config.RewriteHandler == nil && config.RewriteLLM == nil { + return nil, fmt.Errorf("at least one of RewriteHandler and RewriteLLM must not be empty") + } + + // construct rewrite chain + rewriteChain := compose.NewChain[string, []string]() + if config.RewriteHandler != nil { + rewriteChain.AppendLambda(compose.InvokableLambda(config.RewriteHandler), compose.WithNodeName("CustomQueryRewriter")) + } else { + tpl := config.RewriteTemplate + variable := config.QueryVar + parser := config.LLMOutputParser + if tpl == nil { + tpl = prompt.FromMessages(schema.Jinja2, schema.UserMessage(defaultRewritePrompt)) + variable = defaultQueryVariable + } + if parser == nil { + parser = func(ctx context.Context, message *schema.Message) ([]string, error) { + return strings.Split(message.Content, "\n"), nil + } + } + + rewriteChain. + AppendLambda(compose.InvokableLambda(func(ctx context.Context, input string) (output map[string]any, err error) { + return map[string]any{variable: input}, nil + }), compose.WithNodeName("Converter")). + AppendChatTemplate(tpl). + AppendChatModel(config.RewriteLLM). 
+			AppendLambda(compose.InvokableLambda(parser), compose.WithNodeName("OutputParser"))
+	}
+	rewriteRunner, err := rewriteChain.Compile(ctx, compose.WithGraphName("QueryRewrite"))
+	if err != nil {
+		return nil, err
+	}
+
+	maxQueriesNum := config.MaxQueriesNum
+	if maxQueriesNum == 0 {
+		maxQueriesNum = defaultMaxQueriesNum
+	}
+
+	fusionFunc := config.FusionFunc
+	if fusionFunc == nil {
+		fusionFunc = deduplicateFusion
+	}
+
+	return &multiQueryRetriever{
+		queryRunner:   rewriteRunner,
+		maxQueriesNum: maxQueriesNum,
+		origRetriever: config.OrigRetriever,
+		fusionFunc:    fusionFunc,
+	}, nil
+}
+
+// Config is the config for multi-query retriever.
+type Config struct {
+	// Query rewriting can be configured in one of two ways.
+	//
+	// 1. Let a chat model generate the queries by setting the fields below.
+
+	// RewriteLLM is the chat model used to generate queries. Required in this mode.
+	RewriteLLM model.ChatModel
+	// RewriteTemplate is the prompt given to RewriteLLM. A default template is
+	// provided, so this field may be left unset.
+	RewriteTemplate prompt.ChatTemplate
+	// QueryVar is the variable of the custom template that receives the original
+	// query. It may be empty when the default template is used.
+	QueryVar string
+	// LLMOutputParser converts the model output into queries. By default the
+	// content is split on "\n".
+	LLMOutputParser func(context.Context, *schema.Message) ([]string, error)
+
+	// 2. Provide custom query generation logic, possibly without a chat model.
+
+	// RewriteHandler generates the queries itself. If this field is set, it
+	// takes precedence over the chat-model configuration above.
+	RewriteHandler func(ctx context.Context, query string) ([]string, error)
+
+	// MaxQueriesNum limits how many generated queries are used; excess queries
+	// are truncated. 5 by default.
+	MaxQueriesNum int
+
+	// OrigRetriever is the retriever that is run once per generated query.
+	OrigRetriever retriever.Retriever
+
+	// FusionFunc merges the documents recalled for each query. By default
+	// duplicates are removed based on document id.
+	FusionFunc func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error)
+}
+
+// multiQueryRetriever is the retriever.Retriever built from a Config.
+type multiQueryRetriever struct {
+	queryRunner   compose.Runnable[string, []string]
+	maxQueriesNum int
+	origRetriever retriever.Retriever
+	fusionFunc    func(ctx context.Context, docs [][]*schema.Document) ([]*schema.Document, error)
+}
+
+// Retrieve retrieves documents from the multi-query retriever.
+//
+// It rewrites query into several queries (keeping at most maxQueriesNum of
+// them), runs the original retriever once per rewritten query concurrently,
+// and fuses the per-query results with the configured fusion function.
+//
+// opts are forwarded unchanged to every call of the original retriever. The
+// first error found in query order aborts the call.
+func (m *multiQueryRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	// generate queries
+	queries, err := m.queryRunner.Invoke(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	if len(queries) > m.maxQueriesNum {
+		queries = queries[:m.maxQueriesNum]
+	}
+
+	// retrieve
+	tasks := make([]*utils.RetrieveTask, len(queries))
+	for i := range queries {
+		tasks[i] = &utils.RetrieveTask{
+			Retriever: m.origRetriever,
+			Query:     queries[i],
+			// Forward the caller's options so they are not silently dropped.
+			RetrieveOptions: opts,
+		}
+	}
+	utils.ConcurrentRetrieveWithCallback(ctx, tasks)
+	result := make([][]*schema.Document, len(queries))
+	for i, task := range tasks {
+		if task.Err != nil {
+			return nil, task.Err
+		}
+		result[i] = task.Result
+	}
+
+	// fusion, reported to callbacks under its own run info
+	ctx = ctxWithFusionRunInfo(ctx)
+	ctx = callbacks.OnStart(ctx, result)
+	fusionDocs, err := m.fusionFunc(ctx, result)
+	if err != nil {
+		callbacks.OnError(ctx, err)
+		return nil, err
+	}
+	callbacks.OnEnd(ctx, fusionDocs)
+	return fusionDocs, nil
+}
+
+// GetType returns the type of the retriever (MultiQuery).
+func (m *multiQueryRetriever) GetType() string {
+	return "MultiQuery"
+}
+
+// ctxWithFusionRunInfo returns a context whose callback run info describes the
+// fusion step: a lambda component of type "FusionFunc", named by concatenating
+// the type and the component kind.
+func ctxWithFusionRunInfo(ctx context.Context) context.Context {
+	runInfo := &callbacks.RunInfo{
+		Component: compose.ComponentOfLambda,
+		Type:      "FusionFunc",
+	}
+
+	runInfo.Name = runInfo.Type + string(runInfo.Component)
+
+	return callbacks.SwitchRunInfo(ctx, runInfo)
+}
diff --git a/flow/retriever/multiquery/multi_query_test.go b/flow/retriever/multiquery/multi_query_test.go
new file mode 100644
index 0000000..84baf06
--- /dev/null
+++ b/flow/retriever/multiquery/multi_query_test.go
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package multiquery
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/cloudwego/eino/components/model"
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/schema"
+)
+
+// mockRetriever returns one document per digit "1".."5" contained in the
+// query, in ascending digit order, using the digit as the document ID.
+type mockRetriever struct {
+}
+
+func (m *mockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	var ret []*schema.Document
+	for _, id := range []string{"1", "2", "3", "4", "5"} {
+		if strings.Contains(query, id) {
+			ret = append(ret, &schema.Document{ID: id})
+		}
+	}
+	return ret, nil
+}
+
+// mockModel always answers with six newline-separated queries; only Generate
+// is implemented.
+type mockModel struct {
+}
+
+func (m *mockModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
+	return &schema.Message{
+		Content: "12\n23\n34\n14\n23\n45",
+	}, nil
+}
+
+func (m *mockModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
+	panic("implement me")
+}
+
+func (m *mockModel) BindTools(tools []*schema.ToolInfo) error {
+	panic("implement me")
+}
+
+// TestMultiQueryRetriever runs the retriever inside a chain, once with the
+// chat-model based rewrite and once with a custom RewriteHandler.
+func TestMultiQueryRetriever(t *testing.T) {
+	ctx := context.Background()
+
+	// use default llm
+	mqr, err := NewRetriever(ctx, &Config{
+		RewriteLLM:    &mockModel{},
+		OrigRetriever: &mockRetriever{},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	c := compose.NewChain[string, []*schema.Document]()
+	cr, err := c.AppendRetriever(mqr).Compile(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	result, err := cr.Invoke(ctx, "query")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// The model yields six queries; the first five ("12", "23", "34", "14",
+	// "23") recall documents 1-4 once duplicates are fused away.
+	if len(result) != 4 {
+		t.Fatal("default llm retrieve result is unexpected")
+	}
+
+	// use custom
+	mqr, err = NewRetriever(ctx, &Config{
+		RewriteHandler: func(ctx context.Context, query string) ([]string, error) {
+			return []string{"1", "3", "5"}, nil
+		},
+		OrigRetriever: &mockRetriever{},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	c = compose.NewChain[string, []*schema.Document]()
+	cr, err = c.AppendRetriever(mqr).Compile(ctx)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	result, err = cr.Invoke(ctx, "query")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(result) != 3 {
+		t.Fatal("custom rewrite handler retrieve result is unexpected")
+	}
+}
diff --git a/flow/retriever/router/router.go b/flow/retriever/router/router.go
new file mode 100644
index 0000000..3cb131a
--- /dev/null
+++ b/flow/retriever/router/router.go
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package router
+
+import (
+	"context"
+	"fmt"
+	"sort"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/compose"
+	"github.com/cloudwego/eino/flow/retriever/utils"
+	"github.com/cloudwego/eino/schema"
+)
+
+// rrf is the default fusion function. Each document scores 1/(i+60) for every
+// result list it appears in, where i is its zero-based position in that list;
+// documents are returned by descending total score. Documents with equal
+// scores are ordered by ID so the output does not depend on map iteration
+// order. An empty result is an error; a single result list is returned as is.
+var rrf = func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) {
+	if len(result) < 1 {
+		return nil, fmt.Errorf("no documents")
+	}
+	if len(result) == 1 {
+		for _, docs := range result {
+			return docs, nil
+		}
+	}
+
+	docRankMap := make(map[string]float64)
+	docMap := make(map[string]*schema.Document)
+	for _, v := range result {
+		for i := range v {
+			docMap[v[i].ID] = v[i]
+			// A missing key reads as 0, so += covers first and later sightings.
+			docRankMap[v[i].ID] += 1.0 / float64(i+60)
+		}
+	}
+	docList := make([]*schema.Document, 0, len(docMap))
+	for id := range docMap {
+		docList = append(docList, docMap[id])
+	}
+
+	sort.Slice(docList, func(i, j int) bool {
+		ri, rj := docRankMap[docList[i].ID], docRankMap[docList[j].ID]
+		if ri != rj {
+			return ri > rj
+		}
+		// IDs are unique within docList, so this makes the order total.
+		return docList[i].ID < docList[j].ID
+	})
+
+	return docList, nil
+}
+
+// NewRetriever creates a router retriever.
+// router retriever is useful when you want to retrieve documents from multiple retrievers with different queries.
+// If config.Router is nil, every query is routed to all configured retrievers.
+// If config.FusionFunc is nil, results are fused with reciprocal rank fusion.
+// It returns an error when config is nil or has no retrievers.
+// eg.
+//
+//	routerRetriever, err := router.NewRetriever(ctx, &router.Config{
+//		Retrievers: map[string]retriever.Retriever{"docs": docsRetriever},
+//	})
+//	if err != nil {
+//		...
+//	}
+//	docs, err := routerRetriever.Retrieve(ctx, "how to build agent with eino")
+//	if err != nil {
+//		...
+//	}
+//	println(docs)
+func NewRetriever(ctx context.Context, config *Config) (retriever.Retriever, error) {
+	if config == nil {
+		return nil, fmt.Errorf("config is nil")
+	}
+	if len(config.Retrievers) == 0 {
+		return nil, fmt.Errorf("retrievers is empty")
+	}
+
+	router := config.Router
+	if router == nil {
+		retrieverSet := make([]string, 0, len(config.Retrievers))
+		for k := range config.Retrievers {
+			retrieverSet = append(retrieverSet, k)
+		}
+		router = func(ctx context.Context, query string) ([]string, error) {
+			return retrieverSet, nil
+		}
+	}
+
+	fusion := config.FusionFunc
+	if fusion == nil {
+		fusion = rrf
+	}
+
+	return &routerRetriever{
+		retrievers: config.Retrievers,
+		// Use the resolved router, not config.Router, so the default applies.
+		router:     router,
+		fusionFunc: fusion,
+	}, nil
+}
+
+// Config is the config for router retriever.
+type Config struct {
+	// Retrievers is the retrievers to be used.
+	Retrievers map[string]retriever.Retriever
+	// Router is the function to route the query to the retrievers.
+	// It returns keys of Retrievers. Optional: all retrievers are used when nil.
+	Router func(ctx context.Context, query string) ([]string, error)
+	// FusionFunc is the function to fuse the documents from the retrievers.
+	// The map is keyed by retriever name. Optional: rrf is used when nil.
+	FusionFunc func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error)
+}
+
+// routerRetriever is the retriever.Retriever built from a Config.
+type routerRetriever struct {
+	retrievers map[string]retriever.Retriever
+	router     func(ctx context.Context, query string) ([]string, error)
+	fusionFunc func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error)
+}
+
+// Retrieve retrieves documents from the router retriever.
+//
+// It asks the router which retrievers to use, runs those retrievers
+// concurrently with the same query and opts, and fuses their results. The
+// routing and fusion steps are reported to callbacks under their own run info.
+// It fails if the router selects nothing or names an unregistered retriever,
+// or if any selected retriever fails.
+func (e *routerRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	routeCtx := ctxWithRouterRunInfo(ctx)
+	routeCtx = callbacks.OnStart(routeCtx, query)
+	retrieverNames, err := e.router(routeCtx, query)
+	if err != nil {
+		callbacks.OnError(routeCtx, err)
+		return nil, err
+	}
+	if len(retrieverNames) == 0 {
+		err = fmt.Errorf("no retriever has been selected")
+		callbacks.OnError(routeCtx, err)
+		return nil, err
+	}
+	callbacks.OnEnd(routeCtx, retrieverNames)
+
+	// retrieve
+	tasks := make([]*utils.RetrieveTask, len(retrieverNames))
+	for i := range retrieverNames {
+		r, ok := e.retrievers[retrieverNames[i]]
+		if !ok {
+			return nil, fmt.Errorf("router output[%s] has not registered", retrieverNames[i])
+		}
+		tasks[i] = &utils.RetrieveTask{
+			Name:            retrieverNames[i],
+			Retriever:       r,
+			Query:           query,
+			RetrieveOptions: opts,
+		}
+	}
+	utils.ConcurrentRetrieveWithCallback(ctx, tasks)
+	result := make(map[string][]*schema.Document)
+	for i := range tasks {
+		if tasks[i].Err != nil {
+			return nil, tasks[i].Err
+		}
+		result[tasks[i].Name] = tasks[i].Result
+	}
+
+	// fusion
+	fusionCtx := ctxWithFusionRunInfo(ctx)
+	fusionCtx = callbacks.OnStart(fusionCtx, result)
+	fusionDocs, err := e.fusionFunc(fusionCtx, result)
+	if err != nil {
+		callbacks.OnError(fusionCtx, err)
+		return nil, err
+	}
+	callbacks.OnEnd(fusionCtx, fusionDocs)
+	return fusionDocs, nil
+}
+
+// GetType returns the type of the retriever (Router).
+func (e *routerRetriever) GetType() string { return "Router" }
+
+// ctxWithRouterRunInfo returns a context whose callback run info describes the
+// routing step: a lambda component of type "Router", named by concatenating
+// the type and the component kind.
+func ctxWithRouterRunInfo(ctx context.Context) context.Context {
+	runInfo := &callbacks.RunInfo{
+		Component: compose.ComponentOfLambda,
+		Type:      "Router",
+	}
+
+	runInfo.Name = runInfo.Type + string(runInfo.Component)
+
+	return callbacks.SwitchRunInfo(ctx, runInfo)
+}
+
+// ctxWithFusionRunInfo returns a context whose callback run info describes the
+// fusion step: a lambda component of type "FusionFunc", named by concatenating
+// the type and the component kind.
+func ctxWithFusionRunInfo(ctx context.Context) context.Context {
+	runInfo := &callbacks.RunInfo{
+		Component: compose.ComponentOfLambda,
+		Type:      "FusionFunc",
+	}
+
+	runInfo.Name = runInfo.Type + string(runInfo.Component)
+
+	return callbacks.SwitchRunInfo(ctx, runInfo)
+}
diff --git a/flow/retriever/router/router_test.go b/flow/retriever/router/router_test.go
new file mode 100644
index 0000000..edf1af1
--- /dev/null
+++ b/flow/retriever/router/router_test.go
@@ -0,0 +1,134 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package router
+
+import (
+	"context"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/schema"
+)
+
+// mockRetriever returns one document per digit "1".."5" contained in the
+// query, in ascending digit order, using the digit as the document ID.
+type mockRetriever struct {
+}
+
+func (m *mockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
+	var ret []*schema.Document
+	for _, id := range []string{"1", "2", "3", "4", "5"} {
+		if strings.Contains(query, id) {
+			ret = append(ret, &schema.Document{ID: id})
+		}
+	}
+	return ret, nil
+}
+
+func (m *mockRetriever) GetType() string {
+	return "Mock"
+}
+
+// TestRouterRetriever routes a query to two of three retrievers and checks the
+// output type reported to callbacks for each step.
+func TestRouterRetriever(t *testing.T) {
+	ctx := context.Background()
+	r, err := NewRetriever(ctx, &Config{
+		Retrievers: map[string]retriever.Retriever{
+			"1": &mockRetriever{},
+			"2": &mockRetriever{},
+			"3": &mockRetriever{},
+		},
+		Router: func(ctx context.Context, query string) ([]string, error) {
+			return []string{"2", "3"}, nil
+		},
+		FusionFunc: func(ctx context.Context, result map[string][]*schema.Document) ([]*schema.Document, error) {
+			var ret []*schema.Document
+			for _, v := range result {
+				ret = append(ret, v...)
+			}
+			return ret, nil
+		},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// The retriever callbacks run on goroutines started by the concurrent
+	// retrieve helper. t.Fatal must only be called from the test goroutine,
+	// so the handlers report failures with t.Error/t.Errorf instead.
+	handler := callbacks.NewHandlerBuilder().
+		OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
+			switch info.Name {
+			case "FusionFuncLambda":
+				if _, ok := output.([]*schema.Document); !ok {
+					t.Error("FusionFuncLambda output is not a []*schema.Document")
+				}
+			case "RouterLambda":
+				if _, ok := output.([]string); !ok {
+					t.Error("RouterLambda output is not a []string")
+				}
+			case "MockRetriever":
+				if _, ok := output.([]*schema.Document); !ok {
+					t.Error("MockRetriever output is not a []*schema.Document")
+				}
+			default:
+				t.Errorf("unknown name: %s", info.Name)
+			}
+			return ctx
+		}).
+		OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
+			t.Error(err)
+			return ctx
+		}).Build()
+	ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{}, handler)
+	result, err := r.Retrieve(ctx, "3")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Both selected retrievers return document "3"; the custom fusion keeps both.
+	if len(result) != 2 {
+		t.Fatal("expected 2 results")
+	}
+}
+
+// TestRRF checks the ranking produced by the default fusion function for three
+// rotated result lists.
+func TestRRF(t *testing.T) {
+	doc1 := &schema.Document{ID: "1"}
+	doc2 := &schema.Document{ID: "2"}
+	doc3 := &schema.Document{ID: "3"}
+	doc4 := &schema.Document{ID: "4"}
+	doc5 := &schema.Document{ID: "5"}
+
+	input := map[string][]*schema.Document{
+		"1": {doc1, doc2, doc3, doc4, doc5},
+		"2": {doc2, doc3, doc4, doc5, doc1},
+		"3": {doc3, doc4, doc5, doc1, doc2},
+	}
+
+	result, err := rrf(context.Background(), input)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !reflect.DeepEqual(result, []*schema.Document{doc3, doc2, doc4, doc1, doc5}) {
+		t.Fatal("rrf fail")
+	}
+}
diff --git a/flow/retriever/utils/utils.go b/flow/retriever/utils/utils.go
new file mode 100644
index 0000000..482a38e
--- /dev/null
+++ b/flow/retriever/utils/utils.go
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package utils
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/components/retriever"
+	"github.com/cloudwego/eino/schema"
+)
+
+// RetrieveTask is a task for retrieving documents.
+// Name, Retriever, Query and RetrieveOptions are inputs; Result and Err are
+// filled in by ConcurrentRetrieveWithCallback.
+type RetrieveTask struct {
+	Name            string
+	Retriever       retriever.Retriever
+	Query           string
+	RetrieveOptions []retriever.Option
+	Result          []*schema.Document
+	Err             error
+}
+
+// ConcurrentRetrieveWithCallback concurrently retrieves documents with callback.
+//
+// Every task runs in its own goroutine and the function returns once all of
+// them have finished. Each task reports OnStart and then OnEnd or OnError to
+// callbacks under run info derived from its retriever. The outcome is stored
+// in the task's Result or Err; a panic while running a task is recovered and
+// stored in Err.
+func ConcurrentRetrieveWithCallback(ctx context.Context, tasks []*RetrieveTask) {
+	wg := sync.WaitGroup{}
+	for i := range tasks {
+		wg.Add(1)
+		go func(ctx context.Context, t *RetrieveTask) {
+			// Registered before any other work so that a panic anywhere in
+			// the task, including run-info setup, is recovered and Done is
+			// still called.
+			defer func() {
+				if e := recover(); e != nil {
+					t.Err = fmt.Errorf("retrieve panic, query: %s, error: %v", t.Query, e)
+					ctx = callbacks.OnError(ctx, t.Err)
+				}
+				wg.Done()
+			}()
+
+			ctx = ctxWithRetrieverRunInfo(ctx, t.Retriever)
+
+			ctx = callbacks.OnStart(ctx, t.Query)
+			docs, err := t.Retriever.Retrieve(ctx, t.Query, t.RetrieveOptions...)
+			if err != nil {
+				callbacks.OnError(ctx, err)
+				t.Err = err
+				return
+			}
+
+			callbacks.OnEnd(ctx, docs)
+			t.Result = docs
+		}(ctx, tasks[i])
+	}
+	wg.Wait()
+}
+
+// ctxWithRetrieverRunInfo returns a context whose callback run info describes
+// r: a retriever component whose type comes from components.GetType when
+// available, named by concatenating the type and the component kind.
+func ctxWithRetrieverRunInfo(ctx context.Context, r retriever.Retriever) context.Context {
+	runInfo := &callbacks.RunInfo{
+		Component: components.ComponentOfRetriever,
+	}
+
+	if typ, okk := components.GetType(r); okk {
+		runInfo.Type = typ
+	}
+
+	runInfo.Name = runInfo.Type + string(runInfo.Component)
+
+	return callbacks.SwitchRunInfo(ctx, runInfo)
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..0fbdf2b
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,47 @@
+module github.com/cloudwego/eino
+
+go 1.18
+
+require (
+	github.com/bytedance/sonic v1.12.2
+	github.com/getkin/kin-openapi v0.118.0
+	github.com/nikolalohinski/gonja v1.5.3
+	github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f
+	github.com/smartystreets/goconvey v1.8.1
+	github.com/stretchr/testify v1.9.0
+	go.uber.org/mock v0.4.0
+)
+
+require (
+	github.com/bytedance/sonic/loader v0.2.0 // indirect
+	github.com/cloudwego/base64x v0.1.4 // indirect
+	github.com/cloudwego/iasm v0.2.0 // indirect
+	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/dustin/go-humanize v1.0.1 // indirect
+	github.com/go-openapi/jsonpointer v0.19.5 // indirect
+	github.com/go-openapi/swag v0.19.5 // indirect
+	github.com/goph/emperror v0.17.2 // indirect
+	github.com/gopherjs/gopherjs v1.17.2 // indirect
+	github.com/invopop/yaml v0.1.0 // indirect
+	github.com/josharian/intern v1.0.0 // indirect
+	github.com/json-iterator/go v1.1.12 // indirect
+	github.com/jtolds/gls v4.20.0+incompatible // indirect
+	github.com/klauspost/cpuid/v2 v2.0.9 // indirect
+	github.com/mailru/easyjson v0.7.7 // indirect
+	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+	github.com/modern-go/reflect2 v1.0.2 // indirect
+	github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
+	github.com/pelletier/go-toml/v2
v2.0.9 // indirect + github.com/perimeterx/marshmallow v1.1.4 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/smarty/assertions v1.15.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/yargevad/filepathx v1.0.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect + golang.org/x/sys v0.10.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5d77c9f --- /dev/null +++ b/go.sum @@ -0,0 +1,149 @@ +github.com/airbrake/gobrake v3.6.1+incompatible/go.mod h1:wM4gu3Cn0W0K7GUuVWnlXZU11AGBXMILnrdOU8Kn00o= +github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= +github.com/bugsnag/bugsnag-go v1.4.0/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= +github.com/bugsnag/panicwrap v1.2.0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= +github.com/bytedance/sonic v1.12.2 h1:oaMFuRTpMHYLpCntGca65YWt5ny+wAceDERTkT2L9lg= +github.com/bytedance/sonic v1.12.2/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= +github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod 
h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= +github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= +github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= +github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/swag v0.19.5 h1:lTz6Ys4CmqqCQmZPBlbQENR1/GucA2bzYTE12Pw4tFY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/goph/emperror v0.17.2 
h1:yLapQcmEsO0ipe9p5TaN22djm3OFV/TfM/fcYP0/J18= +github.com/goph/emperror v0.17.2/go.mod h1:+ZbQ+fUNO/6FNiUo0ujtMjhgad9Xa6fQL9KhH4LNHic= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 
h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= +github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= +github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.5.0/go.mod 
h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0= +github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/perimeterx/marshmallow v1.1.4 h1:pZLDH9RjlLGGorbXhcaQLhfuV0pFMNfPO55FuFkxqLw= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= +github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx 
v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= +github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= +github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 
h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/internal/gmap/gmap.go b/internal/gmap/gmap.go new file mode 100644 index 0000000..0cc16af --- /dev/null +++ b/internal/gmap/gmap.go @@ -0,0 +1,122 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gmap + +// Concat returns the unions of maps as a new map. 
+// +// 💡 NOTE: +// +// - Once the key conflicts, the newer value always replaces the older one ([DiscardOld]). +// - If the result is an empty set, always return an empty map instead of nil. +// +// 🚀 EXAMPLE: +// +// m := map[int]int{1: 1, 2: 2} +// Concat(m, nil) ⏩ map[int]int{1: 1, 2: 2} +// Concat(m, map[int]int{3: 3}) ⏩ map[int]int{1: 1, 2: 2, 3: 3} +// Concat(m, map[int]int{2: -1}) ⏩ map[int]int{1: 1, 2: -1} // "2:2" is replaced by the newer "2:-1" +// +// 💡 AKA: Merge, Union, Combine +func Concat[K comparable, V any](ms ...map[K]V) map[K]V { + + // FastPath: no map or only one map given. + if len(ms) == 0 { + return make(map[K]V) + } + if len(ms) == 1 { + return cloneWithoutNilCheck(ms[0]) + } + + var maxLen int + for _, m := range ms { + if len(m) > maxLen { + maxLen = len(m) + } + } + ret := make(map[K]V, maxLen) + // FastPath: all maps are empty. + if maxLen == 0 { + return ret + } + + // Concat all maps. + for _, m := range ms { + for k, v := range m { + ret[k] = v + } + } + return ret +} + +// Map applies function f to each key and value of map m. +// Results of f are returned as a new map. +// +// 🚀 EXAMPLE: +// +// f := func(k, v int) (string, string) { return strconv.Itoa(k), strconv.Itoa(v) } +// Map(map[int]int{1: 1}, f) ⏩ map[string]string{"1": "1"} +// Map(map[int]int{}, f) ⏩ map[string]string{} +func Map[K1, K2 comparable, V1, V2 any](m map[K1]V1, f func(K1, V1) (K2, V2)) map[K2]V2 { + r := make(map[K2]V2, len(m)) + for k, v := range m { + k2, v2 := f(k, v) + r[k2] = v2 + } + return r +} + +// Values returns the values of the map m. +// +// 🚀 EXAMPLE: +// +// m := map[int]string{1: "1", 2: "2", 3: "3", 4: "4"} +// Values(m) ⏩ []string{"1", "4", "2", "3"} //⚠️INDETERMINATE ORDER⚠️ +// +// ⚠️ WARNING: The values will be in an indeterminate order. +func Values[K comparable, V any](m map[K]V) []V { + r := make([]V, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} + +// Clone returns a shallow copy of map.
+// If the given map is nil, nil is returned. +// +// 🚀 EXAMPLE: +// +// Clone(map[int]int{1: 1, 2: 2}) ⏩ map[int]int{1: 1, 2: 2} +// Clone(map[int]int{}) ⏩ map[int]int{} +// Clone[int, int](nil) ⏩ nil +// +// 💡 HINT: Both keys and values are copied using assignment (=), so this is a shallow clone. +// 💡 AKA: Copy +func Clone[K comparable, V any, M ~map[K]V](m M) M { + if m == nil { + return nil + } + return cloneWithoutNilCheck(m) +} + +func cloneWithoutNilCheck[K comparable, V any, M ~map[K]V](m M) M { + r := make(M, len(m)) + for k, v := range m { + r[k] = v + } + return r +} diff --git a/internal/gmap/gmap_test.go b/internal/gmap/gmap_test.go new file mode 100644 index 0000000..1c0f1ef --- /dev/null +++ b/internal/gmap/gmap_test.go @@ -0,0 +1,89 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package gmap + +import ( + "fmt" + "sort" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMerge(t *testing.T) { + assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, + Concat(map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, nil)) + assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, + Concat[int, int](nil, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) + assert.Equal(t, map[int]int{}, Concat[int, int](nil, nil)) + assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, + Concat(map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) + assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, + Concat(map[int]int{1: 0, 2: 0}, map[int]int{1: 1, 2: 2, 3: 3, 4: 4})) + assert.Equal(t, map[int]int{1: 1, 2: 2, 3: 3, 4: 4}, + Concat(map[int]int{1: 1, 2: 1}, map[int]int{2: 2, 3: 3, 4: 4})) +} + +func TestMap(t *testing.T) { + assert.Equal(t, + map[string]string{"1": "1", "2": "2"}, + Map(map[int]int{1: 1, 2: 2}, func(k, v int) (string, string) { + return strconv.Itoa(k), strconv.Itoa(v) + })) + assert.Equal(t, + map[string]string{}, + Map(map[int]int{}, func(k, v int) (string, string) { + return strconv.Itoa(k), strconv.Itoa(v) + })) +} + +func TestValues(t *testing.T) { + { + keys := Values(map[int]string{1: "1", 2: "2", 3: "3", 4: "4"}) + sort.Strings(keys) + assert.Equal(t, []string{"1", "2", "3", "4"}, keys) + } + assert.Equal(t, []string{}, Values(map[int]string{})) + assert.Equal(t, []string{}, Values[int, string](nil)) +} + +func TestClone(t *testing.T) { + assert.Equal(t, map[int]int{1: 1, 2: 2}, Clone(map[int]int{1: 1, 2: 2})) + var nilMap map[int]int + assert.Equal(t, map[int]int{}, Clone(map[int]int{})) + assert.NotEqual(t, (map[int]int)(nil), Clone(map[int]int{})) + assert.Equal(t, (map[int]int)(nil), Clone(nilMap)) + assert.NotEqual(t, map[int]int{}, Clone(nilMap)) + + // Test new type. 
+ type I2I map[int]int + assert.Equal(t, I2I{1: 1, 2: 2}, Clone(I2I{1: 1, 2: 2})) + assert.Equal(t, "gmap.I2I", fmt.Sprintf("%T", Clone(I2I{}))) + + // Test shallow clone. + src := map[int]*int{1: ptr(1), 2: ptr(2)} + dst := Clone(src) + assert.Equal(t, src, dst) + assert.True(t, src[1] == dst[1]) + assert.True(t, src[2] == dst[2]) +} + +// ptr returns a pointer to the given value. +func ptr[T any](v T) *T { + return &v +} diff --git a/internal/gslice/gslice.go b/internal/gslice/gslice.go new file mode 100644 index 0000000..a9c05fe --- /dev/null +++ b/internal/gslice/gslice.go @@ -0,0 +1,39 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gslice + +// ToMap collects the elements of a slice into a map; both map keys and values are produced +// by mapping function f.
+// +// 🚀 EXAMPLE: +// +// type Foo struct { +// ID int +// Name string +// } +// mapper := func(f Foo) (int, string) { return f.ID, f.Name } +// ToMap([]Foo{}, mapper) ⏩ map[int]string{} +// s := []Foo{{1, "one"}, {2, "two"}, {3, "three"}} +// ToMap(s, mapper) ⏩ map[int]string{1: "one", 2: "two", 3: "three"} +func ToMap[T, V any, K comparable](s []T, f func(T) (K, V)) map[K]V { + m := make(map[K]V, len(s)) + for _, e := range s { + k, v := f(e) + m[k] = v + } + return m +} diff --git a/internal/gslice/gslice_test.go b/internal/gslice/gslice_test.go new file mode 100644 index 0000000..b55c854 --- /dev/null +++ b/internal/gslice/gslice_test.go @@ -0,0 +1,36 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package gslice + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestToMap(t *testing.T) { + type Foo struct { + ID int + Name string + } + mapper := func(f Foo) (int, string) { return f.ID, f.Name } + assert.Equal(t, map[int]string{}, ToMap([]Foo{}, mapper)) + assert.Equal(t, map[int]string{}, ToMap(nil, mapper)) + assert.Equal(t, + map[int]string{1: "one", 2: "two", 3: "three"}, + ToMap([]Foo{{1, "one"}, {2, "two"}, {3, "three"}}, mapper)) +} diff --git a/internal/mock/components/document/document_mock.go b/internal/mock/components/document/document_mock.go new file mode 100644 index 0000000..696808b --- /dev/null +++ b/internal/mock/components/document/document_mock.go @@ -0,0 +1,159 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go + +// Package document is a generated GoMock package. +package document + +import ( + context "context" + reflect "reflect" + + document "github.com/cloudwego/eino/components/document" + schema "github.com/cloudwego/eino/schema" + gomock "go.uber.org/mock/gomock" +) + +// MockLoaderSplitter is a mock of LoaderSplitter interface. +type MockLoaderSplitter struct { + ctrl *gomock.Controller + recorder *MockLoaderSplitterMockRecorder +} + +// MockLoaderSplitterMockRecorder is the mock recorder for MockLoaderSplitter. 
+type MockLoaderSplitterMockRecorder struct { + mock *MockLoaderSplitter +} + +// NewMockLoaderSplitter creates a new mock instance. +func NewMockLoaderSplitter(ctrl *gomock.Controller) *MockLoaderSplitter { + mock := &MockLoaderSplitter{ctrl: ctrl} + mock.recorder = &MockLoaderSplitterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLoaderSplitter) EXPECT() *MockLoaderSplitterMockRecorder { + return m.recorder +} + +// LoadAndSplit mocks base method. +func (m *MockLoaderSplitter) LoadAndSplit(ctx context.Context, src document.Source, opts ...document.LoaderSplitterOption) ([]*schema.Document, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, src} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LoadAndSplit", varargs...) + ret0, _ := ret[0].([]*schema.Document) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadAndSplit indicates an expected call of LoadAndSplit. +func (mr *MockLoaderSplitterMockRecorder) LoadAndSplit(ctx, src interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, src}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAndSplit", reflect.TypeOf((*MockLoaderSplitter)(nil).LoadAndSplit), varargs...) +} + +// MockLoader is a mock of Loader interface. +type MockLoader struct { + ctrl *gomock.Controller + recorder *MockLoaderMockRecorder +} + +// MockLoaderMockRecorder is the mock recorder for MockLoader. +type MockLoaderMockRecorder struct { + mock *MockLoader +} + +// NewMockLoader creates a new mock instance. +func NewMockLoader(ctrl *gomock.Controller) *MockLoader { + mock := &MockLoader{ctrl: ctrl} + mock.recorder = &MockLoaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockLoader) EXPECT() *MockLoaderMockRecorder { + return m.recorder +} + +// Load mocks base method. +func (m *MockLoader) Load(ctx context.Context, src document.Source, opts ...document.LoaderOption) ([]*schema.Document, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, src} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Load", varargs...) + ret0, _ := ret[0].([]*schema.Document) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Load indicates an expected call of Load. +func (mr *MockLoaderMockRecorder) Load(ctx, src interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, src}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockLoader)(nil).Load), varargs...) +} + +// MockTransformer is a mock of Transformer interface. +type MockTransformer struct { + ctrl *gomock.Controller + recorder *MockTransformerMockRecorder +} + +// MockTransformerMockRecorder is the mock recorder for MockTransformer. +type MockTransformerMockRecorder struct { + mock *MockTransformer +} + +// NewMockTransformer creates a new mock instance. +func NewMockTransformer(ctrl *gomock.Controller) *MockTransformer { + mock := &MockTransformer{ctrl: ctrl} + mock.recorder = &MockTransformerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransformer) EXPECT() *MockTransformerMockRecorder { + return m.recorder +} + +// Transform mocks base method. +func (m *MockTransformer) Transform(ctx context.Context, src []*schema.Document, opts ...document.TransformerOption) ([]*schema.Document, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, src} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Transform", varargs...) 
+ ret0, _ := ret[0].([]*schema.Document) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Transform indicates an expected call of Transform. +func (mr *MockTransformerMockRecorder) Transform(ctx, src interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, src}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transform", reflect.TypeOf((*MockTransformer)(nil).Transform), varargs...) +} diff --git a/internal/mock/components/embedding/Embedding_mock.go b/internal/mock/components/embedding/Embedding_mock.go new file mode 100644 index 0000000..f0c1063 --- /dev/null +++ b/internal/mock/components/embedding/Embedding_mock.go @@ -0,0 +1,77 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -destination ../../internal/mock/components/embedding/Embedding_mock.go --package embedding -source interface.go +// + +// Package embedding is a generated GoMock package. +package embedding + +import ( + context "context" + reflect "reflect" + + embedding "github.com/cloudwego/eino/components/embedding" + gomock "go.uber.org/mock/gomock" +) + +// MockEmbedder is a mock of Embedder interface. 
+type MockEmbedder struct { + ctrl *gomock.Controller + recorder *MockEmbedderMockRecorder +} + +// MockEmbedderMockRecorder is the mock recorder for MockEmbedder. +type MockEmbedderMockRecorder struct { + mock *MockEmbedder +} + +// NewMockEmbedder creates a new mock instance. +func NewMockEmbedder(ctrl *gomock.Controller) *MockEmbedder { + mock := &MockEmbedder{ctrl: ctrl} + mock.recorder = &MockEmbedderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEmbedder) EXPECT() *MockEmbedderMockRecorder { + return m.recorder +} + +// EmbedStrings mocks base method. +func (m *MockEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, texts} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EmbedStrings", varargs...) + ret0, _ := ret[0].([][]float64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EmbedStrings indicates an expected call of EmbedStrings. +func (mr *MockEmbedderMockRecorder) EmbedStrings(ctx, texts any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, texts}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EmbedStrings", reflect.TypeOf((*MockEmbedder)(nil).EmbedStrings), varargs...) +} diff --git a/internal/mock/components/indexer/indexer_mock.go b/internal/mock/components/indexer/indexer_mock.go new file mode 100644 index 0000000..829af22 --- /dev/null +++ b/internal/mock/components/indexer/indexer_mock.go @@ -0,0 +1,78 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -destination ../../internal/mock/components/indexer/indexer_mock.go --package indexer -source interface.go +// + +// Package indexer is a generated GoMock package. +package indexer + +import ( + context "context" + reflect "reflect" + + indexer "github.com/cloudwego/eino/components/indexer" + schema "github.com/cloudwego/eino/schema" + gomock "go.uber.org/mock/gomock" +) + +// MockIndexer is a mock of Indexer interface. +type MockIndexer struct { + ctrl *gomock.Controller + recorder *MockIndexerMockRecorder +} + +// MockIndexerMockRecorder is the mock recorder for MockIndexer. +type MockIndexerMockRecorder struct { + mock *MockIndexer +} + +// NewMockIndexer creates a new mock instance. +func NewMockIndexer(ctrl *gomock.Controller) *MockIndexer { + mock := &MockIndexer{ctrl: ctrl} + mock.recorder = &MockIndexerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIndexer) EXPECT() *MockIndexerMockRecorder { + return m.recorder +} + +// Store mocks base method. +func (m *MockIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) ([]string, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, docs} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Store", varargs...) 
+ ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Store indicates an expected call of Store. +func (mr *MockIndexerMockRecorder) Store(ctx, docs any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, docs}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Store", reflect.TypeOf((*MockIndexer)(nil).Store), varargs...) +} diff --git a/internal/mock/components/model/ChatModel_mock.go b/internal/mock/components/model/ChatModel_mock.go new file mode 100644 index 0000000..b56dcc1 --- /dev/null +++ b/internal/mock/components/model/ChatModel_mock.go @@ -0,0 +1,112 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -destination ../../internal/mock/components/model/ChatModel_mock.go --package model -source interface.go +// + +// Package model is a generated GoMock package. +package model + +import ( + context "context" + reflect "reflect" + + model "github.com/cloudwego/eino/components/model" + schema "github.com/cloudwego/eino/schema" + gomock "go.uber.org/mock/gomock" +) + +// MockChatModel is a mock of ChatModel interface. +type MockChatModel struct { + ctrl *gomock.Controller + recorder *MockChatModelMockRecorder +} + +// MockChatModelMockRecorder is the mock recorder for MockChatModel. 
+type MockChatModelMockRecorder struct { + mock *MockChatModel +} + +// NewMockChatModel creates a new mock instance. +func NewMockChatModel(ctrl *gomock.Controller) *MockChatModel { + mock := &MockChatModel{ctrl: ctrl} + mock.recorder = &MockChatModelMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChatModel) EXPECT() *MockChatModelMockRecorder { + return m.recorder +} + +// BindTools mocks base method. +func (m *MockChatModel) BindTools(tools []*schema.ToolInfo) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BindTools", tools) + ret0, _ := ret[0].(error) + return ret0 +} + +// BindTools indicates an expected call of BindTools. +func (mr *MockChatModelMockRecorder) BindTools(tools any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindTools", reflect.TypeOf((*MockChatModel)(nil).BindTools), tools) +} + +// Generate mocks base method. +func (m *MockChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, input} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Generate", varargs...) + ret0, _ := ret[0].(*schema.Message) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Generate indicates an expected call of Generate. +func (mr *MockChatModelMockRecorder) Generate(ctx, input any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, input}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Generate", reflect.TypeOf((*MockChatModel)(nil).Generate), varargs...) +} + +// Stream mocks base method. 
+func (m *MockChatModel) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) { + m.ctrl.T.Helper() + varargs := []any{ctx, input} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Stream", varargs...) + ret0, _ := ret[0].(*schema.StreamReader[*schema.Message]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Stream indicates an expected call of Stream. +func (mr *MockChatModelMockRecorder) Stream(ctx, input any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, input}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stream", reflect.TypeOf((*MockChatModel)(nil).Stream), varargs...) +} diff --git a/internal/mock/components/retriever/retriever_mock.go b/internal/mock/components/retriever/retriever_mock.go new file mode 100644 index 0000000..386371a --- /dev/null +++ b/internal/mock/components/retriever/retriever_mock.go @@ -0,0 +1,78 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: interface.go +// +// Generated by this command: +// +// mockgen -destination ../../internal/mock/components/retriever/retriever_mock.go --package retriever -source interface.go +// + +// Package retriever is a generated GoMock package. 
+package retriever + +import ( + context "context" + reflect "reflect" + + retriever "github.com/cloudwego/eino/components/retriever" + schema "github.com/cloudwego/eino/schema" + gomock "go.uber.org/mock/gomock" +) + +// MockRetriever is a mock of Retriever interface. +type MockRetriever struct { + ctrl *gomock.Controller + recorder *MockRetrieverMockRecorder +} + +// MockRetrieverMockRecorder is the mock recorder for MockRetriever. +type MockRetrieverMockRecorder struct { + mock *MockRetriever +} + +// NewMockRetriever creates a new mock instance. +func NewMockRetriever(ctrl *gomock.Controller) *MockRetriever { + mock := &MockRetriever{ctrl: ctrl} + mock.recorder = &MockRetrieverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRetriever) EXPECT() *MockRetrieverMockRecorder { + return m.recorder +} + +// Retrieve mocks base method. +func (m *MockRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, query} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Retrieve", varargs...) + ret0, _ := ret[0].([]*schema.Document) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Retrieve indicates an expected call of Retrieve. +func (mr *MockRetrieverMockRecorder) Retrieve(ctx, query any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, query}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retrieve", reflect.TypeOf((*MockRetriever)(nil).Retrieve), varargs...) +} diff --git a/internal/mock/doc.go b/internal/mock/doc.go new file mode 100644 index 0000000..020235d --- /dev/null +++ b/internal/mock/doc.go @@ -0,0 +1,51 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package mock provides mock implementations for testing purposes. +// +// This package aims to provide mock implementations for interfaces in the components package, +// making it easier to use in testing environments. It includes mock implementations for +// various core components such as retrievers, tools, message handlers, and graph runners. +// +// Directory Structure: +// - components/: Contains mock implementations for various components +// - retriever/: Provides mock implementation for the Retriever interface +// - retriever_mock.go: Mock implementation for document retrieval +// - tool/: Mock implementations for tool-related interfaces +// - message/: Mock implementations for message handling components +// - graph/: Mock implementations for graph execution components +// - stream/: Mock implementations for streaming components +// +// Usage: +// These mock implementations are primarily used in unit tests and integration tests, +// allowing developers to conduct tests without depending on actual external services. +// Each mock component strictly follows the contract of its corresponding interface +// while providing controllable behaviors and results. 
+// +// Examples: +// +// - Using mock retriever: +// retriever := mock.NewMockRetriever() +// // Configure retriever behavior +// +// - Using mock tool: +// tool := mock.NewMockTool() +// // Configure tool behavior +// +// - Using mock graph runner: +// runner := mock.NewMockGraphRunner() +// // Configure runner behavior +package mock diff --git a/profile/README.md b/profile/README.md deleted file mode 100644 index 2127160..0000000 --- a/profile/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## Hi there 👋 - -🙋‍♀️ A short introduction - CloudWeGo is an open-source middleware set launched by ByteDance that can be used to quickly build enterprise-class cloud native architectures. The common characteristics of CloudWeGo projects are high performance, high scalability, high reliability and focusing on microservices communication and governance. - -🌈 Community Membership - the [Responsibilities and Requirements](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) of contributor roles in CloudWeGo. 
- -👩‍💻 Useful resources - [Portal](https://www.cloudwego.io/), [Community](https://www.cloudwego.io/zh/community/), [Blogs](https://www.cloudwego.io/zh/blog/), [Use Cases](https://www.cloudwego.io/zh/cooperation/) - -🍿 Security - [Vulnerability Reporting](https://www.cloudwego.io/zh/security/vulnerability-reporting/), [Safety Bulletin](https://www.cloudwego.io/zh/security/safety-bulletin/) - -🌲 Ecosystem - [Kitex-contrib](https://github.com/kitex-contrib), [Hertz-contrib](https://github.com/hertz-contrib), [Volo-rs](https://github.com/volo-rs) - -🎊 Example - [kitex-example](https://github.com/cloudwego/kitex-examples), [hertz-example](https://github.com/cloudwego/hertz-examples), [biz-demo](https://github.com/cloudwego/biz-demo), [netpoll-example](https://github.com/cloudwego/netpoll-examples) diff --git a/schema/doc.go b/schema/doc.go new file mode 100644 index 0000000..592c8e5 --- /dev/null +++ b/schema/doc.go @@ -0,0 +1,17 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema diff --git a/schema/document.go b/schema/document.go new file mode 100644 index 0000000..80f58d3 --- /dev/null +++ b/schema/document.go @@ -0,0 +1,175 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +const ( + metaDataKeySubIndexes = "_sub_indexes" + metaDataKeyScore = "_score" + metaDataKeyVikingExtraInfo = "_viking_extra_info" + metaDataKeyVikingDSL = "_viking_dsl" + metaDataKeyVector = "_vector" +) + +// Document is a piece of text with metadata. +type Document struct { + // ID is the unique identifier of the document. + ID string `json:"id"` + // Content is the content of the document. + Content string `json:"content"` + // MetaData is the metadata of the document, can be used to store extra information. + MetaData map[string]any `json:"meta_data"` +} + +// String returns the content of the document. +func (d *Document) String() string { + return d.Content +} + +// WithSubIndexes sets the sub indexes of the document. +// can use doc.SubIndexes() to get the sub indexes, useful for search engine to use sub indexes to search. +func (d *Document) WithSubIndexes(indexes []string) *Document { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + + d.MetaData[metaDataKeySubIndexes] = indexes + + return d +} + +// SubIndexes returns the sub indexes of the document. +// can use doc.WithSubIndexes() to set the sub indexes. +func (d *Document) SubIndexes() []string { + if d.MetaData == nil { + return nil + } + + indexes, ok := d.MetaData[metaDataKeySubIndexes].([]string) + if ok { + return indexes + } + + return nil +} + +// WithScore sets the score of the document. +// can use doc.Score() to get the score. 
+func (d *Document) WithScore(score float64) *Document { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + + d.MetaData[metaDataKeyScore] = score + + return d +} + +// Score returns the score of the document. +// can use doc.WithScore() to set the score. +func (d *Document) Score() float64 { + if d.MetaData == nil { + return 0 + } + + score, ok := d.MetaData[metaDataKeyScore].(float64) + if ok { + return score + } + + return 0 +} + +// WithVikingExtraInfo sets the extra info of the document. +// can use doc.VikingExtraInfo() to get the extra info. +func (d *Document) WithVikingExtraInfo(extraInfo string) *Document { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + + d.MetaData[metaDataKeyVikingExtraInfo] = extraInfo + + return d +} + +// VikingExtraInfo returns the extra info of the document. +// can use doc.WithVikingExtraInfo() to set the extra info. +func (d *Document) VikingExtraInfo() string { + if d.MetaData == nil { + return "" + } + + extraInfo, ok := d.MetaData[metaDataKeyVikingExtraInfo].(string) + if ok { + return extraInfo + } + + return "" +} + +// WithVikingDSLInfo sets the dsl info of the document. +// can use doc.VikingDSLInfo() to get the dsl info. +func (d *Document) WithVikingDSLInfo(dslInfo map[string]any) *Document { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + + d.MetaData[metaDataKeyVikingDSL] = dslInfo + + return d +} + +// VikingDSLInfo returns the dsl info of the document. +// can use doc.WithVikingDSLInfo() to set the dsl info. +func (d *Document) VikingDSLInfo() map[string]any { + if d.MetaData == nil { + return nil + } + + dslInfo, ok := d.MetaData[metaDataKeyVikingDSL].(map[string]any) + if ok { + return dslInfo + } + + return nil +} + +// WithVector sets the vector of the document. +// can use doc.Vector() to get the vector. 
+func (d *Document) WithVector(vector []float64) *Document { + if d.MetaData == nil { + d.MetaData = make(map[string]any) + } + + d.MetaData[metaDataKeyVector] = vector + + return d +} + +// Vector returns the vector of the document. +// can use doc.WithVector() to set the vector. +func (d *Document) Vector() []float64 { + if d.MetaData == nil { + return nil + } + + vector, ok := d.MetaData[metaDataKeyVector].([]float64) + if ok { + return vector + } + + return nil +} diff --git a/schema/document_test.go b/schema/document_test.go new file mode 100644 index 0000000..f9a13cc --- /dev/null +++ b/schema/document_test.go @@ -0,0 +1,53 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "testing" + + "github.com/smartystreets/goconvey/convey" +) + +func TestDocument(t *testing.T) { + convey.Convey("test document", t, func() { + var ( + subIndexes = []string{"hello", "bye"} + score = 1.1 + extraInfo = "asd" + dslInfo = map[string]any{"hello": true} + vector = []float64{1.1, 2.2} + ) + + d := &Document{ + ID: "asd", + Content: "qwe", + MetaData: nil, + } + + d.WithSubIndexes(subIndexes). + WithVector(vector). + WithScore(score). + WithVikingExtraInfo(extraInfo). 
+ WithVikingDSLInfo(dslInfo) + + convey.So(d.SubIndexes(), convey.ShouldEqual, subIndexes) + convey.So(d.Score(), convey.ShouldEqual, score) + convey.So(d.VikingExtraInfo(), convey.ShouldEqual, extraInfo) + convey.So(d.VikingDSLInfo(), convey.ShouldEqual, dslInfo) + convey.So(d.Vector(), convey.ShouldEqual, vector) + }) +} diff --git a/schema/message.go b/schema/message.go new file mode 100644 index 0000000..c99f840 --- /dev/null +++ b/schema/message.go @@ -0,0 +1,705 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "fmt" + "reflect" + "sort" + "strings" + "sync" + "text/template" + + "github.com/nikolalohinski/gonja" + "github.com/nikolalohinski/gonja/config" + "github.com/nikolalohinski/gonja/nodes" + "github.com/nikolalohinski/gonja/parser" + "github.com/slongfield/pyfmt" + + "github.com/cloudwego/eino/internal/gmap" +) + +// FormatType used by MessageTemplate.Format +type FormatType uint8 + +const ( + // FString Supported by pyfmt(github.com/slongfield/pyfmt), which is an implementation of https://peps.python.org/pep-3101/. + FString FormatType = 0 + // GoTemplate https://pkg.go.dev/text/template. + GoTemplate FormatType = 1 + // Jinja2 Supported by gonja(github.com/nikolalohinski/gonja), which is a implementation of https://jinja.palletsprojects.com/en/3.1.x/templates/. + Jinja2 FormatType = 2 +) + +// RoleType is the type of the role of a message. 
+type RoleType string + +const ( + // Assistant is the role of an assistant, means the message is returned by ChatModel. + Assistant RoleType = "assistant" + // User is the role of a user, means the message is a user message. + User RoleType = "user" + // System is the role of a system, means the message is a system message. + System RoleType = "system" + // Tool is the role of a tool, means the message is a tool call output. + Tool RoleType = "tool" +) + +// FunctionCall is the function call in a message. +// It's used in Assistant Message. +type FunctionCall struct { + // Name is the name of the function to call, it can be used to identify the specific function. + Name string `json:"name,omitempty"` + // Arguments is the arguments to call the function with, in JSON format. + Arguments string `json:"arguments,omitempty"` +} + +// ToolCall is the tool call in a message. +// It's used in Assistant Message when there are tool calls should be made. +type ToolCall struct { + // Index is used when there are multiple tool calls in a message. + // In stream mode, it's used to identify the chunk of the tool call for merging. + Index *int `json:"index,omitempty"` + // ID is the id of the tool call, it can be used to identify the specific tool call. + ID string `json:"id"` + // Type is the type of the tool call, default is "function". + Type string `json:"type"` + // Function is the function call to be made. + Function FunctionCall `json:"function"` + + // Extra is used to store extra information for the tool call. + Extra map[string]any `json:"extra,omitempty"` +} + +// ImageURLDetail is the detail of the image url. +type ImageURLDetail string + +const ( + // ImageURLDetailHigh means the high quality image url. + ImageURLDetailHigh ImageURLDetail = "high" + // ImageURLDetailLow means the low quality image url. + ImageURLDetailLow ImageURLDetail = "low" + // ImageURLDetailAuto means the auto quality image url. 
+ ImageURLDetailAuto ImageURLDetail = "auto" +) + +// ChatMessageImageURL is used to represent an image part in a chat message. +// Choose either URL or URI. +// If your model implementation supports it, URL could be used to embed inline image data as defined in RFC-2397. +type ChatMessageImageURL struct { + // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). + // double check with model implementations for detailed instructions on how to use this. + URL string `json:"url,omitempty"` + URI string `json:"uri,omitempty"` + // Detail is the quality of the image url. + Detail ImageURLDetail `json:"detail,omitempty"` + + // MIMEType is the mime type of the image, eg. "image/png". + MIMEType string `json:"mime_type,omitempty"` + // Extra is used to store extra information for the image url. + Extra map[string]any `json:"extra,omitempty"` +} + +// ChatMessagePartType is the type of the part in a chat message. +type ChatMessagePartType string + +const ( + // ChatMessagePartTypeText means the part is a text. + ChatMessagePartTypeText ChatMessagePartType = "text" + // ChatMessagePartTypeImageURL means the part is an image url. + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + // ChatMessagePartTypeAudioURL means the part is an audio url. + ChatMessagePartTypeAudioURL ChatMessagePartType = "audio_url" + // ChatMessagePartTypeVideoURL means the part is a video url. + ChatMessagePartTypeVideoURL ChatMessagePartType = "video_url" + // ChatMessagePartTypeFileURL means the part is a file url. + ChatMessagePartTypeFileURL ChatMessagePartType = "file_url" +) + +// ChatMessageAudioURL is used to represent an audio part in a chat message. +// Choose either URL or URI. +// If your model implementation supports it, URL could be used to embed inline audio data as defined in RFC-2397. 
+type ChatMessageAudioURL struct { + // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). + // double check with model implementations for detailed instructions on how to use this. + URL string `json:"url,omitempty"` + URI string `json:"uri,omitempty"` + + // MIMEType is the mime type of the audio, eg. "audio/wav" or "audio/ogg". + MIMEType string `json:"mime_type,omitempty"` + // Extra is used to store extra information for the audio url. + Extra map[string]any `json:"extra,omitempty"` +} + +// ChatMessageVideoURL is used to represent a video part in a chat message. +// Choose either URL or URI. +// If your model implementation supports it, URL could be used to embed inline video data as defined in RFC-2397. +type ChatMessageVideoURL struct { + // URL can either be a traditional URL or a special URL conforming to RFC-2397 (https://www.rfc-editor.org/rfc/rfc2397). + // double check with model implementations for detailed instructions on how to use this. + URL string `json:"url,omitempty"` + URI string `json:"uri,omitempty"` + + // MIMEType is the mime type of the video, eg. "video/mp4". + MIMEType string `json:"mime_type,omitempty"` + // Extra is used to store extra information for the video url. + Extra map[string]any `json:"extra,omitempty"` +} + +// ChatMessageFileURL is used to represent a file part in a chat message. +// Choose either URL or URI. +type ChatMessageFileURL struct { + URL string `json:"url,omitempty"` + URI string `json:"uri,omitempty"` + + // MIMEType is the mime type of the file, eg. "application/pdf", "text/plain". + MIMEType string `json:"mime_type,omitempty"` + // Name is the name of the file. + Name string `json:"name,omitempty"` + + // Extra is used to store extra information for the file url. + Extra map[string]any `json:"extra,omitempty"` +} + +// ChatMessagePart is the part in a chat message. +type ChatMessagePart struct { + // Type is the type of the part, eg.
"text", "image_url", "audio_url", "video_url", "file_url". + Type ChatMessagePartType `json:"type,omitempty"` + + // Text is the text of the part, it's used when Type is "text". + Text string `json:"text,omitempty"` + + // ImageURL is the image url of the part, it's used when Type is "image_url". + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + // AudioURL is the audio url of the part, it's used when Type is "audio_url". + AudioURL *ChatMessageAudioURL `json:"audio_url,omitempty"` + // VideoURL is the video url of the part, it's used when Type is "video_url". + VideoURL *ChatMessageVideoURL `json:"video_url,omitempty"` + // FileURL is the file url of the part, it's used when Type is "file_url". + FileURL *ChatMessageFileURL `json:"file_url,omitempty"` +} + +// ResponseMeta collects meta information about a chat response. +type ResponseMeta struct { + // FinishReason is the reason why the chat response is finished. + // It's usually "stop", "length", "tool_calls", "content_filter", "null". This is defined by chat model implementation. + FinishReason string `json:"finish_reason,omitempty"` + // Usage is the token usage of the chat response, whether usage exists depends on whether the chat model implementation returns. 
+ Usage *TokenUsage `json:"usage,omitempty"` +} + +type Message struct { + Role RoleType `json:"role"` + Content string `json:"content"` + + // if MultiContent is not empty, use this instead of Content + // if MultiContent is empty, use Content + MultiContent []ChatMessagePart `json:"multi_content,omitempty"` + + Name string `json:"name,omitempty"` + + // only for AssistantMessage + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // only for ToolMessage + ToolCallID string `json:"tool_call_id,omitempty"` + + ResponseMeta *ResponseMeta `json:"response_meta,omitempty"` + + // customized information for model implementation + Extra map[string]any `json:"extra,omitempty"` +} + +// TokenUsage Represents the token usage of chat model request. +type TokenUsage struct { + // PromptTokens is the number of tokens in the prompt. + PromptTokens int `json:"prompt_tokens"` + // CompletionTokens is the number of tokens in the completion. + CompletionTokens int `json:"completion_tokens"` + // TotalTokens is the total number of tokens in the request. + TotalTokens int `json:"total_tokens"` +} + +var _ MessagesTemplate = &Message{} +var _ MessagesTemplate = MessagesPlaceholder("", false) + +// MessagesTemplate is the interface for messages template. +// It's used to render a template to a list of messages. +// e.g. +// +// chatTemplate := prompt.FromMessages( +// schema.SystemMessage("you are eino helper"), +// schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type MessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*Message, error) +} + +type messagesPlaceholder struct { + key string + optional bool +} + +// MessagesPlaceholder can render a placeholder to a list of messages in params. +// e.g. 
+// +// placeholder := MessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, +// "query": "how to use eino?", +// } +// chatTemplate := prompt.FromMessages( +// schema.SystemMessage("you are eino helper"), +// schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func MessagesPlaceholder(key string, optional bool) MessagesTemplate { + return &messagesPlaceholder{ + key: key, + optional: optional, + } +} + +// Format just returns the messages stored under the placeholder's key, because it's a placeholder. +// The formatType argument is ignored. If the key is missing, an optional placeholder yields an empty list +// and a required one yields an error; a value that is not []*Message is always an error. +// e.g. +// +// placeholder := MessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, +// "query": "how to use eino?", +// } +// msgs, err := placeholder.Format(ctx, params, schema.FString) // <= this will return the value of "history" in params +func (p *messagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*Message, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*Message{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*Message) + if !ok { + return nil, fmt.Errorf("only messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +func formatContent(content string, vs map[string]any, formatType FormatType) (string, error) { + switch formatType { + case FString: + return pyfmt.Fmt(content, vs) + case GoTemplate: + parsedTmpl, err := template.New("template"). + Option("missingkey=error").
+ Parse(content) + if err != nil { + return "", err + } + sb := new(strings.Builder) + err = parsedTmpl.Execute(sb, vs) + if err != nil { + return "", err + } + return sb.String(), nil + case Jinja2: + env, err := getJinjaEnv() + if err != nil { + return "", err + } + tpl, err := env.FromString(content) + if err != nil { + return "", err + } + out, err := tpl.Execute(vs) + if err != nil { + return "", err + } + return out, nil + default: + return "", fmt.Errorf("unknown format type: %v", formatType) + } +} + +// Format returns the message after rendering its Content with the given formatType. +// Only Content is rendered; every other field is carried over by a shallow copy of m, and m itself is not modified. +// e.g. +// +// msg := schema.UserMessage("hello world, {name}") +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) // <= this will render the content of msg by pyfmt +// // msgs[0].Content will be "hello world, eino" +func (m *Message) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*Message, error) { + c, err := formatContent(m.Content, vs, formatType) + if err != nil { + return nil, err + } + + copied := *m + copied.Content = c + return []*Message{&copied}, nil +} + +// String returns the string representation of the message. +// e.g.
+// +// msg := schema.UserMessage("hello world") +// fmt.Println(msg.String()) // Output will be: `user: hello world` +// +// toolMsg := schema.Message{ +// Role: schema.Tool, +// Content: "{...}", +// ToolCallID: "callxxxx", +// } +// fmt.Println(toolMsg.String()) +// Output will be: +// tool: {...} +// tool_call_id: callxxxx +func (m *Message) String() string { + s := fmt.Sprintf("%s: %s", m.Role, m.Content) + if len(m.ToolCalls) > 0 { + s += fmt.Sprintf("\ntool_calls: %v", m.ToolCalls) + } + if m.ToolCallID != "" { + s += fmt.Sprintf("\ntool_call_id: %s", m.ToolCallID) + } + if m.ResponseMeta != nil { + s += fmt.Sprintf("\nfinish_reason: %s", m.ResponseMeta.FinishReason) + if m.ResponseMeta.Usage != nil { + s += fmt.Sprintf("\nusage: %v", m.ResponseMeta.Usage) + } + } + + return s +} + +// SystemMessage returns a new message with Role "system" and the given content. +func SystemMessage(content string) *Message { + return &Message{ + Role: System, + Content: content, + } +} + +// AssistantMessage returns a new message with Role "assistant", the given content and the given tool calls. +func AssistantMessage(content string, toolCalls []ToolCall) *Message { + return &Message{ + Role: Assistant, + Content: content, + ToolCalls: toolCalls, + } +} + +// UserMessage returns a new message with Role "user" and the given content. +func UserMessage(content string) *Message { + return &Message{ + Role: User, + Content: content, + } +} + +// ToolMessage represents a message with Role "tool".
+func ToolMessage(content string, toolCallID string) *Message { + return &Message{ + Role: Tool, + Content: content, + ToolCallID: toolCallID, + } +} + +func concatToolCalls(chunks []ToolCall) ([]ToolCall, error) { + var merged []ToolCall + m := make(map[int][]int) + for i := range chunks { + index := chunks[i].Index + if index == nil { + merged = append(merged, chunks[i]) + } else { + m[*index] = append(m[*index], i) + } + } + + var args strings.Builder + for k, v := range m { + index := k + toolCall := ToolCall{Index: &index} + if len(v) > 0 { + toolCall = chunks[v[0]] + } + + args.Reset() + toolID, toolType, toolName := "", "", "" // these field will output atomically in any chunk + + for _, n := range v { + chunk := chunks[n] + if chunk.ID != "" { + if toolID == "" { + toolID = chunk.ID + } else if toolID != chunk.ID { + return nil, fmt.Errorf("cannot concat ToolCalls with different tool id: '%s' '%s'", toolID, chunk.ID) + } + + } + + if chunk.Type != "" { + if toolType == "" { + toolType = chunk.Type + } else if toolType != chunk.Type { + return nil, fmt.Errorf("cannot concat ToolCalls with different tool type: '%s' '%s'", toolType, chunk.Type) + } + } + + if chunk.Function.Name != "" { + if toolName == "" { + toolName = chunk.Function.Name + } else if toolName != chunk.Function.Name { + return nil, fmt.Errorf("cannot concat ToolCalls with different tool name: '%s' '%s'", toolName, chunk.Function.Name) + } + } + + if chunk.Function.Arguments != "" { + _, err := args.WriteString(chunk.Function.Arguments) + if err != nil { + return nil, err + } + } + } + + toolCall.ID = toolID + toolCall.Type = toolType + toolCall.Function.Name = toolName + toolCall.Function.Arguments = args.String() + + merged = append(merged, toolCall) + } + + if len(merged) > 1 { + sort.SliceStable(merged, func(i, j int) bool { + iVal, jVal := merged[i].Index, merged[j].Index + if iVal == nil && jVal == nil { + return false + } else if iVal == nil && jVal != nil { + return true + } else if 
iVal != nil && jVal == nil { + return false + } + + return *iVal < *jVal + }) + } + + return merged, nil +} + +// ConcatMessages concat messages with the same role and name. +// It will concat tool calls with the same index. +// It will return an error if the messages have different roles or names. +// It's useful for concatenating messages from a stream. +// e.g. +// +// msgs := []*Message{} +// for { +// msg, err := stream.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// if err != nil {...} +// msgs = append(msgs, msg) +// } +// +// concatedMsg, err := ConcatMessages(msgs) // concatedMsg.Content will be full content of all messages +func ConcatMessages(msgs []*Message) (*Message, error) { + + for idx, m := range msgs { + if m == nil { + return nil, fmt.Errorf("unexpected nil chunk in message stream, index: %d", idx) + } + } + + var ( + contents []string + contentLen int + toolCalls []ToolCall + ret = Message{} + extraList = make([]map[string]any, 0, len(msgs)) + ) + + for _, msg := range msgs { + if msg.Role != "" { + if ret.Role == "" { + ret.Role = msg.Role + } else if ret.Role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with "+ + "different roles: '%s' '%s'", ret.Role, msg.Role) + } + } + + if msg.Name != "" { + if ret.Name == "" { + ret.Name = msg.Name + } else if ret.Name != msg.Name { + return nil, fmt.Errorf("cannot concat messages with"+ + " different names: '%s' '%s'", ret.Name, msg.Name) + } + } + + if msg.ToolCallID != "" { + if ret.ToolCallID == "" { + ret.ToolCallID = msg.ToolCallID + } else if ret.ToolCallID != msg.ToolCallID { + return nil, fmt.Errorf("cannot concat messages with"+ + " different toolCallIDs: '%s' '%s'", ret.ToolCallID, msg.ToolCallID) + } + } + + if msg.Content != "" { + contents = append(contents, msg.Content) + contentLen += len(msg.Content) + } + + if len(msg.ToolCalls) > 0 { + toolCalls = append(toolCalls, msg.ToolCalls...) 
+ } + + if len(msg.Extra) > 0 { + extraList = append(extraList, msg.Extra) + } + + // There's no scenario that requires to concat messages with MultiContent currently + if len(msg.MultiContent) > 0 { + ret.MultiContent = msg.MultiContent + } + + if msg.ResponseMeta != nil && ret.ResponseMeta == nil { + ret.ResponseMeta = msg.ResponseMeta + } else if msg.ResponseMeta != nil && ret.ResponseMeta != nil { + // keep the last FinishReason with a valid value. + if msg.ResponseMeta.FinishReason != "" { + ret.ResponseMeta.FinishReason = msg.ResponseMeta.FinishReason + } + + if msg.ResponseMeta.Usage != nil { + if ret.ResponseMeta.Usage == nil { + ret.ResponseMeta.Usage = &TokenUsage{} + } + + if msg.ResponseMeta.Usage.PromptTokens > ret.ResponseMeta.Usage.PromptTokens { + ret.ResponseMeta.Usage.PromptTokens = msg.ResponseMeta.Usage.PromptTokens + } + if msg.ResponseMeta.Usage.CompletionTokens > ret.ResponseMeta.Usage.CompletionTokens { + ret.ResponseMeta.Usage.CompletionTokens = msg.ResponseMeta.Usage.CompletionTokens + } + + if msg.ResponseMeta.Usage.TotalTokens > ret.ResponseMeta.Usage.TotalTokens { + ret.ResponseMeta.Usage.TotalTokens = msg.ResponseMeta.Usage.TotalTokens + } + + } + + } + } + + if len(contents) > 0 { + var sb strings.Builder + sb.Grow(contentLen) + sb.WriteString(ret.Content) + for _, content := range contents { + _, err := sb.WriteString(content) + if err != nil { + return nil, err + } + } + + ret.Content = sb.String() + } + + if len(toolCalls) > 0 { + merged, err := concatToolCalls(toolCalls) + if err != nil { + return nil, err + } + + ret.ToolCalls = merged + } + + extra := gmap.Concat(extraList...) 
+ if len(extra) > 0 { + ret.Extra = extra + } + + return &ret, nil +} + +// custom jinja env +var jinjaEnvOnce sync.Once +var jinjaEnv *gonja.Environment +var envInitErr error + +const ( + jinjaInclude = "include" + jinjaExtends = "extends" + jinjaImport = "import" + jinjaFrom = "from" +) + +func getJinjaEnv() (*gonja.Environment, error) { + jinjaEnvOnce.Do(func() { + jinjaEnv = gonja.NewEnvironment(config.DefaultConfig, gonja.DefaultLoader) + formatInitError := "init jinja env fail: %w" + var err error + if jinjaEnv.Statements.Exists(jinjaInclude) { + err = jinjaEnv.Statements.Replace(jinjaInclude, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { + return nil, fmt.Errorf("keyword[include] has been disabled") + }) + if err != nil { + envInitErr = fmt.Errorf(formatInitError, err) + return + } + } + if jinjaEnv.Statements.Exists(jinjaExtends) { + err = jinjaEnv.Statements.Replace(jinjaExtends, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { + return nil, fmt.Errorf("keyword[extends] has been disabled") + }) + if err != nil { + envInitErr = fmt.Errorf(formatInitError, err) + return + } + } + if jinjaEnv.Statements.Exists(jinjaFrom) { + err = jinjaEnv.Statements.Replace(jinjaFrom, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { + return nil, fmt.Errorf("keyword[from] has been disabled") + }) + if err != nil { + envInitErr = fmt.Errorf(formatInitError, err) + return + } + } + if jinjaEnv.Statements.Exists(jinjaImport) { + err = jinjaEnv.Statements.Replace(jinjaImport, func(parser *parser.Parser, args *parser.Parser) (nodes.Statement, error) { + return nil, fmt.Errorf("keyword[import] has been disabled") + }) + if err != nil { + envInitErr = fmt.Errorf(formatInitError, err) + return + } + } + }) + return jinjaEnv, envInitErr +} diff --git a/schema/message_parser.go b/schema/message_parser.go new file mode 100644 index 0000000..64f620f --- /dev/null +++ b/schema/message_parser.go @@ -0,0 
+1,138 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "fmt" + "strings" + + "github.com/bytedance/sonic" +) + +type MessageParser[T any] interface { + Parse(ctx context.Context, m *Message) (T, error) +} + +// MessageParseFrom determines the source of the data to be parsed. default is content (Message.Content). +type MessageParseFrom string + +const ( + MessageParseFromContent MessageParseFrom = "content" + MessageParseFromToolCall MessageParseFrom = "tool_call" +) + +type MessageJSONParseConfig struct { + // parse from content or tool call, default is content. + ParseFrom MessageParseFrom `json:"parse_from,omitempty"` + + // parse key path, default is empty. + // must be a valid json path expression, eg: field.sub_field + ParseKeyPath string `json:"parse_key_path,omitempty"` +} + +// NewMessageJSONParser creates a new MessageJSONParser. +func NewMessageJSONParser[T any](config *MessageJSONParseConfig) MessageParser[T] { + if config == nil { + config = &MessageJSONParseConfig{} + } + + if config.ParseFrom == "" { + config.ParseFrom = MessageParseFromContent + } + + return &MessageJSONParser[T]{ + ParseFrom: config.ParseFrom, + ParseKeyPath: config.ParseKeyPath, + } +} + +// MessageJSONParser is a parser that parses a message into an object T, using json unmarshal. 
+// eg of parse to single struct: +// +// config := &MessageJSONParseConfig{ +// ParseFrom: MessageParseFromToolCall, +// } +// parser := NewMessageJSONParser[GetUserParam](config) +// param, err := parser.Parse(ctx, message) +// +// eg of parse to slice of struct: +// +// config := &MessageJSONParseConfig{ +// ParseFrom: MessageParseFromToolCall, +// } +// +// parser := NewMessageJSONParser[[]GetUserParam](config) +// params, err := parser.Parse(ctx, message) +type MessageJSONParser[T any] struct { + ParseFrom MessageParseFrom + ParseKeyPath string +} + +// Parse parses a message into an object T. +// With MessageParseFromContent the message Content is parsed. With MessageParseFromToolCall only the arguments +// of the first tool call are parsed, and an error is returned if the message has no tool calls. +// Any other ParseFrom value is an error. +func (p *MessageJSONParser[T]) Parse(ctx context.Context, m *Message) (parsed T, err error) { + if p.ParseFrom == MessageParseFromContent { + return p.parse(m.Content) + } else if p.ParseFrom == MessageParseFromToolCall { + if len(m.ToolCalls) == 0 { + return parsed, fmt.Errorf("no tool call found") + } + + return p.parse(m.ToolCalls[0].Function.Arguments) + } + + return parsed, fmt.Errorf("invalid parse from type: %s", p.ParseFrom) +} + +// extractData extracts data from a string using the parse key path. +// An empty ParseKeyPath returns data unchanged; otherwise the path is split on "." and the node found +// under those keys is returned re-marshalled as JSON. +func (p *MessageJSONParser[T]) extractData(data string) (string, error) { + if p.ParseKeyPath == "" { + return data, nil + } + + keys := strings.Split(p.ParseKeyPath, ".") + interfaceKeys := make([]interface{}, len(keys)) + for i, key := range keys { + interfaceKeys[i] = key + } + + node, err := sonic.GetFromString(data, interfaceKeys...) + if err != nil { + return "", fmt.Errorf("failed to get parse key path: %w", err) + } + + bytes, err := node.MarshalJSON() + if err != nil { + return "", fmt.Errorf("failed to marshal node: %w", err) + } + + return string(bytes), nil +} + +// parse parses a string into an object T.
+func (p *MessageJSONParser[T]) parse(data string) (parsed T, err error) { + parsedData, err := p.extractData(data) + if err != nil { + return parsed, err + } + + if err := sonic.UnmarshalString(parsedData, &parsed); err != nil { + return parsed, fmt.Errorf("failed to unmarshal content: %w", err) + } + + return parsed, nil +} diff --git a/schema/message_parser_test.go b/schema/message_parser_test.go new file mode 100644 index 0000000..e645f87 --- /dev/null +++ b/schema/message_parser_test.go @@ -0,0 +1,179 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package schema + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +type TestStructForParse struct { + ID int `json:"id"` + Name string `json:"name"` + XX struct { + YY int `json:"yy"` + } `json:"xx"` +} + +func TestMessageJSONParser(t *testing.T) { + ctx := context.Background() + + t.Run("parse from content", func(t *testing.T) { + parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromContent, + }) + + parsed, err := parser.Parse(ctx, &Message{ + Content: `{"id": 1, "name": "test", "xx": {"yy": 2}}`, + }) + assert.Nil(t, err) + assert.Equal(t, 1, parsed.ID) + }) + + t.Run("parse from tool call", func(t *testing.T) { + t.Run("only one tool call, default use first tool call", func(t *testing.T) { + parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + }) + + parsed, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{ + {Function: FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": 2}}`}}, + }, + }) + assert.Nil(t, err) + assert.Equal(t, 1, parsed.ID) + }) + + t.Run("parse key path", func(t *testing.T) { + type TestStructForParse2 struct { + YY int `json:"yy"` + } + + parser := NewMessageJSONParser[TestStructForParse2](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + ParseKeyPath: "xx", + }) + + parsed, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{ + {Function: FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": 2}}`}}, + }, + }) + assert.Nil(t, err) + assert.Equal(t, 2, parsed.YY) + }) + + t.Run("parse key path, deep level", func(t *testing.T) { + type TestStructForParse3 struct { + ZZ int `json:"zz"` + } + + parser := NewMessageJSONParser[TestStructForParse3](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + ParseKeyPath: "xx.yy", + }) + + parsed, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{ + {Function: 
FunctionCall{Arguments: `{"id": 1, "name": "test", "xx": {"yy": {"zz": 3}}}`}}, + }, + }) + assert.Nil(t, err) + assert.Equal(t, 3, parsed.ZZ) + }) + + t.Run("parse key with pointer", func(t *testing.T) { + type TestStructForParse4 struct { + ZZ *int `json:"zz"` + } + + parser := NewMessageJSONParser[**TestStructForParse4](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + }) + + parsed, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{{Function: FunctionCall{Arguments: `{"zz": 3}`}}}, + }) + assert.Nil(t, err) + assert.Equal(t, 3, *((**parsed).ZZ)) + }) + }) + + t.Run("parse of slice", func(t *testing.T) { + t.Run("valid slice string, not multiple tool calls", func(t *testing.T) { + parser := NewMessageJSONParser[[]map[string]any](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + }) + + parsed, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{{Function: FunctionCall{Arguments: `[{"id": 1}, {"id": 2}]`}}}, + }) + assert.Nil(t, err) + assert.Equal(t, 2, len(parsed)) + }) + + t.Run("invalid slice string, not multiple tool calls", func(t *testing.T) { + parser := NewMessageJSONParser[[]map[string]any](&MessageJSONParseConfig{ + ParseFrom: MessageParseFromToolCall, + }) + + _, err := parser.Parse(ctx, &Message{ + ToolCalls: []ToolCall{ + {Function: FunctionCall{Arguments: `{"id": 1}`}}, + {Function: FunctionCall{Arguments: `{"id": 2}`}}, + }, + }) + assert.NotNil(t, err) + }) + }) + + t.Run("invalid configs", func(t *testing.T) { + parser := NewMessageJSONParser[TestStructForParse](nil) + _, err := parser.Parse(ctx, &Message{ + Content: "", + }) + assert.NotNil(t, err) + }) + + t.Run("invalid parse key path", func(t *testing.T) { + parser := NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ + ParseKeyPath: "...invalid", + }) + _, err := parser.Parse(ctx, &Message{}) + assert.NotNil(t, err) + }) + + t.Run("invalid parse from", func(t *testing.T) { + parser := 
NewMessageJSONParser[TestStructForParse](&MessageJSONParseConfig{ + ParseFrom: "invalid", + }) + _, err := parser.Parse(ctx, &Message{}) + assert.NotNil(t, err) + }) + + t.Run("invalid parse from type", func(t *testing.T) { + parser := NewMessageJSONParser[int](&MessageJSONParseConfig{ + ParseFrom: MessageParseFrom("invalid"), + }) + _, err := parser.Parse(ctx, &Message{}) + assert.NotNil(t, err) + }) + +} diff --git a/schema/message_test.go b/schema/message_test.go new file mode 100644 index 0000000..d635234 --- /dev/null +++ b/schema/message_test.go @@ -0,0 +1,658 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package schema + +import ( + "context" + "reflect" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/utils/generic" +) + +func TestMessageTemplate(t *testing.T) { + pyFmtMessage := UserMessage("input: {question}") + jinja2Message := UserMessage("input: {{question}}") + goTemplateMessage := UserMessage("input: {{.question}}") + ctx := context.Background() + question := "what's the weather today" + expected := []*Message{UserMessage("input: " + question)} + + ms, err := pyFmtMessage.Format(ctx, map[string]any{"question": question}, FString) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(expected, ms)) + ms, err = jinja2Message.Format(ctx, map[string]any{"question": question}, Jinja2) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(expected, ms)) + ms, err = goTemplateMessage.Format(ctx, map[string]any{"question": question}, GoTemplate) + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(expected, ms)) + + mp := MessagesPlaceholder("chat_history", false) + m1 := UserMessage("how are you?") + m2 := AssistantMessage("I'm good. 
how about you?", nil) + ms, err = mp.Format(ctx, map[string]any{"chat_history": []*Message{m1, m2}}, FString) + assert.Nil(t, err) + + // len(ms) == 2 + assert.Equal(t, 2, len(ms)) + assert.Equal(t, ms[0], m1) + assert.Equal(t, ms[1], m2) +} + +func TestConcatMessage(t *testing.T) { + t.Run("tool_call_normal_append", func(t *testing.T) { + expectMsg := &Message{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "i_am_a_too_call_id", + Type: "function", + Function: FunctionCall{ + Name: "i_am_a_tool_name", + Arguments: "{}", + }, + }, + }, + } + givenMsgList := []*Message{ + { + Role: "", + Content: "", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + }, + }, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "i_am_a_too_call_id", + Type: "function", + Function: FunctionCall{ + Name: "i_am_a_tool_name", + }, + }, + }, + }, + { + Role: "", + Content: "", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "", + Type: "", + Function: FunctionCall{ + Name: "", + Arguments: "{}", + }, + }, + }, + }, + } + + msg, err := ConcatMessages(givenMsgList) + assert.NoError(t, err) + assert.EqualValues(t, expectMsg, msg) + }) + + t.Run("exist_nil_message", func(t *testing.T) { + givenMsgList := []*Message{ + nil, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "i_am_a_too_call_id", + Type: "function", + Function: FunctionCall{ + Name: "i_am_a_tool_name", + }, + }, + }, + }, + } + + _, err := ConcatMessages(givenMsgList) + assert.ErrorContains(t, err, "unexpected nil chunk in message stream") + }) + + t.Run("response_meta", func(t *testing.T) { + expectedMsg := &Message{ + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "stop", + Usage: &TokenUsage{ + CompletionTokens: 15, + PromptTokens: 30, + TotalTokens: 45, + }, + }, + } 
+ + givenMsgList := []*Message{ + { + Role: "assistant", + }, + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "", + Usage: &TokenUsage{ + CompletionTokens: 10, + PromptTokens: 20, + TotalTokens: 30, + }, + }, + }, + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "stop", + }, + }, + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + Usage: &TokenUsage{ + CompletionTokens: 15, + PromptTokens: 30, + TotalTokens: 45, + }, + }, + }, + } + + msg, err := ConcatMessages(givenMsgList) + assert.NoError(t, err) + assert.Equal(t, expectedMsg, msg) + + givenMsgList = append(givenMsgList, &Message{ + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "tool_calls", + }, + }) + msg, err = ConcatMessages(givenMsgList) + assert.NoError(t, err) + expectedMsg.ResponseMeta.FinishReason = "tool_calls" + assert.Equal(t, expectedMsg, msg) + + }) + + t.Run("err: different roles", func(t *testing.T) { + msgs := []*Message{ + {Role: User}, + {Role: Assistant}, + } + + msg, err := ConcatMessages(msgs) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "cannot concat messages with different roles") + assert.Nil(t, msg) + } + }) + + t.Run("err: different name", func(t *testing.T) { + msgs := []*Message{ + {Role: Assistant, Name: "n", Content: "1"}, + {Role: Assistant, Name: "a", Content: "2"}, + } + + msg, err := ConcatMessages(msgs) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "cannot concat messages with different names") + assert.Nil(t, msg) + } + }) + + t.Run("err: different tool name", func(t *testing.T) { + msgs := []*Message{ + { + Role: "", + Content: "", + ToolCallID: "123", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "abc", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + }, + }, + { + Role: "assistant", + Content: "", + ToolCallID: "321", + ToolCalls: []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "abc", + Type: "function", + Function: FunctionCall{ + Name: 
"i_am_a_tool_name", + }, + }, + }, + }, + } + + msg, err := ConcatMessages(msgs) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "cannot concat messages with different toolCallIDs") + assert.Nil(t, msg) + } + }) + + t.Run("first response meta usage is nil", func(t *testing.T) { + exp := &Message{ + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "stop", + Usage: &TokenUsage{ + CompletionTokens: 15, + PromptTokens: 30, + TotalTokens: 45, + }, + }, + } + + msgs := []*Message{ + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "", + Usage: nil, + }, + }, + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + FinishReason: "stop", + }, + }, + { + Role: "assistant", + ResponseMeta: &ResponseMeta{ + Usage: &TokenUsage{ + CompletionTokens: 15, + PromptTokens: 30, + TotalTokens: 45, + }, + }, + }, + } + + msg, err := ConcatMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, exp, msg) + }) + + t.Run("concurrent concat", func(t *testing.T) { + content := "i_am_a_good_concat_message" + exp := &Message{Role: Assistant, Content: content} + var msgs []*Message + for i := 0; i < len(content); i++ { + msgs = append(msgs, &Message{Role: Assistant, Content: content[i : i+1]}) + } + + wg := sync.WaitGroup{} + size := 100 + wg.Add(size) + for i := 0; i < size; i++ { + go func() { + defer wg.Done() + msg, err := ConcatMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, exp, msg) + }() + } + + wg.Wait() + }) +} + +func TestConcatToolCalls(t *testing.T) { + t.Run("atomic_field_in_first_chunk", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + }, + }, + { + Index: generic.PtrOf(0), + Function: FunctionCall{ + Arguments: "call me please", + }, + }, + } + + expectedToolCall := ToolCall{ + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + 
Arguments: "call me please", + }, + } + + tc, err := concatToolCalls(givenToolCalls) + assert.NoError(t, err) + assert.Len(t, tc, 1) + assert.EqualValues(t, expectedToolCall, tc[0]) + }) + + t.Run("atomic_field_in_every_chunk", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + Arguments: "call me please", + }, + }, + } + + expectedToolCall := ToolCall{ + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + Arguments: "call me please", + }, + } + + tc, err := concatToolCalls(givenToolCalls) + assert.NoError(t, err) + assert.Len(t, tc, 1) + assert.EqualValues(t, expectedToolCall, tc[0]) + }) + + t.Run("atomic_field_in_interval", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + { + Index: generic.PtrOf(0), + ID: "", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + Arguments: "", + }, + }, + } + + expectedToolCall := ToolCall{ + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + } + + tc, err := concatToolCalls(givenToolCalls) + assert.NoError(t, err) + assert.Len(t, tc, 1) + assert.EqualValues(t, expectedToolCall, tc[0]) + }) + + t.Run("different_tool_id", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id_1", + Type: 
"function", + Function: FunctionCall{ + Name: "tool_name", + Arguments: "call me please", + }, + }, + } + + _, err := concatToolCalls(givenToolCalls) + t.Logf("concat tool call failed info: %v", err) + assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool id") + }) + + t.Run("different_tool_type", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function_1", + Function: FunctionCall{ + Name: "tool_name", + Arguments: "call me please", + }, + }, + } + + _, err := concatToolCalls(givenToolCalls) + t.Logf("concat tool call failed info: %v", err) + assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool type") + }) + + t.Run("different_tool_name", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "tool_name_1", + Arguments: "call me please", + }, + }, + } + + _, err := concatToolCalls(givenToolCalls) + t.Logf("concat tool call failed info: %v", err) + assert.ErrorContains(t, err, "cannot concat ToolCalls with different tool name") + }) + + t.Run("multi_tool_call", func(t *testing.T) { + givenToolCalls := []ToolCall{ + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + { + Index: generic.PtrOf(0), + ID: "", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + Arguments: "", + }, + }, + { + Index: generic.PtrOf(1), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + 
{ + Index: generic.PtrOf(1), + ID: "", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + }, + { + Index: generic.PtrOf(1), + ID: "tool_call_id", + Type: "", + Function: FunctionCall{ + Name: "", + Arguments: "", + }, + }, + { + Index: nil, + ID: "22", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + { + Index: nil, + ID: "44", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + } + + expectedToolCall := []ToolCall{ + { + Index: nil, + ID: "22", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + { + Index: nil, + ID: "44", + Type: "", + Function: FunctionCall{ + Name: "", + }, + }, + { + Index: generic.PtrOf(0), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + }, + { + Index: generic.PtrOf(1), + ID: "tool_call_id", + Type: "function", + Function: FunctionCall{ + Name: "", + Arguments: "call me please", + }, + }, + } + + tc, err := concatToolCalls(givenToolCalls) + assert.NoError(t, err) + assert.EqualValues(t, expectedToolCall, tc) + }) +} diff --git a/schema/stream.go b/schema/stream.go new file mode 100644 index 0000000..07e79bb --- /dev/null +++ b/schema/stream.go @@ -0,0 +1,773 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package schema + +import ( + "container/list" + "errors" + "io" + "reflect" + "runtime/debug" + "sync" + "sync/atomic" + + "github.com/cloudwego/eino/utils/safe" +) + +// ErrNoValue is the error returned when the value is not found. +// It is used in convert functions when the WithInputKey option is set. +var ErrNoValue = errors.New("no value") + +// Pipe creates a new stream with the given capacity, represented by a StreamReader and a StreamWriter. +// The capacity is the maximum number of items that can be buffered in the stream. +// e.g. +// +// sr, sw := schema.Pipe[int](3) +// go func() { // send data +// defer sw.Close() +// for i := 0; i < 10; i++ { +// sw.Send(i, nil) +// } +// }() +// +// defer sr.Close() +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// fmt.Println(chunk) +// } +func Pipe[T any](cap int) (*StreamReader[T], *StreamWriter[T]) { + stm := newStream[T](cap) + return stm.asReader(), &StreamWriter[T]{stm: stm} +} + +// StreamWriter is the sender of a stream. +// It is created by the Pipe function. +// eg. +// +// sr, sw := schema.Pipe[int](3) +// go func() { // send data +// defer sw.Close() +// for i := 0; i < 10; i++ { +// sw.Send(i, nil) +// } +// }() +type StreamWriter[T any] struct { + stm *stream[T] +} + +// Send sends a value to the stream. +// eg. +// +// closed := sw.Send(i, nil) +// if closed { +// // the stream is closed +// } +func (sw *StreamWriter[T]) Send(chunk T, err error) (closed bool) { + return sw.stm.send(chunk, err) +} + +// Close notifies the receiver that the stream sender has finished. +// The stream receiver will get an error of io.EOF from StreamReader.Recv(). +// Notice: always remember to call Close() after sending all data. +// eg. +// +// defer sw.Close() +// for i := 0; i < 10; i++ { +// sw.Send(i, nil) +// } +func (sw *StreamWriter[T]) Close() { + sw.stm.closeSend() +} + +// StreamReader is the receiver of a stream. +// It is created by the Pipe function. +// eg.
+// +// sr, sw := schema.Pipe[string](3) +// // omit sending data +// // most of the time, the reader is returned by one function and consumed in another. +// +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// if err != nil { +// // handle error +// } +// fmt.Println(chunk) +// } +type StreamReader[T any] struct { + typ readerType + + st *stream[T] + + ar *arrayReader[T] + + msr *multiStreamReader[T] + + srw *streamReaderWithConvert[T] + + csr *childStreamReader[T] +} + +// Recv receives a value from the stream. +// eg. +// +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// if err != nil { +// // handle error +// break +// } +// fmt.Println(chunk) +// } +func (sr *StreamReader[T]) Recv() (T, error) { + switch sr.typ { + case readerTypeStream: + return sr.st.recv() + case readerTypeArray: + return sr.ar.recv() + case readerTypeMultiStream: + return sr.msr.recv() + case readerTypeWithConvert: + return sr.srw.recv() + case readerTypeChild: + return sr.csr.recv() + default: + panic("impossible") // nolint: byted_s_panic_detect + } +} + +// Close safely closes the StreamReader. +// It should be called only once, as multiple calls may not work as expected. +// Notice: always remember to call Close() after using Recv(). +// eg. +// +// defer sr.Close() +// +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// fmt.Println(chunk) +// } +func (sr *StreamReader[T]) Close() { + switch sr.typ { + case readerTypeStream: + sr.st.closeRecv() + case readerTypeArray: + + case readerTypeMultiStream: + sr.msr.close() + case readerTypeWithConvert: + sr.srw.close() + case readerTypeChild: + sr.csr.close() + default: + panic("impossible") // nolint: byted_s_panic_detect + } +} + +// Copy creates a slice of new StreamReaders. +// The number of copies, indicated by the parameter n, should be a non-zero positive integer. +// If n is less than 2, the returned slice contains only the original StreamReader. +// Otherwise the original StreamReader will become unusable after Copy. +// eg.
+// +// sr := schema.StreamReaderFromArray([]int{1, 2, 3}) +// srs := sr.Copy(2) +// +// sr1 := srs[0] +// sr2 := srs[1] +// defer sr1.Close() +// defer sr2.Close() +// +// chunk1, err1 := sr1.Recv() +// chunk2, err2 := sr2.Recv() +func (sr *StreamReader[T]) Copy(n int) []*StreamReader[T] { + if n < 2 { + return []*StreamReader[T]{sr} + } + + if sr.typ == readerTypeArray { + ret := make([]*StreamReader[T], n) + for i, ar := range sr.ar.copy(n) { + ret[i] = &StreamReader[T]{typ: readerTypeArray, ar: ar} + } + return ret + } + + return copyStreamReaders[T](sr, n) +} + +func (sr *StreamReader[T]) recvAny() (any, error) { + return sr.Recv() +} + +func (sr *StreamReader[T]) copyAny(n int) []iStreamReader { + ret := make([]iStreamReader, n) + + srs := sr.Copy(n) + + for i := 0; i < n; i++ { + ret[i] = srs[i] + } + + return ret +} + +type readerType int + +const ( + readerTypeStream readerType = iota + readerTypeArray + readerTypeMultiStream + readerTypeWithConvert + readerTypeChild +) + +type iStreamReader interface { + recvAny() (any, error) + copyAny(int) []iStreamReader + Close() +} + +// stream is a channel-based stream with 1 sender and 1 receiver. +// The sender calls closeSend() to notify the receiver that the stream sender has finished. +// The receiver calls closeRecv() to notify the sender that the receiver stop receiving. 
+type stream[T any] struct { + items chan streamItem[T] + + closed chan struct{} + isClosed uint32 +} + +type streamItem[T any] struct { + chunk T + err error +} + +func newStream[T any](cap int) *stream[T] { + return &stream[T]{ + items: make(chan streamItem[T], cap), + closed: make(chan struct{}), + } +} + +func (s *stream[T]) asReader() *StreamReader[T] { + return &StreamReader[T]{typ: readerTypeStream, st: s} +} + +func (s *stream[T]) recv() (chunk T, err error) { + item, ok := <-s.items + + if !ok { + item.err = io.EOF + } + + return item.chunk, item.err +} + +func (s *stream[T]) send(chunk T, err error) (closed bool) { + // if the stream is closed, return immediately + select { + case <-s.closed: + return true + default: + } + + item := streamItem[T]{chunk, err} + + select { + case <-s.closed: + return true + case s.items <- item: + return false + } +} + +func (s *stream[T]) closeSend() { + close(s.items) +} + +func (s *stream[T]) closeRecv() { + if !atomic.CompareAndSwapUint32(&s.isClosed, 0, 1) { + return + } + + close(s.closed) +} + +// StreamReaderFromArray creates a StreamReader from a given slice of elements. +// It takes an array of type T and returns a pointer to a StreamReader[T]. +// This allows for streaming the elements of the array in a controlled manner. +// eg. 
+// +// sr := schema.StreamReaderFromArray([]int{1, 2, 3}) +// defer sr.Close() +// +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// fmt.Println(chunk) +// } +func StreamReaderFromArray[T any](arr []T) *StreamReader[T] { + return &StreamReader[T]{ar: &arrayReader[T]{arr: arr}, typ: readerTypeArray} +} + +type arrayReader[T any] struct { + arr []T + index int +} + +func (ar *arrayReader[T]) recv() (T, error) { + if ar.index < len(ar.arr) { + ret := ar.arr[ar.index] + ar.index++ + + return ret, nil + } + + var t T + return t, io.EOF +} + +func (ar *arrayReader[T]) copy(n int) []*arrayReader[T] { + ret := make([]*arrayReader[T], n) + + for i := 0; i < n; i++ { + ret[i] = &arrayReader[T]{ + arr: ar.arr, // nolint: byted_use_uninitialized_object + index: ar.index, + } + } + + return ret +} + +type multiStreamReader[T any] struct { + sts []*stream[T] + + itemsCases []reflect.SelectCase + + numOfClosedItemsCh int +} + +func newMultiStreamReader[T any](sts []*stream[T]) *multiStreamReader[T] { + itemsCases := make([]reflect.SelectCase, len(sts)) + + for i, st := range sts { + itemsCases[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(st.items), + } + } + + return &multiStreamReader[T]{ + sts: sts, + itemsCases: itemsCases, + } +} + +func (msr *multiStreamReader[T]) recv() (T, error) { + for msr.numOfClosedItemsCh < len(msr.sts) { + chosen, recv, ok := reflect.Select(msr.itemsCases) + if ok { + item := recv.Interface().(streamItem[T]) // nolint: byted_interface_check_golintx + return item.chunk, item.err + } + + msr.itemsCases[chosen].Chan = reflect.Value{} + msr.numOfClosedItemsCh++ + } + + var t T + + return t, io.EOF +} + +func (msr *multiStreamReader[T]) close() { + for _, s := range msr.sts { + s.closeRecv() + } +} + +type streamReaderWithConvert[T any] struct { + sr iStreamReader + + convert func(any) (T, error) +} + +func newStreamReaderWithConvert[T any](origin iStreamReader, convert func(any) (T, error)) *StreamReader[T] { + srw :=
&streamReaderWithConvert[T]{ + sr: origin, + convert: convert, + } + + return &StreamReader[T]{ + typ: readerTypeWithConvert, + srw: srw, + } +} + +// StreamReaderWithConvert converts the stream reader to another stream reader. +// If convert returns ErrNoValue for a chunk, that chunk is skipped and the next one is read. +// +// eg. +// +// intReader := StreamReaderFromArray([]int{1, 2, 3}) +// stringReader := StreamReaderWithConvert(intReader, func(i int) (string, error) { +// return fmt.Sprintf("val_%d", i), nil +// }) +// +// defer stringReader.Close() // Close the reader after using Recv(), otherwise memory/goroutines may leak. +// s, err := stringReader.Recv() +// fmt.Println(s) // Output: val_1 +func StreamReaderWithConvert[T, D any](sr *StreamReader[T], convert func(T) (D, error)) *StreamReader[D] { + c := func(a any) (D, error) { + return convert(a.(T)) // nolint: byted_interface_check_golintx + } + + return newStreamReaderWithConvert(sr, c) +} + +func (srw *streamReaderWithConvert[T]) recv() (T, error) { + for { + out, err := srw.sr.recvAny() + + if err != nil { + var t T + return t, err + } + + t, err := srw.convert(out) + if err == nil { + return t, nil + } + + if !errors.Is(err, ErrNoValue) { + return t, err + } + } +} + +func (srw *streamReaderWithConvert[T]) close() { + srw.sr.Close() +} + +func (srw *streamReaderWithConvert[T]) toStream() *stream[T] { + ret := newStream[T](5) + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) // nolint: byted_returned_err_should_do_check + + var chunk T + _ = ret.send(chunk, e) + } + + ret.closeSend() + srw.close() + }() + + for { + out, err := srw.recv() + if err == io.EOF { + break + } + + closed := ret.send(out, err) + if closed { + break + } + } + }() + + return ret +} + +type listElement[T any] struct { + item streamItem[T] + refCount int +} + +func copyStreamReaders[T any](sr *StreamReader[T], n int) []*StreamReader[T] { + cpsr := &parentStreamReader[T]{ + sr: sr, + recvMu: sync.Mutex{}, + mem: &cpStreamMem[T]{ + mu: sync.Mutex{}, +
buf: list.New(), + subStreamList: make([]*list.Element, n), + closedNum: 0, + closedList: make([]bool, n), + hasFinished: false, + }, + } + + ret := make([]*StreamReader[T], n) + for i := range ret { + ret[i] = &StreamReader[T]{ + csr: &childStreamReader[T]{ + parent: cpsr, + index: i, + }, + typ: readerTypeChild, + } + } + + return ret +} + +type parentStreamReader[T any] struct { + sr *StreamReader[T] + + recvMu sync.Mutex + + mem *cpStreamMem[T] +} + +type cpStreamMem[T any] struct { + mu sync.Mutex + + buf *list.List + subStreamList []*list.Element + + closedNum int + closedList []bool + + hasFinished bool +} + +func (c *parentStreamReader[T]) peek(idx int) (T, error) { + if t, err, ok := c.mem.peek(idx); ok { + return t, err + } + + c.recvMu.Lock() + defer c.recvMu.Unlock() + + // retry read from buffer + if t, err, ok := c.mem.peek(idx); ok { + return t, err + } + + // get value from StreamReader + nChunk, err := c.sr.Recv() + + c.mem.set(idx, nChunk, err) + + return nChunk, err +} + +func (c *parentStreamReader[T]) close(idx int) { + if allClosed := c.mem.close(idx); allClosed { + c.sr.Close() + } +} + +func (m *cpStreamMem[T]) peek(idx int) (T, error, bool) { + m.mu.Lock() + defer m.mu.Unlock() + + if elem := m.subStreamList[idx]; elem != nil { + next := elem.Next() + cElem := elem.Value.(*listElement[T]) // nolint: byted_interface_check_golintx + cElem.refCount-- + if cElem.refCount == 0 { + m.buf.Remove(elem) + } + + m.subStreamList[idx] = next + return cElem.item.chunk, cElem.item.err, true + } + + var t T + + if m.hasFinished { + return t, io.EOF, true + } + + return t, nil, false +} + +func (m *cpStreamMem[T]) set(idx int, nChunk T, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + if err == io.EOF { // nolint: byted_s_error_binary + m.hasFinished = true + return + } + + nElem := &listElement[T]{ + item: streamItem[T]{chunk: nChunk, err: err}, + refCount: len(m.subStreamList) - m.closedNum - 1, // except chan receiver + } + + if nElem.refCount == 0 
{ + // no need to set buffer when there's no other receivers + return + } + + elem := m.buf.PushBack(nElem) + for i := range m.subStreamList { + if m.subStreamList[i] == nil && i != idx && !m.closedList[i] { + m.subStreamList[i] = elem + } + } +} + +func (m *cpStreamMem[T]) close(idx int) (allClosed bool) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.closedList[idx] { + return false // avoid close multiple times + } + + m.closedList[idx] = true + m.closedNum++ + if m.closedNum == len(m.subStreamList) { + allClosed = true + } + + p := m.subStreamList[idx] + for p != nil { + next := p.Next() + ptr := p.Value.(*listElement[T]) // nolint: byted_interface_check_golintx + ptr.refCount-- + if ptr.refCount == 0 { + m.buf.Remove(p) + } + + p = next + } + + return allClosed +} + +type childStreamReader[T any] struct { + parent *parentStreamReader[T] + index int +} + +func (csr *childStreamReader[T]) recv() (T, error) { + return csr.parent.peek(csr.index) +} + +func (csr *childStreamReader[T]) toStream() *stream[T] { + ret := newStream[T](5) + + go func() { + defer func() { + panicErr := recover() + if panicErr != nil { + e := safe.NewPanicErr(panicErr, debug.Stack()) // nolint: byted_returned_err_should_do_check + + var chunk T + _ = ret.send(chunk, e) + } + + ret.closeSend() + csr.close() + }() + + for { + out, err := csr.recv() + if err == io.EOF { + break + } + + closed := ret.send(out, err) + if closed { + break + } + } + }() + + return ret +} + +func (csr *childStreamReader[T]) close() { + csr.parent.close(csr.index) +} + +// MergeStreamReaders merge multiple StreamReader into one. +// it's useful when you want to merge multiple streams into one. +// eg. 
+// +// sr1, sw1 := schema.Pipe[string](2) +// sr2, sw2 := schema.Pipe[string](2) +// // omit sending data through sw1 and sw2 +// defer sr1.Close() +// defer sr2.Close() +// +// sr := schema.MergeStreamReaders([]*schema.StreamReader[string]{sr1, sr2}) +// +// defer sr.Close() +// for { +// chunk, err := sr.Recv() +// if errors.Is(err, io.EOF) { +// break +// } +// fmt.Println(chunk) +// } +func MergeStreamReaders[T any](srs []*StreamReader[T]) *StreamReader[T] { + if len(srs) < 1 { + return nil + } + + if len(srs) < 2 { + return srs[0] + } + + var arr []T + var ss []*stream[T] + + for _, sr := range srs { + switch sr.typ { + case readerTypeStream: + ss = append(ss, sr.st) + case readerTypeArray: + arr = append(arr, sr.ar.arr[sr.ar.index:]...) + case readerTypeMultiStream: + ss = append(ss, sr.msr.sts...) + case readerTypeWithConvert: + ss = append(ss, sr.srw.toStream()) + case readerTypeChild: + ss = append(ss, sr.csr.toStream()) + default: + panic("impossible") // nolint: byted_s_panic_detect + } + } + + if len(ss) == 0 && len(arr) != 0 { + return &StreamReader[T]{ + typ: readerTypeArray, + ar: &arrayReader[T]{ + arr: arr, + index: 0, + }, + } + } else if len(arr) != 0 { + s := newStream[T](len(arr)) + for i := range arr { + s.send(arr[i], nil) + } + s.closeSend() + ss = append(ss, s) + } + + return &StreamReader[T]{ + typ: readerTypeMultiStream, + msr: newMultiStreamReader(ss), + } +} diff --git a/schema/stream_test.go b/schema/stream_test.go new file mode 100644 index 0000000..1d14c6c --- /dev/null +++ b/schema/stream_test.go @@ -0,0 +1,551 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStream(t *testing.T) { + s := newStream[int](0) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + closed := s.send(i, nil) + t.Logf("send: %d, closed: %v", i, closed) + if closed { + break + } + } + s.closeSend() + }() + + i := 0 + for { + i++ + if i > 5 { + s.closeRecv() + break + } + v, err := s.recv() + if err != nil { + assert.ErrorIs(t, err, io.EOF) + break + } + t.Log(v) + } + + wg.Wait() +} + +func TestStreamCopy(t *testing.T) { + s := newStream[string](10) + srs := s.asReader().Copy(2) + + s.send("a", nil) + s.send("b", nil) + s.send("c", nil) + s.closeSend() + + defer func() { + for _, sr := range srs { + sr.Close() + } + }() + + for { + v, err := srs[0].Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + t.Fatal(err) + } + + t.Log("copy 01 recv", v) + } + + for { + v, err := srs[1].Recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + t.Fatal(err) + } + + t.Log("copy 02 recv", v) + } + + for { + v, err := s.recv() + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + t.Fatal(err) + } + + t.Log("recv origin", v) + } + + t.Log("done") +} + +func TestNewStreamCopy(t *testing.T) { + t.Run("test one index recv channel blocked while other indexes could recv", func(t *testing.T) { + s := newStream[string](1) + scp := s.asReader().Copy(2) + + var t1, t2 time.Time + + go func() { + s.send("a", nil) + t1 = time.Now() + time.Sleep(time.Millisecond * 200) + s.send("a", nil) + s.closeSend() + }() + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + defer func() { + scp[0].Close() + wg.Done() + }() + + for { + str, err := scp[0].Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, str, "a") 
+ } + }() + + go func() { + defer func() { + scp[1].Close() + wg.Done() + }() + + time.Sleep(time.Millisecond * 100) + for { + str, err := scp[1].Recv() + if err == io.EOF { + break + } + + if t2.IsZero() { + t2 = time.Now() + } + + assert.NoError(t, err) + assert.Equal(t, str, "a") + } + }() + + wg.Wait() + + assert.True(t, t2.Sub(t1) < time.Millisecond*200) + }) + + t.Run("test one index recv channel blocked and other index closed", func(t *testing.T) { + s := newStream[string](1) + scp := s.asReader().Copy(2) + + go func() { + s.send("a", nil) + time.Sleep(time.Millisecond * 200) + s.send("a", nil) + s.closeSend() + }() + + wg := sync.WaitGroup{} + wg.Add(2) + + buf := scp[0].csr.parent.mem.buf + go func() { + defer func() { + scp[0].Close() + wg.Done() + }() + + for { + str, err := scp[0].Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, str, "a") + } + }() + + go func() { + time.Sleep(time.Millisecond * 100) + scp[1].Close() + scp[1].Close() // try close multiple times + wg.Done() + }() + + wg.Wait() + + assert.Equal(t, 0, buf.Len()) + }) + + t.Run("test long time recv", func(t *testing.T) { + s := newStream[int](2) + n := 1000 + go func() { + for i := 0; i < n; i++ { + s.send(i, nil) + } + + s.closeSend() + }() + + m := 100 + wg := sync.WaitGroup{} + wg.Add(m) + copies := s.asReader().Copy(m) + for i := 0; i < m; i++ { + idx := i + go func() { + cp := copies[idx] + l := 0 + defer func() { + assert.Equal(t, 1000, l) + cp.Close() + wg.Done() + }() + + for { + exp, err := cp.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, exp, l) + l++ + } + }() + } + + wg.Wait() + memo := copies[0].csr.parent.mem + assert.Equal(t, true, memo.hasFinished) + assert.Equal(t, 0, memo.buf.Len()) + }) + + t.Run("test closes", func(t *testing.T) { + s := newStream[int](20) + n := 1000 + go func() { + for i := 0; i < n; i++ { + s.send(i, nil) + } + + s.closeSend() + }() + + m := 100 + wg := sync.WaitGroup{} + 
wg.Add(m) + + wgEven := sync.WaitGroup{} + wgEven.Add(m / 2) + + copies := s.asReader().Copy(m) + for i := 0; i < m; i++ { + idx := i + go func() { + cp := copies[idx] + l := 0 + defer func() { + cp.Close() + wg.Done() + if idx%2 == 0 { + wgEven.Done() + } + }() + + for { + if idx%2 == 0 && l == idx { + break + } + + exp, err := cp.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, exp, l) + l++ + } + }() + } + + wgEven.Wait() + memo := copies[0].csr.parent.mem + assert.Equal(t, m/2, memo.closedNum) + + wg.Wait() + assert.Equal(t, m, memo.closedNum) + assert.Equal(t, 0, memo.buf.Len()) + }) + + t.Run("test reader do no close", func(t *testing.T) { + s := newStream[int](20) + n := 1000 + go func() { + for i := 0; i < n; i++ { + s.send(i, nil) + } + + s.closeSend() + }() + + m := 4 + wg := sync.WaitGroup{} + wg.Add(m) + + copies := s.asReader().Copy(m) + for i := 0; i < m; i++ { + idx := i + go func() { + cp := copies[idx] + l := 0 + defer func() { + wg.Done() + }() + + for { + exp, err := cp.Recv() + if err == io.EOF { + break + } + + assert.NoError(t, err) + assert.Equal(t, exp, l) + l++ + } + }() + } + + wg.Wait() + memo := copies[0].csr.parent.mem + assert.Equal(t, 0, memo.closedNum) // not closed + assert.Equal(t, 0, memo.buf.Len()) // buff cleared + }) + +} + +func checkStream(s *StreamReader[int]) error { + defer s.Close() + + for i := 0; i < 10; i++ { + chunk, err := s.Recv() + if err != nil { + return err + } + if chunk != i { + return fmt.Errorf("receive err, expected:%d, actual: %d", i, chunk) + } + } + _, err := s.Recv() + if err != io.EOF { + return fmt.Errorf("close chan fail") + } + return nil +} + +func testStreamN(cap, n int) error { + s := newStream[int](cap) + go func() { + for i := 0; i < 10; i++ { + s.send(i, nil) + } + s.closeSend() + }() + + vs := s.asReader().Copy(n) + err := checkStream(vs[0]) + if err != nil { + return err + } + + vs = vs[1].Copy(n) + err = checkStream(vs[0]) + if err != nil { + return err + 
} + vs = vs[1].Copy(n) + err = checkStream(vs[0]) + if err != nil { + return err + } + return nil +} + +func TestCopy(t *testing.T) { + for i := 0; i < 10; i++ { + for j := 2; j < 10; j++ { + err := testStreamN(i, j) + if err != nil { + t.Fatal(err) + } + } + } +} + +func TestCopy5(t *testing.T) { + s := newStream[int](0) + go func() { + for i := 0; i < 10; i++ { + closed := s.send(i, nil) + if closed { + fmt.Printf("has closed") + } + } + s.closeSend() + }() + vs := s.asReader().Copy(5) + time.Sleep(time.Second) + defer func() { + for _, v := range vs { + v.Close() + } + }() + for i := 0; i < 10; i++ { + chunk, err := vs[0].Recv() + if err != nil { + t.Fatal(err) + } + if chunk != i { + t.Fatalf("receive err, expected:%d, actual: %d", i, chunk) + } + } + _, err := vs[0].Recv() + if err != io.EOF { + t.Fatalf("copied stream reader cannot return EOF") + } + _, err = vs[0].Recv() + if err != io.EOF { + t.Fatalf("copied stream reader cannot return EOF repeatedly") + } +} + +func TestStreamReaderWithConvert(t *testing.T) { + s := newStream[int](2) + + var cntA int + var e error + + convA := func(src int) (int, error) { + if src == 1 { + return 0, fmt.Errorf("mock err") + } + + return src, nil + } + + sta := StreamReaderWithConvert[int, int](s.asReader(), convA) + + s.send(1, nil) + s.send(2, nil) + s.closeSend() + + defer sta.Close() + + for { + item, err := sta.Recv() + if err != nil { + if err == io.EOF { + break + } + + e = err + continue + } + + cntA += item + } + + assert.NotNil(t, e) + assert.Equal(t, cntA, 2) +} + +func TestArrayStreamCombined(t *testing.T) { + asr := &StreamReader[int]{ + typ: readerTypeArray, + ar: &arrayReader[int]{ + arr: []int{0, 1, 2}, + index: 0, + }, + } + + s := newStream[int](3) + for i := 3; i < 6; i++ { + s.send(i, nil) + } + s.closeSend() + + nSR := MergeStreamReaders([]*StreamReader[int]{asr, s.asReader()}) + + record := make([]bool, 6) + for i := 0; i < 6; i++ { + chunk, err := nSR.Recv() + if err != nil { + t.Fatal(err) + } + if 
record[chunk] { + t.Fatal("record duplicated") + } + record[chunk] = true + } + + _, err := nSR.Recv() + if err != io.EOF { + t.Fatal("reader haven't finish correctly") + } + + for i := range record { + if !record[i] { + t.Fatal("record missing") + } + } +} diff --git a/schema/tool.go b/schema/tool.go new file mode 100644 index 0000000..e398fce --- /dev/null +++ b/schema/tool.go @@ -0,0 +1,188 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "fmt" + + "github.com/getkin/kin-openapi/openapi3" +) + +// DataType is the type of the parameter. +// It must be one of the following values: "object", "number", "integer", "string", "array", "null", "boolean", which is the same as the type of the parameter in OpenAPI v3.0. +type DataType string + +const ( + Object DataType = "object" + Number DataType = "number" + Integer DataType = "integer" + String DataType = "string" + Array DataType = "array" + Null DataType = "null" + Boolean DataType = "boolean" +) + +// ToolInfo is the information of a tool. +type ToolInfo struct { + // The unique name of the tool that clearly communicates its purpose. + Name string + // Used to tell the model how/when/why to use the tool. + // You can provide few-shot examples as a part of the description. + Desc string + + // The parameters the functions accepts (different models may require different parameter types). 
+ // can be described in two ways: + // - use ParameterInfo: schema.NewParamsOneOfByParams(params) + // - use OpenAPIV3: schema.NewParamsOneOfByOpenAPIV3(openAPIV3) + ParamsOneOf +} + +// ParameterInfo is the information of a parameter. +// It is used to describe the parameters of a tool. +type ParameterInfo struct { + // The type of the parameter. + Type DataType + // The element type of the parameter, only for array. + ElemInfo *ParameterInfo + // The sub parameters of the parameter, only for object. + SubParams map[string]*ParameterInfo + // The description of the parameter. + Desc string + // The enum values of the parameter, only for string. + Enum []string + // Whether the parameter is required. + Required bool +} + +// ParamsOneOf is a union of the different methods user can choose which describe a tool's request parameters. +// User must specify one and ONLY one method to describe the parameters. +// 1. use Params: an intuitive way to describe the parameters that covers most of the use-cases. +// 2. use OpenAPIV3: a formal way to describe the parameters that strictly adheres to OpenAPIV3.0 specification. +// See https://github.com/getkin/kin-openapi/blob/master/openapi3/schema.go. +type ParamsOneOf struct { + // deprecated: use NewParamsOneOfByParams instead, Params will no longer be exported in the future. + Params map[string]*ParameterInfo + + // deprecated: use NewParamsOneOfByOpenAPIV3 instead, OpenAPIV3 will no longer be exported in the future. + OpenAPIV3 *openapi3.Schema +} + +// NewParamsOneOfByParams creates a ParamsOneOf with map[string]*ParameterInfo. +func NewParamsOneOfByParams(params map[string]*ParameterInfo) ParamsOneOf { + return ParamsOneOf{ + Params: params, + } +} + +// NewParamsOneOfByOpenAPIV3 creates a ParamsOneOf with *openapi3.Schema. 
+func NewParamsOneOfByOpenAPIV3(openAPIV3 *openapi3.Schema) ParamsOneOf { + return ParamsOneOf{ + OpenAPIV3: openAPIV3, + } +} + +// ToOpenAPIV3 parses ParamsOneOf, converts the parameter description that user actually provides, into the format ready to be passed to Model. +func (p ParamsOneOf) ToOpenAPIV3() (*openapi3.Schema, error) { + var ( + useParameterInfo = p.Params != nil + useOpenAPIV3 = p.OpenAPIV3 != nil + ) + + if !useParameterInfo && !useOpenAPIV3 { + return nil, fmt.Errorf("ParamsOneOf needs to have at least one method to describe the parameters") + } + + if useParameterInfo && useOpenAPIV3 { + return nil, fmt.Errorf("ParamsOneOf can only have one method to describe the parameters, but not multiple methods") + } + + if p.Params != nil { + sc := &openapi3.Schema{ + Properties: make(map[string]*openapi3.SchemaRef, len(p.Params)), + Type: openapi3.TypeObject, + Required: make([]string, 0, len(p.Params)), + } + + for k := range p.Params { + v := p.Params[k] + sc.Properties[k] = paramInfoToJSONSchema(v) + if v.Required { + sc.Required = append(sc.Required, k) + } + } + + return sc, nil + } + + return p.OpenAPIV3, nil +} + +func paramInfoToJSONSchema(paramInfo *ParameterInfo) *openapi3.SchemaRef { + var types string + switch paramInfo.Type { + case Null: + types = "null" + case Boolean: + types = openapi3.TypeBoolean + case Integer: + types = openapi3.TypeInteger + case Number: + types = openapi3.TypeNumber + case String: + types = openapi3.TypeString + case Array: + types = openapi3.TypeArray + case Object: + types = openapi3.TypeObject + } + + js := &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: types, + Description: paramInfo.Desc, + }, + } + + if len(paramInfo.Enum) > 0 { + js.Value.Enum = make([]any, 0, len(paramInfo.Enum)) + for _, enum := range paramInfo.Enum { + js.Value.Enum = append(js.Value.Enum, enum) + } + } + + if paramInfo.ElemInfo != nil { + js.Value.Items = paramInfoToJSONSchema(paramInfo.ElemInfo) + } + + if 
len(paramInfo.SubParams) > 0 { + required := make([]string, 0, len(paramInfo.SubParams)) + js.Value.Properties = make(map[string]*openapi3.SchemaRef, len(paramInfo.SubParams)) + for k, v := range paramInfo.SubParams { + item := paramInfoToJSONSchema(v) + + js.Value.Properties[k] = item + + if v.Required { + required = append(required, k) + } + } + + js.Value.Required = required + } + + return js +} diff --git a/schema/tool_test.go b/schema/tool_test.go new file mode 100644 index 0000000..fe33b96 --- /dev/null +++ b/schema/tool_test.go @@ -0,0 +1,99 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package schema + +import ( + "testing" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/smartystreets/goconvey/convey" +) + +func TestParamsOneOfToJSONSchema(t *testing.T) { + convey.Convey("TestParamsOneOfToJSONSchema", t, func() { + var ( + oneOf ParamsOneOf + converted any + err error + ) + + convey.Convey("user provides no option if ParamsOneOf", func() { + _, err = oneOf.ToOpenAPIV3() + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "ParamsOneOf needs to have at least one method to describe the parameters") + }) + + convey.Convey("user provides multiple options in ParamsOneOf", func() { + oneOf.Params = make(map[string]*ParameterInfo) + oneOf.OpenAPIV3 = &openapi3.Schema{} + _, err = oneOf.ToOpenAPIV3() + convey.So(err, convey.ShouldNotBeNil) + convey.So(err.Error(), convey.ShouldContainSubstring, "ParamsOneOf can only have one method to describe the parameters, but not multiple methods") + }) + + convey.Convey("user provides openAPIV3.0 json schema directly, use what the user provides", func() { + oneOf.OpenAPIV3 = &openapi3.Schema{ + Type: openapi3.TypeString, + Description: "this is the only argument", + } + converted, err = oneOf.ToOpenAPIV3() + convey.So(err, convey.ShouldBeNil) + convey.So(converted, convey.ShouldResemble, oneOf.OpenAPIV3) + }) + + convey.Convey("user provides map[string]ParameterInfo, converts to json schema", func() { + oneOf.Params = map[string]*ParameterInfo{ + "arg1": { + Type: String, + Desc: "this is the first argument", + Required: true, + Enum: []string{"1", "2"}, + }, + "arg2": { + Type: Object, + Desc: "this is the second argument", + SubParams: map[string]*ParameterInfo{ + "sub_arg1": { + Type: String, + Desc: "this is the sub argument", + Required: true, + Enum: []string{"1", "2"}, + }, + "sub_arg2": { + Type: String, + Desc: "this is the sub argument 2", + }, + }, + Required: true, + }, + "arg3": { + Type: Array, + Desc: "this is the third argument", + 
// NewInstance creates a ready-to-use instance of the given type T.
// It handles T being a pointer or not:
//   - map types yield an empty, non-nil map;
//   - slice types yield an empty, non-nil slice;
//   - pointer types yield a non-nil pointer, with every pointer level
//     allocated down to the zero value of the innermost type;
//   - every other type (arrays, structs, interfaces, primitives, ...) yields
//     its zero value.
//
// eg. NewInstance[int] returns 0.
// eg. NewInstance[*int] returns *0 (will be ptr of 0, not nil!).
func NewInstance[T any]() T {
	typ := TypeOf[T]()

	switch typ.Kind() {
	case reflect.Map:
		return reflect.MakeMap(typ).Interface().(T)
	case reflect.Slice:
		// Arrays must not take this branch: reflect.MakeSlice panics on any
		// non-slice type, and an array's zero value is already usable.
		return reflect.MakeSlice(typ, 0, 0).Interface().(T)
	case reflect.Ptr:
		typ = typ.Elem()
		origin := reflect.New(typ)
		inst := origin

		// For multi-level pointers keep allocating until the innermost
		// non-pointer type is reached, so no level is left nil.
		for typ.Kind() == reflect.Ptr {
			typ = typ.Elem()
			inst = inst.Elem()
			inst.Set(reflect.New(typ))
		}

		return origin.Interface().(T)
	default:
		var t T
		return t
	}
}

// TypeOf returns the reflect.Type of T.
// It works for interface types too, because the type is taken from a nil *T
// rather than from a value of T.
// eg. TypeOf[int] returns reflect.TypeOf(int).
// eg. TypeOf[*int] returns reflect.TypeOf(*int).
func TypeOf[T any]() reflect.Type {
	return reflect.TypeOf((*T)(nil)).Elem()
}

// PtrOf returns a pointer to a copy of v.
// useful when you want to get a pointer of a value, in some config, for example.
// eg. PtrOf[int] returns *int.
// eg. PtrOf[*int] returns **int.
func PtrOf[T any](v T) *T {
	return &v
}
+ */ + +package generic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewInstance(t *testing.T) { + t.Run("struct", func(t *testing.T) { + type Test struct{} + + inst := NewInstance[Test]() + + assert.IsType(t, Test{}, inst) + }) + + t.Run("pointer", func(t *testing.T) { + type Test struct{} + + inst := NewInstance[*Test]() + + assert.IsType(t, &Test{}, inst) + }) + + t.Run("interface", func(t *testing.T) { + type Test interface{} + + inst := NewInstance[Test]() + assert.IsType(t, Test(nil), inst) + }) + + t.Run("pointer of pointer of pointer", func(t *testing.T) { + type Test struct { + Value int + } + + inst := NewInstance[***Test]() + + ptr := &Test{} + ptrOfPtr := &ptr + assert.NotNil(t, inst) + assert.NotNil(t, *inst) + assert.IsType(t, ptrOfPtr, *inst) + assert.NotNil(t, **inst) + assert.Equal(t, Test{Value: 0}, ***inst) + }) + + t.Run("primitive_map", func(t *testing.T) { + inst := NewInstance[map[string]any]() + assert.NotNil(t, inst) + inst["a"] = 1 + assert.Equal(t, map[string]any{"a": 1}, inst) + }) + + t.Run("primitive_slice", func(t *testing.T) { + inst := NewInstance[[]int]() + assert.NotNil(t, inst) + inst = append(inst, 1) + assert.Equal(t, []int{1}, inst) + }) + + t.Run("primitive_string", func(t *testing.T) { + inst := NewInstance[string]() + assert.Equal(t, "", inst) + }) + + t.Run("primitive_int64", func(t *testing.T) { + inst := NewInstance[int64]() + assert.Equal(t, int64(0), inst) + }) +} diff --git a/utils/generic/type_name.go b/utils/generic/type_name.go new file mode 100644 index 0000000..4829e74 --- /dev/null +++ b/utils/generic/type_name.go @@ -0,0 +1,71 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
var (
	// regOfAnonymousFunc matches the last dot-separated segment the Go runtime
	// gives a function literal, such as "func6".
	regOfAnonymousFunc = regexp.MustCompile(`^func[0-9]+`)
	// regOfNumber matches a last segment made only of digits, such as the "1"
	// in "TestParseTypeName.func6.1".
	regOfNumber = regexp.MustCompile(`^\d+$`)
)

// ParseTypeName returns the name of the type of the given value.
// It takes a reflect.Value as input and processes it to determine the underlying type. If the type is a pointer, it dereferences it to get the actual type.
// eg: ParseTypeName(reflect.ValueOf(&&myStruct{})) returns "myStruct" (not "**myStruct")
//
// If the type is a function, it retrieves the function's name, handling both named and anonymous functions.
// examples of function paths: [package_path].[receiver_type].[func_name]
// named function: xxx/utils.ParseTypeName
// method: xxx/utils.(*MyStruct).Method
// anonymous function: xxx/utils.TestParseTypeName.func6.1
//
// An empty string is returned for an invalid (zero) reflect.Value, for
// unnamed types, for anonymous functions and for nil functions or nil
// pointers to functions.
func ParseTypeName(val reflect.Value) string {
	// Type() panics on the zero Value (e.g. reflect.ValueOf(nil)).
	if !val.IsValid() {
		return ""
	}

	typ := val.Type()
	for typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}

	if typ.Kind() != reflect.Func {
		return typ.Name()
	}

	// Follow the same pointer chain on the value, so that Pointer() below is
	// asked for the function's code pointer and not for the address of the
	// variable holding the function.
	for val.Kind() == reflect.Pointer && !val.IsNil() {
		val = val.Elem()
	}

	if val.Kind() != reflect.Func || val.IsNil() {
		return ""
	}

	fn := runtime.FuncForPC(val.Pointer())
	if fn == nil {
		return ""
	}

	funcName := fn.Name()
	idx := strings.LastIndex(funcName, ".")
	if idx < 0 {
		return funcName
	}

	name := funcName[idx+1:]
	if regOfAnonymousFunc.MatchString(name) || regOfNumber.MatchString(name) {
		return ""
	}

	return name
}
// TestParseTypeName checks the names ParseTypeName reports for structs,
// interfaces and functions, both named and anonymous.
func TestParseTypeName(t *testing.T) {
	// A named struct reports its type name.
	t.Run("named_struct", func(t *testing.T) {
		type OpenAI struct{}
		model := &OpenAI{}
		name := ParseTypeName(reflect.Indirect(reflect.ValueOf(model)))
		assert.Equal(t, "OpenAI", name)
	})

	// A pointer to an unnamed struct has no name to report.
	t.Run("anonymous_struct", func(t *testing.T) {
		model := &struct{}{}
		name := ParseTypeName(reflect.ValueOf(model))
		assert.Equal(t, "", name)
	})

	// Same as above, with the unnamed struct created by another function.
	t.Run("anonymous_struct_from_func", func(t *testing.T) {
		model := genStruct()
		name := ParseTypeName(reflect.ValueOf(model))
		assert.Equal(t, "", name)
	})

	t.Run("named_interface", func(t *testing.T) {
		type OpenAI interface{}
		// reflect.ValueOf sees the dynamic value (an unnamed struct pointer),
		// not the interface type it was stored in.
		model := OpenAI(&struct{}{})
		name := ParseTypeName(reflect.ValueOf(model))
		assert.Equal(t, "", name)

		// A nil pointer to the interface type keeps the interface's name.
		name = ParseTypeName(reflect.ValueOf((*OpenAI)(nil)))
		assert.Equal(t, "OpenAI", name)
	})

	// A top-level function reports its identifier without the package path.
	t.Run("named_function", func(t *testing.T) {
		f := genOpenAI
		name := ParseTypeName(reflect.ValueOf(f))
		assert.Equal(t, "genOpenAI", name)
	})

	// Function literals report an empty name, whether they were created in
	// another function or right here inside the subtest closure.
	t.Run("anonymous_function", func(t *testing.T) {
		f := genAnonymousFunc()
		name := ParseTypeName(reflect.ValueOf(f))
		assert.Equal(t, "", name)

		ff := func(n string) {
			_ = n
		}

		name = ParseTypeName(reflect.ValueOf(ff))
		assert.Equal(t, "", name)
	})
}

// genStruct returns a pointer to an unnamed empty struct.
func genStruct() *struct{} {
	return &struct{}{}
}

// genOpenAI is a named function with an empty body, used as a test fixture.
func genOpenAI() {}

// genAnonymousFunc returns a function literal, used as a test fixture.
func genAnonymousFunc() func(n string) {
	return func(n string) {
		_ = n
	}
}
// panicErr carries the value recovered from a panic together with the stack
// trace captured at recovery time.
type panicErr struct {
	info  any
	stack []byte
}

// Error renders the recovered value followed by the stack trace.
func (p *panicErr) Error() string {
	return "panic error: " + fmt.Sprint(p.info) + ", \nstack: " + string(p.stack)
}

// NewPanicErr creates a new panic error.
// The returned error wraps the recovered panic value and the stack trace, and
// its message contains both.
func NewPanicErr(info any, stack []byte) error {
	wrapped := new(panicErr)
	wrapped.info = info
	wrapped.stack = stack

	return wrapped
}