From da6fc4497178efd4888651eb8733be9e3aad77ea Mon Sep 17 00:00:00 2001 From: alufers Date: Tue, 17 Oct 2023 15:12:16 +0200 Subject: [PATCH] Add OnAfterUpdate hook, pass stop processing flag via context --- paczkobot/image_scanning_service.go | 12 ++++++--- providers/cainiao/cainiao_provider.go | 11 ++++---- tghelpers/ask_service.go | 18 ++++++++------ tghelpers/ask_service_test.go | 4 +-- tghelpers/command_dispatcher.go | 36 +++++++++++++-------------- tghelpers/command_dispatcher_test.go | 8 ++++-- tghelpers/update_hook.go | 17 ++++++++++++- 7 files changed, 67 insertions(+), 39 deletions(-) diff --git a/paczkobot/image_scanning_service.go b/paczkobot/image_scanning_service.go index 8bd9e1c..09359b0 100644 --- a/paczkobot/image_scanning_service.go +++ b/paczkobot/image_scanning_service.go @@ -35,7 +35,7 @@ func NewImageScanningService(app *BotApp) *ImageScanningService { } } -func (i *ImageScanningService) OnUpdate(ctx context.Context) bool { +func (i *ImageScanningService) OnUpdate(ctx context.Context) context.Context { update := tghelpers.UpdateFromCtx(ctx) if update.Message != nil && update.Message.Photo != nil && len(update.Message.Photo) > 0 { @@ -46,16 +46,16 @@ func (i *ImageScanningService) OnUpdate(ctx context.Context) bool { }) if err != nil { log.Printf("Failed to get file: %v", err) - return false + return ctx } url := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", i.App.Bot.Token, file.FilePath) err = i.ScanIncomingImage(ctx, tghelpers.ArgsFromCtx(ctx), url) if err != nil { log.Printf("Failed to ScanIncomingImage: %v", err) } - return true + return tghelpers.WithStopProcessingCommands(ctx) } - return false + return ctx } func (i *ImageScanningService) ScanIncomingImage(ctx context.Context, args *tghelpers.CommandArguments, url string) error { @@ -178,6 +178,10 @@ func (i *ImageScanningService) ScanIncomingImage(ctx context.Context, args *tghe return nil } +func (i *ImageScanningService) OnAfterUpdate(ctx context.Context) context.Context { + return ctx +} + func (*ImageScanningService) DrawResultPoints(img image.Image, points []gozxing.ResultPoint) image.Image { if len(points) <= 1 { return img diff --git a/providers/cainiao/cainiao_provider.go b/providers/cainiao/cainiao_provider.go index a5583ff..38928f3 100644 --- a/providers/cainiao/cainiao_provider.go +++ b/providers/cainiao/cainiao_provider.go @@ -23,7 +23,6 @@ func (pp *CainiaoProvider) MatchesNumber(trackingNumber string) bool { } func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*commondata.TrackingData, error) { - req, err := http.NewRequest( "GET", "https://global.cainiao.com/global/detail.json?mailNos="+url.QueryEscape(trackingNumber)+"&lang=en-US&language=en-US", @@ -37,6 +36,7 @@ func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*c if err != nil { return nil, commonerrors.NewNetworkError(pp.GetName(), req) } + defer httpResponse.Body.Close() if httpResponse.StatusCode != 200 { return nil, commonerrors.NotFoundError @@ -83,20 +83,21 @@ func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*c nil, ) if err != nil { - return td, nil + return td, nil //nolint:nilerr } commondata.SetCommonHTTPHeaders(&cityReq.Header) cityResp, err := http.DefaultClient.Do(cityReq) if err != nil { - return td, nil + return td, nil //nolint:nilerr } + defer cityResp.Body.Close() if cityResp.StatusCode != 200 { - return td, nil + return td, nil //nolint:nilerr } var cityResponse GetCityResponse err = json.NewDecoder(cityResp.Body).Decode(&cityResponse) if err != nil || !cityResponse.Success { - return td, nil + return td, nil //nolint:nilerr } td.Destination = cityResponse.Module + ", " + td.Destination diff --git a/tghelpers/ask_service.go b/tghelpers/ask_service.go index a30847a..5a99695 100644 --- a/tghelpers/ask_service.go +++ b/tghelpers/ask_service.go @@ -27,13 +27,13 @@ func NewAskService(bot BotAPI) *AskService { } // Implements UpdateHook -func (a *AskService) OnUpdate(ctx context.Context) bool { +func (a *AskService) OnUpdate(ctx context.Context) context.Context { update := UpdateFromCtx(ctx) a.AskCallbacksMutex.Lock() defer a.AskCallbacksMutex.Unlock() if update.CallbackQuery != nil { if update.CallbackQuery.Message == nil || update.CallbackQuery.Message.Chat == nil { - return false + return ctx } chatID := update.CallbackQuery.Message.Chat.ID if update.CallbackQuery.Data == "/cancel" { @@ -45,7 +45,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool { callback("", errors.New("canceled")) delete(a.AskCallbacks, chatID) } - return true + return WithStopProcessingCommands(ctx) } if update.CallbackQuery.Data == "/yes" { if callback, ok := a.AskCallbacks[chatID]; ok { @@ -56,7 +56,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool { callback("", nil) delete(a.AskCallbacks, chatID) } - return true + return WithStopProcessingCommands(ctx) } if strings.HasPrefix(update.CallbackQuery.Data, "/sugg ") { val := strings.TrimPrefix(update.CallbackQuery.Data, "/sugg ") @@ -77,7 +77,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool { if callback, ok := a.AskCallbacks[update.Message.Chat.ID]; ok { callback("", errors.New("canceled")) delete(a.AskCallbacks, update.Message.Chat.ID) - return false + return ctx } } @@ -88,11 +88,15 @@ func (a *AskService) OnUpdate(ctx context.Context) bool { } callback(update.Message.Text, nil) delete(a.AskCallbacks, update.Message.Chat.ID) - return true + return WithStopProcessingCommands(ctx) } } - return false + return ctx +} + +func (a *AskService) OnAfterUpdate(ctx context.Context) context.Context { + return ctx } // AskForArgument asks the user at the specified chatID for a text value. diff --git a/tghelpers/ask_service_test.go b/tghelpers/ask_service_test.go index ef52c44..38a0f39 100644 --- a/tghelpers/ask_service_test.go +++ b/tghelpers/ask_service_test.go @@ -27,7 +27,7 @@ func TestAskServiceReturnsFalseForUnrelatedUpdates(t *testing.T) { }, ) res := askService.OnUpdate(ctx) - assert.False(t, res) + assert.True(t, res.Value(tghelpers.StopProcessingCommandsCtxKey) == nil) // should return false because it's not a related update } func TestAskServiceConfirmWorks(t *testing.T) { @@ -63,7 +63,7 @@ func TestAskServiceConfirmWorks(t *testing.T) { }, ) res := askService.OnUpdate(ctx) - assert.True(t, res) // should return true because it's a related update + assert.True(t, res.Value(tghelpers.StopProcessingCommandsCtxKey) != nil) // should return true because it's a related update }() return msg, nil diff --git a/tghelpers/command_dispatcher.go b/tghelpers/command_dispatcher.go index 03d58d0..fdda1cd 100644 --- a/tghelpers/command_dispatcher.go +++ b/tghelpers/command_dispatcher.go @@ -79,30 +79,30 @@ func (d *CommandDispatcher) processIncomingUpdate(ctx context.Context, u tgbotap } for _, hook := range d.UpdateHooks { - if hook.OnUpdate(ctx) { - return // hook has handled the message stop processing - } + ctx = hook.OnUpdate(ctx) } + shouldProcessCommands := ctx.Value(StopProcessingCommandsCtxKey) == nil var err error - - for _, cmd := range d.Commands { - if CommandMatches(cmd, cmdText) { - args.Command = cmd - for i, argTpl := range cmd.Arguments() { - if argTpl.Variadic { - args.NamedArguments[argTpl.Name] = strings.Join(args.Arguments[i:], " ") - break + if shouldProcessCommands { + for _, cmd := range d.Commands { + if CommandMatches(cmd, cmdText) { + args.Command = cmd + for i, argTpl := range cmd.Arguments() { + if argTpl.Variadic { + args.NamedArguments[argTpl.Name] = strings.Join(args.Arguments[i:], " ") + break + } + if i >= len(args.Arguments) { + break + } + args.NamedArguments[argTpl.Name] = args.Arguments[i] } - if i >= len(args.Arguments) { - break - } - args.NamedArguments[argTpl.Name] = args.Arguments[i] - } - err = cmd.Execute(ctx) + err = cmd.Execute(ctx) - break + break + } } } diff --git a/tghelpers/command_dispatcher_test.go b/tghelpers/command_dispatcher_test.go index 5b35b76..a256481 100644 --- a/tghelpers/command_dispatcher_test.go +++ b/tghelpers/command_dispatcher_test.go @@ -11,9 +11,13 @@ type FakeUpdateHook struct { didRun bool } -func (h *FakeUpdateHook) OnUpdate(context.Context) bool { +func (h *FakeUpdateHook) OnUpdate(ctx context.Context) context.Context { h.didRun = true - return true + return ctx +} + +func (h *FakeUpdateHook) OnAfterUpdate(ctx context.Context) context.Context { + return ctx } // a test that chcks if command dispatcher executes update hooks diff --git a/tghelpers/update_hook.go b/tghelpers/update_hook.go index 63ee3ec..a1fe2e4 100644 --- a/tghelpers/update_hook.go +++ b/tghelpers/update_hook.go @@ -4,6 +4,19 @@ import ( "context" ) +type stopProcessingCommandsCtxKeyType struct{} + +// StopProcessingCommandsCtxKey is a context key that can be used to stop +// processing commands for an update. + +// It should be added to the context by an UpdateHook, which +// wishes to stop processing commands. +var StopProcessingCommandsCtxKey = stopProcessingCommandsCtxKeyType{} + +func WithStopProcessingCommands(ctx context.Context) context.Context { + return context.WithValue(ctx, StopProcessingCommandsCtxKey, true) +} + // UpdateHook allows a service to listen for all telegram updates // before they are processed for commands type UpdateHook interface { @@ -12,5 +25,7 @@ type UpdateHook interface { // as handled by the hook. Further processing is stopped. // The update shall be extracted from the context using // tghelpers.UpdateFromCtx(ctx) - OnUpdate(context.Context) bool + OnUpdate(context.Context) context.Context + + OnAfterUpdate(context.Context) context.Context }