Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions pubsub/aws/snssqs/snssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package snssqs

import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
Expand All @@ -20,8 +19,8 @@ import (
type snsSqs struct {
// key is the topic name, value is the ARN of the topic
topics map[string]string
// key is the hashed topic name, value is the actual topic name
topicHash map[string]string
// key is the sanitized topic name, value is the actual topic name
topicSanitized map[string]string
// key is the topic name, value holds the ARN of the queue and its url
queues map[string]*sqsQueueInfo
snsClient *sns.SNS
Expand Down Expand Up @@ -93,13 +92,28 @@ func parseInt64(input string, propertyName string) (int64, error) {
return int64(number), nil
}

// take a name and hash it for compatibility with AWS resource names
// the output is fixed at 64 characters
func nameToHash(name string) string {
h := sha256.New()
h.Write([]byte(name))
// sanitize topic/queue name to conform with:
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html
func nameToAWSSanitizedName(name string) string {
s := []byte(name)

j := 0
for _, b := range s {
if ('a' <= b && b <= 'z') ||
('A' <= b && b <= 'Z') ||
('0' <= b && b <= '9') ||
(b == '-') ||
(b == '_') {
s[j] = b
j++

if j == 80 {
break
}
}
}

return fmt.Sprintf("%x", h.Sum(nil))
return string(s[:j])
}

func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) {
Expand Down Expand Up @@ -207,7 +221,7 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
// both Publish and Subscribe need reference the topic ARN
// track these ARNs in this map
s.topics = make(map[string]string)
s.topicHash = make(map[string]string)
s.topicSanitized = make(map[string]string)
s.queues = make(map[string]*sqsQueueInfo)
sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint)
if err != nil {
Expand All @@ -220,16 +234,16 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error {
}

func (s *snsSqs) createTopic(topic string) (string, string, error) {
hashedName := nameToHash(topic)
sanitizedName := nameToAWSSanitizedName(topic)
createTopicResponse, err := s.snsClient.CreateTopic(&sns.CreateTopicInput{
Name: aws.String(hashedName),
Name: aws.String(sanitizedName),
Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}},
})
if err != nil {
return "", "", err
}

return *(createTopicResponse.TopicArn), hashedName, nil
return *(createTopicResponse.TopicArn), sanitizedName, nil
}

// get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist
Expand All @@ -245,7 +259,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {

s.logger.Debugf("No topic ARN found for %s\n Creating topic instead.", topic)

topicArn, hashedName, err := s.createTopic(topic)
topicArn, sanitizedName, err := s.createTopic(topic)
if err != nil {
s.logger.Errorf("error creating new topic %s: %v", topic, err)

Expand All @@ -254,14 +268,14 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) {

// record topic ARN
s.topics[topic] = topicArn
s.topicHash[hashedName] = topic
s.topicSanitized[sanitizedName] = topic

return topicArn, nil
}

func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) {
createQueueResponse, err := s.sqsClient.CreateQueue(&sqs.CreateQueueInput{
QueueName: aws.String(nameToHash(queueName)),
QueueName: aws.String(nameToAWSSanitizedName(queueName)),
Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)},
})
if err != nil {
Expand Down Expand Up @@ -397,7 +411,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo *sqsQueueInfo, ha
}

topic := parseTopicArn(messageBody.TopicArn)
topic = s.topicHash[topic]
topic = s.topicSanitized[topic]
err = handler(context.Background(), &pubsub.NewMessage{
Data: []byte(messageBody.Message),
Topic: topic,
Expand Down Expand Up @@ -491,12 +505,6 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler)
}

func (s *snsSqs) Close() error {
for _, sub := range s.subscriptions {
s.snsClient.Unsubscribe(&sns.UnsubscribeInput{
SubscriptionArn: sub,
})
}

return nil
}

Expand Down
23 changes: 6 additions & 17 deletions pubsub/aws/snssqs/snssqs_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package snssqs

import (
"fmt"
"strings"
"testing"

"github.com/dapr/components-contrib/pubsub"
Expand Down Expand Up @@ -236,22 +234,13 @@ func Test_parseInt64(t *testing.T) {
r.Error(err)
}

func Test_nameToHash(t *testing.T) {
func Test_replaceNameToAWSSanitizedName(t *testing.T) {
r := require.New(t)

// This string is too long and contains invalid character for either an SQS queue or an SNS topic
hashedName := nameToHash(`
Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid
s := `Some_invalid-name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid
name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an
AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^
`)

r.Equal(64, len(hashedName))
// Output is only expected to contain lower case characters representing valid hexadecimal numerals
for _, c := range hashedName {
r.True(
strings.ContainsAny(
"abcdef0123456789", string(c)),
fmt.Sprintf("Invalid character %s in hashed name", string(c)))
}
AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^`
v := nameToAWSSanitizedName(s)
r.Equal(80, len(v))
r.Equal("Some_invalid-nameforanAWSresourceSomeinvalidnameforanAWSresourceSomeinvalidnamef", v)
}