99package  extension
1010
1111import  (
12+ 	"bytes" 
13+ 	"context" 
14+ 	"encoding/json" 
1215	"fmt" 
1316	"net/http" 
1417	"os" 
1518	"time" 
1619
1720	"github.com/DataDog/datadog-lambda-go/internal/logger" 
21+ 
22+ 	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace" 
23+ )
24+ 
25+ type  ddTraceContext  string 
26+ 
27+ const  (
28+ 	DdTraceId           ddTraceContext  =  "x-datadog-trace-id" 
29+ 	DdParentId          ddTraceContext  =  "x-datadog-parent-id" 
30+ 	DdSpanId            ddTraceContext  =  "x-datadog-span-id" 
31+ 	DdSamplingPriority  ddTraceContext  =  "x-datadog-sampling-priority" 
32+ 	DdInvocationError   ddTraceContext  =  "x-datadog-invocation-error" 
33+ 
34+ 	DdSeverlessSpan   ddTraceContext  =  "dd-tracer-serverless-span" 
35+ 	DdLambdaResponse  ddTraceContext  =  "dd-response" 
1836)
1937
2038const  (
@@ -23,30 +41,38 @@ const (
2341	// want to let it having some time for its cold start so we should not set this too low. 
2442	timeout  =  3000  *  time .Millisecond 
2543
26- 	helloUrl  =  "http://localhost:8124/lambda/hello" 
27- 	flushUrl  =  "http://localhost:8124/lambda/flush" 
44+ 	helloUrl            =  "http://localhost:8124/lambda/hello" 
45+ 	flushUrl            =  "http://localhost:8124/lambda/flush" 
46+ 	startInvocationUrl  =  "http://localhost:8124/lambda/start-invocation" 
47+ 	endInvocationUrl    =  "http://localhost:8124/lambda/end-invocation" 
2848
2949	extensionPath  =  "/opt/extensions/datadog-agent" 
3050)
3151
3252type  ExtensionManager  struct  {
33- 	helloRoute          string 
34- 	flushRoute          string 
35- 	extensionPath       string 
36- 	httpClient          HTTPClient 
37- 	isExtensionRunning  bool 
53+ 	helloRoute                  string 
54+ 	flushRoute                  string 
55+ 	extensionPath               string 
56+ 	startInvocationUrl          string 
57+ 	endInvocationUrl            string 
58+ 	httpClient                  HTTPClient 
59+ 	isExtensionRunning          bool 
60+ 	isUniversalInstrumentation  bool 
3861}
3962
4063type  HTTPClient  interface  {
4164	Do (req  * http.Request ) (* http.Response , error )
4265}
4366
44- func  BuildExtensionManager () * ExtensionManager  {
67+ func  BuildExtensionManager (isUniversalInstrumentation   bool ) * ExtensionManager  {
4568	em  :=  & ExtensionManager {
46- 		helloRoute :    helloUrl ,
47- 		flushRoute :    flushUrl ,
48- 		extensionPath : extensionPath ,
49- 		httpClient :    & http.Client {Timeout : timeout },
69+ 		helloRoute :                 helloUrl ,
70+ 		flushRoute :                 flushUrl ,
71+ 		startInvocationUrl :         startInvocationUrl ,
72+ 		endInvocationUrl :           endInvocationUrl ,
73+ 		extensionPath :              extensionPath ,
74+ 		httpClient :                 & http.Client {Timeout : timeout },
75+ 		isUniversalInstrumentation : isUniversalInstrumentation ,
5076	}
5177	em .checkAgentRunning ()
5278	return  em 
@@ -57,15 +83,81 @@ func (em *ExtensionManager) checkAgentRunning() {
5783		logger .Debug ("Will use the API" )
5884		em .isExtensionRunning  =  false 
5985	} else  {
60- 		req , _  :=  http .NewRequest (http .MethodGet , em .helloRoute , nil )
61- 		if  response , err  :=  em .httpClient .Do (req ); err  ==  nil  &&  response .StatusCode  ==  200  {
62- 			logger .Debug ("Will use the Serverless Agent" )
63- 			em .isExtensionRunning  =  true 
64- 		} else  {
65- 			logger .Debug ("Will use the API since the Serverless Agent was detected but the hello route was unreachable" )
66- 			em .isExtensionRunning  =  false 
86+ 		logger .Debug ("Will use the Serverless Agent" )
87+ 		em .isExtensionRunning  =  true 
88+ 
89+ 		// Tell the extension not to create an execution span if universal instrumentation is disabled 
90+ 		if  ! em .isUniversalInstrumentation  {
91+ 			req , _  :=  http .NewRequest (http .MethodGet , em .helloRoute , nil )
92+ 			if  response , err  :=  em .httpClient .Do (req ); err  ==  nil  &&  response .StatusCode  ==  200  {
93+ 				logger .Debug ("Hit the extension /hello route" )
94+ 			} else  {
95+ 				logger .Debug ("Will use the API since the Serverless Agent was detected but the hello route was unreachable" )
96+ 				em .isExtensionRunning  =  false 
97+ 			}
98+ 		}
99+ 	}
100+ }
101+ 
102+ func  (em  * ExtensionManager ) SendStartInvocationRequest (ctx  context.Context , eventPayload  json.RawMessage ) context.Context  {
103+ 	body  :=  bytes .NewBuffer (eventPayload )
104+ 	req , _  :=  http .NewRequest (http .MethodPost , em .startInvocationUrl , body )
105+ 
106+ 	if  response , err  :=  em .httpClient .Do (req ); err  ==  nil  &&  response .StatusCode  ==  200  {
107+ 		// Propagate dd-trace context from the extension response if found in the response headers 
108+ 		traceId  :=  response .Header .Get (string (DdTraceId ))
109+ 		if  traceId  !=  ""  {
110+ 			ctx  =  context .WithValue (ctx , DdTraceId , traceId )
111+ 		}
112+ 		parentId  :=  response .Header .Get (string (DdParentId ))
113+ 		if  parentId  !=  ""  {
114+ 			ctx  =  context .WithValue (ctx , DdParentId , parentId )
115+ 		}
116+ 		samplingPriority  :=  response .Header .Get (string (DdSamplingPriority ))
117+ 		if  samplingPriority  !=  ""  {
118+ 			ctx  =  context .WithValue (ctx , DdSamplingPriority , samplingPriority )
67119		}
68120	}
121+ 	return  ctx 
122+ }
123+ 
124+ func  (em  * ExtensionManager ) SendEndInvocationRequest (ctx  context.Context , functionExecutionSpan  ddtrace.Span , err  error ) {
125+ 	// Handle Lambda response 
126+ 	lambdaResponse  :=  ctx .Value (DdLambdaResponse )
127+ 	content , responseErr  :=  json .Marshal (lambdaResponse )
128+ 	if  responseErr  !=  nil  {
129+ 		content  =  []byte ("{}" )
130+ 	}
131+ 	body  :=  bytes .NewBuffer (content )
132+ 	req , _  :=  http .NewRequest (http .MethodPost , em .endInvocationUrl , body )
133+ 
134+ 	// Mark the invocation as an error if any 
135+ 	if  err  !=  nil  {
136+ 		req .Header .Set (string (DdInvocationError ), "true" )
137+ 	}
138+ 
139+ 	// Extract the DD trace context and pass them to the extension via request headers 
140+ 	traceId , ok  :=  ctx .Value (DdTraceId ).(string )
141+ 	if  ok  {
142+ 		req .Header .Set (string (DdTraceId ), traceId )
143+ 		if  parentId , ok  :=  ctx .Value (DdParentId ).(string ); ok  {
144+ 			req .Header .Set (string (DdParentId ), parentId )
145+ 		}
146+ 		if  spanId , ok  :=  ctx .Value (DdSpanId ).(string ); ok  {
147+ 			req .Header .Set (string (DdSpanId ), spanId )
148+ 		}
149+ 		if  samplingPriority , ok  :=  ctx .Value (DdSamplingPriority ).(string ); ok  {
150+ 			req .Header .Set (string (DdSamplingPriority ), samplingPriority )
151+ 		}
152+ 	} else  {
153+ 		req .Header .Set (string (DdTraceId ), fmt .Sprint (functionExecutionSpan .Context ().TraceID ()))
154+ 		req .Header .Set (string (DdSpanId ), fmt .Sprint (functionExecutionSpan .Context ().SpanID ()))
155+ 	}
156+ 
157+ 	resp , err  :=  em .httpClient .Do (req )
158+ 	if  err  !=  nil  ||  resp .StatusCode  !=  200  {
159+ 		logger .Error (fmt .Errorf ("could not send end invocation payload to the extension: %v" , err ))
160+ 	}
69161}
70162
71163func  (em  * ExtensionManager ) IsExtensionRunning () bool  {
0 commit comments