Skip to content

Commit

Permalink
feat: Сreated various Memory classes from LangChain documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
HavenDV committed Oct 19, 2023
1 parent ced4b45 commit 41e916a
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 44 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2023 tryAGI, James Eastha, Uros Biberdzicand, @CrazyBaran and Contributors
Copyright (c) 2023 tryAGI, HavenDV, TesAnti, James Eastha, Uros Biberdzicand, @CrazyBaran, scule and Contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
27 changes: 12 additions & 15 deletions src/libs/LangChain.Core/Chains/LLM/LLMChain.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using LangChain.Abstractions.Schema;
using LangChain.Base;
using LangChain.Callback;
using LangChain.Memory;
using LangChain.Prompts.Base;
using LangChain.Providers;
using LangChain.Schema;
Expand All @@ -12,11 +13,13 @@ namespace LangChain.Chains.LLM;
using System.Collections.Generic;
using System.Threading.Tasks;

public class LlmChain : BaseChain, ILlmChainInput
public class LlmChain(LlmChainInput fields) : BaseChain, ILlmChainInput
{
public BasePromptTemplate Prompt { get; }
public IChatModel Llm { get; }
public string OutputKey { get; set; }
public BasePromptTemplate Prompt { get; } = fields.Prompt;
public IChatModel Llm { get; } = fields.Llm;
public BaseMemory? Memory { get; } = fields.Memory;
public string OutputKey { get; set; } = fields.OutputKey;

public override string ChainType() => "llm_chain";

public bool? Verbose { get; set; }
Expand All @@ -25,13 +28,6 @@ public class LlmChain : BaseChain, ILlmChainInput
public override string[] InputKeys => Prompt.InputVariables.ToArray();
public override string[] OutputKeys => new[] { OutputKey };

public LlmChain(LlmChainInput fields)
{
Prompt = fields.Prompt;
Llm = fields.Llm;
OutputKey = fields.OutputKey;
}

protected async Task<object?> GetFinalOutput(
List<Generation> generations,
BasePromptValue promptValue,
Expand All @@ -56,15 +52,15 @@ public override async Task<IChainValues> CallAsync(IChainValues values)
stop = stopList;
}

BasePromptValue promptValue = await Prompt.FormatPromptValue(new InputValues(values.Value));
var chatMessages = promptValue.ToChatMessages();
var promptValue = await Prompt.FormatPromptValue(new InputValues(values.Value));
var chatMessages = promptValue.ToChatMessages().WithHistory(Memory);
if (Verbose == true)
{

Console.WriteLine(string.Join("\n\n", chatMessages));
Console.WriteLine("\n".PadLeft(Console.WindowWidth, '>'));
}
var response = await Llm.GenerateAsync(new ChatRequest(promptValue.ToChatMessages(), stop));
var response = await Llm.GenerateAsync(new ChatRequest(chatMessages, stop));
if (Verbose == true)
{

Expand All @@ -77,10 +73,11 @@ public override async Task<IChainValues> CallAsync(IChainValues values)

return new ChainValues(OutputKey,response.Messages.Last().Content);
}

public async Task<object> Predict(ChainValues values)
{
var output = await CallAsync(values);
return output.Value[OutputKey];
}

}
19 changes: 9 additions & 10 deletions src/libs/LangChain.Core/Chains/LLM/LLMChainInput.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
using LangChain.Base;
using LangChain.Callback;
using LangChain.Memory;
using LangChain.Prompts.Base;
using LangChain.Providers;

namespace LangChain.Chains.LLM;

public class LlmChainInput : ILlmChainInput
public class LlmChainInput(
IChatModel llm,
BasePromptTemplate prompt,
BaseMemory? memory = null)
: ILlmChainInput
{
public LlmChainInput(IChatModel llm, BasePromptTemplate prompt)
{
this.Llm = llm;
this.Prompt = prompt;
}

public BasePromptTemplate Prompt { get; set; }
public IChatModel Llm { get; set; }
public BasePromptTemplate Prompt { get; set; } = prompt;
public IChatModel Llm { get; set; } = llm;
public string OutputKey { get; set; }
public bool? Verbose { get; set; }
public CallbackManager CallbackManager { get; set; }
public BaseMemory? Memory { get; set; } = memory;
}
10 changes: 6 additions & 4 deletions src/libs/LangChain.Core/Docstore/Document.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace LangChain.Docstore;
using System.Globalization;

namespace LangChain.Docstore;

/// <summary>
/// Class for storing document
Expand Down Expand Up @@ -41,9 +43,9 @@ public string Summary()
public string Lookup(string searchString)
{
// if there is a new search string, reset the index
if (searchString.ToLower() != LookupStr)
if (searchString.ToLower(CultureInfo.InvariantCulture) != LookupStr)

Check warning on line 46 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 46 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 46 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 46 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)
{
LookupStr = searchString.ToLower();
LookupStr = searchString.ToLower(CultureInfo.InvariantCulture);

Check warning on line 48 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 48 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 48 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)
LookupIndex = 0;
}
else
Expand All @@ -52,7 +54,7 @@ public string Lookup(string searchString)
}

// get all the paragraphs that contain the search string
var lookups = Paragraphs().Where(p => p.ToLower().Contains(LookupStr)).ToList();
var lookups = Paragraphs().Where(p => p.ToLower(CultureInfo.InvariantCulture).Contains(LookupStr)).ToList();

Check warning on line 57 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 57 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

Check warning on line 57 in src/libs/LangChain.Core/Docstore/Document.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

In method 'Lookup', replace the call to 'ToLower' with 'ToUpperInvariant' (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1308)

