diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs new file mode 100644 index 00000000..1c4ed7d6 --- /dev/null +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs @@ -0,0 +1,41 @@ +using LangChain.Abstractions.Embeddings.Base; +using LLama.Common; +using LLama; + +namespace LangChain.Providers.LLamaSharp; + +public class LLamaSharpEmbeddings:IEmbeddings +{ + protected readonly LLamaSharpConfiguration _configuration; + protected readonly LLamaWeights _model; + protected readonly ModelParams _parameters; + private readonly LLamaEmbedder _embedder; + + public LLamaSharpEmbeddings(LLamaSharpConfiguration configuration) + { + _parameters = new ModelParams(configuration.PathToModelFile) + { + ContextSize = (uint)configuration.ContextSize, + Seed = (uint)configuration.Seed, + + }; + _model = LLamaWeights.LoadFromFile(_parameters); + _configuration = configuration; + _embedder = new LLamaEmbedder(_model, _parameters); + } + + public Task EmbedDocumentsAsync(string[] texts, CancellationToken cancellationToken = default) + { + float[][] result = new float[texts.Length][]; + for (int i = 0; i < texts.Length; i++) + { + result[i] = _embedder.GetEmbeddings(texts[i]); + } + return Task.FromResult(result); + } + + public Task EmbedQueryAsync(string text, CancellationToken cancellationToken = default) + { + return Task.FromResult(_embedder.GetEmbeddings(text)); + } +} \ No newline at end of file diff --git a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs index 6cad5fdb..57153af7 100644 --- a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs +++ b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/LLamaSharpTests.cs @@ -54,4 +54,49 @@ public void InstructionTest() Assert.AreEqual("4",response.Messages.Last().Content.Trim()); } + + float VectorDistance(float[] a, float[] b) + { + float result = 0; + for (int i = 0; i < a.Length; i++) + { + result += (a[i] - b[i]) * (a[i] - b[i]); + } + + return result; + + } + [TestMethod] +#if CONTINUOUS_INTEGRATION_BUILD + [Ignore] +#endif + public void EmbeddingsTest() + { + var model = new LLamaSharpEmbeddings(new LLamaSharpConfiguration + { + PathToModelFile = ModelPath, + Temperature = 0 + }); + + string[] texts = new string[] + { + "I spent entire day watching TV", + "My dog name is Bob", + "This icecream is delicious", + "It is cold in space" + }; + + var database = model.EmbedDocumentsAsync(texts).Result; + + + var query = model.EmbedQueryAsync("How do you call your pet?").Result; + + var zipped = database.Zip(texts); + + var ordered= zipped.Select(x=>new {text=x.Second,dist=VectorDistance(x.First,query)}); + + var closest = ordered.OrderBy(x => x.dist).First(); + + Assert.AreEqual("My dog name is Bob", closest.text); + } } \ No newline at end of file