115 lines
2.8 KiB
Go
115 lines
2.8 KiB
Go
package sqs
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/awserr"
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
)
|
|
|
|
// errChecksumMissingBody is returned by checksumsMatch when the message body
// pointer is nil, so there is nothing to hash.
var errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body")

// errChecksumMissingMD5 is returned by checksumsMatch when the expected MD5
// pointer is nil, so there is nothing to compare the computed digest against.
var errChecksumMissingMD5 = fmt.Errorf("cannot verify checksum. missing response MD5")
|
|
|
|
func setupChecksumValidation(r *request.Request) {
|
|
if aws.BoolValue(r.Config.DisableComputeChecksums) {
|
|
return
|
|
}
|
|
|
|
switch r.Operation.Name {
|
|
case opSendMessage:
|
|
r.Handlers.Unmarshal.PushBack(verifySendMessage)
|
|
case opSendMessageBatch:
|
|
r.Handlers.Unmarshal.PushBack(verifySendMessageBatch)
|
|
case opReceiveMessage:
|
|
r.Handlers.Unmarshal.PushBack(verifyReceiveMessage)
|
|
}
|
|
}
|
|
|
|
func verifySendMessage(r *request.Request) {
|
|
if r.DataFilled() && r.ParamsFilled() {
|
|
in := r.Params.(*SendMessageInput)
|
|
out := r.Data.(*SendMessageOutput)
|
|
err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody)
|
|
if err != nil {
|
|
setChecksumError(r, err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
func verifySendMessageBatch(r *request.Request) {
|
|
if r.DataFilled() && r.ParamsFilled() {
|
|
entries := map[string]*SendMessageBatchResultEntry{}
|
|
ids := []string{}
|
|
|
|
out := r.Data.(*SendMessageBatchOutput)
|
|
for _, entry := range out.Successful {
|
|
entries[*entry.Id] = entry
|
|
}
|
|
|
|
in := r.Params.(*SendMessageBatchInput)
|
|
for _, entry := range in.Entries {
|
|
if e := entries[*entry.Id]; e != nil {
|
|
err := checksumsMatch(entry.MessageBody, e.MD5OfMessageBody)
|
|
if err != nil {
|
|
ids = append(ids, *e.MessageId)
|
|
}
|
|
}
|
|
}
|
|
if len(ids) > 0 {
|
|
setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
|
|
}
|
|
}
|
|
}
|
|
|
|
func verifyReceiveMessage(r *request.Request) {
|
|
if r.DataFilled() && r.ParamsFilled() {
|
|
ids := []string{}
|
|
out := r.Data.(*ReceiveMessageOutput)
|
|
for i, msg := range out.Messages {
|
|
err := checksumsMatch(msg.Body, msg.MD5OfBody)
|
|
if err != nil {
|
|
if msg.MessageId == nil {
|
|
if r.Config.Logger != nil {
|
|
r.Config.Logger.Log(fmt.Sprintf(
|
|
"WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.",
|
|
r.RequestID, i,
|
|
))
|
|
}
|
|
continue
|
|
}
|
|
|
|
ids = append(ids, *msg.MessageId)
|
|
}
|
|
}
|
|
if len(ids) > 0 {
|
|
setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
|
|
}
|
|
}
|
|
}
|
|
|
|
func checksumsMatch(body, expectedMD5 *string) error {
|
|
if body == nil {
|
|
return errChecksumMissingBody
|
|
} else if expectedMD5 == nil {
|
|
return errChecksumMissingMD5
|
|
}
|
|
|
|
msum := md5.Sum([]byte(*body))
|
|
sum := hex.EncodeToString(msum[:])
|
|
if sum != *expectedMD5 {
|
|
return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, sum)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func setChecksumError(r *request.Request, format string, args ...interface{}) {
|
|
r.Retryable = aws.Bool(true)
|
|
r.Error = awserr.New("InvalidChecksum", fmt.Sprintf(format, args...), nil)
|
|
}
|