diff --git a/client.go b/client.go index 6740eac..daf2785 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "io" + "log/slog" + "os" "regexp" "strconv" "strings" @@ -43,11 +45,16 @@ type S3Client interface { DeleteObjects(ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error) } +type Logger interface { + Warn(msg string, args ...any) +} + // Client is a wrapper for the [github.com/aws/aws-sdk-go-v2/service/sqs.Client], providing extra // functionality for retrieving, sending and deleting messages. type Client struct { SQSClient s3c S3Client + logger Logger bucketName string messageSizeThreshold int64 batchMessageSizeThreshold int64 @@ -55,6 +62,8 @@ type Client struct { pointerClass string reservedAttrs []string objectPrefix string + baseS3PointerSize int + baseAttributeSize int } type ClientOption func(*Client) error @@ -75,6 +84,7 @@ func New( c := Client{ SQSClient: sqsc, s3c: s3c, + logger: slog.New(slog.NewTextHandler(os.Stdout, nil)), messageSizeThreshold: maxMsgSizeInBytes, batchMessageSizeThreshold: maxMsgSizeInBytes, pointerClass: "software.amazon.payloadoffloading.PayloadS3Pointer", @@ -89,9 +99,38 @@ func New( } } + // create an example s3 pointer + ptr := &s3Pointer{ + S3Key: uuid.NewString(), + class: c.pointerClass, + } + + // get its string representation + s3PointerBytes, _ := ptr.MarshalJSON() + + // store the length of this string to be used when calculating optimal payload sizing in the + // BatchSendMessage method. Note the size with the S3 Bucket is excluded here as this can + // change on a per-call basis. + c.baseS3PointerSize = len(s3PointerBytes) + + // similarly, store the base size of the attribute added for extended payloads. Note the string + // representation of the length of the payload is also omitted here as that obviously changes + // between each message. + c.baseAttributeSize = len(c.reservedAttrs[0]) + len("Number") + return &c, nil } +// WithLogger allows the caller to control how messages will be logged from the client. The expected +// interface matches the `log/slog` function signature and will default to a TextHandler unless +// overwritten by this method. +func WithLogger(logger Logger) ClientOption { + return func(c *Client) error { + c.logger = logger + return nil + } +} + // Set the destination bucket for large messages that are sent by this client. This is a // soft-requirement for using the SendMessage function. func WithS3BucketName(bucketName string) ClientOption { @@ -167,15 +206,45 @@ func (c *Client) s3Key(filename string) string { return filename } +// messageSize describes the size of a SQS message (body and its attributes) +type messageSize struct { + bodySize int64 + attributeSize int64 +} + +// Total returns the full message size +func (m messageSize) Total() int64 { + return m.bodySize + m.attributeSize +} + +// ToExtendedSize will convert a messageSize to its equivalent extended payload size. This can be +// useful for estimating the size of a message if it were to be converted without actually having to +// handle the conversion. +func (m messageSize) ToExtendedSize(pointerSize, attributeSize int) messageSize { + n, numDigits := int64(10), int64(1) + for n <= m.bodySize { + n *= 10 + numDigits++ + } + + return messageSize{ + bodySize: int64(pointerSize), + attributeSize: int64(attributeSize) + numDigits + m.attributeSize, + } +} + // getMessageSize returns the size of the body and attributes of a message -func (c *Client) messageSize(body *string, attributes map[string]types.MessageAttributeValue) int64 { - return int64(len(*body)) + c.attributeSize(attributes) +func (c *Client) messageSize(body *string, attributes map[string]types.MessageAttributeValue) messageSize { + return messageSize{ + bodySize: int64(len(*body)), + attributeSize: c.attributeSize(attributes), + } } // messageExceedsThreshold determines if the size of the body and attributes exceeds the configured // message size threshold func (c *Client) messageExceedsThreshold(body *string, attributes map[string]types.MessageAttributeValue) bool { - return c.messageSize(body, attributes) > c.messageSizeThreshold + return c.messageSize(body, attributes).Total() > c.messageSizeThreshold } // attributeSize will return the size of all provided attributes and their values @@ -268,7 +337,7 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput, if c.alwaysThroughS3 || c.messageExceedsThreshold(input.MessageBody, input.MessageAttributes) { // generate s3 object key - s3Key := c.s3Key(uuid.New().String()) + s3Key := c.s3Key(uuid.NewString()) // upload large payload to S3 _, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{ @@ -312,13 +381,82 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput, return c.SQSClient.SendMessage(ctx, &input, optFns...) } +// batchMessageMeta is used to maintain a reference to the original payload inside of a batch +// request while also storing metadata about its size to be used during the optimize step +type batchMessageMeta struct { + payloadIndex int + msgSize messageSize +} + +// batchPayload stores information about a combination of messages in order to determine the most +// efficient batches +type batchPayload struct { + batchBytes int64 + s3PointerSize int + extendedMessages []batchMessageMeta +} + +// optimizeBatchPayload will attempt to recursively determine the best way to distribute the passed +// in messages into extended and regular messages in order to optimize the resulting batchPayload +// across a series of factors. The most important factors for determining the best batch are: +// 1. Always prefer a payload that is under the batchMessageSizeThreshold +// 2. Prefer a payload with the LEAST amount of extended messages +// 3. Prefer a payload which sends the LEAST amount of data to S3 +// +// These decisions are made in an effort to be cognizant about performance and costs of the caller. +func (c *Client) optimizeBatchPayload(bp *batchPayload, messages []batchMessageMeta) *batchPayload { + // return if we have no more messages to examine + if len(messages) == 0 { + return bp + } + + currMsg := messages[0] + numExtMsg := len(bp.extendedMessages) + + // case 1 - assume we leave the message as-is + c1 := c.optimizeBatchPayload(&batchPayload{ + batchBytes: bp.batchBytes + currMsg.msgSize.Total(), + extendedMessages: bp.extendedMessages, + s3PointerSize: bp.s3PointerSize, + }, messages[1:]) + + // case 2 - assume we convert the message into an extended payload + extendedMessageSize := currMsg.msgSize.ToExtendedSize(bp.s3PointerSize, c.baseAttributeSize).Total() + c2 := c.optimizeBatchPayload(&batchPayload{ + batchBytes: bp.batchBytes + extendedMessageSize, + extendedMessages: append(bp.extendedMessages[:numExtMsg:numExtMsg], currMsg), + s3PointerSize: bp.s3PointerSize, + }, messages[1:]) + + // preform the checks against factors provided in the function description + if c1.batchBytes <= c.batchMessageSizeThreshold && c2.batchBytes > c.batchMessageSizeThreshold { + return c1 + } else if c2.batchBytes <= c.batchMessageSizeThreshold && c1.batchBytes > c.batchMessageSizeThreshold { + return c2 + } else if c1.batchBytes > c.batchMessageSizeThreshold && c2.batchBytes > c.batchMessageSizeThreshold { + // in this case, both payloads suck- attempt to return the best of the worst + if c1.batchBytes <= c2.batchBytes { + return c1 + } + return c2 + } else if len(c2.extendedMessages) > len(c1.extendedMessages) { + return c1 + } else if c1.batchBytes > c2.batchBytes { + return c1 + } + + return c2 +} + // Extended SQS Client wrapper around // [github.com/aws/aws-sdk-go-v2/service/sqs.Client.SendMessageBatch]. When preparing the messages -// for transport, each message will be iterated through and checks performed: -// -// 1. If the size of the message exceeds the messageSizeThreshold, it will be uploaded to S3 -// 2. If the size of the message when added to the size of all previous messages exceeds the -// batchMessageSizeThreshold, then that message with be uploaded to S3. +// for transport, if the size of any message exceeds the messageSizeThreshold or if alwaysS3 is set +// to true, the message will be uploaded to S3. For the remaining messages, this method will +// calculate the least amount of messages required to upload to S3 in order to reduce the overall +// payload size under the batchMessageSizeThreshold. If there are multiple combinations to reduce +// the payload below the threshold with uploading the same amount of messages, preference will be +// given to the combination that results in the smallest amount of data sent to S3 in order to +// minimize costs. // // For each message that is successfully uploaded to S3, the messages will be altered by: // @@ -354,7 +492,6 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput, func (c *Client) SendMessageBatch(ctx context.Context, params *sqs.SendMessageBatchInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageBatchOutput, error) { input := *params copyEntries := make([]types.SendMessageBatchRequestEntry, len(input.Entries)) - g := new(errgroup.Group) // determine bucket name, either from client (default) or from provided SQS URL queueURL, s3Bucket, found := strings.Cut(*params.QueueUrl, "|") @@ -364,70 +501,98 @@ func (c *Client) SendMessageBatch(ctx context.Context, params *sqs.SendMessageBa input.QueueUrl = &queueURL - batchSizeBytes := int64(0) + // initialize the payload struct which will hold the data describing the batch + bp := &batchPayload{ + s3PointerSize: c.baseS3PointerSize + len(s3Bucket), + extendedMessages: make([]batchMessageMeta, 0, len(input.Entries)), + } + + // store "regular" (non-extended) messages separately to iterate during optimize step + regularMessages := make([]batchMessageMeta, 0, len(input.Entries)) for i, e := range input.Entries { i, e := i, e - // always copy the entry, regardless of size + // always copy the entry copyEntries[i] = e + // calculate the starting message size msgSize := c.messageSize(e.MessageBody, e.MessageAttributes) - // check if we always send through s3, or if the message size exceeds the threshold, or if - // sending the message would cause the batch itself to overflow - if c.alwaysThroughS3 || - msgSize > c.messageSizeThreshold || - batchSizeBytes+msgSize > c.batchMessageSizeThreshold { - - // generate s3 object key - s3Key := c.s3Key(uuid.New().String()) - - // upload large payload to S3 - g.Go(func() error { - _, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{ - Bucket: &s3Bucket, - Key: aws.String(s3Key), - Body: strings.NewReader(*e.MessageBody), - }) - - if err != nil { - return fmt.Errorf("unable to upload large payload to s3: %w", err) - } + // build a "meta" struct in order to keep track of this message when optimizing the batch + // payload in the next stage + msgMeta := batchMessageMeta{ + payloadIndex: i, + msgSize: msgSize, + } - return nil - }) + // check if we always send through s3, or if the message size exceeds the threshold + if c.alwaysThroughS3 || msgSize.Total() > c.messageSizeThreshold { + // track the payload under the batch's extendedMessages + bp.extendedMessages = append(bp.extendedMessages, msgMeta) + + // update the base payload size + bp.batchBytes += msgSize.ToExtendedSize(bp.s3PointerSize, c.baseAttributeSize).Total() + } else { + regularMessages = append(regularMessages, msgMeta) + } + } + + // attempt to find the most efficient message combination for our batch + bp = c.optimizeBatchPayload(bp, regularMessages) + + if bp.batchBytes > c.batchMessageSizeThreshold { + c.logger.Warn(fmt.Sprintf("SendMessageBatch is only able to reduce the batch size to <%d> even though BatchMessageSizeThreshold is set to <%d>. Errors might occur.", bp.batchBytes, c.batchMessageSizeThreshold)) + } - // create an s3 pointer that will be uploaded to SQS in place of the large payload - asBytes, err := jsonMarshal(&s3Pointer{ - S3BucketName: s3Bucket, - S3Key: s3Key, - class: c.pointerClass, + g := new(errgroup.Group) + for _, em := range bp.extendedMessages { + // generate s3 object key + s3Key := c.s3Key(uuid.NewString()) + + // deep copy full message payload to send to S3 + msgBody := *copyEntries[em.payloadIndex].MessageBody + + // upload large payload to S3 + g.Go(func() error { + _, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &s3Bucket, + Key: aws.String(s3Key), + Body: strings.NewReader(msgBody), }) if err != nil { - return nil, fmt.Errorf("unable to marshal S3 pointer: %w", err) + return fmt.Errorf("unable to upload large payload to s3: %w", err) } - // copy over all attributes, leaving space for our reserved attribute - updatedAttributes := make(map[string]types.MessageAttributeValue, len(e.MessageAttributes)+1) - for k, v := range e.MessageAttributes { - updatedAttributes[k] = v - } + return nil + }) - // assign the reserved attribute to a number containing the size of the original body - updatedAttributes[c.reservedAttrs[0]] = types.MessageAttributeValue{ - DataType: aws.String("Number"), - StringValue: aws.String(strconv.Itoa(len(*e.MessageBody))), - } + // create an s3 pointer that will be uploaded to SQS in place of the large payload + asBytes, err := jsonMarshal(&s3Pointer{ + S3BucketName: s3Bucket, + S3Key: s3Key, + class: c.pointerClass, + }) + + if err != nil { + return nil, fmt.Errorf("unable to marshal S3 pointer: %w", err) + } - // override attributes and body in the original message - copyEntries[i].MessageAttributes = updatedAttributes - copyEntries[i].MessageBody = aws.String(string(asBytes)) + // copy over all attributes, leaving space for our reserved attribute + updatedAttributes := make(map[string]types.MessageAttributeValue, len(copyEntries[em.payloadIndex].MessageAttributes)+1) + for k, v := range copyEntries[em.payloadIndex].MessageAttributes { + updatedAttributes[k] = v + } - msgSize = c.messageSize(copyEntries[i].MessageBody, copyEntries[i].MessageAttributes) + // assign the reserved attribute to a number containing the size of the original body + updatedAttributes[c.reservedAttrs[0]] = types.MessageAttributeValue{ + DataType: aws.String("Number"), + StringValue: aws.String(strconv.FormatInt(em.msgSize.bodySize, 10)), } - batchSizeBytes += msgSize + // override attributes and body in the original message + copyEntries[em.payloadIndex].MessageAttributes = updatedAttributes + copyEntries[em.payloadIndex].MessageBody = aws.String(string(asBytes)) } if err := g.Wait(); err != nil { diff --git a/client_test.go b/client_test.go index 803a281..6e2766f 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "log/slog" "strings" "testing" @@ -99,6 +100,8 @@ func TestNewClient(t *testing.T) { } func TestNewClientOptions(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + c, err := New( nil, nil, @@ -108,6 +111,7 @@ func TestNewClientOptions(t *testing.T) { WithReservedAttributeNames([]string{"Reserved", "Attributes"}), WithS3BucketName("BUCKET!"), WithObjectPrefix("custom_prefix"), + WithLogger(logger), ) assert.Nil(t, err) @@ -119,6 +123,7 @@ func TestNewClientOptions(t *testing.T) { assert.Equal(t, []string{"Reserved", "Attributes"}, c.reservedAttrs) assert.Equal(t, "BUCKET!", c.bucketName) assert.Equal(t, "custom_prefix", c.objectPrefix) + assert.Equal(t, logger, c.logger) } func TestNewClientOptionsFailure(t *testing.T) { @@ -132,6 +137,43 @@ func TestNewClientOptionsFailure(t *testing.T) { assert.Nil(t, c) } +func TestNewClientSizeCalculation(t *testing.T) { + testCases := []struct { + desc string + options []ClientOption + expectedAttributeSize int + expectedPointerSize int + }{ + { + desc: "default sizes", + options: nil, + expectedAttributeSize: 25, + expectedPointerSize: 121, + }, + { + desc: "custom pointer class", + options: []ClientOption{WithPointerClass("custom.pointer")}, + expectedAttributeSize: 25, + expectedPointerSize: 85, + }, + { + desc: "custom reserved attributes", + options: []ClientOption{WithReservedAttributeNames([]string{"CustomAttr"})}, + expectedAttributeSize: 16, + expectedPointerSize: 121, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + c, err := New(nil, nil, tC.options...) + assert.NoError(t, err) + assert.Equal(t, tC.expectedAttributeSize, c.baseAttributeSize) + assert.Equal(t, tC.expectedPointerSize, c.baseS3PointerSize) + }) + } + +} + func TestAttributeSize(t *testing.T) { c, err := New(nil, nil) assert.Nil(t, err) @@ -384,6 +426,130 @@ func TestSendMessageMarshalFailure(t *testing.T) { assert.ErrorContains(t, err, "unable to marshal S3 pointer") } +func TestOptimizeBatchPayload(t *testing.T) { + testCases := []struct { + desc string + clientOptions []ClientOption + baseBatchPayload batchPayload + messages []batchMessageMeta + checks func(*testing.T, *batchPayload) + }{ + { + desc: "payload under threshold", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(50)}, + messages: []batchMessageMeta{ + {payloadIndex: 0, msgSize: messageSize{bodySize: 10}}, + {payloadIndex: 1, msgSize: messageSize{bodySize: 30}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 5}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(10+30+5), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 0) + }, + }, + { + desc: "small payloads under large threshold", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(5000)}, + messages: []batchMessageMeta{ + {payloadIndex: 0, msgSize: messageSize{bodySize: 10}}, + {payloadIndex: 1, msgSize: messageSize{bodySize: 30}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 5}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(10+30+5), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 0) + }, + }, + { + desc: "payload equals threshold", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(45)}, + messages: []batchMessageMeta{ + {payloadIndex: 0, msgSize: messageSize{bodySize: 10}}, + {payloadIndex: 1, msgSize: messageSize{bodySize: 30}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 5}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(10+30+5), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 0) + }, + }, + { + desc: "single message causes payload to exceed threshold", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(200)}, + messages: []batchMessageMeta{ + {payloadIndex: 1, msgSize: messageSize{bodySize: 30}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 5}}, + {payloadIndex: 3, msgSize: messageSize{bodySize: 200}}, // replaced with a payload of size 149 + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(30+5+149), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 1) + assert.Equal(t, 3, bp.extendedMessages[0].payloadIndex) + }, + }, + { + desc: "all messages cause payload to exceed threshold", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(300)}, + messages: []batchMessageMeta{ + {payloadIndex: 1, msgSize: messageSize{bodySize: 1000}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 100}}, + {payloadIndex: 3, msgSize: messageSize{bodySize: 120}}, + {payloadIndex: 4, msgSize: messageSize{bodySize: 200}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(150+100+120+149), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 2) + assert.Equal(t, 1, bp.extendedMessages[0].payloadIndex) + assert.Equal(t, 4, bp.extendedMessages[1].payloadIndex) + }, + }, + { + desc: "minimize data sent to s3", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(600)}, + messages: []batchMessageMeta{ + {payloadIndex: 1, msgSize: messageSize{bodySize: 400}}, + {payloadIndex: 2, msgSize: messageSize{bodySize: 450}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(149+450), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 1) + assert.Equal(t, 1, bp.extendedMessages[0].payloadIndex) + }, + }, + { + desc: "minimize data sent to s3 - alternate", + clientOptions: []ClientOption{WithBatchMessageSizeThreshold(800)}, + messages: []batchMessageMeta{ + {payloadIndex: 2, msgSize: messageSize{bodySize: 450}}, + {payloadIndex: 1, msgSize: messageSize{bodySize: 400}}, + }, + baseBatchPayload: batchPayload{}, + checks: func(t *testing.T, bp *batchPayload) { + assert.Equal(t, int64(149+450), bp.batchBytes) + assert.Len(t, bp.extendedMessages, 1) + assert.Equal(t, 1, bp.extendedMessages[0].payloadIndex) + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + c, err := New(nil, nil, tC.clientOptions...) + assert.NoError(t, err) + + tC.baseBatchPayload.s3PointerSize += c.baseS3PointerSize + + bp := c.optimizeBatchPayload(&tC.baseBatchPayload, tC.messages) + tC.checks(t, bp) + }) + } +} + func TestSendMessageBatch(t *testing.T) { key1, key2 := new(string), new(string) ms3c := &mockS3Client{&mock.Mock{}} @@ -521,9 +687,10 @@ func TestSendMessageBatchSizeAboveThreshold(t *testing.T) { assert.Equal(t, "entry_2", *params.Entries[1].Id) assert.Equal(t, "entry_3", *params.Entries[2].Id) assert.Nil(t, params.Entries[0].MessageAttributes["ExtendedPayloadSize"].StringValue) - assert.Equal(t, "43", *params.Entries[1].MessageAttributes["ExtendedPayloadSize"].StringValue) - assert.Equal(t, "53", *params.Entries[2].MessageAttributes["ExtendedPayloadSize"].StringValue) + assert.Nil(t, params.Entries[1].MessageAttributes["ExtendedPayloadSize"].StringValue) + assert.Equal(t, "500", *params.Entries[2].MessageAttributes["ExtendedPayloadSize"].StringValue) assert.Equal(t, "testing body 1", *params.Entries[0].MessageBody) + assert.Equal(t, "testing body 2 with a little larger payload", *params.Entries[1].MessageBody) return true }), mock.Anything). @@ -544,13 +711,13 @@ func TestSendMessageBatchSizeAboveThreshold(t *testing.T) { }, { Id: aws.String("entry_3"), - MessageBody: aws.String("testing body 3 with an even bigger and larger payload"), + MessageBody: aws.String(strings.Repeat("large", 100)), }, }, QueueUrl: aws.String("some_url"), }) - assert.Len(t, ms3c.Calls, 2) + assert.Len(t, ms3c.Calls, 1) assert.Nil(t, err) }