From 28c61f6ed1f7331d860d7439851dcafd84d58db6 Mon Sep 17 00:00:00 2001 From: Aaron Powell Date: Mon, 18 Nov 2024 01:22:49 +0000 Subject: [PATCH] Moving from a lifecycle hook to event for model download The lifecycle hook would get blocked if you used WaitFor on the model resource, as it wouldn't fire until the model was healthy, which it couldn't achieve. Refactored to use an event instead that relies on the server resource ready event. Had to make the event handler non-blocking with Task.Run, otherwise models couldn't be downloaded concurrently, as the events would run one at a time, waiting for model download completion before continuing on. Also took this time to refactor the AddModel method to have a few extension methods, making it a bit clearer as to what happens when. Fixes #256 --- .../Program.cs | 1 + .../OllamaModelResourceLifecycleHook.cs | 120 -------- .../OllamaResourceBuilderExtensions.Model.cs | 282 ++++++++++++------ 3 files changed, 190 insertions(+), 213 deletions(-) delete mode 100644 src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaModelResourceLifecycleHook.cs diff --git a/examples/ollama/CommunityToolkit.Aspire.Hosting.Ollama.AppHost/Program.cs b/examples/ollama/CommunityToolkit.Aspire.Hosting.Ollama.AppHost/Program.cs index 1a06a195..d6c867d5 100644 --- a/examples/ollama/CommunityToolkit.Aspire.Hosting.Ollama.AppHost/Program.cs +++ b/examples/ollama/CommunityToolkit.Aspire.Hosting.Ollama.AppHost/Program.cs @@ -10,6 +10,7 @@ builder.AddProject("webfrontend") .WithExternalHttpEndpoints() .WithReference(phi3) + .WaitFor(phi3) .WithReference(llama); builder.Build().Run(); diff --git a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaModelResourceLifecycleHook.cs b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaModelResourceLifecycleHook.cs deleted file mode 100644 index b4af6404..00000000 --- a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaModelResourceLifecycleHook.cs +++ /dev/null @@ -1,120 +0,0 @@ -using Aspire.Hosting; -using
Aspire.Hosting.ApplicationModel; -using Aspire.Hosting.Lifecycle; -using Microsoft.Extensions.Logging; -using OllamaSharp; -using System.Data.Common; -using System.Globalization; - -namespace CommunityToolkit.Aspire.Hosting.Ollama; - -internal class OllamaModelResourceLifecycleHook( - ResourceLoggerService loggerService, - ResourceNotificationService notificationService, - DistributedApplicationExecutionContext context) : IDistributedApplicationLifecycleHook, IAsyncDisposable -{ - private readonly CancellationTokenSource _tokenSource = new(); - - public async Task AfterResourcesCreatedAsync(DistributedApplicationModel appModel, CancellationToken cancellationToken = default) - { - if (context.IsPublishMode) - { - return; - } - - await Parallel.ForEachAsync(appModel.Resources.OfType(), _tokenSource.Token, async (resource, ct) => - { - await DownloadModel(resource, ct); - }); - } - - private async Task DownloadModel(OllamaModelResource modelResource, CancellationToken cancellationToken) - { - var logger = loggerService.GetLogger(modelResource); - string model = modelResource.ModelName; - var ollamaResource = modelResource.Parent; - - try - { - var connectionString = await ollamaResource.ConnectionStringExpression.GetValueAsync(cancellationToken).ConfigureAwait(false); - - if (string.IsNullOrWhiteSpace(connectionString)) - { - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("No connection string", KnownResourceStateStyles.Error) }); - return; - } - - if (!Uri.TryCreate(connectionString, UriKind.Absolute, out _)) - { - var connectionBuilder = new DbConnectionStringBuilder - { - ConnectionString = connectionString - }; - - if (connectionBuilder.ContainsKey("Endpoint") && Uri.TryCreate(connectionBuilder["Endpoint"].ToString(), UriKind.Absolute, out var endpoint)) - { - connectionString = endpoint.ToString(); - } - } - - var ollamaClient = new OllamaApiClient(new Uri(connectionString)); - - await 
notificationService.PublishUpdateAsync(modelResource, state => state with - { - State = new ResourceStateSnapshot($"Checking {model}", KnownResourceStateStyles.Info), - Properties = [.. state.Properties, new(CustomResourceKnownProperties.Source, model)] - }); - - var hasModel = await HasModelAsync(ollamaClient, model, cancellationToken); - - if (!hasModel) - { - logger.LogInformation("{TimeStamp}: [{Model}] needs to be downloaded for {ResourceName}", - DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffZ", CultureInfo.InvariantCulture), - model, - ollamaResource.Name); - await OllamaUtilities.PullModelAsync(modelResource, ollamaClient, model, logger, notificationService, cancellationToken); - } - else - { - logger.LogInformation("{TimeStamp}: [{Model}] already exists for {ResourceName}", - DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffZ", CultureInfo.InvariantCulture), - model, - ollamaResource.Name); - } - - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Running, KnownResourceStateStyles.Success) }); - } - catch (Exception ex) - { - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(ex.Message, KnownResourceStateStyles.Error) }); - } - } - - private static async Task HasModelAsync(OllamaApiClient ollamaClient, string model, CancellationToken cancellationToken) - { - int retryCount = 0; - while (retryCount < 5) - { - try - { - var localModels = await ollamaClient.ListLocalModelsAsync(cancellationToken); - return localModels.Any(m => m.Name.Equals(model, StringComparison.OrdinalIgnoreCase)); - } - catch (TaskCanceledException) - { - // wait 30 seconds before retrying - await Task.Delay(TimeSpan.FromSeconds(30), cancellationToken); - retryCount++; - } - } - - throw new TimeoutException("Failed to list local models after 5 retries. 
Likely that the container image was not pulled in time, or the container is not running."); - } - - public ValueTask DisposeAsync() - { - _tokenSource.Cancel(); - return default; - } -} diff --git a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.Model.cs b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.Model.cs index 344dde67..83c7492e 100644 --- a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.Model.cs +++ b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.Model.cs @@ -1,10 +1,10 @@ using Aspire.Hosting.ApplicationModel; -using Aspire.Hosting.Lifecycle; using CommunityToolkit.Aspire.Hosting.Ollama; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Diagnostics.HealthChecks; using Microsoft.Extensions.Logging; using OllamaSharp; +using System.Data.Common; +using System.Globalization; namespace Aspire.Hosting; @@ -39,67 +39,9 @@ public static IResourceBuilder AddModel(this IResourceBuild ArgumentNullException.ThrowIfNull(builder, nameof(builder)); ArgumentException.ThrowIfNullOrWhiteSpace(modelName, nameof(modelName)); - builder.ApplicationBuilder.Services.TryAddLifecycleHook(); - builder.Resource.AddModel(modelName); var modelResource = new OllamaModelResource(name, modelName, builder.Resource); - modelResource.AddModelResourceCommand( - name: "Redownload", - displayName: "Redownload Model", - executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => - { - await OllamaUtilities.PullModelAsync(modelResource, ollamaClient, modelResource.ModelName, logger, notificationService, ct); - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Running, KnownResourceStateStyles.Success) }); - - return CommandResults.Success(); - }, - displayDescription: $"Redownload the model {modelName}.", - iconName: "ArrowDownload" - 
).AddModelResourceCommand( - name: "Delete", - displayName: "Delete Model", - executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => - { - await ollamaClient.DeleteModelAsync(modelResource.ModelName, ct); - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("Stopped", KnownResourceStateStyles.Success) }); - - return CommandResults.Success(); - }, - displayDescription: $"Delete the model {modelName}.", - iconName: "Delete", - confirmationMessage: $"Are you sure you want to delete the model {modelName}?" - ).AddModelResourceCommand( - name: "ModelInfo", - displayName: "Print Model Info", - executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => - { - var modelInfo = await ollamaClient.ShowModelAsync(modelResource.ModelName, ct); - logger.LogInformation("Model Info: {ModelInfo}", modelInfo.ToJson()); - - return CommandResults.Success(); - }, - displayDescription: $"Print the info for the model {modelName}.", - iconName: "Info" - ).AddModelResourceCommand( - name: "Stop", - displayName: "Stop Model", - executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => - { - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Stopping, KnownResourceStateStyles.Success) }); - await foreach (var result in ollamaClient.GenerateAsync(new OllamaSharp.Models.GenerateRequest { Model = modelResource.ModelName, KeepAlive = "0" }, ct)) - { - logger.LogInformation("{Result}", result?.ToJson()); - } - await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("Stopped", KnownResourceStateStyles.Success) }); - - return CommandResults.Success(); - }, - displayDescription: $"Stop the model {modelName}.", - iconName: "Stop", - isHighlighted: true - ); - var healthCheckKey = 
$"{name}-{modelName}-health"; builder.ApplicationBuilder.Services.AddHealthChecks() @@ -107,7 +49,9 @@ public static IResourceBuilder AddModel(this IResourceBuild return builder.ApplicationBuilder .AddResource(modelResource) - .WithHealthCheck(healthCheckKey); + .WithHealthCheck(healthCheckKey) + .WithModelCommands(modelName) + .WithModelDownload(); } /// @@ -122,8 +66,6 @@ public static IResourceBuilder AddHuggingFaceModel(this IRe ArgumentNullException.ThrowIfNull(builder, nameof(builder)); ArgumentException.ThrowIfNullOrWhiteSpace(modelName, nameof(modelName)); - builder.ApplicationBuilder.Services.TryAddLifecycleHook(); - if (!modelName.StartsWith("hf.co/") && !modelName.StartsWith("huggingface.co/")) { modelName = "hf.co/" + modelName; @@ -132,8 +74,65 @@ public static IResourceBuilder AddHuggingFaceModel(this IRe return AddModel(builder, name, modelName); } - private static OllamaModelResource AddModelResourceCommand( - this OllamaModelResource modelResource, + private static IResourceBuilder WithModelCommands(this IResourceBuilder builder, string modelName) => + builder.AddModelResourceCommand( + name: "Redownload", + displayName: "Redownload Model", + executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => + { + await OllamaUtilities.PullModelAsync(modelResource, ollamaClient, modelResource.ModelName, logger, notificationService, ct); + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Running, KnownResourceStateStyles.Success) }); + + return CommandResults.Success(); + }, + displayDescription: $"Redownload the model {modelName}.", + iconName: "ArrowDownload" + ).AddModelResourceCommand( + name: "Delete", + displayName: "Delete Model", + executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => + { + await ollamaClient.DeleteModelAsync(modelResource.ModelName, ct); + await 
notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("Stopped", KnownResourceStateStyles.Success) }); + + return CommandResults.Success(); + }, + displayDescription: $"Delete the model {modelName}.", + iconName: "Delete", + confirmationMessage: $"Are you sure you want to delete the model {modelName}?" + ).AddModelResourceCommand( + name: "ModelInfo", + displayName: "Print Model Info", + executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => + { + var modelInfo = await ollamaClient.ShowModelAsync(modelResource.ModelName, ct); + logger.LogInformation("Model Info: {ModelInfo}", modelInfo.ToJson()); + + return CommandResults.Success(); + }, + displayDescription: $"Print the info for the model {modelName}.", + iconName: "Info" + ).AddModelResourceCommand( + name: "Stop", + displayName: "Stop Model", + executeCommand: async (modelResource, ollamaClient, logger, notificationService, ct) => + { + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Stopping, KnownResourceStateStyles.Success) }); + await foreach (var result in ollamaClient.GenerateAsync(new OllamaSharp.Models.GenerateRequest { Model = modelResource.ModelName, KeepAlive = "0" }, ct)) + { + logger.LogInformation("{Result}", result?.ToJson()); + } + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("Stopped", KnownResourceStateStyles.Success) }); + + return CommandResults.Success(); + }, + displayDescription: $"Stop the model {modelName}.", + iconName: "Stop", + isHighlighted: true + ); + + private static IResourceBuilder AddModelResourceCommand( + this IResourceBuilder builder, string name, string displayName, Func> executeCommand, @@ -142,39 +141,136 @@ private static OllamaModelResource AddModelResourceCommand( string? confirmationMessage = null, string? 
iconName = null, IconVariant? iconVariant = IconVariant.Filled, - bool isHighlighted = false) + bool isHighlighted = false) => + builder.WithCommand( + name: name, + displayName: displayName, + updateState: context => + context.ResourceSnapshot.State?.Text == KnownResourceStates.Running ? + ResourceCommandState.Enabled : + ResourceCommandState.Disabled, + executeCommand: async context => + { + var modelResource = builder.Resource; + (var success, var endpoint) = await OllamaUtilities.TryGetEndpointAsync(modelResource, context.CancellationToken); + + if (!success || endpoint is null) + { + return new ExecuteCommandResult { Success = false, ErrorMessage = "Invalid connection string" }; + } + + var ollamaClient = new OllamaApiClient(endpoint); + var logger = context.ServiceProvider.GetRequiredService().GetLogger(modelResource); + var notificationService = context.ServiceProvider.GetRequiredService(); + + return await executeCommand(modelResource, ollamaClient, logger, notificationService, context.CancellationToken); + }, + displayDescription: displayDescription, + parameter: parameter, + confirmationMessage: confirmationMessage, + iconName: iconName, + iconVariant: iconVariant, + isHighlighted: isHighlighted + ); + + private static IResourceBuilder WithModelDownload(this IResourceBuilder builder) { - modelResource.Annotations.Add(new ResourceCommandAnnotation( - name: name, - displayName: displayName, - updateState: context => - context.ResourceSnapshot.State?.Text == KnownResourceStates.Running ? 
- ResourceCommandState.Enabled : - ResourceCommandState.Disabled, - executeCommand: async context => + builder.ApplicationBuilder.Eventing.Subscribe(builder.Resource.Parent, (@event, cancellationToken) => + { + var loggerService = @event.Services.GetRequiredService(); + var notificationService = @event.Services.GetRequiredService(); + + if (builder.Resource is not OllamaModelResource modelResource) { - (var success, var endpoint) = await OllamaUtilities.TryGetEndpointAsync(modelResource, context.CancellationToken); + return Task.CompletedTask; + } + + var logger = loggerService.GetLogger(modelResource); + string model = modelResource.ModelName; + + _ = Task.Run(async () => + { + try + { + var connectionString = await modelResource.ConnectionStringExpression.GetValueAsync(cancellationToken).ConfigureAwait(false); + + if (string.IsNullOrWhiteSpace(connectionString)) + { + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot("No connection string", KnownResourceStateStyles.Error) }); + return; + } + + if (!Uri.TryCreate(connectionString, UriKind.Absolute, out _)) + { + var connectionBuilder = new DbConnectionStringBuilder + { + ConnectionString = connectionString + }; + + if (connectionBuilder.ContainsKey("Endpoint") && Uri.TryCreate(connectionBuilder["Endpoint"].ToString(), UriKind.Absolute, out var endpoint)) + { + connectionString = endpoint.ToString(); + } + } + + var ollamaClient = new OllamaApiClient(new Uri(connectionString)); + + await notificationService.PublishUpdateAsync(modelResource, state => state with + { + State = new ResourceStateSnapshot($"Checking {model}", KnownResourceStateStyles.Info), + Properties = [.. 
state.Properties, new(CustomResourceKnownProperties.Source, model)] + }); - if (!success || endpoint is null) + var hasModel = await HasModelAsync(ollamaClient, model, cancellationToken); + + if (!hasModel) + { + logger.LogInformation("{TimeStamp}: [{Model}] needs to be downloaded for {ResourceName}", + DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffZ", CultureInfo.InvariantCulture), + model, + modelResource.Name); + await OllamaUtilities.PullModelAsync(modelResource, ollamaClient, model, logger, notificationService, cancellationToken); + } + else + { + logger.LogInformation("{TimeStamp}: [{Model}] already exists for {ResourceName}", + DateTime.UtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffZ", CultureInfo.InvariantCulture), + model, + modelResource.Name); + } + + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(KnownResourceStates.Running, KnownResourceStateStyles.Success) }); + } + catch (Exception ex) { - return new ExecuteCommandResult { Success = false, ErrorMessage = "Invalid connection string" }; + await notificationService.PublishUpdateAsync(modelResource, state => state with { State = new ResourceStateSnapshot(ex.Message, KnownResourceStateStyles.Error) }); } + }, cancellationToken); - var ollamaClient = new OllamaApiClient(endpoint); - var logger = context.ServiceProvider.GetRequiredService().GetLogger(modelResource); - var notificationService = context.ServiceProvider.GetRequiredService(); - - return await executeCommand(modelResource, ollamaClient, logger, notificationService, context.CancellationToken); - }, - displayDescription: displayDescription, - parameter: parameter, - confirmationMessage: confirmationMessage, - iconName: iconName, - iconVariant: iconVariant, - isHighlighted: isHighlighted - )); - - return modelResource; - } + return Task.CompletedTask; + + static async Task HasModelAsync(OllamaApiClient ollamaClient, string model, CancellationToken cancellationToken) + { + int 
retryCount = 0; + while (retryCount < 5) + { + try + { + var localModels = await ollamaClient.ListLocalModelsAsync(cancellationToken); + return localModels.Any(m => m.Name.Equals(model, StringComparison.OrdinalIgnoreCase)); + } + catch (TaskCanceledException) + { + // wait 30 seconds before retrying + await Task.Delay(TimeSpan.FromSeconds(30), cancellationToken); + retryCount++; + } + } + throw new TimeoutException("Failed to list local models after 5 retries. Likely that the container image was not pulled in time, or the container is not running."); + } + }); + + return builder; + } } \ No newline at end of file