Skip to content

Commit

Permalink
feat: bugfixes and massive simplification of DocumentQnA test (#51)
Browse files Browse the repository at this point in the history
* VectorStoreIndexCreator, error fixes, tweaking, Documents QnA test

* changing Web and Pdf sources to use proper document

* bugfixes and massive simplification of DocumentQnA test

* bugfix
  • Loading branch information
TesAnti authored Nov 5, 2023
1 parent cbd1a9c commit fcb2cf1
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public abstract class BaseCombineDocumentsChain(BaseCombineDocumentsChainInput f
/// <returns></returns>
public override async Task<IChainValues> CallAsync(IChainValues values)
{
var docs = values.Value[InputKey];
var docs = values.Value["input_documents"];

//Other keys are assumed to be needed for LLM prediction
var otherKeys = values.Value
Expand Down
2 changes: 1 addition & 1 deletion src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class LlmChainInput(
{
public BasePromptTemplate Prompt { get; set; } = prompt;
public IChatModel Llm { get; set; } = llm;
public string OutputKey { get; set; }
public string OutputKey { get; set; } = "text";
public bool? Verbose { get; set; }
public CallbackManager CallbackManager { get; set; }
public BaseMemory? Memory { get; set; } = memory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ public abstract class BaseRetrievalQaChain(BaseRetrievalQaChainInput fields) : B
/// <exception cref="NotImplementedException"></exception>
public override async Task<IChainValues> CallAsync(IChainValues values)
{

var question = values.Value[_inputKey].ToString();

var docs = (await GetDocsAsync(question)).ToList();

var input = new Dictionary<string, object>
{
[fields.DocumentsKey] = docs,
[_inputKey] = question
["input_documents"] = docs,
[_inputKey]= question
};

var answer = await _combineDocumentsChain.Run(input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public class BaseRetrievalQaChainInput(BaseCombineDocumentsChain combineDocument
public bool ReturnSourceDocuments { get; set; }

public string InputKey { get; set; } = "question";
public string DocumentsKey { get; set; } = "input_documents";
public string OutputKey { get; set; } = "output_text";
public bool? Verbose { get; set; }
public CallbackManager? CallbackManager { get; set; }
Expand Down
5 changes: 2 additions & 3 deletions src/libs/LangChain.Core/Indexes/VectorStoreIndexWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ public VectorStoreIndexWrapper(VectorStore vectorStore)
_vectorStore = vectorStore;
}

public Task<string?> QueryAsync(string question, BaseCombineDocumentsChain llm, string documentsKey= "input_documents", string questionKey="question", string outputKey= "output_text")
public Task<string?> QueryAsync(string question, BaseCombineDocumentsChain llm, string inputKey= "question", string outputKey= "output_text")
{
var chain = new RetrievalQaChain(
new RetrievalQaChainInput(
llm,
_vectorStore.AsRetreiver())
{
InputKey= questionKey,
DocumentsKey= documentsKey,
InputKey= inputKey,
OutputKey= outputKey,
}
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Chains.CombineDocuments;
using LangChain.Abstractions.Embeddings.Base;
using LangChain.Chains.CombineDocuments;
using LangChain.Chains.LLM;
using LangChain.Databases;
using LangChain.Databases.InMemory;
Expand All @@ -9,6 +10,7 @@
using LangChain.Providers.Downloader;
using LangChain.Providers.LLamaSharp;
using LangChain.TextSplitters;
using Microsoft.SemanticKernel.AI.Embeddings;

namespace LangChain.Providers.LLamaSharp.IntegrationTests;

Expand Down Expand Up @@ -103,58 +105,85 @@ public void EmbeddingsTestWithInMemory()

Assert.AreEqual("My dog name is Bob", closest.PageContent);
}

[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void DocumentsQuestionAnsweringTest()


#region Helpers
/// <summary>
/// Builds a LLamaSharp-backed embeddings provider over the test model file.
/// Temperature is pinned to 0 so embedding runs are deterministic across test executions.
/// </summary>
IEmbeddings CreateEmbeddings()
{
    var configuration = new LLamaSharpConfiguration
    {
        PathToModelFile = ModelPath,
        Temperature = 0
    };
    return new LLamaSharpEmbeddings(configuration);
}

/// <summary>
/// Builds an instruction-tuned LLamaSharp chat model over the test model file.
/// Temperature is pinned to 0 so generations are deterministic across test executions.
/// </summary>
IChatModel CreateInstructionModel()
{
    var configuration = new LLamaSharpConfiguration
    {
        PathToModelFile = ModelPath,
        Temperature = 0
    };
    return new LLamaSharpModelInstruction(configuration);
}

/// <summary>
/// Indexes the given texts into an in-memory vector store and returns a queryable index wrapper.
/// </summary>
/// <param name="embeddings">Embeddings provider used by the in-memory vector store.</param>
/// <param name="texts">Raw texts; each one is wrapped in a <see cref="Document"/> before indexing.</param>
/// <returns>A <see cref="VectorStoreIndexWrapper"/> over the populated store.</returns>
VectorStoreIndexWrapper CreateVectorStoreIndex(IEmbeddings embeddings, string[] texts)
{
    var vectorStore = new InMemoryVectorStore(embeddings);
    var textSplitter = new CharacterTextSplitter();
    var indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);

    var documents = texts.Select(text => new Document(text)).ToList();

    // Synchronous wait is acceptable in this test helper, but GetAwaiter().GetResult()
    // (unlike .Result) rethrows the original exception instead of an AggregateException,
    // which keeps test failures readable.
    return indexCreator.FromDocumentsAsync(documents).GetAwaiter().GetResult();
}

/// <summary>
/// Builds the QnA prompt template. The template declares two input variables,
/// "context" (the retrieved documents) and "question" (the user query).
/// </summary>
PromptTemplate CreatePromptTemplate()
{
    const string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
    var inputVariables = new List<string> { "context", "question" };
    return new PromptTemplate(new PromptTemplateInput(prompt, inputVariables));
}
#endregion


[TestMethod]
#if CONTINUOUS_INTEGRATION_BUILD
[Ignore]
#endif
public void DocumentsQuestionAnsweringTest()
{
// setup
var embeddings = CreateEmbeddings();
var model = CreateInstructionModel();

string[] texts = new string[]
{
"I spent entire day watching TV",
"My dog name is Bob",
"This icecream is delicious",
"It is cold in space"
};
var textSplitter = new CharacterTextSplitter();
VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
var index=indexCreator.FromDocumentsAsync(texts.Select(x=>new Document(x)).ToList()).Result;

var index = CreateVectorStoreIndex(embeddings, texts);
var template = CreatePromptTemplate();

var chain = new LlmChain(new LlmChainInput(model, template));

string prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n\n{context}\n\nQuestion: {question}\nHelpful Answer:";
var template = new PromptTemplate(new PromptTemplateInput(prompt, new List<string>() { "context", "question" }));
var chain = new LlmChain(new LlmChainInput(model, template)
{
OutputKey = "text"
});
var stuffDocumentsChain = new StuffDocumentsChain(new StuffDocumentsChainInput(chain)
{
InputKey = "context",
DocumentVariableName = "context",
OutputKey = "text",
DocumentVariableName = "context", // variable name in prompt template
// for the documents
});


// test
var query = "What is the good name for a pet? Tell me only the name, no explanations.";
var answer=index.QueryAsync(query, stuffDocumentsChain,
documentsKey: "context",
questionKey: "question",
outputKey: "text").Result;
var question = "What is the good name for a pet?";
var answer=index.QueryAsync(question, stuffDocumentsChain,
inputKey:"question" // variable name in prompt template for the question
// it would be passed by to stuffDocumentsChain
).Result;



Expand Down

0 comments on commit fcb2cf1

Please sign in to comment.