Skip to content

Commit

Permalink
images quick regeneration
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-zhur committed Oct 9, 2024
1 parent 8830880 commit 110a747
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace OneShelf.Common.OpenAi.Models.Memory;

public class ChatBotMemoryPointWithDeserializableTraces : ChatBotMemoryPoint
{
public List<string> ImageTraces { get; init; } = new();

public List<string> ImageUrlTraces { get; init; } = new();
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
namespace OneShelf.Common.OpenAi.Models.Memory;

public class ChatBotMemoryPointWithTraces : ChatBotMemoryPoint
public class ChatBotMemoryPointWithTraces : ChatBotMemoryPointWithDeserializableTraces
{
public List<MemoryPointTrace> Traces { get; init; } = new();

public List<string> ImageTraces { get; init; } = new();

public List<string> ImageUrlTraces { get; init; } = new();
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public AiDialogHandler(
_dogContext = dogContext;
}

protected override void OnInitializing(Update update)
protected override void OnInitializing(long userId, long chatId)
{
_dogDatabase.InitializeInteractionsRepositoryScope(_dogContext.DomainId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ public AiDialogHandler(
_options = options;
}

protected override void OnInitializing(Update update)
protected override void OnInitializing(long userId, long chatId)
{
_dragonDatabase.InitializeInteractionsRepositoryScope(update.Message!.From!.Id, update.Message.Chat.Id);
_dragonDatabase.InitializeInteractionsRepositoryScope(userId, chatId);
}

protected override bool TraceImages => _options.Value.IsAdmin(_dragonScope.UserId);

protected override IInteraction<InteractionType> CreateInteraction(Update update) => new Interaction
{
ChatId = update.Message!.Chat.Id,
ChatId = update.Message?.Chat.Id ?? update.CallbackQuery!.Message!.Chat.Id,
UpdateId = update.UpdateId,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ protected override async Task<bool> HandleSync(Update update)
_dragonDatabase.Updates.Add(dbUpdate);
await _dragonDatabase.SaveChangesAsync();

_scope.Initialize(dbUpdate.Id, update.Message?.Chat.Id, update.Message?.From?.Id);
_scope.Initialize(dbUpdate.Id, update.Message?.Chat.Id ?? update.CallbackQuery?.Message?.Chat.Id, update.Message?.From?.Id ?? update.CallbackQuery?.From.Id);
return false;
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using OneShelf.Common;
using OneShelf.Common.OpenAi.Models;
using OneShelf.Common.OpenAi.Models.Memory;
using OneShelf.Common.OpenAi.Services;
using OneShelf.Telegram.Ai.Model;
using OneShelf.Telegram.Options;
using OneShelf.Telegram.Services.Base;
using Telegram.BotAPI;
using Telegram.BotAPI.AvailableMethods;
Expand Down Expand Up @@ -80,68 +78,78 @@ protected TelegramBotClient GetApi()
return _api ??= new(ScopedAbstractions.GetBotToken());
}

protected async Task SendMessage(Update respondTo, IReadOnlyList<string> images, bool reply)
protected async Task SendMessage(long chatId, int? messageThreadId, int messageId, IReadOnlyList<string> images, bool reply)
{
await GetApi().SendMediaGroupAsync(
new(respondTo.Message!.Chat.Id, images.Select(x => new InputMediaPhoto(x.ToString())))
new(chatId, images.Select(x => new InputMediaPhoto(x.ToString())))
{
MessageThreadId = respondTo.Message!.MessageThreadId,
MessageThreadId = messageThreadId,
ReplyParameters = !reply ? null : new()
{
MessageId = respondTo.Message.MessageId,
MessageId = messageId,
AllowSendingWithoutReply = false,
},
DisableNotification = true,
});
}

protected async Task SendMessage(Update respondTo, string text, IReadOnlyList<string> images, bool reply)
protected async Task SendMessage(long chatId, int? messageThreadId, int messageId, string text, IReadOnlyList<string> images, bool reply, ReplyMarkup? replyMarkup = null)
{
var (messageEntities, result) = GetMessageEntities(text);
if (replyMarkup != null)
{
await SendSeparately(chatId, messageThreadId, messageId, text, images, reply, replyMarkup);
return;
}

try
{
await GetApi().SendMediaGroupAsync(new(respondTo.Message!.Chat.Id, images.WithIndices().Select(x => new InputMediaPhoto(x.x.ToString())
var (messageEntities, result) = GetMessageEntities(text);
await GetApi().SendMediaGroupAsync(new(chatId, images.WithIndices().Select(x => new InputMediaPhoto(x.x.ToString())
{
Caption = x.i == 0 ? result : null,
CaptionEntities = x.i == 0 ? messageEntities : null,
}))
{
MessageThreadId = respondTo.Message!.MessageThreadId,
MessageThreadId = messageThreadId,
ReplyParameters = !reply ? null : new()
{
MessageId = respondTo.Message.MessageId,
MessageId = messageId,
AllowSendingWithoutReply = false,
},
DisableNotification = true,
});
}
catch (BotRequestException e) when (e.Message.Contains("message caption is too long"))
{
await SendMessage(respondTo, images, reply);

await SendMessage(respondTo, text, reply);
await SendSeparately(chatId, messageThreadId, messageId, text, images, reply);
}
}

protected async Task SendMessage(Update respondTo, string text, bool reply)
private async Task SendSeparately(long chatId, int? messageThreadId, int messageId, string text, IReadOnlyList<string> images, bool reply, ReplyMarkup? replyMarkup = null)
{
await SendMessage(chatId, messageThreadId, messageId, images, reply);
await SendMessage(chatId, messageThreadId, messageId, text, reply, replyMarkup);
}

protected async Task SendMessage(long chatId, int? messageThreadId, int messageId, string text, bool reply, ReplyMarkup? replyMarkup = null)
{
var (messageEntities, result) = GetMessageEntities(text);

await GetApi().SendMessageAsync(new(respondTo.Message!.Chat.Id, result)
await GetApi().SendMessageAsync(new(chatId, result)
{
MessageThreadId = respondTo.Message!.MessageThreadId,
MessageThreadId = messageThreadId,
ReplyParameters = !reply ? null : new()
{
MessageId = respondTo.Message.MessageId,
MessageId = messageId,
AllowSendingWithoutReply = false,
},
DisableNotification = true,
LinkPreviewOptions = new()
{
IsDisabled = true,
},
Entities = messageEntities
Entities = messageEntities,
ReplyMarkup = replyMarkup,
});
}

Expand Down Expand Up @@ -203,16 +211,22 @@ protected async void LongTyping(Update update, CancellationToken cancellationTok

protected override async Task<bool> HandleSync(Update update)
{
if (update.CallbackQuery?.Data?.StartsWith("image, ") == true && update.CallbackQuery.Message != null)
{
OnInitializing(update.CallbackQuery.From.Id, update.CallbackQuery.Message!.Chat.Id);
return await HandleCallback(update);
}

if (update.Message?.From == null) return false;

if (!CheckRelevant(update)) return false;

OnInitializing(update);
OnInitializing(update.Message!.From!.Id, update.Message.Chat.Id);

var chatUnavailableUntil = await GetChatUnavailableUntil();
if (chatUnavailableUntil.HasValue)
{
Queued(SendMessage(update, string.Format(UnavailableUntilTemplate, chatUnavailableUntil.Value.ToString("f")), false));
Queued(SendMessage(update.Message!.Chat.Id, update.Message.MessageThreadId, update.Message.MessageId, string.Format(UnavailableUntilTemplate, chatUnavailableUntil.Value.ToString("f")), false));
return true;
}

Expand All @@ -227,9 +241,97 @@ protected override async Task<bool> HandleSync(Update update)
return false;
}

private async Task<bool> HandleCallback(Update update)
{
var callbackQuery = update.CallbackQuery!;

var data = callbackQuery.Data!.Split(", ");

if (data is not ["image", { } interactionIdValue, { } imageIndexValue, { } timesValue]
|| !int.TryParse(interactionIdValue, out var interactionId)
|| !int.TryParse(imageIndexValue, out var imageIndex)
|| !int.TryParse(timesValue, out var times))
{
return false;
}

var interaction = (await _repository.Get(x => x.Where(x => x.Id == interactionId))).Single();
DateTime? imagesUnavailableUntil = null;
if (times != 0)
{
imagesUnavailableUntil = await GetImagesUnavailableUntil(DateTime.Now);
if (!imagesUnavailableUntil.HasValue)
{
var interaction2 = CreateInteraction(update);
interaction2.CreatedOn = DateTime.Now;
interaction2.InteractionType = _repository.ImagesSuccess;
interaction2.Serialized = timesValue;
interaction2.ShortInfoSerialized = callbackQuery.Data;
interaction2.UserId = callbackQuery.From.Id;
await _repository.Add(interaction2.Once().ToList());
}
}

QueueApi(callbackQuery.From.Id.ToString(), api => React(api, callbackQuery, interaction, imageIndex, times, imagesUnavailableUntil));
return true;
}

private async Task React(TelegramBotClient api, CallbackQuery callbackQuery, IInteraction<TInteractionType> interaction, int imageIndex, int times, DateTime? imagesUnavailableUntil)
{
var prompt = JsonSerializer.Deserialize<ChatBotMemoryPointWithDeserializableTraces>(interaction.Serialized)!.ImageTraces[imageIndex];
var text = times == 0
? prompt
: imagesUnavailableUntil.HasValue
? string.Format(RegenerationUnavailableUntilTemplate, imagesUnavailableUntil!.Value.ToString("f"))
: RegenerationTemplate;

try
{
await api.AnswerCallbackQueryAsync(new AnswerCallbackQueryArgs(callbackQuery.Id)
{
ShowAlert = true,
Text = text,
});
}
catch (BotRequestException e) when (e.Message.Contains("query is too old and response timeout expired or query ID is invalid"))
{
_logger.LogWarning(e, "Couldn't interactively respond to the image callback query. {chat}, {user}.", callbackQuery.Message!.Chat, callbackQuery.From.Id);
}

if (!imagesUnavailableUntil.HasValue && times > 0)
{
var aiParameters = await GetAiParameters();
var images = await _dialogRunner.GenerateImages(
Enumerable.Repeat(prompt, times).ToList(),
new()
{
ImagesVersion = aiParameters.imagesVersion,
UserId = callbackQuery.From.Id,
DomainId = -1,
Version = aiParameters.version!,
ChatId = callbackQuery.Message!.Chat.Id,
UseCase = "direct regeneration",
AdditionalBillingInfo = "images regeneration",
SystemMessage = "no message",
});

images = images.Where(x => !string.IsNullOrWhiteSpace(x)).ToList();
if (images.Any())
{
await SendSeparately(callbackQuery.Message!.Chat.Id, null, callbackQuery.Message!.MessageId, $"⟳ {imageIndex + 1} × {times}", images!, true);
}
else
{
await SendMessage(callbackQuery.Message!.Chat.Id, null, callbackQuery.Message!.MessageId, "Не получилось нарисовать новые изображения.", true);
}
}
}

protected abstract string UnavailableUntilTemplate { get; }
protected virtual string RegenerationUnavailableUntilTemplate => "Я пока отдыхаю, а ты возвращайся {0} UTC.";
protected virtual string RegenerationTemplate => "Минутку, процесс идёт...";

protected virtual void OnInitializing(Update update)
protected virtual void OnInitializing(long userId, long chatId)
{
}

Expand Down Expand Up @@ -311,7 +413,7 @@ protected async Task Respond(Update update)
catch (Exception e)
{
_logger.LogError(e, "Error requesting the data.");
await SendMessage(update, ResponseError, true);
await SendMessage(update.Message!.Chat.Id, update.Message.MessageThreadId, update.Message.MessageId, ResponseError, true);
return;
}
finally
Expand All @@ -333,6 +435,7 @@ protected async Task Respond(Update update)
interaction.ShortInfoSerialized = JsonSerializer.Serialize(result);
interaction.UserId = update.Message.From.Id;
await _repository.Add(interaction.Once().ToList());
var memoryPointInteractionId = interaction.Id;

if (result.Images.Any())
{
Expand All @@ -354,25 +457,30 @@ protected async Task Respond(Update update)
{
try
{
if (TraceImages)
InlineKeyboardMarkup? replyMarkup = null;
if (TraceImages && IsPrivate(update.Message.Chat))
{
try
if (string.IsNullOrWhiteSpace(text))
{
await SendMessage(update, string.Join(Environment.NewLine, newMessagePoint.ImageTraces.Select(x => $"- {x}").Prepend(string.Empty).Prepend("Traces:")), false);
}
catch (Exception e)
{
_logger.LogError(e, "Error writing the traces.");
text = "⚙ управление";
}

replyMarkup = new(Enumerable.Range(0, result.Images.Count).Select(i => (InlineKeyboardButton[])[
new($"👀 {i + 1}") { CallbackData = $"image, {memoryPointInteractionId}, {i}, 0" },
new($"⟳ {i + 1}") { CallbackData = $"image, {memoryPointInteractionId}, {i}, 1" },
new($"⟳ {i + 1} × 2") { CallbackData = $"image, {memoryPointInteractionId}, {i}, 2" },
new($"⟳ {i + 1} × 3") { CallbackData = $"image, {memoryPointInteractionId}, {i}, 3" },
new($"⟳ {i + 1} × 4") { CallbackData = $"image, {memoryPointInteractionId}, {i}, 4" },
]));
}

if (!string.IsNullOrWhiteSpace(text))
{
await SendMessage(update, text, result.Images, false);
await SendMessage(update.Message!.Chat.Id, update.Message.MessageThreadId, update.Message.MessageId, text, result.Images, false, replyMarkup);
}
else
{
await SendMessage(update, result.Images, false);
await SendMessage(update.Message!.Chat.Id, update.Message!.MessageThreadId, update.Message!.MessageId, result.Images, false);
}

return;
Expand All @@ -385,7 +493,7 @@ protected async Task Respond(Update update)

if (!string.IsNullOrWhiteSpace(text))
{
await SendMessage(update, text, false);
await SendMessage(update.Message!.Chat.Id, update.Message.MessageThreadId, update.Message.MessageId, text, false);
}
}

Expand Down

0 comments on commit 110a747

Please sign in to comment.