Skip to content

Commit

Permalink
Merge pull request #12 from marcominerva/develop
Browse files Browse the repository at this point in the history
Add parameters to callbacks
  • Loading branch information
marcominerva authored Oct 11, 2023
2 parents 041803f + 3bfdf8c commit 3e23d79
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 22 deletions.
10 changes: 6 additions & 4 deletions src/DatabaseGpt/DatabaseGptClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using ChatGptNet;
using DatabaseGpt.DataAccessLayer;
using DatabaseGpt.Exceptions;
using DatabaseGpt.Models;
using DatabaseGpt.Settings;
using Microsoft.Extensions.Options;
using Polly;
Expand Down Expand Up @@ -50,7 +51,8 @@ public async Task<IDataReader> ExecuteNaturalLanguageQueryAsync(Guid sessionId,
You must answer the following question, '{question}', using a T-SQL query. Take into account also the previous messages.
From the comma separated list of tables available in the database, select those tables that might be useful in the generated T-SQL query.
The selected tables should be returned in a comma separated list. Your response should just contain the comma separated list of selected tables.
If there are no tables that might be useful, then return just the string 'NONE'.
If there are no tables that might be useful, return only the string 'NONE', without any other words. You shouldn't never explain the reason why you haven't found any table.
If the question is unclear or you don't understand the question, or you need a clarification, then return only the string 'NONE', without any other words.
""";

if (options?.OnStarting is not null)
Expand All @@ -67,10 +69,10 @@ The selected tables should be returned in a comma separated list. Your response
throw new NoTableFoundException($"I'm sorry, but there's no available information in the provided tables that can be useful for the question '{question}'.");
}

var tables = candidateTables.Split(',', StringSplitOptions.RemoveEmptyEntries);
var tables = candidateTables.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries);
if (options?.OnCandidateTablesFound is not null)
{
await options.OnCandidateTablesFound.Invoke(tables, serviceProvider);
await options.OnCandidateTablesFound.Invoke(new(sessionId, question, tables), serviceProvider);
}

var createTableScripts = await GetCreateTablesScriptAsync(tables, databaseSettings.ExcludedColumns);
Expand All @@ -93,7 +95,7 @@ CREATE TABLE Table2 (Column3 VARCHAR(255), Column4 VARCHAR(255))

if (options?.OnQueryGenerated is not null)
{
await options.OnQueryGenerated.Invoke(sql, serviceProvider);
await options.OnQueryGenerated.Invoke(new(sessionId, question, tables, sql), serviceProvider);
}

var reader = await ExecuteQueryAsync(sql);
Expand Down
3 changes: 1 addition & 2 deletions src/DatabaseGpt/DatabaseGptServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ public static IServiceCollection AddDatabaseGpt(this IServiceCollection services
ShouldHandle = new PredicateBuilder().Handle<ArgumentOutOfRangeException>().Handle<IndexOutOfRangeException>().Handle<SqlException>(),
OnRetry = args =>
{
Console.WriteLine($"Error ('{args.Outcome.Exception!.Message}'). Retrying (Attempt {args.AttemptNumber + 1} of {databaseSettings.MaxRetries})...");

//Console.WriteLine($"Error ('{args.Outcome.Exception!.Message}'). Retrying (Attempt {args.AttemptNumber + 1} of {databaseSettings.MaxRetries})...");
return default;
}
});
Expand Down
1 change: 1 addition & 0 deletions src/DatabaseGpt/IDatabaseGptClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Data;
using DatabaseGpt.Models;

namespace DatabaseGpt;

Expand Down
14 changes: 14 additions & 0 deletions src/DatabaseGpt/Models/CallbackArguments.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
namespace DatabaseGpt.Models;

public abstract class CallbackArguments
{
public Guid SessionId { get; }

public string Question { get; }

public CallbackArguments(Guid sessionId, string question)
{
SessionId = sessionId;
Question = question;
}
}
10 changes: 10 additions & 0 deletions src/DatabaseGpt/Models/NaturalLanguageQueryOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace DatabaseGpt.Models;

public class NaturalLanguageQueryOptions
{
public Func<IServiceProvider, ValueTask>? OnStarting { get; set; }

public Func<OnCandidateTablesFoundArguments, IServiceProvider, ValueTask>? OnCandidateTablesFound { get; set; }

public Func<OnQueryGeneratedArguments, IServiceProvider, ValueTask>? OnQueryGenerated { get; set; }
}
12 changes: 12 additions & 0 deletions src/DatabaseGpt/Models/OnCandidateTablesFoundArguments.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace DatabaseGpt.Models;

public class OnCandidateTablesFoundArguments : CallbackArguments
{
public IEnumerable<string> Tables { get; }

public OnCandidateTablesFoundArguments(Guid sessionId, string question, IEnumerable<string> tables)
: base(sessionId, question)
{
Tables = tables;
}
}
12 changes: 12 additions & 0 deletions src/DatabaseGpt/Models/OnQueryGeneratedArguments.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace DatabaseGpt.Models;

public class OnQueryGeneratedArguments : OnCandidateTablesFoundArguments
{
public string Sql { get; }

public OnQueryGeneratedArguments(Guid sessionId, string question, IEnumerable<string> tables, string sql)
: base(sessionId, question, tables)
{
Sql = sql;
}
}
10 changes: 0 additions & 10 deletions src/DatabaseGpt/NaturalLanguageQueryOptions.cs

This file was deleted.

14 changes: 8 additions & 6 deletions src/DatabaseGptConsole/Application.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DatabaseGpt;
using DatabaseGpt.Exceptions;
using DatabaseGpt.Extensions;
using DatabaseGpt.Models;
using DatabaseGpt.Settings;
using Microsoft.Extensions.Options;
using Spectre.Console;
Expand Down Expand Up @@ -37,23 +38,23 @@ public async Task ExecuteAsync()

return default;
},
OnCandidateTablesFound = (tables, _) =>
OnCandidateTablesFound = (args, _) =>
{
AnsiConsole.WriteLine();
AnsiConsole.WriteLine();

AnsiConsole.Write($"I think the following tables might be useful: {string.Join(", ", tables)}.");
AnsiConsole.Write($"I think the following tables might be useful: {string.Join(", ", args.Tables)}.");

return default;
},
OnQueryGenerated = (sql, _) =>
OnQueryGenerated = (args, _) =>
{
AnsiConsole.WriteLine();
AnsiConsole.WriteLine();

AnsiConsole.WriteLine("The query to answer the question should be the following:");

AnsiConsole.WriteLine(sql);
AnsiConsole.WriteLine(args.Sql);
AnsiConsole.WriteLine();

return default;
Expand All @@ -75,7 +76,8 @@ public async Task ExecuteAsync()
using var reader = await databaseGptClient.ExecuteNaturalLanguageQueryAsync(conversationId, question, options);

var table = new Table();
table.AddColumns(reader.GetColumnNames().ToArray());
var columns = reader.GetColumnNames().Select(c => $"[olive]{c}[/]").ToArray();
table.AddColumns(columns);

while (reader.Read())
{
Expand Down Expand Up @@ -107,7 +109,7 @@ public async Task ExecuteAsync()
catch (Exception ex)
{
AnsiConsole.WriteException(ex,
ExceptionFormats.ShortenPaths | ExceptionFormats.ShortenTypes | ExceptionFormats.ShortenMethods | ExceptionFormats.ShowLinks);
ExceptionFormats.ShortenPaths | ExceptionFormats.ShortenTypes | ExceptionFormats.ShortenMethods);
}
finally
{
Expand Down

0 comments on commit 3e23d79

Please sign in to comment.