Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
read and append queue attributes
  • Loading branch information
Amit Mor committed Nov 25, 2021
commit e33e2b39c298cad8bd7e6e09538cdc0666ffefb1
104 changes: 83 additions & 21 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,27 @@ type snsSqsMetadata struct {
messageMaxNumber int64
}

type ArnEquals struct {
AwsSourceArn string `json:"aws\:SourceArn"`
}

type Condition struct {
ArnEquals ArnEquals
}

type Statement struct {
Effect string
Principal string
Action string
Resource string
Condition Condition
}

type policy struct {
Version string
Statement []Statement
}

const (
awsSqsQueueNameKey = "dapr-queue-name"
awsSnsTopicNameKey = "dapr-topic-name"
Expand Down Expand Up @@ -122,6 +143,23 @@ func nameToAWSSanitizedName(name string) string {
return string(s[:j])
}

func (p *policy) statementExists(other *Statement) bool {
for _, s := range p.Statement {
if s.Effect == other.Effect &&
s.Principal == other.Principal &&
s.Action == other.Action &&
s.Resource == other.Resource &&
s.Condition.ArnEquals.AwsSourceArn == other.Condition.ArnEquals.AwsSourceArn {
return true
}
}
return false
}

func (p *policy) addStatement(other *Statement) {
p.Statement = append(p.Statement, *other)
}

func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
md := snsSqsMetadata{}
props := metadata.Properties
Expand Down Expand Up @@ -356,7 +394,7 @@ func (s *snsSqs) Publish(req *pubsub.PublishRequest) error {
})

if err != nil {
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %v", req.Topic, topicArn, err)
wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err)
s.logger.Error(wrappedErr)

return wrappedErr
Expand All @@ -375,12 +413,14 @@ func parseTopicArn(arn string) string {
}

func (s *snsSqs) acknowledgeMessage(queueURL string, receiptHandle *string) error {
_, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
if _, err := s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
QueueUrl: &queueURL,
ReceiptHandle: receiptHandle,
})
}); err != nil {
return fmt.Errorf("error deleting SQS message: %w", err)
}

return fmt.Errorf("error deleting SQS message: %w", err)
return nil
}

func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueueInfo *sqsQueueInfo, handler pubsub.Handler) error {
Expand Down Expand Up @@ -413,7 +453,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue
"message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName)
}

// otherwise try to handle the message
// otherwise try to handle the message.
var messageBody snsMessage
err = json.Unmarshal([]byte(*(message.Body)), &messageBody)

Expand Down Expand Up @@ -511,22 +551,44 @@ func (s *snsSqs) createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueu

func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(sqsQueueInfo *sqsQueueInfo, snsARN string) error {
// only permit SNS to send messages to SQS using the created subscription.
if _, err := s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}})
if err != nil {
return fmt.Errorf("error getting queue attributes: %w", err)
}

newStatement := &Statement{
Effect: "Allow",
Principal: `{"Service": "sns.amazonaws.com"}`,
Action: "sqs:SendMessage",
Resource: sqsQueueInfo.arn,
Condition: Condition{
ArnEquals: ArnEquals{
AwsSourceArn: snsARN,
},
},
}

policy := &policy{Version: "2012-11-05"}
if policyStr, ok := getQueueAttributesOutput.Attributes[sqs.QueueAttributeNamePolicy]; ok {
// look for the current statement if exists, else add it and store.
if err = json.Unmarshal([]byte(*policyStr), policy); err != nil {
return fmt.Errorf("error unmarshalling sqs policy: %w", err)
}
if policy.statementExists(newStatement) {
// nothing to do.
return nil
}
}

policy.addStatement(newStatement)
b, uerr := json.Marshal(policy)
if uerr != nil {
return fmt.Errorf("failed serializing new sqs policy: %w", uerr)
}

if _, err = s.sqsClient.SetQueueAttributes(&(sqs.SetQueueAttributesInput{
Attributes: map[string]*string{
"Policy": aws.String(fmt.Sprintf(`{
"Version": "2012-10-17",
"Statement": [{
"Effect":"Allow",
"Principal":{"Service": "sns.amazonaws.com"},
"Action":"sqs:SendMessage",
"Resource":"%s",
"Condition": {
"ArnEquals":{
"aws:SourceArn":"%s"
}
}
}]
}`, sqsQueueInfo.arn, snsARN)),
"Policy": aws.String(string(b)),
},
QueueUrl: &sqsQueueInfo.url,
})); err != nil {
Expand Down Expand Up @@ -594,7 +656,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
// subscription creation is idempotent. Subscriptions are unique by topic/queue.
subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{
Attributes: nil,
Endpoint: &queueInfo.arn, // create SQS queue per subscription
Endpoint: &queueInfo.arn, // create SQS queue per subscription.
Protocol: aws.String("sqs"),
ReturnSubscriptionArn: nil,
TopicArn: &topicArn,
Expand Down