Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfixes and massive simplification of DocumentQnA test #51

Merged
merged 5 commits into from
Nov 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput f
/// <returns></returns>
public override async Task<IChainValues> CallAsync(IChainValues values)
{
var docs = values.Value[InputKey];
var docs = values.Value["input_documents"];

//Other keys are assumed to be needed for LLM prediction
var otherKeys = values.Value
Expand Down
2 changes: 1 addition & 1 deletion src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class LlmChainInput(
{
public BasePromptTemplate Prompt { get; set; } = prompt;
public IChatModel Llm { get; set; } = llm;
public string OutputKey { get; set; }
public string OutputKey { get; set; } = "text";
public bool? Verbose { get; set; }
public CallbackManager CallbackManager { get; set; }
public BaseMemory? Memory { get; set; } = memory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : B
/// <exception cref="NotImplementedException"></exception>
public override async Task<IChainValues> CallAsync(IChainValues values)
{

var question = values.Value[_inputKey].ToString();

var docs = (await GetDocsAsync(question)).ToList();

var input = new Dictionary<string, object>
{
[fields.DocumentsKey] = docs,
[_inputKey] = question
["input_documents"] = docs,
[_inputKey]= question
};

var answer = await _combineDocumentsChain.Run(input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocument
public bool ReturnSourceDocuments { get; set; }

public string InputKey { get; set; } = "question";
public string DocumentsKey { get; set; } = "input_documents";
public string OutputKey { get; set; } = "output_text";
public bool? Verbose { get; set; }
public CallbackManager? CallbackManager { get; set; }
Expand Down
5 changes: 2 additions & 3 deletions src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ public VectorStoreIndexWrapper(VectorStore vectorStore)
_vectorStore = vectorStore;
}

public Task<string?> QueryAsync(string question, BaseCombineDocumentsChain llm, string documentsKey= "input_documents", string questionKey="question", string outputKey= "output_text")
public Task<string?> QueryAsync(string question, BaseCombineDocumentsChain llm, string inputKey= "question", string outputKey= "output_text")
{
var chain = new RetrievalQaChain(
new RetrievalQaChainInput(
llm,
_vectorStore.AsRetreiver())
{
InputKey= questionKey,
DocumentsKey= documentsKey,
InputKey= inputKey,
OutputKey= outputKey,
}
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Chains.CombineDocuments;
using LangChain.Abstractions.Embeddings.Base;
using LangChain.Chains.CombineDocuments;
using LangChain.Chains.LLM;
using LangChain.Databases;
using LangChain.Databases.InMemory;
Expand All @@ -9,6 +10,7 @@
using LangChain.Providers.Downloader;
using LangChain.Providers.LLamaSharp;
using LangChain.TextSplitters;
using Microsoft.SemanticKernel.AI.Embeddings;

namespace LangChain.Providers.LLamaSharp.IntegrationTests;

Expand Down Expand Up @@ -103,58 +105,85 @@ public void EmbeddingsTestWithInMemory()

Assert.AreEqual("My dog name is Bob", closest.PageContent);
}

[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void DocumentsQuestionAnsweringTest()


#region Helpers
// Creates a LLamaSharp embeddings provider backed by the local test model.
// Temperature = 0 keeps the embedding output deterministic across runs.
IEmbeddings CreateEmbeddings()
{
    var configuration = new LLamaSharpConfiguration
    {
        PathToModelFile = ModelPath,
        Temperature = 0
    };
    return new LLamaSharpEmbeddings(configuration);
}

// Creates an instruction-style LLamaSharp chat model over the local test model.
// Temperature = 0 makes generations reproducible, which the assertions rely on.
IChatModel CreateInstructionModel()
{
    var configuration = new LLamaSharpConfiguration
    {
        PathToModelFile = ModelPath,
        Temperature = 0
    };
    return new LLamaSharpModelInstruction(configuration);
}

// Wraps each text in a Document, splits it, embeds it, and loads the result
// into an in-memory vector store, returning the index wrapper for querying.
// NOTE(review): blocks on .Result — acceptable in a sync test helper, but the
// surrounding test could be made async to avoid it.
VectorStoreIndexWrapper CreateVectorStoreIndex(IEmbeddings embeddings, string[] texts)
{
    var store = new InMemoryVectorStore(embeddings);
    var splitter = new CharacterTextSplitter();
    var creator = new VectorStoreIndexCreator(store, splitter);
    var documents = texts.Select(text => new Document(text)).ToList();
    return creator.FromDocumentsAsync(documents).Result;
}

// Builds the question-answering prompt template. The {context} and {question}
// placeholders must match the keys fed in by the document-combining chain.
PromptTemplate CreatePromptTemplate()
{
    const string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
    var inputVariables = new List<string> { "context", "question" };
    return new PromptTemplate(new PromptTemplateInput(prompt, inputVariables));
}
#endregion


[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void DocumentsQuestionAnsweringTest()
{
// setup
var embeddings = CreateEmbeddings();
var model = CreateInstructionModel();

string[] texts = new string[]
{
"I spent entire day watching TV",
"My dog name is Bob",
"This icecream is delicious",
"It is cold in space"
};
var textSplitter = new CharacterTextSplitter();
VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
var index=indexCreator.FromDocumentsAsync(texts.Select(x=>new Document(x)).ToList()).Result;

var index = CreateVectorStoreIndex(embeddings, texts);
var template = CreatePromptTemplate();

var chain = new LlmChain(new LlmChainInput(model, template));

string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
var template = new PromptTemplate(new PromptTemplateInput(prompt, new List<string>() { "context", "question" }));
var chain = new LlmChain(new LlmChainInput(model, template)
{
OutputKey = "text"
});
var stuffDocumentsChain = new StuffDocumentsChain(new StuffDocumentsChainInput(chain)
{
InputKey = "context",
DocumentVariableName = "context",
OutputKey = "text",
DocumentVariableName = "context", // variable name in prompt template
// for the documents
});


// test
var query = "What is the good name for a pet? Tell me only the name, no explanations.";
var answer=index.QueryAsync(query, stuffDocumentsChain,
documentsKey: "context",
questionKey: "question",
outputKey: "text").Result;
var question = "What is the good name for a pet?";
var answer=index.QueryAsync(question, stuffDocumentsChain,
inputKey:"question" // variable name in prompt template for the question
// it would be passed by to stuffDocumentsChain
).Result;



Expand Down
Loading