Skip to content
19 changes: 14 additions & 5 deletions ddlambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ const (
DatadogTraceEnabledEnvVar = "DD_TRACE_ENABLED"
// MergeXrayTracesEnvVar is the environment variable that enables the merging of X-Ray and Datadog traces.
MergeXrayTracesEnvVar = "DD_MERGE_XRAY_TRACES"
// UniversalInstrumentation is the environment variable that enables universal instrumentation with the DD Extension
UniversalInstrumentation = "DD_UNIVERSAL_INSTRUMENTATION"

// DefaultSite to send API messages to.
DefaultSite = "datadoghq.com"
Expand Down Expand Up @@ -183,8 +185,9 @@ func InvokeDryRun(callback func(ctx context.Context), cfg *Config) (interface{},

func (cfg *Config) toTraceConfig() trace.Config {
traceConfig := trace.Config{
DDTraceEnabled: false,
MergeXrayTraces: false,
DDTraceEnabled: false,
MergeXrayTraces: false,
UniversalInstrumentation: false,
}

if cfg != nil {
Expand All @@ -205,6 +208,10 @@ func (cfg *Config) toTraceConfig() trace.Config {
traceConfig.MergeXrayTraces, _ = strconv.ParseBool(os.Getenv(MergeXrayTracesEnvVar))
}

if !traceConfig.UniversalInstrumentation {
traceConfig.UniversalInstrumentation, _ = strconv.ParseBool(os.Getenv(UniversalInstrumentation))
}

return traceConfig
}

Expand All @@ -213,12 +220,14 @@ func initializeListeners(cfg *Config) []wrapper.HandlerListener {
if strings.EqualFold(logLevel, "debug") || (cfg != nil && cfg.DebugLogging) {
logger.SetLogLevel(logger.LevelDebug)
}
extensionManager := extension.BuildExtensionManager()
traceConfig := cfg.toTraceConfig()
extensionManager := extension.BuildExtensionManager(traceConfig.UniversalInstrumentation)
isExtensionRunning := extensionManager.IsExtensionRunning()
metricsConfig := cfg.toMetricsConfig(isExtensionRunning)

// Wrap the handler with listeners that add instrumentation for traces and metrics.
tl := trace.MakeListener(cfg.toTraceConfig(), extensionManager)
ml := metrics.MakeListener(cfg.toMetricsConfig(isExtensionRunning), extensionManager)
tl := trace.MakeListener(traceConfig, extensionManager)
ml := metrics.MakeListener(metricsConfig, extensionManager)
return []wrapper.HandlerListener{
&tl, &ml,
}
Expand Down
130 changes: 111 additions & 19 deletions internal/extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,30 @@
package extension

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"time"

"github.com/DataDog/datadog-lambda-go/internal/logger"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
)

type ddTraceContext string

const (
DdTraceId ddTraceContext = "x-datadog-trace-id"
DdParentId ddTraceContext = "x-datadog-parent-id"
DdSpanId ddTraceContext = "x-datadog-span-id"
DdSamplingPriority ddTraceContext = "x-datadog-sampling-priority"
DdInvocationError ddTraceContext = "x-datadog-invocation-error"

DdSeverlessSpan ddTraceContext = "dd-tracer-serverless-span"
DdLambdaResponse ddTraceContext = "dd-response"
)

const (
Expand All @@ -23,30 +41,38 @@ const (
// want to let it having some time for its cold start so we should not set this too low.
timeout = 3000 * time.Millisecond

helloUrl = "http://localhost:8124/lambda/hello"
flushUrl = "http://localhost:8124/lambda/flush"
helloUrl = "http://localhost:8124/lambda/hello"
flushUrl = "http://localhost:8124/lambda/flush"
startInvocationUrl = "http://localhost:8124/lambda/start-invocation"
endInvocationUrl = "http://localhost:8124/lambda/end-invocation"

extensionPath = "/opt/extensions/datadog-agent"
)

type ExtensionManager struct {
helloRoute string
flushRoute string
extensionPath string
httpClient HTTPClient
isExtensionRunning bool
helloRoute string
flushRoute string
extensionPath string
startInvocationUrl string
endInvocationUrl string
httpClient HTTPClient
isExtensionRunning bool
isUniversalInstrumentation bool
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looked into adding trace.Config + metrics.Config to this, but it was throwing errors for a circular dependency.

}

type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

func BuildExtensionManager() *ExtensionManager {
func BuildExtensionManager(isUniversalInstrumentation bool) *ExtensionManager {
em := &ExtensionManager{
helloRoute: helloUrl,
flushRoute: flushUrl,
extensionPath: extensionPath,
httpClient: &http.Client{Timeout: timeout},
helloRoute: helloUrl,
flushRoute: flushUrl,
startInvocationUrl: startInvocationUrl,
endInvocationUrl: endInvocationUrl,
extensionPath: extensionPath,
httpClient: &http.Client{Timeout: timeout},
isUniversalInstrumentation: isUniversalInstrumentation,
}
em.checkAgentRunning()
return em
Expand All @@ -57,15 +83,81 @@ func (em *ExtensionManager) checkAgentRunning() {
logger.Debug("Will use the API")
em.isExtensionRunning = false
} else {
req, _ := http.NewRequest(http.MethodGet, em.helloRoute, nil)
if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 {
logger.Debug("Will use the Serverless Agent")
em.isExtensionRunning = true
} else {
logger.Debug("Will use the API since the Serverless Agent was detected but the hello route was unreachable")
em.isExtensionRunning = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if you think we should still make this call to /hello in the case where universal instrumentation is not enabled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha yup, good catch. I made a note to add this back but totally forgot 👍 side note: /hello is a confusing route name for its actual usage in the extension 😢

logger.Debug("Will use the Serverless Agent")
em.isExtensionRunning = true

// Tell the extension not to create an execution span if universal instrumentation is disabled
if !em.isUniversalInstrumentation {
req, _ := http.NewRequest(http.MethodGet, em.helloRoute, nil)
if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 {
logger.Debug("Hit the extension /hello route")
} else {
logger.Debug("Will use the API since the Serverless Agent was detected but the hello route was unreachable")
em.isExtensionRunning = false
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, if there was an error hitting the hello route, we'd set em.isExtensionRunning = false and log.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woops, adding that back in!

}
}
}

func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, eventPayload json.RawMessage) context.Context {
body := bytes.NewBuffer(eventPayload)
req, _ := http.NewRequest(http.MethodPost, em.startInvocationUrl, body)

if response, err := em.httpClient.Do(req); err == nil && response.StatusCode == 200 {
// Propagate dd-trace context from the extension response if found in the response headers
traceId := response.Header.Get(string(DdTraceId))
if traceId != "" {
ctx = context.WithValue(ctx, DdTraceId, traceId)
}
parentId := response.Header.Get(string(DdParentId))
if parentId != "" {
ctx = context.WithValue(ctx, DdParentId, parentId)
}
samplingPriority := response.Header.Get(string(DdSamplingPriority))
if samplingPriority != "" {
ctx = context.WithValue(ctx, DdSamplingPriority, samplingPriority)
}
}
return ctx
}

func (em *ExtensionManager) SendEndInvocationRequest(ctx context.Context, functionExecutionSpan ddtrace.Span, err error) {
// Handle Lambda response
lambdaResponse := ctx.Value(DdLambdaResponse)
content, responseErr := json.Marshal(lambdaResponse)
if responseErr != nil {
content = []byte("{}")
}
body := bytes.NewBuffer(content)
req, _ := http.NewRequest(http.MethodPost, em.endInvocationUrl, body)

// Mark the invocation as an error if any
if err != nil {
req.Header.Set(string(DdInvocationError), "true")
}

// Extract the DD trace context and pass them to the extension via request headers
traceId, ok := ctx.Value(DdTraceId).(string)
if ok {
req.Header.Set(string(DdTraceId), traceId)
if parentId, ok := ctx.Value(DdParentId).(string); ok {
req.Header.Set(string(DdParentId), parentId)
}
if spanId, ok := ctx.Value(DdSpanId).(string); ok {
req.Header.Set(string(DdSpanId), spanId)
}
if samplingPriority, ok := ctx.Value(DdSamplingPriority).(string); ok {
req.Header.Set(string(DdSamplingPriority), samplingPriority)
}
} else {
req.Header.Set(string(DdTraceId), fmt.Sprint(functionExecutionSpan.Context().TraceID()))
req.Header.Set(string(DdSpanId), fmt.Sprint(functionExecutionSpan.Context().SpanID()))
}

resp, err := em.httpClient.Do(req)
if err != nil || resp.StatusCode != 200 {
logger.Error(fmt.Errorf("could not send end invocation payload to the extension: %v", err))
}
}

func (em *ExtensionManager) IsExtensionRunning() bool {
Expand Down
103 changes: 102 additions & 1 deletion internal/extension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
package extension

import (
"bytes"
"context"
"fmt"
"net/http"
"os"
"testing"

"github.com/DataDog/datadog-lambda-go/internal/logger"
"github.com/stretchr/testify/assert"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)

type ClientErrorMock struct {
Expand All @@ -26,6 +30,19 @@ type ClientSuccessMock struct {
type ClientSuccess202Mock struct {
}

type ClientSuccessStartInvoke struct {
headers http.Header
}

type ClientSuccessEndInvoke struct {
}

const (
mockTraceId = "1"
mockParentId = "2"
mockSamplingPriority = "3"
)

func (c *ClientErrorMock) Do(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("KO")
}
Expand All @@ -38,11 +55,30 @@ func (c *ClientSuccess202Mock) Do(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 202, Status: "KO"}, nil
}

func (c *ClientSuccessStartInvoke) Do(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200, Status: "KO", Header: c.headers}, nil
}

func (c *ClientSuccessEndInvoke) Do(req *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: 200, Status: "KO"}, nil
}

func captureLog(f func()) string {
var buf bytes.Buffer
logger.SetOutput(&buf)
f()
logger.SetOutput(os.Stdout)
return buf.String()
}

func TestBuildExtensionManager(t *testing.T) {
em := BuildExtensionManager()
em := BuildExtensionManager(false)
assert.Equal(t, "http://localhost:8124/lambda/hello", em.helloRoute)
assert.Equal(t, "http://localhost:8124/lambda/flush", em.flushRoute)
assert.Equal(t, "http://localhost:8124/lambda/start-invocation", em.startInvocationUrl)
assert.Equal(t, "http://localhost:8124/lambda/end-invocation", em.endInvocationUrl)
assert.Equal(t, "/opt/extensions/datadog-agent", em.extensionPath)
assert.Equal(t, false, em.isUniversalInstrumentation)
assert.NotNil(t, em.httpClient)
}

Expand Down Expand Up @@ -96,3 +132,68 @@ func TestFlushSuccess(t *testing.T) {
err := em.Flush()
assert.Nil(t, err)
}

func TestExtensionStartInvoke(t *testing.T) {
em := &ExtensionManager{
startInvocationUrl: startInvocationUrl,
httpClient: &ClientSuccessStartInvoke{},
}
ctx := em.SendStartInvocationRequest(context.TODO(), []byte{})
traceId := ctx.Value(DdTraceId)
parentId := ctx.Value(DdParentId)
samplingPriority := ctx.Value(DdSamplingPriority)
err := em.Flush()

assert.Nil(t, err)
assert.Nil(t, traceId)
assert.Nil(t, parentId)
assert.Nil(t, samplingPriority)
}

func TestExtensionStartInvokeWithTraceContext(t *testing.T) {
headers := http.Header{}
headers.Set(string(DdTraceId), mockTraceId)
headers.Set(string(DdParentId), mockParentId)
headers.Set(string(DdSamplingPriority), mockSamplingPriority)

em := &ExtensionManager{
startInvocationUrl: startInvocationUrl,
httpClient: &ClientSuccessStartInvoke{
headers: headers,
},
}
ctx := em.SendStartInvocationRequest(context.TODO(), []byte{})
traceId := ctx.Value(DdTraceId)
parentId := ctx.Value(DdParentId)
samplingPriority := ctx.Value(DdSamplingPriority)
err := em.Flush()

assert.Nil(t, err)
assert.Equal(t, mockTraceId, traceId)
assert.Equal(t, mockParentId, parentId)
assert.Equal(t, mockSamplingPriority, samplingPriority)
}

func TestExtensionEndInvocation(t *testing.T) {
em := &ExtensionManager{
endInvocationUrl: endInvocationUrl,
httpClient: &ClientSuccessEndInvoke{},
}
span := tracer.StartSpan("aws.lambda")
logOutput := captureLog(func() { em.SendEndInvocationRequest(context.TODO(), span, nil) })
span.Finish()

assert.Equal(t, "", logOutput)
}

func TestExtensionEndInvocationError(t *testing.T) {
em := &ExtensionManager{
endInvocationUrl: endInvocationUrl,
httpClient: &ClientErrorMock{},
}
span := tracer.StartSpan("aws.lambda")
logOutput := captureLog(func() { em.SendEndInvocationRequest(context.TODO(), span, nil) })
span.Finish()

assert.Contains(t, logOutput, "could not send end invocation payload to the extension")
}
Loading