Skip to content

Commit

Permalink
feat: HuggingFace model downloader (#45)
Browse files Browse the repository at this point in the history
* LLamaSharp integration with tests

* changed test project name. added projects to solution

* added HF Downloader, used it in LLamaSharp tests

* added missing reference
  • Loading branch information
TesAnti authored Nov 4, 2023
1 parent c776a41 commit b2be43a
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
namespace LangChain.Providers.Downloader;

/// <summary>
/// <see cref="HttpClient"/> helpers for streaming a response body into a destination stream.
/// </summary>
internal static class HttpClientExtensions
{
    // Copy buffer size in bytes, used by both the plain and the progress-reporting copy path.
    private const int CopyBufferSize = 81920;

    /// <summary>
    /// Issues a GET request and streams the response body into <paramref name="destination"/>,
    /// optionally reporting relative progress in the range 0..1.
    /// </summary>
    /// <param name="client">The client used to issue the request.</param>
    /// <param name="requestUri">The URI to download.</param>
    /// <param name="destination">The writable stream that receives the response body.</param>
    /// <param name="progress">
    /// Optional receiver for the downloaded fraction. It is only used when the response
    /// declares a Content-Length; otherwise the body is copied without progress reports.
    /// </param>
    /// <param name="cancellationToken">Token that cancels the request and the body copy.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="client"/> or <paramref name="destination"/> is null.
    /// </exception>
    /// <exception cref="HttpRequestException">The server answered with a non-success status code.</exception>
    public static async Task DownloadAsync(this HttpClient client, string requestUri, Stream destination, IProgress<double> progress = null, CancellationToken cancellationToken = default)
    {
        if (client == null)
        {
            throw new ArgumentNullException(nameof(client));
        }

        if (destination == null)
        {
            throw new ArgumentNullException(nameof(destination));
        }

        // Get the http headers first to examine the content length
        using (var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false))
        {
            // Without this check an error page (404, 5xx, ...) would be written to the
            // destination as if it were the requested file.
            response.EnsureSuccessStatusCode();

            var contentLength = response.Content.Headers.ContentLength;

            using (var download = await response.Content.ReadAsStreamAsync().ConfigureAwait(false))
            {
                // Ignore progress reporting when no progress reporter was
                // passed or when the content length is unknown
                if (progress == null || !contentLength.HasValue)
                {
                    // Pass the token along so this path is cancellable as well.
                    await download.CopyToAsync(destination, CopyBufferSize, cancellationToken).ConfigureAwait(false);
                    return;
                }

                // Convert absolute progress (bytes downloaded) into relative progress (0% - 100%).
                // A synchronous adapter is used instead of Progress<T>, which dispatches callbacks
                // asynchronously and could deliver a stale value after the final Report(1) below.
                var relativeProgress = new RelativeProgress(progress, contentLength.Value);

                // Use extension method to report progress while downloading
                await download.CopyToAsync(destination, CopyBufferSize, relativeProgress, cancellationToken).ConfigureAwait(false);
                progress.Report(1);
            }
        }
    }

    /// <summary>
    /// Forwards a byte count to an <see cref="IProgress{T}"/> of double as a fraction of a known total.
    /// </summary>
    private sealed class RelativeProgress : IProgress<long>
    {
        private readonly IProgress<double> target;
        private readonly long totalLength;

        public RelativeProgress(IProgress<double> target, long totalLength)
        {
            this.target = target;
            this.totalLength = totalLength;
        }

        public void Report(long value)
        {
            target.Report((double)value / totalLength);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
namespace LangChain.Providers.Downloader;

/// <summary>
/// A downloader for HuggingFace models
/// </summary>
public class HuggingFaceModelDownloader
{
    /// <summary>
    /// A shared, ready-to-use downloader instance
    /// </summary>
    public static HuggingFaceModelDownloader Instance { get; } = new HuggingFaceModelDownloader();

    /// <summary>
    /// The HttpClient used to download the models
    /// </summary>
    public HttpClient HttpClient { get; set; } = new HttpClient();

    /// <summary>
    /// The default storage path for the models
    /// </summary>
    public static string DefaultStoragePath =>
        Path.Combine(
            Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData),
            "LangChain", "CSharp", "Models");

    /// <summary>
    /// Streams <paramref name="url"/> into a newly created file at <paramref name="path"/>
    /// while showing a console progress bar.
    /// </summary>
    private async Task DownloadModel(string url, string path, CancellationToken? cancellationToken = null)
    {
        var client = HttpClient;

        // The progress bar is disposed before the file stream, so the console line is
        // cleared before the file handle is released.
        using (var file = new FileStream(path, FileMode.Create, FileAccess.Write, FileShare.None))
        using (var progress = new ProgressBar())
        {
            await client.DownloadAsync(url, file, progress, cancellationToken ?? CancellationToken.None).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Downloads a model from HuggingFace with caching and return path to it
    /// </summary>
    /// <param name="repository">Repository identifier used in the download URL and as a cache sub-folder.</param>
    /// <param name="fileName">Path of the file inside the repository.</param>
    /// <param name="version">Revision segment placed in the download URL.</param>
    /// <param name="storagePath">Cache root; <see cref="DefaultStoragePath"/> is used when null.</param>
    /// <param name="cancellationToken">Token that cancels a running download.</param>
    /// <returns>The local path of the cached model file.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="repository"/>, <paramref name="fileName"/> or <paramref name="version"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="repository"/>, <paramref name="fileName"/> or <paramref name="version"/> is empty or whitespace.
    /// </exception>
    public async Task<string> GetModel(string repository, string fileName, string version = "master", string storagePath = null, CancellationToken cancellationToken = default)
    {
        RequireNonEmpty(repository, nameof(repository));
        RequireNonEmpty(fileName, nameof(fileName));
        RequireNonEmpty(version, nameof(version));

        storagePath ??= DefaultStoragePath;

        // Directory.CreateDirectory does nothing when the directory already exists,
        // so no Directory.Exists pre-check is needed.
        var repositoryPath = Path.Combine(storagePath, repository);
        Directory.CreateDirectory(repositoryPath);

        var modelPath = Path.Combine(repositoryPath, version, fileName);
        var directory = Path.GetDirectoryName(modelPath);
        if (!string.IsNullOrEmpty(directory))
        {
            Directory.CreateDirectory(directory);
        }

        // The marker file exists for the duration of a download. If it is still present on
        // a later call, the previous download did not finish and the model is fetched again.
        var downloadMarkerPath = modelPath + ".hfdownload";
        if (!File.Exists(modelPath) || File.Exists(downloadMarkerPath))
        {
            File.WriteAllText(downloadMarkerPath, "");
            File.Delete(modelPath);
            Console.WriteLine("No model file found. Downloading...");
            var downloadUrl = $"https://huggingface.co/{repository}/resolve/{version}/{fileName}";
            await DownloadModel(downloadUrl, modelPath, cancellationToken).ConfigureAwait(false);
            File.Delete(downloadMarkerPath);
        }

        return modelPath;
    }

    /// <summary>
    /// Throws when <paramref name="value"/> is null, empty or whitespace.
    /// </summary>
    private static void RequireNonEmpty(string value, string paramName)
    {
        if (value == null)
        {
            throw new ArgumentNullException(paramName);
        }

        if (string.IsNullOrWhiteSpace(value))
        {
            throw new ArgumentException("Value must not be empty.", paramName);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using System.Text;

namespace LangChain.Providers.Downloader;

/// <summary>
/// An ASCII progress bar that redraws itself on the console from a timer
/// until it is disposed.
/// </summary>
internal class ProgressBar : IDisposable, IProgress<double>
{
    private const int blockCount = 10;
    private readonly TimeSpan animationInterval = TimeSpan.FromSeconds(1.0 / 8);
    private const string animation = @"|/-\";

    // Dedicated lock object guarding drawing and disposal.
    private readonly object gate = new object();
    private readonly Timer timer;

    private double currentProgress = 0;
    private string currentText = string.Empty;
    private bool disposed = false;
    private int animationIndex = 0;

    /// <summary>
    /// Creates the progress bar and starts drawing unless console output is redirected.
    /// </summary>
    public ProgressBar()
    {
        timer = new Timer(TimerHandler);

        // A progress bar is only for temporary display in a console window.
        // If the console output is redirected to a file, draw nothing.
        // Otherwise, we'll end up with a lot of garbage in the target file.
        if (!Console.IsOutputRedirected)
        {
            ResetTimer();
        }
    }

    /// <summary>
    /// Stores the latest progress value; it is drawn on the next timer tick.
    /// </summary>
    /// <param name="value">Progress fraction; values outside [0..1] are clamped, NaN is treated as 0.</param>
    public void Report(double value)
    {
        // NaN survives Math.Min/Math.Max and would later produce an invalid block count.
        if (double.IsNaN(value))
        {
            value = 0;
        }

        // Make sure value is in [0..1] range
        value = Math.Max(0, Math.Min(1, value));
        Interlocked.Exchange(ref currentProgress, value);
    }

    /// <summary>
    /// Timer callback: renders the current progress and schedules the next tick.
    /// </summary>
    private void TimerHandler(object state)
    {
        lock (gate)
        {
            if (disposed)
            {
                return;
            }

            // Take one atomic snapshot so the bar and the percentage show the same value.
            double progressSnapshot = Interlocked.CompareExchange(ref currentProgress, 0, 0);

            int progressBlockCount = (int)(progressSnapshot * blockCount);
            int percent = (int)(progressSnapshot * 100);
            string text = string.Format("[{0}{1}] {2,3}% {3}",
                new string('#', progressBlockCount), new string('-', blockCount - progressBlockCount),
                percent,
                animation[animationIndex++ % animation.Length]);
            UpdateText(text);

            ResetTimer();
        }
    }

    /// <summary>
    /// Rewrites the console text in place, emitting only the part that differs
    /// from what was written last.
    /// </summary>
    private void UpdateText(string text)
    {
        // Get length of common portion
        int commonPrefixLength = 0;
        int commonLength = Math.Min(currentText.Length, text.Length);
        while (commonPrefixLength < commonLength && text[commonPrefixLength] == currentText[commonPrefixLength])
        {
            commonPrefixLength++;
        }

        // Backtrack to the first differing character
        StringBuilder outputBuilder = new StringBuilder();
        outputBuilder.Append('\b', currentText.Length - commonPrefixLength);

        // Output new suffix
        outputBuilder.Append(text.Substring(commonPrefixLength));

        // If the new text is shorter than the old one: delete overlapping characters
        int overlapCount = currentText.Length - text.Length;
        if (overlapCount > 0)
        {
            outputBuilder.Append(' ', overlapCount);
            outputBuilder.Append('\b', overlapCount);
        }

        Console.Write(outputBuilder);
        currentText = text;
    }

    /// <summary>
    /// Schedules a single timer tick after the animation interval (no periodic signalling).
    /// </summary>
    private void ResetTimer()
    {
        timer.Change(animationInterval, TimeSpan.FromMilliseconds(-1));
    }

    /// <summary>
    /// Erases the bar from the console and releases the timer. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        lock (gate)
        {
            if (disposed)
            {
                return;
            }

            disposed = true;
            UpdateText(string.Empty);

            // Release the timer. A callback that is already waiting on the lock
            // sees 'disposed' and returns without touching the timer again.
            timer.Dispose();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
namespace LangChain.Providers.Downloader;

/// <summary>
/// <see cref="Stream"/> helpers for copying with progress reporting.
/// </summary>
internal static class StreamExtensions
{
    /// <summary>
    /// Copies <paramref name="source"/> to <paramref name="destination"/> in chunks of at most
    /// <paramref name="bufferSize"/> bytes, reporting the running total of bytes copied after each chunk.
    /// </summary>
    /// <param name="source">The readable stream to copy from.</param>
    /// <param name="destination">The writable stream to copy to.</param>
    /// <param name="bufferSize">Size of the copy buffer in bytes; must be positive.</param>
    /// <param name="progress">Optional receiver for the total number of bytes copied so far.</param>
    /// <param name="cancellationToken">Token passed to every read and write.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> or <paramref name="destination"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="source"/> is not readable or <paramref name="destination"/> is not writable.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="bufferSize"/> is zero or negative.</exception>
    public static async Task CopyToAsync(this Stream source, Stream destination, int bufferSize, IProgress<long> progress = null, CancellationToken cancellationToken = default)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }

        if (!source.CanRead)
        {
            throw new ArgumentException("Has to be readable", nameof(source));
        }

        if (destination == null)
        {
            throw new ArgumentNullException(nameof(destination));
        }

        if (!destination.CanWrite)
        {
            throw new ArgumentException("Has to be writable", nameof(destination));
        }

        // A zero-length buffer makes ReadAsync return 0 immediately, which would look like
        // end-of-stream and silently copy nothing, so zero is rejected together with negatives.
        if (bufferSize <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(bufferSize));
        }

        var buffer = new byte[bufferSize];
        long totalBytesRead = 0;
        int bytesRead;
        while ((bytesRead = await source.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) != 0)
        {
            await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false);
            totalBytesRead += bytesRead;
            progress?.Report(totalBytesRead);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LangChain.Providers;
using LangChain.Providers.Downloader;
using LangChain.Providers.LLamaSharp;

namespace LangChain.Providers.LLamaSharp.IntegrationTests;
Expand All @@ -14,7 +15,7 @@ public void PrepromptTest()
{
var model = new LLamaSharpModel(new LLamaSharpConfiguration
{
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"),
PathToModelFile = HuggingFaceModelDownloader.Instance.GetModel("AsakusaRinne/LLamaSharpSamples", "LLaMa/7B/ggml-model-f32-q4_0.bin", version: "v0.3.0").Result,
});

var response=model.GenerateAsync(new ChatRequest(new List<Message>
Expand All @@ -39,7 +40,7 @@ public void InstructionTest()
{
var model = new LLamaSharpModel(new LLamaSharpConfiguration
{
PathToModelFile = Path.Combine(Environment.ExpandEnvironmentVariables("%LLAMA_MODELS%"), "ggml-model-f32-q4_0.bin"),
PathToModelFile = HuggingFaceModelDownloader.Instance.GetModel("AsakusaRinne/LLamaSharpSamples", "LLaMa/7B/ggml-model-f32-q4_0.bin", version: "v0.3.0").Result,
Mode = ELLamaSharpModelMode.Instruction
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\libs\Providers\LangChain.Providers.HuggingFace\LangChain.Providers.HuggingFace.csproj" />
<ProjectReference Include="..\..\libs\Providers\LangChain.Providers.LLamaSharp\LangChain.Providers.LLamaSharp.csproj" />
</ItemGroup>

Expand Down

0 comments on commit b2be43a

Please sign in to comment.