diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 79892dcbda..1c8a7a332e 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -2,7 +2,6 @@ package snssqs import ( "context" - "crypto/sha256" "encoding/json" "errors" "fmt" @@ -20,8 +19,8 @@ import ( type snsSqs struct { // key is the topic name, value is the ARN of the topic topics map[string]string - // key is the hashed topic name, value is the actual topic name - topicHash map[string]string + // key is the sanitized topic name, value is the actual topic name + topicSanitized map[string]string // key is the topic name, value holds the ARN of the queue and its url queues map[string]*sqsQueueInfo snsClient *sns.SNS @@ -93,13 +92,28 @@ func parseInt64(input string, propertyName string) (int64, error) { return int64(number), nil } -// take a name and hash it for compatibility with AWS resource names -// the output is fixed at 64 characters -func nameToHash(name string) string { - h := sha256.New() - h.Write([]byte(name)) +// sanitize topic/queue name to conform with: +// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html +func nameToAWSSanitizedName(name string) string { + s := []byte(name) + + j := 0 + for _, b := range s { + if ('a' <= b && b <= 'z') || + ('A' <= b && b <= 'Z') || + ('0' <= b && b <= '9') || + (b == '-') || + (b == '_') { + s[j] = b + j++ + + if j == 80 { + break + } + } + } - return fmt.Sprintf("%x", h.Sum(nil)) + return string(s[:j]) } func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, error) { @@ -207,7 +221,7 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error { // both Publish and Subscribe need reference the topic ARN // track these ARNs in this map s.topics = make(map[string]string) - s.topicHash = make(map[string]string) + s.topicSanitized = make(map[string]string) s.queues = make(map[string]*sqsQueueInfo) sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) if err != nil { @@ -220,16 +234,16 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error { } func (s *snsSqs) createTopic(topic string) (string, string, error) { - hashedName := nameToHash(topic) + sanitizedName := nameToAWSSanitizedName(topic) createTopicResponse, err := s.snsClient.CreateTopic(&sns.CreateTopicInput{ - Name: aws.String(hashedName), + Name: aws.String(sanitizedName), Tags: []*sns.Tag{{Key: aws.String(awsSnsTopicNameKey), Value: aws.String(topic)}}, }) if err != nil { return "", "", err } - return *(createTopicResponse.TopicArn), hashedName, nil + return *(createTopicResponse.TopicArn), sanitizedName, nil } // get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist @@ -245,7 +259,7 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) { s.logger.Debugf("No topic ARN found for %s\n Creating topic instead.", topic) - topicArn, hashedName, err := s.createTopic(topic) + topicArn, sanitizedName, err := s.createTopic(topic) if err != nil { s.logger.Errorf("error creating new topic %s: %v", topic, err) @@ -254,14 +268,14 @@ func (s *snsSqs) getOrCreateTopic(topic string) (string, error) { // record topic ARN s.topics[topic] = topicArn - s.topicHash[hashedName] = topic + s.topicSanitized[sanitizedName] = topic return topicArn, nil } func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) { createQueueResponse, err := s.sqsClient.CreateQueue(&sqs.CreateQueueInput{ - QueueName: aws.String(nameToHash(queueName)), + QueueName: aws.String(nameToAWSSanitizedName(queueName)), Tags: map[string]*string{awsSqsQueueNameKey: aws.String(queueName)}, }) if err != nil { @@ -397,7 +411,7 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo *sqsQueueInfo, ha } topic := parseTopicArn(messageBody.TopicArn) - topic = s.topicHash[topic] + topic = s.topicSanitized[topic] err = handler(context.Background(), &pubsub.NewMessage{ Data: []byte(messageBody.Message), Topic: topic, @@ -491,12 +505,6 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) } func (s *snsSqs) Close() error { - for _, sub := range s.subscriptions { - s.snsClient.Unsubscribe(&sns.UnsubscribeInput{ - SubscriptionArn: sub, - }) - } - return nil } diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 3979df4001..835100bc83 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -1,8 +1,6 @@ package snssqs import ( - "fmt" - "strings" "testing" "github.com/dapr/components-contrib/pubsub" @@ -236,22 +234,13 @@ func Test_parseInt64(t *testing.T) { r.Error(err) } -func Test_nameToHash(t *testing.T) { +func Test_replaceNameToAWSSanitizedName(t *testing.T) { r := require.New(t) - // This string is too long and contains invalid character for either an SQS queue or an SNS topic - hashedName := nameToHash(` - Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid + s := `Some_invalid-name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^Some invalid name // for an - AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^ - `) - - r.Equal(64, len(hashedName)) - // Output is only expected to contain lower case characters representing valid hexadecimal numerals - for _, c := range hashedName { - r.True( - strings.ContainsAny( - "abcdef0123456789", string(c)), - fmt.Sprintf("Invalid character %s in hashed name", string(c))) - } + AWS resource &*()*&&^Some invalid name // for an AWS resource &*()*&&^` + v := nameToAWSSanitizedName(s) + r.Equal(80, len(v)) + r.Equal("Some_invalid-nameforanAWSresourceSomeinvalidnameforanAWSresourceSomeinvalidnamef", v) }