Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
add tests
  • Loading branch information
cjerad committed May 17, 2022
commit eac92610127fcc25016b0f9471058182904148d8
217 changes: 217 additions & 0 deletions src/test/reconciliation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"time"

Expand Down Expand Up @@ -56,6 +58,7 @@ import (
"github.com/aws/aws-node-termination-handler/pkg/sqsmessage"
"github.com/aws/aws-node-termination-handler/pkg/terminator"
terminatoradapter "github.com/aws/aws-node-termination-handler/pkg/terminator/adapter"
"github.com/aws/aws-node-termination-handler/pkg/webhook"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
Expand Down Expand Up @@ -124,6 +127,8 @@ var _ = Describe("Reconciliation", func() {
resizeCluster func(nodeCount uint)
// Create an ASG lifecycle action state entry for an EC2 instance ID.
createPendingASGLifecycleAction func(EC2InstanceID)
// Requests sent to the configured webhook.
webhookRequests []*http.Request

// Name of default terminator.
terminatorNamespaceName types.NamespacedName
Expand All @@ -144,6 +149,7 @@ var _ = Describe("Reconciliation", func() {
deleteSQSMessageFunc DeleteSQSMessageFunc
cordonFunc kubectlcordondrain.CordonFunc
drainFunc kubectlcordondrain.DrainFunc
webhookSendFunc webhook.HttpSendFunc
)

When("the SQS queue is empty", func() {
Expand Down Expand Up @@ -1427,6 +1433,196 @@ var _ = Describe("Reconciliation", func() {
})
})

When("the terminator has webhook configuration", func() {
const webhookURL = "http://webhook.example.aws"
webhookHeaders := []v1alpha1.HeaderSpec{{Name: "Content-Type", Value: "application/json"}}
webhookTemplate := fmt.Sprintf(
`EventID={{ .EventID }}, Kind={{ .Kind }}, InstanceID={{ .InstanceID }}, NodeName={{ .NodeName }}, StartTime={{ (.StartTime.Format "%s") }}`,
time.RFC3339,
)

When("the reconciliation takes no action", func() {
BeforeEach(func() {
terminator := terminators[terminatorNamespaceName]
terminator.Spec.Webhook.URL = webhookURL
})

It("returns success and requeues the request with the reconciler's configured interval", func() {
Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval)))
})

It("does not send any webhook requests", func() {
Expect(webhookRequests).To(BeEmpty())
})
})

When("the reconciliation acts on a node", func() {
const msgID = "id-123"
msgTime := time.Now().Format(time.RFC3339)

BeforeEach(func() {
resizeCluster(3)

sqsQueues[queueURL] = append(sqsQueues[queueURL], &sqs.Message{
ReceiptHandle: aws.String("msg-1"),
Body: aws.String(fmt.Sprintf(`{
"id": "%s",
"time": "%s",
"source": "aws.ec2",
"detail-type": "EC2 Spot Instance Interruption Warning",
"version": "1",
"detail": {
"instance-id": "%s"
}
}`, msgID, msgTime, instanceIDs[1])),
})

terminator := terminators[terminatorNamespaceName]
terminator.Spec.Webhook.URL = webhookURL
terminator.Spec.Webhook.Headers = webhookHeaders
terminator.Spec.Webhook.Template = webhookTemplate
})

It("returns success and requeues the request with the reconciler's configured interval", func() {
Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval)))
})

It("sends a webhook notification", func() {
Expect(webhookRequests).To(HaveLen(1))
Expect(webhookRequests[0].Method).To(Equal(http.MethodPost))
Expect(webhookRequests[0].URL.String()).To(Equal(webhookURL))
Expect(webhookRequests[0].Header).To(And(
HaveLen(1),
HaveKeyWithValue("Content-Type", And(
HaveLen(1),
ContainElement("application/json"),
))))

Expect(ReadAll(webhookRequests[0].Body)).To(Equal(fmt.Sprintf(
"EventID=%s, Kind=spotInterruption, InstanceID=%s, NodeName=%s, StartTime=%s",
msgID, instanceIDs[1], nodeNames[1], msgTime,
)))
})
})

