Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/azcore/headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (
HeaderIfUnmodifiedSince = "If-Unmodified-Since"
HeaderMetadata = "Metadata"
HeaderRange = "Range"
HeaderRetryAfter = "Retry-After"
HeaderURLEncoded = "application/x-www-form-urlencoded"
HeaderUserAgent = "User-Agent"
HeaderXmsDate = "x-ms-date"
Expand Down
6 changes: 4 additions & 2 deletions sdk/azcore/policy_logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ func TestPolicyLoggingSuccess(t *testing.T) {
srv.SetResponse()
pl := NewPipeline(srv, NewRequestLogPolicy(RequestLogOptions{}))
req := NewRequest(http.MethodGet, srv.URL())
req.SetQueryParam("one", "fish")
req.SetQueryParam("sig", "redact")
qp := req.URL.Query()
qp.Set("one", "fish")
qp.Set("sig", "redact")
req.URL.RawQuery = qp.Encode()
resp, err := pl.Do(context.Background(), req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
Expand Down
35 changes: 23 additions & 12 deletions sdk/azcore/policy_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@ const (
// RetryOptions configures the retry policy's behavior.
type RetryOptions struct {
// MaxTries specifies the maximum number of attempts an operation will be tried before producing an error (0=default).
// A value of zero means that you accept our default policy. A value of 1 means 1 try and no retries.
// A value of zero means that you accept the default value. A value of 1 means 1 try and no retries.
MaxTries int32

// TryTimeout indicates the maximum time allowed for any single try of an HTTP request.
// A value of zero means that you accept our default timeout. NOTE: When transferring large amounts
// of data, the default TryTimeout will probably not be sufficient. You should override this value
// based on the bandwidth available to the host machine and proximity to the service. A good
// starting point may be something like (60 seconds per MB of anticipated-payload-size).
// A value of zero means that you accept the default timeout.
TryTimeout time.Duration

// RetryDelay specifies the amount of delay to use before retrying an operation (0=default).
Expand All @@ -49,9 +46,9 @@ type RetryOptions struct {

var (
// StatusCodesForRetry is the default set of HTTP status code for which the policy will retry.
StatusCodesForRetry = [6]int{
// Changing the value of StatusCodesForRetry will affect all clients that use the default values.
StatusCodesForRetry = []int{
http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
Expand All @@ -62,14 +59,23 @@ var (
// DefaultRetryOptions returns an instance of RetryOptions initialized with default values.
func DefaultRetryOptions() RetryOptions {
return RetryOptions{
StatusCodes: StatusCodesForRetry[:],
StatusCodes: StatusCodesForRetry,
MaxTries: defaultMaxTries,
TryTimeout: 1 * time.Minute,
RetryDelay: 4 * time.Second,
MaxRetryDelay: 120 * time.Second,
}
}

// used as a context key for adding/retrieving RetryOptions
type ctxWithRetryOptionsKey struct{}

// WithRetryOptions adds the specified RetryOptions to the parent context.
// Use this to specify custom RetryOptions at the API-call level.
func WithRetryOptions(parent context.Context, options RetryOptions) context.Context {
return context.WithValue(parent, ctxWithRetryOptionsKey{}, options)
}

func (o RetryOptions) calcDelay(try int32) time.Duration { // try is >=1; never 0
pow := func(number int64, exponent int32) int64 { // pow is nested helper function
var result int64 = 1
Expand Down Expand Up @@ -105,6 +111,11 @@ type retryPolicy struct {
}

func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err error) {
options := p.options
// check if the retry options have been overridden for this call
if override := ctx.Value(ctxWithRetryOptionsKey{}); override != nil {
options = override.(RetryOptions)
}
// Exponential retry algorithm: ((2 ^ attempt) - 1) * delay * random(0.8, 1.2)
// When to retry: connection failure or temporary/timeout.
if req.Body != nil {
Expand Down Expand Up @@ -134,14 +145,14 @@ func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err
}

// Set the time for this particular retry operation and then Do the operation.
tryCtx, tryCancel := context.WithTimeout(ctx, p.options.TryTimeout)
tryCtx, tryCancel := context.WithTimeout(ctx, options.TryTimeout)
resp, err = req.Next(tryCtx) // Make the request
tryCancel()
if shouldLog {
Log().Write(LogRetryPolicy, fmt.Sprintf("Err=%v, response=%v\n", err, resp))
}

if err == nil && !resp.HasStatusCode(p.options.StatusCodes...) {
if err == nil && !resp.HasStatusCode(options.StatusCodes...) {
// if there is no error and the response code isn't in the list of retry codes then we're done.
return
} else if ctx.Err() != nil {
Expand All @@ -155,15 +166,15 @@ func (p *retryPolicy) Do(ctx context.Context, req *Request) (resp *Response, err
// drain before retrying so nothing is leaked
resp.Drain()

if try == p.options.MaxTries {
if try == options.MaxTries {
// max number of tries has been reached, don't sleep again
return
}

// use the delay from retry-after if available
delay, ok := resp.RetryAfter()
if !ok {
delay = p.options.calcDelay(try)
delay = options.calcDelay(try)
}
if shouldLog {
Log().Write(LogRetryPolicy, fmt.Sprintf("Try=%d, Delay=%v\n", try, delay))
Expand Down
29 changes: 29 additions & 0 deletions sdk/azcore/policy_retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,35 @@ func TestRetryPolicyIsNotRetriable(t *testing.T) {
}
}

func TestWithRetryOptions(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.RepeatResponse(9, mock.WithStatusCode(http.StatusRequestTimeout))
srv.AppendResponse(mock.WithStatusCode(http.StatusOK))
defaultOptions := testRetryOptions()
pl := NewPipeline(srv, NewRetryPolicy(defaultOptions))
customOptions := *defaultOptions
customOptions.MaxTries = 10
customOptions.MaxRetryDelay = 200 * time.Millisecond
retryCtx := WithRetryOptions(context.Background(), customOptions)
req := NewRequest(http.MethodGet, srv.URL())
body := newRewindTrackingBody("stuff")
req.SetBody(body)
resp, err := pl.Do(retryCtx, req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if body.rcount != int(customOptions.MaxTries-1) {
t.Fatalf("unexpected rewind count: %d", body.rcount)
}
if !body.closed {
t.Fatal("request body wasn't closed")
}
}

// TODO: add test for retry failing to read response body

// TODO: add test for per-retry timeout failed but e2e succeeded
Expand Down
14 changes: 0 additions & 14 deletions sdk/azcore/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ const (
type Request struct {
*http.Request
policies []Policy
qp url.Values
values opValues
}

Expand Down Expand Up @@ -79,11 +78,6 @@ func (req *Request) Next(ctx context.Context) (*Response, error) {
nextPolicy := req.policies[0]
nextReq := *req
nextReq.policies = nextReq.policies[1:]
// encode any pending query params
if nextReq.qp != nil {
nextReq.Request.URL.RawQuery = nextReq.qp.Encode()
nextReq.qp = nil
}
return nextPolicy.Do(ctx, &nextReq)
}

Expand Down Expand Up @@ -125,14 +119,6 @@ func (req *Request) OperationValue(value interface{}) bool {
return req.values.get(value)
}

// SetQueryParam sets the key to value.
func (req *Request) SetQueryParam(key, value string) {
if req.qp == nil {
req.qp = req.Request.URL.Query()
}
req.qp.Set(key, value)
}

// SetBody sets the specified ReadSeekCloser as the HTTP request body.
func (req *Request) SetBody(body ReadSeekCloser) error {
// Set the body and content length.
Expand Down
12 changes: 10 additions & 2 deletions sdk/azcore/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,21 @@ func (r *Response) removeBOM() {
}
}

// RetryAfter returns (non-zero, true) if the response contains a Retry-After header value
// RetryAfter returns (non-zero, true) if the response contains a Retry-After header value.
func (r *Response) RetryAfter() (time.Duration, bool) {
if r == nil {
return 0, false
}
if retryAfter, _ := strconv.Atoi(r.Header.Get("Retry-After")); retryAfter > 0 {
ra := r.Header.Get(HeaderRetryAfter)
if ra == "" {
return 0, false
}
// retry-after values are expressed in either number of
// seconds or an HTTP-date indicating when to try again
if retryAfter, _ := strconv.Atoi(ra); retryAfter > 0 {
return time.Duration(retryAfter) * time.Second, true
} else if t, err := time.Parse(time.RFC1123, ra); err == nil {
return t.Sub(time.Now()), true
}
return 0, false
}
Expand Down
29 changes: 29 additions & 0 deletions sdk/azcore/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"context"
"net/http"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)
Expand Down Expand Up @@ -103,3 +104,31 @@ func TestResponseUnmarshalXMLNoBody(t *testing.T) {
t.Fatalf("unexpected error unmarshalling: %v", err)
}
}

func TestRetryAfter(t *testing.T) {
raw := &http.Response{
Header: http.Header{},
}
resp := Response{raw}
if d, ok := resp.RetryAfter(); ok {
t.Fatalf("unexpected retry-after value %d", d)
}
raw.Header.Set(HeaderRetryAfter, "300")
d, ok := resp.RetryAfter()
if !ok {
t.Fatal("expected retry-after value from seconds")
}
if d != 300*time.Second {
t.Fatalf("expected 300 seconds, got %d", d/time.Second)
}
atDate := time.Now().Add(600 * time.Second)
raw.Header.Set(HeaderRetryAfter, atDate.Format(time.RFC1123))
d, ok = resp.RetryAfter()
if !ok {
t.Fatal("expected retry-after value from date")
}
// d will not be exactly 600 seconds but it will be close
if d/time.Second != 599 {
t.Fatalf("expected ~600 seconds, got %d", d/time.Second)
}
}