diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs
index 9c62b5eb..3f25caa5 100644
--- a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs
+++ b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs
@@ -32,7 +32,7 @@ public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput f
     ///
     public override async Task<IChainValues> CallAsync(IChainValues values)
     {
-        var docs = values.Value[InputKey];
+        var docs = values.Value["input_documents"];
 
         //Other keys are assumed to be needed for LLM prediction
         var otherKeys = values.Value
diff --git a/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs b/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
index d33f3cf0..918b7f8f 100644
--- a/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
+++ b/src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
@@ -13,7 +13,7 @@ public class LlmChainInput(
 {
     public BasePromptTemplate Prompt { get; set; } = prompt;
     public IChatModel Llm { get; set; } = llm;
-    public string OutputKey { get; set; }
+    public string OutputKey { get; set; } = "text";
     public bool? Verbose { get; set; }
     public CallbackManager CallbackManager { get; set; }
     public BaseMemory? Memory { get; set; } = memory;
diff --git a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs
index d22b648d..6b8465df 100644
--- a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs
+++ b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChain.cs
@@ -36,14 +36,15 @@ public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : B
     ///
    public override async Task<IChainValues> CallAsync(IChainValues values)
     {
+
        var question = values.Value[_inputKey].ToString();
        var docs = (await GetDocsAsync(question)).ToList();

        var input = new Dictionary<string, object>
        {
-            [fields.DocumentsKey] = docs,
-            [_inputKey] = question
+            ["input_documents"] = docs,
+            [_inputKey]= question
        };

        var answer = await _combineDocumentsChain.Run(input);
diff --git a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs
index 228a97d9..3c3dc7bb 100644
--- a/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs
+++ b/src/libs/LangChain.Core/Chains/RetrievalQA/BaseRetrievalQaChainInput.cs
@@ -13,7 +13,6 @@ public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocument
 {
     public bool ReturnSourceDocuments { get; set; }
     public string InputKey { get; set; } = "question";
-    public string DocumentsKey { get; set; } = "input_documents";
     public string OutputKey { get; set; } = "output_text";
     public bool? Verbose { get; set; }
     public CallbackManager? CallbackManager { get; set; }
diff --git a/src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs b/src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs
index bd81e404..bbce0b99 100644
--- a/src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs
+++ b/src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs
@@ -16,15 +16,14 @@
     public VectorStoreIndexWrapper(VectorStore vectorStore)
     {
         _vectorStore = vectorStore;
     }

-    public Task<string> QueryAsync(string question, BaseCombineDocumentsChain llm, string documentsKey= "input_documents", string questionKey="question", string outputKey= "output_text")
+    public Task<string> QueryAsync(string question, BaseCombineDocumentsChain llm, string inputKey= "question", string outputKey= "output_text")
     {
         var chain = new RetrievalQaChain(
             new RetrievalQaChainInput(
                 llm, _vectorStore.AsRetreiver())
             {
-                InputKey= questionKey,
-                DocumentsKey= documentsKey,
+                InputKey= inputKey,
                 OutputKey= outputKey,
             }
         );
diff --git a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs
index 594a611f..45a39f88 100644
--- a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs
+++ b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs
@@ -1,4 +1,5 @@
-using LangChain.Chains.CombineDocuments;
+using LangChain.Abstractions.Embeddings.Base;
+using LangChain.Chains.CombineDocuments;
 using LangChain.Chains.LLM;
 using LangChain.Databases;
 using LangChain.Databases.InMemory;
@@ -9,6 +10,7 @@
 using LangChain.Providers.Downloader;
 using LangChain.Providers.LLamaSharp;
 using LangChain.TextSplitters;
+using Microsoft.SemanticKernel.AI.Embeddings;

 namespace LangChain.Providers.LLamaSharp.IntegrationTests;

@@ -103,27 +105,59 @@ public void EmbeddingsTestWithInMemory()

         Assert.AreEqual("My dog name is Bob", closest.PageContent);
     }
-
-    [TestMethod]
-#if CONTINUOUS_INTEGRATION_BUILD
-    [Ignore]
-#endif
-    public void DocumentsQuestionAnsweringTest()
+
+
+    #region Helpers
+    IEmbeddings CreateEmbeddings()
     {
-        // setup
         var embeddings = new LLamaSharpEmbeddings(new LLamaSharpConfiguration
         {
             PathToModelFile = ModelPath,
             Temperature = 0
         });
+        return embeddings;
+    }
+
+    IChatModel CreateInstructionModel()
+    {
         var model = new LLamaSharpModelInstruction(new LLamaSharpConfiguration
         {
             PathToModelFile = ModelPath,
             Temperature = 0
         });
+        return model;
+    }
+
+    VectorStoreIndexWrapper CreateVectorStoreIndex(IEmbeddings embeddings, string[] texts)
+    {
         InMemoryVectorStore vectorStore = new InMemoryVectorStore(embeddings);
+        var textSplitter = new CharacterTextSplitter();
+        VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
+        var index = indexCreator.FromDocumentsAsync(texts.Select(x => new Document(x)).ToList()).Result;
+        return index;
+    }
+
+    PromptTemplate CreatePromptTemplate()
+    {
+        string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
+        var template = new PromptTemplate(new PromptTemplateInput(prompt, new List<string>() { "context", "question" }));
+        return template;
+    }
+    #endregion
+
+
+    [TestMethod]
+#if CONTINUOUS_INTEGRATION_BUILD
+    [Ignore]
+#endif
+    public void DocumentsQuestionAnsweringTest()
+    {
+        // setup
+        var embeddings = CreateEmbeddings();
+        var model = CreateInstructionModel();
+
         string[] texts = new string[]
         {
             "I spent entire day watching TV",
@@ -131,30 +165,25 @@ public void DocumentsQuestionAnsweringTest()
             "This icecream is delicious",
             "It is cold in space"
         };
-        var textSplitter = new CharacterTextSplitter();
-        VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
-        var index=indexCreator.FromDocumentsAsync(texts.Select(x=>new Document(x)).ToList()).Result;
+
+        var index = CreateVectorStoreIndex(embeddings, texts);
+
+        var template = CreatePromptTemplate();
+
+        var chain = new LlmChain(new LlmChainInput(model, template));

-        string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
-        var template = new PromptTemplate(new PromptTemplateInput(prompt, new List<string>() { "context", "question" }));
-        var chain = new LlmChain(new LlmChainInput(model, template)
-        {
-            OutputKey = "text"
-        });
         var stuffDocumentsChain = new StuffDocumentsChain(new StuffDocumentsChainInput(chain)
         {
-            InputKey = "context",
-            DocumentVariableName = "context",
-            OutputKey = "text",
+            DocumentVariableName = "context", // variable name in prompt template
+                                              // for the documents
         });

         // test
-        var query = "What is the good name for a pet? Tell me only the name, no explanations.";
-        var answer=index.QueryAsync(query, stuffDocumentsChain,
-            documentsKey: "context",
-            questionKey: "question",
-            outputKey: "text").Result;
+        var question = "What is the good name for a pet?";
+        var answer=index.QueryAsync(question, stuffDocumentsChain,
+            inputKey:"question" // variable name in prompt template for the question
+                                // it would be passed by to stuffDocumentsChain
+            ).Result;