When("the reconciliation acts on multiple nodes", func() {
msgIDs := []string{"msg-1", "msg-2", "msg-2"}
msgBaseTime := time.Now()
msgTimes := []string{
msgBaseTime.Add(-1 * time.Minute).Format(time.RFC3339),
msgBaseTime.Format(time.RFC3339),
msgBaseTime.Format(time.RFC3339),
}
kinds := []string{"spotInterruption", "scheduledChange", "scheduledChange"}

BeforeEach(func() {
resizeCluster(5)

sqsQueues[queueURL] = append(sqsQueues[queueURL],
&sqs.Message{
ReceiptHandle: aws.String("msg-1"),
Body: aws.String(fmt.Sprintf(`{
"id": "%s",
"time": "%s",
"source": "aws.ec2",
"detail-type": "EC2 Spot Instance Interruption Warning",
"version": "1",
"detail": {
"instance-id": "%s"
}
}`, msgIDs[0], msgTimes[0], instanceIDs[1])),
},
&sqs.Message{
ReceiptHandle: aws.String("msg-1"),
Body: aws.String(fmt.Sprintf(`{
"id": "%s",
"time": "%s",
"source": "aws.health",
"detail-type": "AWS Health Event",
"version": "1",
"detail": {
"service": "EC2",
"eventTypeCategory": "scheduledChange",
"affectedEntities": [
{"entityValue": "%s"},
{"entityValue": "%s"}
]
}
}`, msgIDs[1], msgTimes[1], instanceIDs[2], instanceIDs[3])),
},
)

terminator := terminators[terminatorNamespaceName]
terminator.Spec.Webhook.URL = webhookURL
terminator.Spec.Webhook.Headers = webhookHeaders
terminator.Spec.Webhook.Template = webhookTemplate
})

It("returns success and requeues the request with the reconciler's configured interval", func() {
Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval)))
})

It("sends a webhook notification", func() {
Expect(webhookRequests).To(HaveLen(3))

for i := 0; i < 3; i++ {
Expect(webhookRequests[i].Method).To(Equal(http.MethodPost), "request #%d", i)
Expect(webhookRequests[i].URL.String()).To(Equal(webhookURL), "request #%d", i)
Expect(webhookRequests[i].Header).To(And(
HaveLen(1),
HaveKeyWithValue("Content-Type", And(
HaveLen(1),
ContainElement("application/json"),
))),
"request #%d", i)

Expect(ReadAll(webhookRequests[i].Body)).To(Equal(fmt.Sprintf(
"EventID=%s, Kind=%s, InstanceID=%s, NodeName=%s, StartTime=%s",
msgIDs[i], kinds[i], instanceIDs[i+1], nodeNames[i+1], msgTimes[i],
)),
"request #%d", i,
)
}
})
})

When("there is an error sending the request", func() {
const msgID = "id-123"
msgTime := time.Now().Format(time.RFC3339)

BeforeEach(func() {
resizeCluster(3)

sqsQueues[queueURL] = append(sqsQueues[queueURL], &sqs.Message{
ReceiptHandle: aws.String("msg-1"),
Body: aws.String(fmt.Sprintf(`{
"id": "%s",
"time": "%s",
"source": "aws.ec2",
"detail-type": "EC2 Spot Instance Interruption Warning",
"version": "1",
"detail": {
"instance-id": "%s"
}
}`, msgID, msgTime, instanceIDs[1])),
})

terminator := terminators[terminatorNamespaceName]
terminator.Spec.Webhook.URL = webhookURL
terminator.Spec.Webhook.Headers = webhookHeaders
terminator.Spec.Webhook.Template = webhookTemplate

webhookSendFunc = func(_ *http.Request) (*http.Response, error) {
return nil, errors.New("test error")
}
})

It("returns success and requeues the request with the reconciler's configured interval", func() {
Expect(result, err).To(HaveField("RequeueAfter", Equal(reconciler.RequeueInterval)))
})
})
})

When("there is an error deleting an SQS message", func() {
BeforeEach(func() {
resizeCluster(3)
Expand Down Expand Up @@ -2340,6 +2536,12 @@ var _ = Describe("Reconciliation", func() {
asgLifecycleActions[instanceID] = StatePending
}

webhookRequests = []*http.Request{}
webhookSendFunc = func(req *http.Request) (*http.Response, error) {
webhookRequests = append(webhookRequests, req)
return &http.Response{StatusCode: 200}, nil
}

// 2. Setup stub clients.

describeEC2InstancesFunc = func(ctx aws.Context, input *ec2.DescribeInstancesInput, _ ...awsrequest.Option) (*ec2.DescribeInstancesOutput, error) {
Expand Down Expand Up @@ -2493,6 +2695,10 @@ var _ = Describe("Reconciliation", func() {
Drainer: drainer,
}

newHttpClientDoFunc := func(_ webhook.ProxyFunc) webhook.HttpSendFunc {
return webhookSendFunc
}

reconciler = terminator.Reconciler{
Name: "terminator",
RequeueInterval: time.Duration(10) * time.Second,
Expand All @@ -2508,6 +2714,9 @@ var _ = Describe("Reconciliation", func() {
CordonDrainerBuilder: terminatoradapter.CordonDrainerBuilder{
Builder: cordonDrainerBuilder,
},
WebhookClientBuilder: terminatoradapter.WebhookClientBuilder(
webhook.ClientBuilder(newHttpClientDoFunc).NewClient,
),
}
})

Expand All @@ -2516,3 +2725,11 @@ var _ = Describe("Reconciliation", func() {
result, err = reconciler.Reconcile(ctx, request)
})
})

func ReadAll(r io.Reader) (string, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return "", err
}
return string(bs), nil
}