feat: implement optimized batch payloads (#107)
co-go authored Mar 22, 2024
1 parent dcc82c8 commit c103cb5
Showing 2 changed files with 391 additions and 59 deletions.
client.go (275 changes: 220 additions & 55 deletions)
@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"regexp"
"strconv"
"strings"
@@ -43,18 +45,25 @@ type S3Client interface {
DeleteObjects(ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
}

+type Logger interface {
+Warn(msg string, args ...any)
+}

// Client is a wrapper for the [github.com/aws/aws-sdk-go-v2/service/sqs.Client], providing extra
// functionality for retrieving, sending and deleting messages.
type Client struct {
SQSClient
s3c S3Client
+logger Logger
bucketName string
messageSizeThreshold int64
batchMessageSizeThreshold int64
alwaysThroughS3 bool
pointerClass string
reservedAttrs []string
objectPrefix string
+baseS3PointerSize int
+baseAttributeSize int
}

type ClientOption func(*Client) error
@@ -75,6 +84,7 @@ func New(
c := Client{
SQSClient: sqsc,
s3c: s3c,
+logger: slog.New(slog.NewTextHandler(os.Stdout, nil)),
messageSizeThreshold: maxMsgSizeInBytes,
batchMessageSizeThreshold: maxMsgSizeInBytes,
pointerClass: "software.amazon.payloadoffloading.PayloadS3Pointer",
@@ -89,9 +99,38 @@ }
}
}

+// create an example s3 pointer
+ptr := &s3Pointer{
+S3Key: uuid.NewString(),
+class: c.pointerClass,
+}

+// get its string representation
+s3PointerBytes, _ := ptr.MarshalJSON()

+// store the length of this string to be used when calculating optimal payload sizing in the
+// SendMessageBatch method. Note the size of the S3 bucket name is excluded here, as it can
+// change on a per-call basis.
+c.baseS3PointerSize = len(s3PointerBytes)

+// similarly, store the base size of the attribute added for extended payloads. Note the string
+// representation of the payload's length is also omitted here, as it changes from message to
+// message.
+c.baseAttributeSize = len(c.reservedAttrs[0]) + len("Number")

return &c, nil
}
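To make the precomputation above concrete, here is a small self-contained sketch of the same arithmetic. The attribute name, base pointer size, and bucket name are hypothetical stand-ins, not values from this repository:

package main

import (
	"fmt"
	"strconv"
)

func main() {
	// assumed reserved attribute name; the real one lives in c.reservedAttrs[0]
	const reservedAttr = "ExtendedPayloadSize"
	baseAttributeSize := len(reservedAttr) + len("Number") // 19 + 6 = 25

	// a 12345-byte body renders as the 5-digit string "12345", so the
	// reserved attribute for that message occupies 25 + 5 = 30 bytes
	attrSize := baseAttributeSize + len(strconv.Itoa(12345))

	// the per-call pointer size adds the bucket name to the bucket-less base
	baseS3PointerSize := 120 // hypothetical len(ptr.MarshalJSON()) result
	pointerSize := baseS3PointerSize + len("my-bucket")

	fmt.Println(attrSize, pointerSize) // 30 129
}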

+// WithLogger allows the caller to control how messages will be logged from the client. The expected
+// interface matches the `log/slog` function signature and will default to a TextHandler unless
+// overridden by this method.
+func WithLogger(logger Logger) ClientOption {
+return func(c *Client) error {
+c.logger = logger
+return nil
+}
+}
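A minimal usage sketch for this option; the New arguments (SQS client, S3 client, max message size, options) are assumptions based on the surrounding diff:

// *slog.Logger satisfies the Logger interface above through its
// Warn(msg string, args ...any) method, so any slog handler works here.
jsonLogger := slog.New(slog.NewJSONHandler(os.Stderr, nil))

// hypothetical constructor call; sqsc, s3c, and maxMsgSize come from the caller
client, err := New(sqsc, s3c, maxMsgSize, WithLogger(jsonLogger))
if err != nil {
	return err
}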

// Set the destination bucket for large messages that are sent by this client. This is a
// soft-requirement for using the SendMessage function.
func WithS3BucketName(bucketName string) ClientOption {
@@ -167,15 +206,45 @@ func (c *Client) s3Key(filename string) string {
return filename
}

+// messageSize describes the size of a SQS message (body and its attributes)
+type messageSize struct {
+bodySize int64
+attributeSize int64
+}

+// Total returns the full message size
+func (m messageSize) Total() int64 {
+return m.bodySize + m.attributeSize
+}

+// ToExtendedSize will convert a messageSize to its equivalent extended payload size. This can be
+// useful for estimating the size a message would have if it were converted, without actually
+// having to perform the conversion.
+func (m messageSize) ToExtendedSize(pointerSize, attributeSize int) messageSize {
+n, numDigits := int64(10), int64(1)
+for n <= m.bodySize {
+n *= 10
+numDigits++
+}

+return messageSize{
+bodySize: int64(pointerSize),
+attributeSize: int64(attributeSize) + numDigits + m.attributeSize,
+}
+}
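A worked example of the digit counting above, reusing the messageSize type from this diff with made-up numbers:

// For a 12345-byte body the loop runs n=10, 100, 1000, 10000, 100000,
// incrementing numDigits to 5 == len("12345"), the digits the reserved
// attribute's StringValue will occupy after conversion.
m := messageSize{bodySize: 12345, attributeSize: 40}
ext := m.ToExtendedSize(150, 25) // hypothetical pointer/attribute base sizes

// ext.bodySize == 150 (the S3 pointer replaces the body)
// ext.attributeSize == 25 + 5 + 40 == 70, so ext.Total() == 220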

// getMessageSize returns the size of the body and attributes of a message
-func (c *Client) messageSize(body *string, attributes map[string]types.MessageAttributeValue) int64 {
-return int64(len(*body)) + c.attributeSize(attributes)
+func (c *Client) messageSize(body *string, attributes map[string]types.MessageAttributeValue) messageSize {
+return messageSize{
+bodySize: int64(len(*body)),
+attributeSize: c.attributeSize(attributes),
+}
}

// messageExceedsThreshold determines if the size of the body and attributes exceeds the configured
// message size threshold
func (c *Client) messageExceedsThreshold(body *string, attributes map[string]types.MessageAttributeValue) bool {
-return c.messageSize(body, attributes) > c.messageSizeThreshold
+return c.messageSize(body, attributes).Total() > c.messageSizeThreshold
}

// attributeSize will return the size of all provided attributes and their values
@@ -268,7 +337,7 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput,

if c.alwaysThroughS3 || c.messageExceedsThreshold(input.MessageBody, input.MessageAttributes) {
// generate s3 object key
-s3Key := c.s3Key(uuid.New().String())
+s3Key := c.s3Key(uuid.NewString())

// upload large payload to S3
_, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{
@@ -312,13 +381,82 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput,
return c.SQSClient.SendMessage(ctx, &input, optFns...)
}

+// batchMessageMeta is used to maintain a reference to the original payload inside of a batch
+// request while also storing metadata about its size to be used during the optimize step
+type batchMessageMeta struct {
+payloadIndex int
+msgSize messageSize
+}

+// batchPayload stores information about a combination of messages in order to determine the most
+// efficient batches
+type batchPayload struct {
+batchBytes int64
+s3PointerSize int
+extendedMessages []batchMessageMeta
+}

+// optimizeBatchPayload will attempt to recursively determine the best way to distribute the passed
+// in messages into extended and regular messages in order to optimize the resulting batchPayload
+// across a series of factors. The most important factors for determining the best batch are:
+// 1. Always prefer a payload that is under the batchMessageSizeThreshold
+// 2. Prefer a payload with the FEWEST extended messages
+// 3. Prefer a payload which sends the LEAST amount of data to S3
+//
+// These decisions are made in an effort to be cognizant of the caller's performance and costs.
+func (c *Client) optimizeBatchPayload(bp *batchPayload, messages []batchMessageMeta) *batchPayload {
+// return if we have no more messages to examine
+if len(messages) == 0 {
+return bp
+}

+currMsg := messages[0]
+numExtMsg := len(bp.extendedMessages)

+// case 1 - assume we leave the message as-is
+c1 := c.optimizeBatchPayload(&batchPayload{
+batchBytes: bp.batchBytes + currMsg.msgSize.Total(),
+extendedMessages: bp.extendedMessages,
+s3PointerSize: bp.s3PointerSize,
+}, messages[1:])

+// case 2 - assume we convert the message into an extended payload
+extendedMessageSize := currMsg.msgSize.ToExtendedSize(bp.s3PointerSize, c.baseAttributeSize).Total()
+c2 := c.optimizeBatchPayload(&batchPayload{
+batchBytes: bp.batchBytes + extendedMessageSize,
+extendedMessages: append(bp.extendedMessages[:numExtMsg:numExtMsg], currMsg),
+s3PointerSize: bp.s3PointerSize,
+}, messages[1:])

+// perform the checks against factors provided in the function description
+if c1.batchBytes <= c.batchMessageSizeThreshold && c2.batchBytes > c.batchMessageSizeThreshold {
+return c1
+} else if c2.batchBytes <= c.batchMessageSizeThreshold && c1.batchBytes > c.batchMessageSizeThreshold {
+return c2
+} else if c1.batchBytes > c.batchMessageSizeThreshold && c2.batchBytes > c.batchMessageSizeThreshold {
+// in this case, both payloads exceed the threshold; attempt to return the best of the worst
+if c1.batchBytes <= c2.batchBytes {
+return c1
+}
+return c2
+} else if len(c2.extendedMessages) > len(c1.extendedMessages) {
+return c1
+} else if c1.batchBytes > c2.batchBytes {
+return c1
+}

+return c2
+}
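To make the three preference rules concrete, here is a self-contained toy version of the same recursive search over plain ints. It assumes a fixed placeholder size per offloaded message, whereas the real code derives it per message:

package main

import "fmt"

type result struct {
	batchBytes int // bytes kept in the SQS batch
	offloaded  int // messages sent to S3
}

// better applies the three rules above: stay under the limit, offload as
// few messages as possible, then keep the most bytes in the batch (which
// means sending the least data to S3).
func better(a, b result, limit int) result {
	switch {
	case a.batchBytes <= limit && b.batchBytes > limit:
		return a
	case b.batchBytes <= limit && a.batchBytes > limit:
		return b
	case a.batchBytes > limit && b.batchBytes > limit: // best of the worst
		if a.batchBytes <= b.batchBytes {
			return a
		}
		return b
	case a.offloaded < b.offloaded:
		return a
	case b.offloaded < a.offloaded:
		return b
	case a.batchBytes > b.batchBytes:
		return a
	}
	return b
}

// optimize recursively tries keeping vs. offloading each message; extSize
// stands in for the pointer-plus-attribute size of an offloaded placeholder.
func optimize(acc result, sizes []int, extSize, limit int) result {
	if len(sizes) == 0 {
		return acc
	}
	keep := optimize(result{acc.batchBytes + sizes[0], acc.offloaded}, sizes[1:], extSize, limit)
	offload := optimize(result{acc.batchBytes + extSize, acc.offloaded + 1}, sizes[1:], extSize, limit)
	return better(keep, offload, limit)
}

func main() {
	// 60- and 70-byte messages against a 100-byte limit with 30-byte
	// placeholders: offloading only the 60-byte message wins, since it
	// keeps 100 bytes inline and sends the least data (60 bytes) to S3.
	fmt.Println(optimize(result{}, []int{60, 70}, 30, 100)) // {100 1}
}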

// Extended SQS Client wrapper around
// [github.com/aws/aws-sdk-go-v2/service/sqs.Client.SendMessageBatch]. When preparing the messages
-// for transport, each message will be iterated through and checks performed:
-//
-// 1. If the size of the message exceeds the messageSizeThreshold, it will be uploaded to S3
-// 2. If the size of the message when added to the size of all previous messages exceeds the
-//    batchMessageSizeThreshold, then that message will be uploaded to S3.
+// for transport, if the size of any message exceeds the messageSizeThreshold or if alwaysThroughS3
+// is set to true, the message will be uploaded to S3. For the remaining messages, this method will
+// calculate the fewest messages required to upload to S3 in order to reduce the overall payload
+// size under the batchMessageSizeThreshold. If there are multiple combinations that reduce the
+// payload below the threshold while uploading the same number of messages, preference will be
+// given to the combination that results in the smallest amount of data sent to S3 in order to
+// minimize costs.
//
// Each message that is successfully uploaded to S3 will be altered by:
//
@@ -354,7 +492,6 @@ func (c *Client) SendMessage(ctx context.Context, params *sqs.SendMessageInput,
func (c *Client) SendMessageBatch(ctx context.Context, params *sqs.SendMessageBatchInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageBatchOutput, error) {
input := *params
copyEntries := make([]types.SendMessageBatchRequestEntry, len(input.Entries))
-g := new(errgroup.Group)

// determine bucket name, either from client (default) or from provided SQS URL
queueURL, s3Bucket, found := strings.Cut(*params.QueueUrl, "|")
@@ -364,70 +501,98 @@ func (c *Client) SendMessageBatch(ctx context.Context, params *sqs.SendMessageBa

input.QueueUrl = &queueURL

-batchSizeBytes := int64(0)
+// initialize the payload struct which will hold the data describing the batch
+bp := &batchPayload{
+s3PointerSize: c.baseS3PointerSize + len(s3Bucket),
+extendedMessages: make([]batchMessageMeta, 0, len(input.Entries)),
+}

// store "regular" (non-extended) messages separately to iterate during optimize step
regularMessages := make([]batchMessageMeta, 0, len(input.Entries))
for i, e := range input.Entries {
i, e := i, e

// always copy the entry, regardless of size
// always copy the entry
copyEntries[i] = e

// calculate the starting message size
msgSize := c.messageSize(e.MessageBody, e.MessageAttributes)

-// check if we always send through s3, or if the message size exceeds the threshold, or if
-// sending the message would cause the batch itself to overflow
-if c.alwaysThroughS3 ||
-msgSize > c.messageSizeThreshold ||
-batchSizeBytes+msgSize > c.batchMessageSizeThreshold {

-// generate s3 object key
-s3Key := c.s3Key(uuid.New().String())

-// upload large payload to S3
-g.Go(func() error {
-_, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{
-Bucket: &s3Bucket,
-Key: aws.String(s3Key),
-Body: strings.NewReader(*e.MessageBody),
-})

-if err != nil {
-return fmt.Errorf("unable to upload large payload to s3: %w", err)
-}
+// build a "meta" struct in order to keep track of this message when optimizing the batch
+// payload in the next stage
+msgMeta := batchMessageMeta{
+payloadIndex: i,
+msgSize: msgSize,
+}

-return nil
-})
+// check if we always send through s3, or if the message size exceeds the threshold
+if c.alwaysThroughS3 || msgSize.Total() > c.messageSizeThreshold {
+// track the payload under the batch's extendedMessages
+bp.extendedMessages = append(bp.extendedMessages, msgMeta)

+// update the base payload size
+bp.batchBytes += msgSize.ToExtendedSize(bp.s3PointerSize, c.baseAttributeSize).Total()
+} else {
+regularMessages = append(regularMessages, msgMeta)
+}
}

+// attempt to find the most efficient message combination for our batch
+bp = c.optimizeBatchPayload(bp, regularMessages)

+if bp.batchBytes > c.batchMessageSizeThreshold {
+c.logger.Warn(fmt.Sprintf("SendMessageBatch is only able to reduce the batch size to <%d> even though BatchMessageSizeThreshold is set to <%d>. Errors might occur.", bp.batchBytes, c.batchMessageSizeThreshold))
+}

-// create an s3 pointer that will be uploaded to SQS in place of the large payload
-asBytes, err := jsonMarshal(&s3Pointer{
-S3BucketName: s3Bucket,
-S3Key: s3Key,
-class: c.pointerClass,
+g := new(errgroup.Group)
+for _, em := range bp.extendedMessages {
+// generate s3 object key
+s3Key := c.s3Key(uuid.NewString())

+// deep copy full message payload to send to S3
+msgBody := *copyEntries[em.payloadIndex].MessageBody

+// upload large payload to S3
+g.Go(func() error {
+_, err := c.s3c.PutObject(ctx, &s3.PutObjectInput{
+Bucket: &s3Bucket,
+Key: aws.String(s3Key),
+Body: strings.NewReader(msgBody),
+})

if err != nil {
-return nil, fmt.Errorf("unable to marshal S3 pointer: %w", err)
+return fmt.Errorf("unable to upload large payload to s3: %w", err)
}

-// copy over all attributes, leaving space for our reserved attribute
-updatedAttributes := make(map[string]types.MessageAttributeValue, len(e.MessageAttributes)+1)
-for k, v := range e.MessageAttributes {
-updatedAttributes[k] = v
-}
+return nil
+})

-// assign the reserved attribute to a number containing the size of the original body
-updatedAttributes[c.reservedAttrs[0]] = types.MessageAttributeValue{
-DataType: aws.String("Number"),
-StringValue: aws.String(strconv.Itoa(len(*e.MessageBody))),
-}
+// create an s3 pointer that will be uploaded to SQS in place of the large payload
+asBytes, err := jsonMarshal(&s3Pointer{
+S3BucketName: s3Bucket,
+S3Key: s3Key,
+class: c.pointerClass,
+})

+if err != nil {
+return nil, fmt.Errorf("unable to marshal S3 pointer: %w", err)
+}

-// override attributes and body in the original message
-copyEntries[i].MessageAttributes = updatedAttributes
-copyEntries[i].MessageBody = aws.String(string(asBytes))
+// copy over all attributes, leaving space for our reserved attribute
+updatedAttributes := make(map[string]types.MessageAttributeValue, len(copyEntries[em.payloadIndex].MessageAttributes)+1)
+for k, v := range copyEntries[em.payloadIndex].MessageAttributes {
+updatedAttributes[k] = v
+}

-msgSize = c.messageSize(copyEntries[i].MessageBody, copyEntries[i].MessageAttributes)
+// assign the reserved attribute to a number containing the size of the original body
+updatedAttributes[c.reservedAttrs[0]] = types.MessageAttributeValue{
+DataType: aws.String("Number"),
+StringValue: aws.String(strconv.FormatInt(em.msgSize.bodySize, 10)),
+}

-batchSizeBytes += msgSize
+// override attributes and body in the original message
+copyEntries[em.payloadIndex].MessageAttributes = updatedAttributes
+copyEntries[em.payloadIndex].MessageBody = aws.String(string(asBytes))
}

if err := g.Wait(); err != nil {
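For callers, a usage sketch of the batch path; the queue URL, bucket override, and payloads are made up, and error handling is elided:

// the "|" suffix is split off by strings.Cut in the diff above, letting a
// single call override the client-level bucket set via WithS3BucketName
queue := "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue|my-override-bucket"

out, err := client.SendMessageBatch(ctx, &sqs.SendMessageBatchInput{
	QueueUrl: aws.String(queue),
	Entries: []types.SendMessageBatchRequestEntry{
		// well past any plausible threshold; uploaded to S3 outright
		{Id: aws.String("big"), MessageBody: aws.String(strings.Repeat("x", 300_000))},
		// small enough that the optimizer decides whether it stays inline
		{Id: aws.String("small"), MessageBody: aws.String("hello")},
	},
})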