if (lookups.Count == 0)
{
Expand Down
29 changes: 29 additions & 0 deletions src/libs/LangChain.Core/Memory/BaseChatMemory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using LangChain.Schema;

namespace LangChain.Memory;

public abstract class BaseChatMemory : BaseMemory
{
protected BaseChatMessageHistory ChatHistory { get; set; }
public bool ReturnMessages { get; set; }

public BaseChatMemory(BaseChatMemoryInput input)
{
if (input.ChatHistory is null) ChatHistory = new ChatMessageHistory();
else ChatHistory = input.ChatHistory;
ReturnMessages = input.ReturnMessages;
}

public abstract override OutputValues LoadMemoryVariables(InputValues inputValues);

public override void SaveContext(InputValues inputValues, OutputValues outputValues)
{
ChatHistory.AddUserMessage(inputValues.Value[inputValues.Value.Keys.FirstOrDefault().ToString()].ToString());
ChatHistory.AddAiMessage(outputValues.Value[outputValues.Value.Keys.FirstOrDefault().ToString()].ToString());
}

public void Clear()
{
ChatHistory.Clear();
}
}
9 changes: 9 additions & 0 deletions src/libs/LangChain.Core/Memory/BaseChatMemoryInput.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace LangChain.Memory;

public class BaseChatMemoryInput
{
public BaseChatMessageHistory ChatHistory { get; set; }
public string InputKey { get; set; }
public string MemoryKey { get; set; }
public bool ReturnMessages { get; set; }
}
2 changes: 1 addition & 1 deletion src/libs/LangChain.Core/Memory/BaseChatMessageHistory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace LangChain.Memory;

public abstract class BaseChatMessageHistory
{
public IList<Message> Messages;
public IList<Message> Messages { get; set; } = new List<Message>();

public void AddUserMessage(string message)
{
Expand Down
9 changes: 9 additions & 0 deletions src/libs/LangChain.Core/Memory/BaseMemory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using LangChain.Schema;

namespace LangChain.Memory;

public abstract class BaseMemory
{
public abstract OutputValues LoadMemoryVariables(InputValues inputValues);
public abstract void SaveContext(InputValues inputValues, OutputValues outputValues);
}
15 changes: 15 additions & 0 deletions src/libs/LangChain.Core/Memory/BufferMemory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using LangChain.Schema;

namespace LangChain.Memory;

public class BufferMemory : BaseChatMemory
{
public BufferMemory(BufferMemoryInput input) : base(input)
{
}

public override OutputValues LoadMemoryVariables(InputValues? inputValues)
{
return new OutputValues(new Dictionary<string, object> { { "history", ChatHistory } });
}
}
7 changes: 7 additions & 0 deletions src/libs/LangChain.Core/Memory/BufferMemoryInput.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace LangChain.Memory;

public sealed class BufferMemoryInput : BaseChatMemoryInput
{
public string AiPrefix { get; set; }
public string HumanPrefix { get; set; }
}
4 changes: 0 additions & 4 deletions src/libs/LangChain.Core/Memory/ChatMessageHistory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ namespace LangChain.Memory;

public class ChatMessageHistory : BaseChatMessageHistory
{
public ChatMessageHistory()
{
Messages = new List<Message>();
}
public override void AddMessage(Message message)
{
Messages.Add(message);
Expand Down
8 changes: 8 additions & 0 deletions src/libs/LangChain.Core/Memory/ConversationBufferMemory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace LangChain.Memory;

public class ConversationBufferMemory : BufferMemory
{
public ConversationBufferMemory(BufferMemoryInput input) : base(input)
{
}
}
26 changes: 26 additions & 0 deletions src/libs/LangChain.Core/Memory/MemoryExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using LangChain.Providers;

namespace LangChain.Memory;

public static class MemoryExtensions
{
public static IReadOnlyCollection<Message> WithHistory(this IReadOnlyCollection<Message> messages, BaseMemory? memory)
{
var history = "These are our previous conversations:\n";
var previousMessages = memory.LoadMemoryVariables(null);
if (previousMessages.Value is { } messageDict &&
messageDict["history"] is ChatMessageHistory msg)
{
foreach (var chatMessage in msg.Messages)
{
history += chatMessage.Content + "\n";
}
}

return new[]
{
history.AsHumanMessage(),
}.Concat(messages).ToArray();
}

}
11 changes: 11 additions & 0 deletions src/libs/LangChain.Core/Schema/OutputValues.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace LangChain.Schema;

public class OutputValues
{
public OutputValues(Dictionary<string, object> value)
{
this.Value = value;
}

public Dictionary<string, object> Value { get; set; } = new();
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public static Message Add(Message left, Message right)
return left + right;
}

/// <summary>
///
/// </summary>
/// <returns></returns>
public override string ToString()
{
if (FunctionName!=null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,30 @@ public class HuggingFaceConfiguration
/// </summary>
public string? ModelId { get; set; }

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Top_k"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Top_k"/>
public int? TopK { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Top_p"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Top_p"/>
public double? TopP { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Temperature"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Temperature"/>
public double? Temperature { get; set; } = 1D;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Repetition_penalty"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Repetition_penalty"/>
public double? RepetitionPenalty { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Max_new_tokens"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Max_new_tokens"/>
public int? MaxNewTokens { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Max_time"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Max_time"/>
public double? MaxTime { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Return_full_text"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Return_full_text"/>
public object? ReturnFullText { get; set; } = default!;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Num_return_sequences"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Num_return_sequences"/>
public int? NumReturnSequences { get; set; } = 1;

/// <inheritdoc cref="HuggingFace.GenerateTextRequestParameters.Do_sample"/>
/// <inheritdoc cref="GenerateTextRequestParameters.Do_sample"/>
public bool? DoSample { get; set; } = default!;
}

0 comments on commit 41e916a

Please sign in to comment.