Skip to content

Commit 44acf0c

Browse files
authored
Merge pull request #2105 from aws/recursion-detection
Add Recursion Detection middleware to all SDK requests
2 parents 7399331 + 75061a4 commit 44acf0c

File tree

13,519 files changed

+40773
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

13,519 files changed

+40773
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "d74f8a81-3ddb-431f-b600-6abefbdaba1b",
3+
"type": "feature",
4+
"description": "add recursion detection middleware to all SDK requests to avoid recursion invocation in Lambda",
5+
"modules": [
6+
"."
7+
]
8+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/aws/smithy-go/middleware"
7+
smithyhttp "github.com/aws/smithy-go/transport/http"
8+
"os"
9+
)
10+
11+
const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME"
12+
const envAmznTraceID = "_X_AMZN_TRACE_ID"
13+
const amznTraceIDHeader = "X-Amzn-Trace-Id"
14+
15+
// AddRecursionDetection adds recursionDetection to the middleware stack
16+
func AddRecursionDetection(stack *middleware.Stack) error {
17+
return stack.Build.Add(&RecursionDetection{}, middleware.After)
18+
}
19+
20+
// RecursionDetection detects Lambda environment and sets its X-Ray trace ID to request header if absent
21+
// to avoid recursion invocation in Lambda
22+
type RecursionDetection struct{}
23+
24+
// ID returns the middleware identifier
25+
func (m *RecursionDetection) ID() string {
26+
return "RecursionDetection"
27+
}
28+
29+
// HandleBuild detects Lambda environment and adds its trace ID to request header if absent
30+
func (m *RecursionDetection) HandleBuild(
31+
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
32+
) (
33+
out middleware.BuildOutput, metadata middleware.Metadata, err error,
34+
) {
35+
req, ok := in.Request.(*smithyhttp.Request)
36+
if !ok {
37+
return out, metadata, fmt.Errorf("unknown request type %T", req)
38+
}
39+
40+
_, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName)
41+
xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID)
42+
value := req.Header.Get(amznTraceIDHeader)
43+
// only set the X-Amzn-Trace-Id header when it is not set initially, the
44+
// current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists
45+
if value != "" || !hasLambdaEnv || !hasTraceID {
46+
return next.HandleBuild(ctx, in)
47+
}
48+
49+
req.Header.Set(amznTraceIDHeader, percentEncode(xAmznTraceID))
50+
return next.HandleBuild(ctx, in)
51+
}
52+
53+
func percentEncode(s string) string {
54+
upperhex := "0123456789ABCDEF"
55+
hexCount := 0
56+
for i := 0; i < len(s); i++ {
57+
c := s[i]
58+
if shouldEncode(c) {
59+
hexCount++
60+
}
61+
}
62+
63+
if hexCount == 0 {
64+
return s
65+
}
66+
67+
required := len(s) + 2*hexCount
68+
t := make([]byte, required)
69+
j := 0
70+
for i := 0; i < len(s); i++ {
71+
if c := s[i]; shouldEncode(c) {
72+
t[j] = '%'
73+
t[j+1] = upperhex[c>>4]
74+
t[j+2] = upperhex[c&15]
75+
j += 3
76+
} else {
77+
t[j] = c
78+
j++
79+
}
80+
}
81+
return string(t)
82+
}
83+
84+
func shouldEncode(c byte) bool {
85+
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
86+
return false
87+
}
88+
switch c {
89+
case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',':
90+
return false
91+
default:
92+
return true
93+
}
94+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
smithymiddleware "github.com/aws/smithy-go/middleware"
6+
smithyhttp "github.com/aws/smithy-go/transport/http"
7+
"os"
8+
"testing"
9+
)
10+
11+
func TestRecursionDetection(t *testing.T) {
12+
cases := map[string]struct {
13+
LambdaFuncName string
14+
TraceID string
15+
HeaderBefore string
16+
HeaderAfter string
17+
}{
18+
"non lambda env and no trace ID header before": {},
19+
"with lambda env but no trace ID env variable, no trace ID header before": {
20+
LambdaFuncName: "some-function1",
21+
},
22+
"with lambda env and trace ID env variable, no trace ID header before": {
23+
LambdaFuncName: "some-function2",
24+
TraceID: "traceID1",
25+
HeaderAfter: "traceID1",
26+
},
27+
"with lambda env and trace ID env variable, has trace ID header before": {
28+
LambdaFuncName: "some-function3",
29+
TraceID: "traceID2",
30+
HeaderBefore: "traceID1",
31+
HeaderAfter: "traceID1",
32+
},
33+
"with lambda env and trace ID (needs encoding) env variable, no trace ID header before": {
34+
LambdaFuncName: "some-function4",
35+
TraceID: "traceID3\n",
36+
HeaderAfter: "traceID3%0A",
37+
},
38+
"with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": {
39+
LambdaFuncName: "some-function5",
40+
TraceID: "traceID4-=;:+&[]{}\"'",
41+
HeaderAfter: "traceID4-=;:+&[]{}\"'",
42+
},
43+
}
44+
45+
for name, c := range cases {
46+
t.Run(name, func(t *testing.T) {
47+
// clear current case's environment variables and restore them at the end of the test func goroutine
48+
restoreEnv := clearEnv()
49+
defer restoreEnv()
50+
51+
setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName)
52+
setEnvVar(t, envAmznTraceID, c.TraceID)
53+
54+
req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
55+
if c.HeaderBefore != "" {
56+
req.Header.Set(amznTraceIDHeader, c.HeaderBefore)
57+
}
58+
var updatedRequest *smithyhttp.Request
59+
m := RecursionDetection{}
60+
_, _, err := m.HandleBuild(context.Background(),
61+
smithymiddleware.BuildInput{Request: req},
62+
smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) (
63+
out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {
64+
updatedRequest = input.Request.(*smithyhttp.Request)
65+
return out, metadata, nil
66+
}),
67+
)
68+
if err != nil {
69+
t.Fatalf("expect no error, got %v", err)
70+
}
71+
72+
if e, a := c.HeaderAfter, updatedRequest.Header.Get(amznTraceIDHeader); e != a {
73+
t.Errorf("expect header value %v found, got %v", e, a)
74+
}
75+
})
76+
}
77+
}
78+
79+
// check if test case has environment variable and set to os if it has
80+
func setEnvVar(t *testing.T, key, value string) {
81+
if value != "" {
82+
err := os.Setenv(key, value)
83+
if err != nil {
84+
t.Fatalf("expect no error, got %v", err)
85+
}
86+
}
87+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package software.amazon.smithy.aws.go.codegen.customization;
2+
3+
import software.amazon.smithy.aws.go.codegen.AwsGoDependency;
4+
import software.amazon.smithy.go.codegen.SymbolUtils;
5+
import software.amazon.smithy.go.codegen.integration.GoIntegration;
6+
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
7+
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
8+
import software.amazon.smithy.utils.ListUtils;
9+
10+
import java.util.List;
11+
12+
/**
13+
* Add middleware during operation builder step, which detects Lambda environment and sets its X-Ray trace ID to
14+
* request header if absent to avoid recursion invocation in Lambda
15+
*/
16+
public class LambdaRecursionDetection implements GoIntegration {
17+
/**
18+
* Gets the sort order of the customization from -128 to 127, with lowest
19+
* executed first.
20+
*
21+
* @return Returns the sort order, defaults to -40.
22+
*/
23+
@Override
24+
public byte getOrder() {
25+
return 126;
26+
}
27+
28+
@Override
29+
public List<RuntimeClientPlugin> getClientPlugins() {
30+
return ListUtils.of(
31+
RuntimeClientPlugin.builder()
32+
.registerMiddleware(MiddlewareRegistrar.builder()
33+
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(
34+
"AddRecursionDetection", AwsGoDependency.AWS_MIDDLEWARE)
35+
.build())
36+
.build()
37+
)
38+
.build()
39+
);
40+
}
41+
}

codegen/smithy-aws-go-codegen/src/main/resources/META-INF/services/software.amazon.smithy.go.codegen.integration.GoIntegration

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
4949
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
5050
software.amazon.smithy.aws.go.codegen.customization.S3100Continue
5151
software.amazon.smithy.aws.go.codegen.customization.ApiGatewayExportsNullabilityExceptionIntegration
52+
software.amazon.smithy.aws.go.codegen.customization.LambdaRecursionDetection

internal/protocoltest/awsrestjson/api_op_AllQueryStringTypes.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/protocoltest/awsrestjson/api_op_ConstantAndVariableQueryString.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/protocoltest/awsrestjson/api_op_ConstantQueryString.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/protocoltest/awsrestjson/api_op_DatetimeOffsets.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/protocoltest/awsrestjson/api_op_DocumentType.go

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)