diff --git a/context/context.go b/context/context.go index db1c95fab1..24cea68820 100644 --- a/context/context.go +++ b/context/context.go @@ -2,44 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package context defines the Context type, which carries deadlines, -// cancellation signals, and other request-scoped values across API boundaries -// and between processes. -// As of Go 1.7 this package is available in the standard library under the -// name [context], and migrating to it can be done automatically with [go fix]. -// -// Incoming requests to a server should create a [Context], and outgoing -// calls to servers should accept a Context. The chain of function -// calls between them must propagate the Context, optionally replacing -// it with a derived Context created using [WithCancel], [WithDeadline], -// [WithTimeout], or [WithValue]. -// -// Programs that use Contexts should follow these rules to keep interfaces -// consistent across packages and enable static analysis tools to check context -// propagation: -// -// Do not store Contexts inside a struct type; instead, pass a Context -// explicitly to each function that needs it. This is discussed further in -// https://go.dev/blog/context-and-structs. The Context should be the first -// parameter, typically named ctx: -// -// func DoSomething(ctx context.Context, arg Arg) error { -// // ... use ctx ... -// } -// -// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO] -// if you are unsure about which Context to use. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. +// Package context has been superseded by the standard library [context] package. // -// The same Context may be passed to functions running in different goroutines; -// Contexts are safe for simultaneous use by multiple goroutines. -// -// See https://go.dev/blog/context for example code for a server that uses -// Contexts. -// -// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs +// Deprecated: Use the standard library context package instead. package context import ( @@ -51,36 +16,37 @@ import ( // API boundaries. // // Context's methods may be called by multiple goroutines simultaneously. +// +//go:fix inline type Context = context.Context // Canceled is the error returned by [Context.Err] when the context is canceled // for some reason other than its deadline passing. +// +//go:fix inline var Canceled = context.Canceled // DeadlineExceeded is the error returned by [Context.Err] when the context is canceled // due to its deadline passing. +// +//go:fix inline var DeadlineExceeded = context.DeadlineExceeded // Background returns a non-nil, empty Context. It is never canceled, has no // values, and has no deadline. It is typically used by the main function, // initialization, and tests, and as the top-level Context for incoming // requests. -func Background() Context { - return background -} +// +//go:fix inline +func Background() Context { return context.Background() } // TODO returns a non-nil, empty Context. Code should use context.TODO when // it's unclear which Context to use or it is not yet available (because the // surrounding function has not yet been extended to accept a Context // parameter). 
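A minimal sketch of what the //go:fix inline directives above enable (this snippet is not part of the change): tools that recognize the directive, for example the gofix analyzer that gopls runs, can rewrite callers to use the standard library package directly. After such a rewrite only the import path becomes "context"; every exported name here is an alias for its standard-library counterpart, so no other edits are needed.

package main

import (
	"fmt"
	"time"

	"golang.org/x/net/context" // after inlining: just "context"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	<-ctx.Done()
	fmt.Println(ctx.Err()) // context.DeadlineExceeded
}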
-func TODO() Context { - return todo -} - -var ( - background = context.Background() - todo = context.TODO() -) +// +//go:fix inline +func TODO() Context { return context.TODO() } // A CancelFunc tells an operation to abandon its work. // A CancelFunc does not wait for the work to stop. @@ -95,6 +61,8 @@ type CancelFunc = context.CancelFunc // // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this [Context] complete. +// +//go:fix inline func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { return context.WithCancel(parent) } @@ -108,6 +76,8 @@ func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { // // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this [Context] complete. +// +//go:fix inline func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { return context.WithDeadline(parent, d) } @@ -122,6 +92,8 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { // defer cancel() // releases resources if slowOperation completes before timeout elapses // return slowOperation(ctx) // } +// +//go:fix inline func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { return context.WithTimeout(parent, timeout) } @@ -139,6 +111,8 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { // interface{}, context keys often have concrete type // struct{}. Alternatively, exported context key variables' static // type should be a pointer or interface. +// +//go:fix inline func WithValue(parent Context, key, val interface{}) Context { return context.WithValue(parent, key, val) } diff --git a/dns/dnsmessage/message.go b/dns/dnsmessage/message.go index a656efc128..7a978b47f6 100644 --- a/dns/dnsmessage/message.go +++ b/dns/dnsmessage/message.go @@ -17,8 +17,21 @@ import ( ) // Message formats - -// A Type is a type of DNS request and response. +// +// To add a new Resource Record type: +// 1. Create Resource Record types +// 1.1. Add a Type constant named "Type" +// 1.2. Add the corresponding entry to the typeNames map +// 1.3. Add a [ResourceBody] implementation named "Resource" +// 2. Implement packing +// 2.1. Implement Builder.Resource() +// 3. Implement unpacking +// 3.1. Add the unpacking code to unpackResourceBody() +// 3.2. Implement Parser.Resource() + +// A Type is the type of a DNS Resource Record, as defined in the [IANA registry]. +// +// [IANA registry]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4 type Type uint16 const ( @@ -33,6 +46,8 @@ const ( TypeAAAA Type = 28 TypeSRV Type = 33 TypeOPT Type = 41 + TypeSVCB Type = 64 + TypeHTTPS Type = 65 // Question.Type TypeWKS Type = 11 @@ -53,6 +68,8 @@ var typeNames = map[Type]string{ TypeAAAA: "TypeAAAA", TypeSRV: "TypeSRV", TypeOPT: "TypeOPT", + TypeSVCB: "TypeSVCB", + TypeHTTPS: "TypeHTTPS", TypeWKS: "TypeWKS", TypeHINFO: "TypeHINFO", TypeMINFO: "TypeMINFO", @@ -273,6 +290,8 @@ var ( errTooManyAdditionals = errors.New("too many Additionals to pack (>65535)") errNonCanonicalName = errors.New("name is not in canonical format (it must end with a .)") errStringTooLong = errors.New("character string exceeds maximum length (255)") + errParamOutOfOrder = errors.New("parameter out of order") + errTooLongSVCBValue = errors.New("value too long (>65535 bytes)") ) // Internal constants. 
@@ -2220,6 +2239,16 @@ func unpackResourceBody(msg []byte, off int, hdr ResourceHeader) (ResourceBody, rb, err = unpackSRVResource(msg, off) r = &rb name = "SRV" + case TypeSVCB: + var rb SVCBResource + rb, err = unpackSVCBResource(msg, off, hdr.Length) + r = &rb + name = "SVCB" + case TypeHTTPS: + var rb HTTPSResource + rb.SVCBResource, err = unpackSVCBResource(msg, off, hdr.Length) + r = &rb + name = "HTTPS" case TypeOPT: var rb OPTResource rb, err = unpackOPTResource(msg, off, hdr.Length) diff --git a/dns/dnsmessage/message_test.go b/dns/dnsmessage/message_test.go index 1fa93e63ad..e004db7840 100644 --- a/dns/dnsmessage/message_test.go +++ b/dns/dnsmessage/message_test.go @@ -363,6 +363,49 @@ func TestResourceNotStarted(t *testing.T) { } } +func buildTestSVCBMsg() Message { + svcb := &SVCBResource{ + Priority: 1, + Target: MustNewName("svc.example.com."), + Params: []SVCParam{{Key: SVCParamALPN, Value: []byte("h2")}}, + } + + https := &HTTPSResource{ + SVCBResource{ + Priority: 2, + Target: MustNewName("https.example.com."), + Params: []SVCParam{ + {Key: SVCParamPort, Value: []byte{0x01, 0xbb}}, + {Key: SVCParamIPv4Hint, Value: []byte{192, 0, 2, 1}}, + }, + }, + } + + return Message{ + Questions: []Question{}, + Answers: []Resource{ + { + ResourceHeader{ + Name: MustNewName("foo.bar.example.com."), + Type: TypeSVCB, + Class: ClassINET, + }, + svcb, + }, + { + ResourceHeader{ + Name: MustNewName("foo.bar.example.com."), + Type: TypeHTTPS, + Class: ClassINET, + }, + https, + }, + }, + Authorities: []Resource{}, + Additionals: []Resource{}, + } +} + func TestDNSPackUnpack(t *testing.T) { wants := []Message{ { @@ -378,6 +421,7 @@ func TestDNSPackUnpack(t *testing.T) { Additionals: []Resource{}, }, largeTestMsg(), + buildTestSVCBMsg(), } for i, want := range wants { b, err := want.Pack() @@ -390,7 +434,14 @@ func TestDNSPackUnpack(t *testing.T) { t.Fatalf("%d: Message.Unapck() = %v", i, err) } if !reflect.DeepEqual(got, want) { - t.Errorf("%d: Message.Pack/Unpack() roundtrip: got = %+v, want = %+v", i, &got, &want) + t.Errorf("%d: Message.Pack/Unpack() roundtrip: got = %#v, want = %#v", i, &got, &want) + if len(got.Answers) > 0 && len(want.Answers) > 0 { + if !reflect.DeepEqual(got.Answers[0].Body, want.Answers[0].Body) { + t.Errorf("Answer 0 Body mismatch") + t.Errorf("got: %#v", got.Answers[0].Body) + t.Errorf("want: %#v", want.Answers[0].Body) + } + } } } } @@ -684,16 +735,19 @@ func TestBuilderResourceError(t *testing.T) { name string fn func(*Builder) error }{ + // Keep it sorted by resource type name. 
+ {"AResource", func(b *Builder) error { return b.AResource(ResourceHeader{}, AResource{}) }}, + {"AAAAResource", func(b *Builder) error { return b.AAAAResource(ResourceHeader{}, AAAAResource{}) }}, {"CNAMEResource", func(b *Builder) error { return b.CNAMEResource(ResourceHeader{}, CNAMEResource{}) }}, + {"HTTPSResource", func(b *Builder) error { return b.HTTPSResource(ResourceHeader{}, HTTPSResource{}) }}, {"MXResource", func(b *Builder) error { return b.MXResource(ResourceHeader{}, MXResource{}) }}, {"NSResource", func(b *Builder) error { return b.NSResource(ResourceHeader{}, NSResource{}) }}, + {"OPTResource", func(b *Builder) error { return b.OPTResource(ResourceHeader{}, OPTResource{}) }}, {"PTRResource", func(b *Builder) error { return b.PTRResource(ResourceHeader{}, PTRResource{}) }}, {"SOAResource", func(b *Builder) error { return b.SOAResource(ResourceHeader{}, SOAResource{}) }}, - {"TXTResource", func(b *Builder) error { return b.TXTResource(ResourceHeader{}, TXTResource{}) }}, {"SRVResource", func(b *Builder) error { return b.SRVResource(ResourceHeader{}, SRVResource{}) }}, - {"AResource", func(b *Builder) error { return b.AResource(ResourceHeader{}, AResource{}) }}, - {"AAAAResource", func(b *Builder) error { return b.AAAAResource(ResourceHeader{}, AAAAResource{}) }}, - {"OPTResource", func(b *Builder) error { return b.OPTResource(ResourceHeader{}, OPTResource{}) }}, + {"SVCBResource", func(b *Builder) error { return b.SVCBResource(ResourceHeader{}, SVCBResource{}) }}, + {"TXTResource", func(b *Builder) error { return b.TXTResource(ResourceHeader{}, TXTResource{}) }}, {"UnknownResource", func(b *Builder) error { return b.UnknownResource(ResourceHeader{}, UnknownResource{}) }}, } @@ -785,6 +839,14 @@ func TestBuilder(t *testing.T) { if err := b.SRVResource(a.Header, *a.Body.(*SRVResource)); err != nil { t.Fatalf("Builder.SRVResource(%#v) = %v", a, err) } + case TypeSVCB: + if err := b.SVCBResource(a.Header, *a.Body.(*SVCBResource)); err != nil { + t.Fatalf("Builder.SVCBResource(%#v) = %v", a, err) + } + case TypeHTTPS: + if err := b.HTTPSResource(a.Header, *a.Body.(*HTTPSResource)); err != nil { + t.Fatalf("Builder.HTTPSResource(%#v) = %v", a, err) + } case privateUseType: if err := b.UnknownResource(a.Header, *a.Body.(*UnknownResource)); err != nil { t.Fatalf("Builder.UnknownResource(%#v) = %v", a, err) @@ -1262,6 +1324,14 @@ func benchmarkParsing(tb testing.TB, buf []byte) { if _, err := p.NSResource(); err != nil { tb.Fatal("Parser.NSResource() =", err) } + case TypeSVCB: + if _, err := p.SVCBResource(); err != nil { + tb.Fatal("Parser.SVCBResource() =", err) + } + case TypeHTTPS: + if _, err := p.HTTPSResource(); err != nil { + tb.Fatal("Parser.HTTPSResource() =", err) + } case TypeOPT: if _, err := p.OPTResource(); err != nil { tb.Fatal("Parser.OPTResource() =", err) diff --git a/dns/dnsmessage/svcb.go b/dns/dnsmessage/svcb.go new file mode 100644 index 0000000000..4840516a7f --- /dev/null +++ b/dns/dnsmessage/svcb.go @@ -0,0 +1,326 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dnsmessage + +import ( + "slices" +) + +// An SVCBResource is an SVCB Resource record. +type SVCBResource struct { + Priority uint16 + Target Name + Params []SVCParam // Must be in strict increasing order by Key. +} + +func (r *SVCBResource) realType() Type { + return TypeSVCB +} + +// GoString implements fmt.GoStringer.GoString. 
+func (r *SVCBResource) GoString() string { + b := []byte("dnsmessage.SVCBResource{" + + "Priority: " + printUint16(r.Priority) + ", " + + "Target: " + r.Target.GoString() + ", " + + "Params: []dnsmessage.SVCParam{") + if len(r.Params) > 0 { + b = append(b, r.Params[0].GoString()...) + for _, p := range r.Params[1:] { + b = append(b, ", "+p.GoString()...) + } + } + b = append(b, "}}"...) + return string(b) +} + +// An HTTPSResource is an HTTPS Resource record. +// It has the same format as the SVCB record. +type HTTPSResource struct { + // Alias for SVCB resource record. + SVCBResource +} + +func (r *HTTPSResource) realType() Type { + return TypeHTTPS +} + +// GoString implements fmt.GoStringer.GoString. +func (r *HTTPSResource) GoString() string { + return "dnsmessage.HTTPSResource{SVCBResource: " + r.SVCBResource.GoString() + "}" +} + +// GetParam returns a parameter value by key. +func (r *SVCBResource) GetParam(key SVCParamKey) (value []byte, ok bool) { + for i := range r.Params { + if r.Params[i].Key == key { + return r.Params[i].Value, true + } + if r.Params[i].Key > key { + break + } + } + return nil, false +} + +// SetParam sets a parameter value by key. +// The Params list is kept sorted by key. +func (r *SVCBResource) SetParam(key SVCParamKey, value []byte) { + i := 0 + for i < len(r.Params) { + if r.Params[i].Key >= key { + break + } + i++ + } + + if i < len(r.Params) && r.Params[i].Key == key { + r.Params[i].Value = value + return + } + + r.Params = slices.Insert(r.Params, i, SVCParam{Key: key, Value: value}) +} + +// DeleteParam deletes a parameter by key. +// It returns true if the parameter was present. +func (r *SVCBResource) DeleteParam(key SVCParamKey) bool { + for i := range r.Params { + if r.Params[i].Key == key { + r.Params = slices.Delete(r.Params, i, i+1) + return true + } + if r.Params[i].Key > key { + break + } + } + return false +} + +// A SVCParam is a service parameter. +type SVCParam struct { + Key SVCParamKey + Value []byte +} + +// GoString implements fmt.GoStringer.GoString. +func (p SVCParam) GoString() string { + return "dnsmessage.SVCParam{" + + "Key: " + p.Key.GoString() + ", " + + "Value: []byte{" + printByteSlice(p.Value) + "}}" +} + +// A SVCParamKey is a key for a service parameter. +type SVCParamKey uint16 + +// Values defined at https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml#dns-svcparamkeys. +const ( + SVCParamMandatory SVCParamKey = 0 + SVCParamALPN SVCParamKey = 1 + SVCParamNoDefaultALPN SVCParamKey = 2 + SVCParamPort SVCParamKey = 3 + SVCParamIPv4Hint SVCParamKey = 4 + SVCParamECH SVCParamKey = 5 + SVCParamIPv6Hint SVCParamKey = 6 + SVCParamDOHPath SVCParamKey = 7 + SVCParamOHTTP SVCParamKey = 8 + SVCParamTLSSupportedGroups SVCParamKey = 9 +) + +var svcParamKeyNames = map[SVCParamKey]string{ + SVCParamMandatory: "Mandatory", + SVCParamALPN: "ALPN", + SVCParamNoDefaultALPN: "NoDefaultALPN", + SVCParamPort: "Port", + SVCParamIPv4Hint: "IPv4Hint", + SVCParamECH: "ECH", + SVCParamIPv6Hint: "IPv6Hint", + SVCParamDOHPath: "DOHPath", + SVCParamOHTTP: "OHTTP", + SVCParamTLSSupportedGroups: "TLSSupportedGroups", +} + +// String implements fmt.Stringer.String. +func (k SVCParamKey) String() string { + if n, ok := svcParamKeyNames[k]; ok { + return n + } + return printUint16(uint16(k)) +} + +// GoString implements fmt.GoStringer.GoString. 
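A minimal usage sketch of the accessor methods above (not part of this change; it assumes only the API introduced in this file): SetParam keeps Params ordered by key regardless of insertion order, which is the invariant that pack and the RFC 9460 wire format require.

package main

import (
	"fmt"

	"golang.org/x/net/dns/dnsmessage"
)

func main() {
	rr := dnsmessage.SVCBResource{
		Priority: 1,
		Target:   dnsmessage.MustNewName("svc.example.com."),
	}

	// Insert out of numeric key order; SetParam slots each entry so that
	// Params stays sorted by key (ALPN=1 ends up before IPv4Hint=4).
	rr.SetParam(dnsmessage.SVCParamIPv4Hint, []byte{192, 0, 2, 1})
	rr.SetParam(dnsmessage.SVCParamALPN, []byte{0x02, 'h', '2'})

	if v, ok := rr.GetParam(dnsmessage.SVCParamALPN); ok {
		fmt.Printf("alpn wire value: % x\n", v) // 02 68 32
	}

	rr.DeleteParam(dnsmessage.SVCParamIPv4Hint)
	fmt.Println("params remaining:", len(rr.Params)) // 1
}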
+func (k SVCParamKey) GoString() string { + if n, ok := svcParamKeyNames[k]; ok { + return "dnsmessage.SVCParam" + n + } + return printUint16(uint16(k)) +} + +func (r *SVCBResource) pack(msg []byte, _ map[string]uint16, _ int) ([]byte, error) { + oldMsg := msg + msg = packUint16(msg, r.Priority) + // https://datatracker.ietf.org/doc/html/rfc3597#section-4 prohibits name + // compression for RR types that are not "well-known". + // https://datatracker.ietf.org/doc/html/rfc9460#section-2.2 explicitly states that + // compression of the Target is prohibited, following RFC 3597. + msg, err := r.Target.pack(msg, nil, 0) + if err != nil { + return oldMsg, &nestedError{"SVCBResource.Target", err} + } + var previousKey SVCParamKey + for i, param := range r.Params { + if i > 0 && param.Key <= previousKey { + return oldMsg, &nestedError{"SVCBResource.Params", errParamOutOfOrder} + } + if len(param.Value) > (1<<16)-1 { + return oldMsg, &nestedError{"SVCBResource.Params", errTooLongSVCBValue} + } + msg = packUint16(msg, uint16(param.Key)) + msg = packUint16(msg, uint16(len(param.Value))) + msg = append(msg, param.Value...) + } + return msg, nil +} + +func unpackSVCBResource(msg []byte, off int, length uint16) (SVCBResource, error) { + // Wire format reference: https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2. + r := SVCBResource{} + paramsOff := off + bodyEnd := off + int(length) + + var err error + if r.Priority, paramsOff, err = unpackUint16(msg, paramsOff); err != nil { + return SVCBResource{}, &nestedError{"Priority", err} + } + + if paramsOff, err = r.Target.unpack(msg, paramsOff); err != nil { + return SVCBResource{}, &nestedError{"Target", err} + } + + // Two-pass parsing to avoid allocations. + // First, count the number of params. + n := 0 + var totalValueLen uint16 + off = paramsOff + var previousKey uint16 + for off < bodyEnd { + var key, len uint16 + if key, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"Params key", err} + } + if n > 0 && key <= previousKey { + // As per https://www.rfc-editor.org/rfc/rfc9460.html#section-2.2, clients MUST + // consider the RR malformed if the SvcParamKeys are not in strictly increasing numeric order + return SVCBResource{}, &nestedError{"Params", errParamOutOfOrder} + } + if len, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"Params value length", err} + } + if off+int(len) > bodyEnd { + return SVCBResource{}, errResourceLen + } + totalValueLen += len + off += int(len) + n++ + } + if off != bodyEnd { + return SVCBResource{}, errResourceLen + } + + // Second, fill in the params. + r.Params = make([]SVCParam, n) + // valuesBuf is used to hold all param values to reduce allocations. + // Each param's Value slice will point into this buffer. + valuesBuf := make([]byte, totalValueLen) + off = paramsOff + for i := 0; i < n; i++ { + p := &r.Params[i] + var key, len uint16 + if key, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"param key", err} + } + p.Key = SVCParamKey(key) + if len, off, err = unpackUint16(msg, off); err != nil { + return SVCBResource{}, &nestedError{"param length", err} + } + if copy(valuesBuf, msg[off:off+int(len)]) != int(len) { + return SVCBResource{}, &nestedError{"param value", errCalcLen} + } + p.Value = valuesBuf[:len:len] + valuesBuf = valuesBuf[len:] + off += int(len) + } + + return r, nil +} + +// genericSVCBResource parses a single Resource Record compatible with SVCB. 
+func (p *Parser) genericSVCBResource(svcbType Type) (SVCBResource, error) { + if !p.resHeaderValid || p.resHeaderType != svcbType { + return SVCBResource{}, ErrNotStarted + } + r, err := unpackSVCBResource(p.msg, p.off, p.resHeaderLength) + if err != nil { + return SVCBResource{}, err + } + p.off += int(p.resHeaderLength) + p.resHeaderValid = false + p.index++ + return r, nil +} + +// SVCBResource parses a single SVCBResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) SVCBResource() (SVCBResource, error) { + return p.genericSVCBResource(TypeSVCB) +} + +// HTTPSResource parses a single HTTPSResource. +// +// One of the XXXHeader methods must have been called before calling this +// method. +func (p *Parser) HTTPSResource() (HTTPSResource, error) { + svcb, err := p.genericSVCBResource(TypeHTTPS) + if err != nil { + return HTTPSResource{}, err + } + return HTTPSResource{svcb}, nil +} + +// genericSVCBResource is the generic implementation for adding SVCB-like resources. +func (b *Builder) genericSVCBResource(h ResourceHeader, r SVCBResource) error { + if err := b.checkResourceSection(); err != nil { + return err + } + msg, lenOff, err := h.pack(b.msg, b.compression, b.start) + if err != nil { + return &nestedError{"ResourceHeader", err} + } + preLen := len(msg) + if msg, err = r.pack(msg, b.compression, b.start); err != nil { + return &nestedError{"ResourceBody", err} + } + if err := h.fixLen(msg, lenOff, preLen); err != nil { + return err + } + if err := b.incrementSectionCount(); err != nil { + return err + } + b.msg = msg + return nil +} + +// SVCBResource adds a single SVCBResource. +func (b *Builder) SVCBResource(h ResourceHeader, r SVCBResource) error { + h.Type = r.realType() + return b.genericSVCBResource(h, r) +} + +// HTTPSResource adds a single HTTPSResource. +func (b *Builder) HTTPSResource(h ResourceHeader, r HTTPSResource) error { + h.Type = r.realType() + return b.genericSVCBResource(h, r.SVCBResource) +} diff --git a/dns/dnsmessage/svcb_test.go b/dns/dnsmessage/svcb_test.go new file mode 100644 index 0000000000..74fcccdac5 --- /dev/null +++ b/dns/dnsmessage/svcb_test.go @@ -0,0 +1,393 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
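The Parser and Builder entry points above slot into the usual dnsmessage workflow. A hedged end-to-end sketch, mirroring what the tests below exercise rather than prescribing a recipe: build a response carrying a single HTTPS record, then read it back with a Parser.

package main

import (
	"fmt"
	"log"

	"golang.org/x/net/dns/dnsmessage"
)

func main() {
	name := dnsmessage.MustNewName("example.com.")

	b := dnsmessage.NewBuilder(nil, dnsmessage.Header{Response: true})
	if err := b.StartQuestions(); err != nil {
		log.Fatal(err)
	}
	if err := b.StartAnswers(); err != nil {
		log.Fatal(err)
	}
	https := dnsmessage.HTTPSResource{SVCBResource: dnsmessage.SVCBResource{
		Priority: 1,
		Target:   dnsmessage.MustNewName("."),
		Params:   []dnsmessage.SVCParam{{Key: dnsmessage.SVCParamALPN, Value: []byte{0x02, 'h', '2'}}},
	}}
	hdr := dnsmessage.ResourceHeader{Name: name, Class: dnsmessage.ClassINET, TTL: 300}
	if err := b.HTTPSResource(hdr, https); err != nil {
		log.Fatal(err)
	}
	buf, err := b.Finish()
	if err != nil {
		log.Fatal(err)
	}

	var p dnsmessage.Parser
	if _, err := p.Start(buf); err != nil {
		log.Fatal(err)
	}
	if err := p.SkipAllQuestions(); err != nil {
		log.Fatal(err)
	}
	if _, err := p.AnswerHeader(); err != nil {
		log.Fatal(err)
	}
	got, err := p.HTTPSResource()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(got.Priority, got.Target, len(got.Params))
}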
+ +package dnsmessage + +import ( + "bytes" + "math" + "reflect" + "testing" +) + +func TestSVCBParamsRoundTrip(t *testing.T) { + testSVCBParam := func(t *testing.T, p *SVCParam) { + t.Helper() + rr := &SVCBResource{ + Priority: 1, + Target: MustNewName("svc.example.com."), + Params: []SVCParam{*p}, + } + buf, err := rr.pack([]byte{}, nil, 0) + if err != nil { + t.Fatalf("pack() = %v", err) + } + got, n, err := unpackResourceBody(buf, 0, ResourceHeader{Type: TypeSVCB, Length: uint16(len(buf))}) + if err != nil { + t.Fatalf("unpackResourceBody() = %v", err) + } + if n != len(buf) { + t.Fatalf("unpacked different amount than packed: got = %d, want = %d", n, len(buf)) + } + if !reflect.DeepEqual(got, rr) { + t.Fatalf("roundtrip mismatch: got = %#v, want = %#v", got, rr) + } + } + + testSVCBParam(t, &SVCParam{Key: SVCParamMandatory, Value: []byte{0x00, 0x01, 0x00, 0x03, 0x00, 0x05}}) + testSVCBParam(t, &SVCParam{Key: SVCParamALPN, Value: []byte{0x02, 'h', '2', 0x02, 'h', '3'}}) + testSVCBParam(t, &SVCParam{Key: SVCParamNoDefaultALPN, Value: []byte{}}) + testSVCBParam(t, &SVCParam{Key: SVCParamPort, Value: []byte{0x1f, 0x90}}) // 8080 + testSVCBParam(t, &SVCParam{Key: SVCParamIPv4Hint, Value: []byte{192, 0, 2, 1, 198, 51, 100, 2}}) + testSVCBParam(t, &SVCParam{Key: SVCParamECH, Value: []byte{0x01, 0x02, 0x03, 0x04}}) + testSVCBParam(t, &SVCParam{Key: SVCParamIPv6Hint, Value: []byte{0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}) + testSVCBParam(t, &SVCParam{Key: SVCParamDOHPath, Value: []byte("/dns-query{?dns}")}) + testSVCBParam(t, &SVCParam{Key: SVCParamOHTTP, Value: []byte{0x00, 0x01, 0x02, 0x03}}) + testSVCBParam(t, &SVCParam{Key: SVCParamTLSSupportedGroups, Value: []byte{0x00, 0x1d, 0x00, 0x17}}) +} + +func TestSVCBParsingAllocs(t *testing.T) { + name := MustNewName("foo.bar.example.com.") + msg := Message{ + Header: Header{Response: true, Authoritative: true}, + Questions: []Question{{Name: name, Type: TypeA, Class: ClassINET}}, + Answers: []Resource{{ + Header: ResourceHeader{Name: name, Type: TypeSVCB, Class: ClassINET, TTL: 300}, + Body: &SVCBResource{ + Priority: 1, + Target: MustNewName("svc.example.com."), + Params: []SVCParam{ + {Key: SVCParamMandatory, Value: []byte{0x00, 0x01, 0x00, 0x03, 0x00, 0x05}}, + {Key: SVCParamALPN, Value: []byte{0x02, 'h', '2', 0x02, 'h', '3'}}, + {Key: SVCParamPort, Value: []byte{0x1f, 0x90}}, // 8080 + {Key: SVCParamIPv4Hint, Value: []byte{192, 0, 2, 1, 198, 51, 100, 2}}, + {Key: SVCParamIPv6Hint, Value: []byte{0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, + }, + }, + }}, + } + buf, err := msg.Pack() + if err != nil { + t.Fatal(err) + } + + allocs := int(testing.AllocsPerRun(1, func() { + var p Parser + if _, err := p.Start(buf); err != nil { + t.Fatal("Parser.Start(non-nil) =", err) + } + if err := p.SkipAllQuestions(); err != nil { + t.Fatal("Parser.SkipAllQuestions(non-nil) =", err) + } + if _, err = p.AnswerHeader(); err != nil { + t.Fatal("Parser.AnswerHeader(non-nil) =", err) + } + if _, err = p.SVCBResource(); err != nil { + t.Fatal("Parser.SVCBResource(non-nil) =", err) + } + })) + + // Make sure we have only two allocations: one for the SVCBResource.Params slice, and one + // for the SVCParam Values. 
+ if allocs != 2 { + t.Errorf("allocations during parsing: got = %d, want 2", allocs) + } +} + +func TestHTTPSBuildAllocs(t *testing.T) { + b := NewBuilder([]byte{}, Header{Response: true, Authoritative: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + t.Fatalf("StartQuestions() = %v", err) + } + if err := b.Question(Question{Name: MustNewName("foo.bar.example.com."), Type: TypeHTTPS, Class: ClassINET}); err != nil { + t.Fatalf("Question() = %v", err) + } + if err := b.StartAnswers(); err != nil { + t.Fatalf("StartAnswers() = %v", err) + } + + header := ResourceHeader{Name: MustNewName("foo.bar.example.com."), Type: TypeHTTPS, Class: ClassINET, TTL: 300} + resource := HTTPSResource{SVCBResource{Priority: 1, Target: MustNewName("svc.example.com.")}} + + // AllocsPerRun runs the function once to "warm up" before running the measurement. + // So technically this function is running twice, on different data, which can potentially + // make the measurement inaccurate (e.g. by using the name cache the second time). + // So we make sure we don't run in the warm-up phase. + warmUp := true + allocs := int(testing.AllocsPerRun(1, func() { + if warmUp { + warmUp = false + return + } + if err := b.HTTPSResource(header, resource); err != nil { + t.Fatalf("HTTPSResource() = %v", err) + } + })) + if allocs != 1 { + t.Fatalf("unexpected allocations: got = %d, want = 1", allocs) + } +} + +func TestSVCBParams(t *testing.T) { + rr := SVCBResource{Priority: 1, Target: MustNewName("svc.example.com.")} + if _, ok := rr.GetParam(SVCParamALPN); ok { + t.Fatal("GetParam found non-existent param") + } + rr.SetParam(SVCParamIPv4Hint, []byte{192, 0, 2, 1}) + inALPN := []byte{0x02, 'h', '2', 0x02, 'h', '3'} + rr.SetParam(SVCParamALPN, inALPN) + + // Check sorting of params + packed, err := rr.pack([]byte{}, nil, 0) + if err != nil { + t.Fatal("pack() =", err) + } + expectedBytes := []byte{ + 0x00, 0x01, // priority + 0x03, 0x73, 0x76, 0x63, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x00, 0x01, // key 1 + 0x00, 0x06, // length 6 + 0x02, 'h', '2', 0x02, 'h', '3', // value + 0x00, 0x04, // key 4 + 0x00, 0x04, // length 4 + 192, 0, 2, 1, // value + } + if !reflect.DeepEqual(packed, expectedBytes) { + t.Fatalf("pack() produced unexpected output: want = %v, got = %v", expectedBytes, packed) + } + + // Check GetParam and DeleteParam. 
+ if outALPN, ok := rr.GetParam(SVCParamALPN); !ok || !bytes.Equal(outALPN, inALPN) { + t.Fatal("GetParam failed to retrieve set param") + } + if !rr.DeleteParam(SVCParamALPN) { + t.Fatal("DeleteParam failed to remove existing param") + } + if _, ok := rr.GetParam(SVCParamALPN); ok { + t.Fatal("GetParam found deleted param") + } + if len(rr.Params) != 1 || rr.Params[0].Key != SVCParamIPv4Hint { + t.Fatalf("DeleteParam removed wrong param: got = %#v, want = [%#v]", rr.Params, SVCParam{Key: SVCParamIPv4Hint, Value: []byte{192, 0, 2, 1}}) + } +} + +func TestSVCBWireFormat(t *testing.T) { + testRecord := func(bytesInput []byte, parsedInput *SVCBResource) { + parsedOutput, n, err := unpackResourceBody(bytesInput, 0, ResourceHeader{Type: TypeSVCB, Length: uint16(len(bytesInput))}) + if err != nil { + t.Fatalf("unpackResourceBody() = %v", err) + } + if n != len(bytesInput) { + t.Fatalf("unpacked different amount than packed: got = %d, want = %d", n, len(bytesInput)) + } + if !reflect.DeepEqual(parsedOutput, parsedInput) { + t.Fatalf("unpack mismatch: got = %#v, want = %#v", parsedOutput, parsedInput) + } + + bytesOutput, err := parsedInput.pack([]byte{}, nil, 0) + if err != nil { + t.Fatalf("pack() = %v", err) + } + if !reflect.DeepEqual(bytesOutput, bytesInput) { + t.Fatalf("pack mismatch: got = %#v, want = %#v", bytesOutput, bytesInput) + } + } + // Test examples from https://datatracker.ietf.org/doc/html/rfc9460#name-test-vectors + + // Example D.1. Alias Mode + + // Figure 2: AliasMode + // example.com. HTTPS 0 foo.example.com. + bytes := []byte{ + 0x00, 0x00, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target: foo.example.com. + } + parsed := &SVCBResource{ + Priority: 0, + Target: MustNewName("foo.example.com."), + Params: []SVCParam{}, + } + testRecord(bytes, parsed) + + // Example D.2. Service Mode + + // Figure 3: TargetName Is "." + // example.com. SVCB 1 . + bytes = []byte{ + 0x00, 0x01, // priority + 0x00, // target (root label) + } + parsed = &SVCBResource{ + Priority: 1, + Target: MustNewName("."), + Params: []SVCParam{}, + } + testRecord(bytes, parsed) + + // Figure 4: Specifies a Port + // example.com. SVCB 16 foo.example.com. port=53 + bytes = []byte{ + 0x00, 0x10, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x00, 0x03, // key 3 + 0x00, 0x02, // length 2 + 0x00, 0x35, // value + } + parsed = &SVCBResource{ + Priority: 16, + Target: MustNewName("foo.example.com."), + Params: []SVCParam{{Key: SVCParamPort, Value: []byte{0x00, 0x35}}}, + } + testRecord(bytes, parsed) + + // Figure 5: A Generic Key and Unquoted Value + // example.com. SVCB 1 foo.example.com. key667=hello + bytes = []byte{ + 0x00, 0x01, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x02, 0x9b, // key 667 + 0x00, 0x05, // length 5 + 0x68, 0x65, 0x6c, 0x6c, 0x6f, // value + } + parsed = &SVCBResource{ + Priority: 1, + Target: MustNewName("foo.example.com."), + Params: []SVCParam{{Key: 667, Value: []byte("hello")}}, + } + testRecord(bytes, parsed) + + // Figure 6: A Generic Key and Quoted Value with a Decimal Escape + // example.com. SVCB 1 foo.example.com. 
key667="hello\210qoo" + bytes = []byte{ + 0x00, 0x01, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x02, 0x9b, // key 667 + 0x00, 0x09, // length 9 + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xd2, 0x71, 0x6f, 0x6f, // value + } + parsed = &SVCBResource{ + Priority: 1, + Target: MustNewName("foo.example.com."), + Params: []SVCParam{{Key: 667, Value: []byte("hello\xd2qoo")}}, + } + testRecord(bytes, parsed) + + // Figure 7: Two Quoted IPv6 Hints + // example.com. SVCB 1 foo.example.com. ( + // ipv6hint="2001:db8::1,2001:db8::53:1" + // ) + bytes = []byte{ + 0x00, 0x01, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x00, 0x06, // key 6 + 0x00, 0x20, // length 32 + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // first address + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x53, 0x00, 0x01, // second address + } + parsed = &SVCBResource{ + Priority: 1, + Target: MustNewName("foo.example.com."), + Params: []SVCParam{{Key: SVCParamIPv6Hint, Value: []byte{0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x53, 0x00, 0x01}}}, + } + testRecord(bytes, parsed) + + // Figure 8: An IPv6 Hint Using the Embedded IPv4 Syntax + // example.com. SVCB 1 example.com. ( + // ipv6hint="2001:db8:122:344::192.0.2.33" + // ) + bytes = []byte{ + 0x00, 0x01, // priority + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // target + 0x00, 0x06, // key 6 + 0x00, 0x10, // length 16 + 0x20, 0x01, 0x0d, 0xb8, 0x01, 0x22, 0x03, 0x44, 0x00, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x02, 0x21, // address + } + parsed = &SVCBResource{ + Priority: 1, + Target: MustNewName("example.com."), + Params: []SVCParam{{Key: SVCParamIPv6Hint, Value: []byte{0x20, 0x01, 0x0d, 0xb8, 0x01, 0x22, 0x03, 0x44, 0x00, 0x00, 0x00, 0x00, 0xc0, 0x00, 0x02, 0x21}}}, + } + testRecord(bytes, parsed) + + // Figure 9: SvcParamKey Ordering Is Arbitrary in Presentation Format but Sorted in Wire Format + // example.com. SVCB 16 foo.example.org. ( + // alpn=h2,h3-19 mandatory=ipv4hint,alpn + // ipv4hint=192.0.2.1 + // ) + bytes = []byte{ + 0x00, 0x10, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x6f, 0x72, 0x67, 0x00, // target + 0x00, 0x00, // key 0 + 0x00, 0x04, // param length 4 + 0x00, 0x01, // value: key 1 + 0x00, 0x04, // value: key 4 + 0x00, 0x01, // key 1 + 0x00, 0x09, // param length 9 + 0x02, // alpn length 2 + 0x68, 0x32, // alpn value + 0x05, // alpn length 5 + 0x68, 0x33, 0x2d, 0x31, 0x39, // alpn value + 0x00, 0x04, // key 4 + 0x00, 0x04, // param length 4 + 0xc0, 0x00, 0x02, 0x01, // param value + } + parsed = &SVCBResource{ + Priority: 16, + Target: MustNewName("foo.example.org."), + Params: []SVCParam{ + {Key: SVCParamMandatory, Value: []byte{0x00, 0x01, 0x00, 0x04}}, + {Key: SVCParamALPN, Value: []byte{0x02, 0x68, 0x32, 0x05, 0x68, 0x33, 0x2d, 0x31, 0x39}}, + {Key: SVCParamIPv4Hint, Value: []byte{0xc0, 0x00, 0x02, 0x01}}, + }, + } + testRecord(bytes, parsed) + + // Figure 10: An "alpn" Value with an Escaped Comma and an Escaped Backslash in Two Presentation Formats + // example.com. SVCB 16 foo.example.org. 
alpn=f\\\092oo\092,bar,h2 + bytes = []byte{ + 0x00, 0x10, // priority + 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x6f, 0x72, 0x67, 0x00, // target + 0x00, 0x01, // key 1 + 0x00, 0x0c, // param length 12 + 0x08, // alpn length 8 + 0x66, 0x5c, 0x6f, 0x6f, 0x2c, 0x62, 0x61, 0x72, // alpn value + 0x02, // alpn length 2 + 0x68, 0x32, // alpn value + } + parsed = &SVCBResource{ + Priority: 16, + Target: MustNewName("foo.example.org."), + Params: []SVCParam{ + {Key: SVCParamALPN, Value: []byte{0x08, 0x66, 0x5c, 0x6f, 0x6f, 0x2c, 0x62, 0x61, 0x72, 0x02, 0x68, 0x32}}, + }, + } + testRecord(bytes, parsed) +} + +func TestSVCBPackLongValue(t *testing.T) { + b := NewBuilder(nil, Header{}) + b.StartQuestions() + b.StartAnswers() + + res := SVCBResource{ + Target: MustNewName("example.com."), + Params: []SVCParam{ + { + Key: SVCParamMandatory, + Value: make([]byte, math.MaxUint16+1), + }, + }, + } + + err := b.SVCBResource(ResourceHeader{Name: MustNewName("example.com.")}, res) + if err == nil || err.Error() != "ResourceBody: SVCBResource.Params: value too long (>65535 bytes)" { + t.Fatalf(`b.SVCBResource() = %v; want = "ResourceBody: SVCBResource.Params: value too long (>65535 bytes)"`, err) + } + + err = b.HTTPSResource(ResourceHeader{Name: MustNewName("example.com.")}, HTTPSResource{res}) + if err == nil || err.Error() != "ResourceBody: SVCBResource.Params: value too long (>65535 bytes)" { + t.Fatalf(`b.HTTPSResource() = %v; want = "ResourceBody: SVCBResource.Params: value too long (>65535 bytes)"`, err) + } +} diff --git a/go.mod b/go.mod index 39cac244a1..f58c787ab5 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module golang.org/x/net -go 1.23.0 +go 1.24.0 require ( - golang.org/x/crypto v0.41.0 - golang.org/x/sys v0.35.0 - golang.org/x/term v0.34.0 - golang.org/x/text v0.28.0 + golang.org/x/crypto v0.46.0 + golang.org/x/sys v0.39.0 + golang.org/x/term v0.38.0 + golang.org/x/text v0.32.0 ) diff --git a/go.sum b/go.sum index 1ce4678e25..ca5a57bbb9 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= -golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= diff --git a/html/escape.go b/html/escape.go index 04c6bec210..12f2273706 100644 --- a/html/escape.go +++ b/html/escape.go @@ -299,7 +299,7 @@ func escape(w writer, s string) error { case '\r': esc = " " default: - 
panic("unrecognized escape character") + panic("html: unrecognized escape character") } s = s[i+1:] if _, err := w.WriteString(esc); err != nil { diff --git a/html/parse.go b/html/parse.go index 518ee4c94e..88fc0056a3 100644 --- a/html/parse.go +++ b/html/parse.go @@ -136,7 +136,7 @@ func (p *parser) indexOfElementInScope(s scope, matchTags ...a.Atom) int { return -1 } default: - panic("unreachable") + panic(fmt.Sprintf("html: internal error: indexOfElementInScope unknown scope: %d", s)) } } switch s { @@ -179,7 +179,7 @@ func (p *parser) clearStackToContext(s scope) { return } default: - panic("unreachable") + panic(fmt.Sprintf("html: internal error: clearStackToContext unknown scope: %d", s)) } } } @@ -231,7 +231,14 @@ func (p *parser) addChild(n *Node) { } if n.Type == ElementNode { - p.oe = append(p.oe, n) + p.insertOpenElement(n) + } +} + +func (p *parser) insertOpenElement(n *Node) { + p.oe = append(p.oe, n) + if len(p.oe) > 512 { + panic("html: open stack of elements exceeds 512 nodes") } } @@ -810,7 +817,7 @@ func afterHeadIM(p *parser) bool { p.im = inFramesetIM return true case a.Base, a.Basefont, a.Bgsound, a.Link, a.Meta, a.Noframes, a.Script, a.Style, a.Template, a.Title: - p.oe = append(p.oe, p.head) + p.insertOpenElement(p.head) defer p.oe.remove(p.head) return inHeadIM(p) case a.Head: @@ -1678,7 +1685,7 @@ func inTableBodyIM(p *parser) bool { return inTableIM(p) } -// Section 12.2.6.4.14. +// Section 13.2.6.4.14. func inRowIM(p *parser) bool { switch p.tok.Type { case StartTagToken: @@ -1690,7 +1697,9 @@ func inRowIM(p *parser) bool { p.im = inCellIM return true case a.Caption, a.Col, a.Colgroup, a.Tbody, a.Tfoot, a.Thead, a.Tr: - if p.popUntil(tableScope, a.Tr) { + if p.elementInScope(tableScope, a.Tr) { + p.clearStackToContext(tableRowScope) + p.oe.pop() p.im = inTableBodyIM return false } @@ -1700,22 +1709,28 @@ func inRowIM(p *parser) bool { case EndTagToken: switch p.tok.DataAtom { case a.Tr: - if p.popUntil(tableScope, a.Tr) { + if p.elementInScope(tableScope, a.Tr) { + p.clearStackToContext(tableRowScope) + p.oe.pop() p.im = inTableBodyIM return true } // Ignore the token. return true case a.Table: - if p.popUntil(tableScope, a.Tr) { + if p.elementInScope(tableScope, a.Tr) { + p.clearStackToContext(tableRowScope) + p.oe.pop() p.im = inTableBodyIM return false } // Ignore the token. return true case a.Tbody, a.Tfoot, a.Thead: - if p.elementInScope(tableScope, p.tok.DataAtom) { - p.parseImpliedToken(EndTagToken, a.Tr, a.Tr.String()) + if p.elementInScope(tableScope, p.tok.DataAtom) && p.elementInScope(tableScope, a.Tr) { + p.clearStackToContext(tableRowScope) + p.oe.pop() + p.im = inTableBodyIM return false } // Ignore the token. @@ -2222,16 +2237,20 @@ func parseForeignContent(p *parser) bool { p.acknowledgeSelfClosingTag() } case EndTagToken: + if strings.EqualFold(p.oe[len(p.oe)-1].Data, p.tok.Data) { + p.oe = p.oe[:len(p.oe)-1] + return true + } for i := len(p.oe) - 1; i >= 0; i-- { - if p.oe[i].Namespace == "" { - return p.im(p) - } if strings.EqualFold(p.oe[i].Data, p.tok.Data) { p.oe = p.oe[:i] + return true + } + if i > 0 && p.oe[i-1].Namespace == "" { break } } - return true + return p.im(p) default: // Ignore the token. } @@ -2312,9 +2331,13 @@ func (p *parser) parseCurrentToken() { } } -func (p *parser) parse() error { +func (p *parser) parse() (err error) { + defer func() { + if panicErr := recover(); panicErr != nil { + err = fmt.Errorf("%s", panicErr) + } + }() // Iterate until EOF. Any other error will cause an early return. 
- var err error for err != io.EOF { // CDATA sections are allowed only in foreign content. n := p.oe.top() @@ -2343,6 +2366,8 @@ func (p *parser) parse() error { // s. Conversely, explicit s in r's data can be silently dropped, // with no corresponding node in the resulting tree. // +// Parse will reject HTML that is nested deeper than 512 elements. +// // The input is assumed to be UTF-8 encoded. func Parse(r io.Reader) (*Node, error) { return ParseWithOptions(r) diff --git a/html/parse_test.go b/html/parse_test.go index fea110a4b3..fe66eb44e8 100644 --- a/html/parse_test.go +++ b/html/parse_test.go @@ -251,31 +251,35 @@ func TestParser(t *testing.T) { t.Fatal(err) } for _, tf := range testFiles { - f, err := os.Open(tf) - if err != nil { - t.Fatal(err) - } - defer f.Close() - r := bufio.NewReader(f) - - for i := 0; ; i++ { - ta, err := readParseTest(r) - if err == io.EOF { - break - } + t.Run(tf, func(t *testing.T) { + f, err := os.Open(tf) if err != nil { t.Fatal(err) } - if parseTestBlacklist[ta.text] { - continue + defer f.Close() + r := bufio.NewReader(f) + + for i := 0; ; i++ { + ta, err := readParseTest(r) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if parseTestBlacklist[ta.text] { + continue + } + + t.Run(fmt.Sprint(i), func(t *testing.T) { + err = testParseCase(ta.text, ta.want, ta.context, ParseOptionEnableScripting(ta.scripting)) + + if err != nil { + t.Errorf("%s test #%d %q, %s", tf, i, ta.text, err) + } + }) } - - err = testParseCase(ta.text, ta.want, ta.context, ParseOptionEnableScripting(ta.scripting)) - - if err != nil { - t.Errorf("%s test #%d %q, %s", tf, i, ta.text, err) - } - } + }) } } } @@ -506,3 +510,35 @@ func BenchmarkParser(b *testing.B) { Parse(bytes.NewBuffer(buf)) } } + +func TestIssue70179(t *testing.T) { + _, err := Parse(strings.NewReader("")) + if err != nil { + t.Fatalf("unexpected failure: %v", err) + } +} + +func TestDepthLimit(t *testing.T) { + for _, tc := range []struct { + name string + input string + succeed bool + }{ + // Not we don't use 512 as the limit here, because the parser will + // insert implied and tags, increasing the size of the + // stack by two before we start parsing the
. + {"above depth limit", strings.Repeat("
", 511), false}, + {"below depth limit", strings.Repeat("
", 510), true}, + {"above depth limit, interspersed elements", strings.Repeat("
", 511), false}, + {"closing tags", strings.Repeat("
", 512), true}, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := Parse(strings.NewReader(tc.input)) + if tc.succeed && err != nil { + t.Errorf("unexpected error: %v", err) + } else if !tc.succeed && err == nil { + t.Errorf("unexpected success") + } + }) + } +} diff --git a/html/render.go b/html/render.go index e8c1233455..0157d89e1f 100644 --- a/html/render.go +++ b/html/render.go @@ -184,7 +184,7 @@ func render1(w writer, n *Node) error { return err } - // Add initial newline where there is danger of a newline beging ignored. + // Add initial newline where there is danger of a newline being ignored. if c := n.FirstChild; c != nil && c.Type == TextNode && strings.HasPrefix(c.Data, "\n") { switch n.Data { case "pre", "listing", "textarea": diff --git a/html/token_test.go b/html/token_test.go index 44773f1712..e5ac62308b 100644 --- a/html/token_test.go +++ b/html/token_test.go @@ -908,7 +908,7 @@ func benchmarkTokenizer(b *testing.B, level int) { // not unescape < to <, or lower-case tag names and attribute keys. z.Raw() case lowLevel: - // Caling z.Text, z.TagName and z.TagAttr returns []byte values + // Calling z.Text, z.TagName and z.TagAttr returns []byte values // whose contents may change on the next call to z.Next. switch tt { case TextToken, CommentToken, DoctypeToken: diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index f9e9a2fdaa..0f57d3e37e 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -5,6 +5,8 @@ // Infrastructure for testing ClientConn.RoundTrip. // Put actual tests in transport_test.go. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -17,6 +19,7 @@ import ( "reflect" "sync/atomic" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -24,7 +27,8 @@ import ( ) // TestTestClientConn demonstrates usage of testClientConn. -func TestTestClientConn(t *testing.T) { +func TestTestClientConn(t *testing.T) { synctestTest(t, testTestClientConn) } +func testTestClientConn(t testing.TB) { // newTestClientConn creates a *ClientConn and surrounding test infrastructure. tc := newTestClientConn(t) @@ -91,12 +95,11 @@ func TestTestClientConn(t *testing.T) { // testClientConn manages synchronization, so tests can generally be written as // a linear sequence of actions and validations without additional synchronization. type testClientConn struct { - t *testing.T + t testing.TB - tr *Transport - fr *Framer - cc *ClientConn - group *synctestGroup + tr *Transport + fr *Framer + cc *ClientConn testConnFramer encbuf bytes.Buffer @@ -107,12 +110,11 @@ type testClientConn struct { netconn *synctestNetConn } -func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { +func newTestClientConnFromClientConn(t testing.TB, cc *ClientConn) *testClientConn { tc := &testClientConn{ - t: t, - tr: cc.t, - cc: cc, - group: cc.t.transportTestHooks.group.(*synctestGroup), + t: t, + tr: cc.t, + cc: cc, } // srv is the side controlled by the test. @@ -121,7 +123,7 @@ func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientCo // If cc.tconn is nil, we're being called with a new conn created by the // Transport's client pool. This path skips dialing the server, and we // create a test connection pair here. - cc.tconn, srv = synctestNetPipe(tc.group) + cc.tconn, srv = synctestNetPipe() } else { // If cc.tconn is non-nil, we're in a test which provides a conn to the // Transport via a TLSNextProto hook. Extract the test connection pair. 
@@ -133,7 +135,7 @@ func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientCo srv = cc.tconn.(*synctestNetConn).peer } - srv.SetReadDeadline(tc.group.Now()) + srv.SetReadDeadline(time.Now()) srv.autoWait = true tc.netconn = srv tc.enc = hpack.NewEncoder(&tc.encbuf) @@ -163,12 +165,12 @@ func (tc *testClientConn) readClientPreface() { } } -func newTestClientConn(t *testing.T, opts ...any) *testClientConn { +func newTestClientConn(t testing.TB, opts ...any) *testClientConn { t.Helper() tt := newTestTransport(t, opts...) const singleUse = false - _, err := tt.tr.newClientConn(nil, singleUse) + _, err := tt.tr.newClientConn(nil, singleUse, nil) if err != nil { t.Fatalf("newClientConn: %v", err) } @@ -176,18 +178,6 @@ func newTestClientConn(t *testing.T, opts ...any) *testClientConn { return tt.getConn() } -// sync waits for the ClientConn under test to reach a stable state, -// with all goroutines blocked on some input. -func (tc *testClientConn) sync() { - tc.group.Wait() -} - -// advance advances synthetic time by a duration. -func (tc *testClientConn) advance(d time.Duration) { - tc.group.AdvanceTime(d) - tc.sync() -} - // hasFrame reports whether a frame is available to be read. func (tc *testClientConn) hasFrame() bool { return len(tc.netconn.Peek()) > 0 @@ -204,6 +194,13 @@ func (tc *testClientConn) closeWrite() { tc.netconn.Close() } +// closeWrite causes the net.Conn used by the ClientConn to return a error +// from Write calls. +func (tc *testClientConn) closeWriteWithError(err error) { + tc.netconn.loc.setReadError(io.EOF) + tc.netconn.loc.setWriteError(err) +} + // testRequestBody is a Request.Body for use in tests. type testRequestBody struct { tc *testClientConn @@ -258,17 +255,17 @@ func (b *testRequestBody) Close() error { // writeBytes adds n arbitrary bytes to the body. func (b *testRequestBody) writeBytes(n int) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() b.bytes += n b.checkWrite() - b.tc.sync() + synctest.Wait() } // Write adds bytes to the body. func (b *testRequestBody) Write(p []byte) (int, error) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() n, err := b.buf.Write(p) @@ -287,7 +284,7 @@ func (b *testRequestBody) checkWrite() { // closeWithError sets an error which will be returned by Read. func (b *testRequestBody) closeWithError(err error) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() b.err = err @@ -298,19 +295,21 @@ func (b *testRequestBody) closeWithError(err error) { // (Note that the RoundTrip won't complete until response headers are received, // the request times out, or some other terminal condition is reached.) func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) rt := &testRoundTrip{ - t: tc.t, - donec: make(chan struct{}), + t: tc.t, + donec: make(chan struct{}), + cancel: cancel, } tc.roundtrips = append(tc.roundtrips, rt) go func() { - tc.group.Join() defer close(rt.donec) rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) { rt.id.Store(cs.ID) }) }() - tc.sync() + synctest.Wait() tc.t.Cleanup(func() { if !rt.done() { @@ -336,7 +335,7 @@ func (tc *testClientConn) greet(settings ...Setting) { // makeHeaderBlockFragment encodes headers in a form suitable for inclusion // in a HEADERS or CONTINUATION frame. // -// It takes a list of alernating names and values. +// It takes a list of alternating names and values. 
func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte { if len(s)%2 != 0 { tc.t.Fatalf("uneven list of header name/value pairs") @@ -366,11 +365,12 @@ func (tc *testClientConn) inflowWindow(streamID uint32) int32 { // testRoundTrip manages a RoundTrip in progress. type testRoundTrip struct { - t *testing.T + t testing.TB resp *http.Response respErr error donec chan struct{} id atomic.Uint32 + cancel context.CancelFunc } // streamID returns the HTTP/2 stream ID of the request. @@ -396,6 +396,7 @@ func (rt *testRoundTrip) done() bool { func (rt *testRoundTrip) result() (*http.Response, error) { t := rt.t t.Helper() + synctest.Wait() select { case <-rt.donec: default: @@ -494,19 +495,16 @@ func diffHeaders(got, want http.Header) string { // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling // should use testClientConn instead. type testTransport struct { - t *testing.T - tr *Transport - group *synctestGroup + t testing.TB + tr *Transport ccs []*testClientConn } -func newTestTransport(t *testing.T, opts ...any) *testTransport { +func newTestTransport(t testing.TB, opts ...any) *testTransport { tt := &testTransport{ - t: t, - group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + t: t, } - tt.group.Join() tr := &Transport{} for _, o := range opts { @@ -525,7 +523,6 @@ func newTestTransport(t *testing.T, opts ...any) *testTransport { tt.tr = tr tr.transportTestHooks = &transportTestHooks{ - group: tt.group, newclientconn: func(cc *ClientConn) { tc := newTestClientConnFromClientConn(t, cc) tt.ccs = append(tt.ccs, tc) @@ -533,25 +530,15 @@ func newTestTransport(t *testing.T, opts ...any) *testTransport { } t.Cleanup(func() { - tt.sync() + synctest.Wait() if len(tt.ccs) > 0 { t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) } - tt.group.Close(t) }) return tt } -func (tt *testTransport) sync() { - tt.group.Wait() -} - -func (tt *testTransport) advance(d time.Duration) { - tt.group.AdvanceTime(d) - tt.sync() -} - func (tt *testTransport) hasConn() bool { return len(tt.ccs) > 0 } @@ -563,9 +550,9 @@ func (tt *testTransport) getConn() *testClientConn { } tc := tt.ccs[0] tt.ccs = tt.ccs[1:] - tc.sync() + synctest.Wait() tc.readClientPreface() - tc.sync() + synctest.Wait() return tc } @@ -575,11 +562,10 @@ func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { donec: make(chan struct{}), } go func() { - tt.group.Join() defer close(rt.donec) rt.resp, rt.respErr = tt.tr.RoundTrip(req) }() - tt.sync() + synctest.Wait() tt.t.Cleanup(func() { if !rt.done() { diff --git a/http2/config.go b/http2/config.go index ca645d9a1a..8a7a89d016 100644 --- a/http2/config.go +++ b/http2/config.go @@ -27,6 +27,7 @@ import ( // - If the resulting value is zero or out of range, use a default. type http2Config struct { MaxConcurrentStreams uint32 + StrictMaxConcurrentRequests bool MaxDecoderHeaderTableSize uint32 MaxEncoderHeaderTableSize uint32 MaxReadFrameSize uint32 @@ -55,7 +56,7 @@ func configFromServer(h1 *http.Server, h2 *Server) http2Config { PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites, CountError: h2.CountError, } - fillNetHTTPServerConfig(&conf, h1) + fillNetHTTPConfig(&conf, h1.HTTP2) setConfigDefaults(&conf, true) return conf } @@ -64,12 +65,13 @@ func configFromServer(h1 *http.Server, h2 *Server) http2Config { // (the net/http Transport). 
func configFromTransport(h2 *Transport) http2Config { conf := http2Config{ - MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize, - MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize, - MaxReadFrameSize: h2.MaxReadFrameSize, - SendPingTimeout: h2.ReadIdleTimeout, - PingTimeout: h2.PingTimeout, - WriteByteTimeout: h2.WriteByteTimeout, + StrictMaxConcurrentRequests: h2.StrictMaxConcurrentStreams, + MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize, + MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize, + MaxReadFrameSize: h2.MaxReadFrameSize, + SendPingTimeout: h2.ReadIdleTimeout, + PingTimeout: h2.PingTimeout, + WriteByteTimeout: h2.WriteByteTimeout, } // Unlike most config fields, where out-of-range values revert to the default, @@ -81,7 +83,7 @@ func configFromTransport(h2 *Transport) http2Config { } if h2.t1 != nil { - fillNetHTTPTransportConfig(&conf, h2.t1) + fillNetHTTPConfig(&conf, h2.t1.HTTP2) } setConfigDefaults(&conf, false) return conf @@ -120,3 +122,48 @@ func adjustHTTP1MaxHeaderSize(n int64) int64 { const typicalHeaders = 10 // conservative return n + typicalHeaders*perFieldOverhead } + +func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) { + if h2 == nil { + return + } + if h2.MaxConcurrentStreams != 0 { + conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + } + if http2ConfigStrictMaxConcurrentRequests(h2) { + conf.StrictMaxConcurrentRequests = true + } + if h2.MaxEncoderHeaderTableSize != 0 { + conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize) + } + if h2.MaxDecoderHeaderTableSize != 0 { + conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize) + } + if h2.MaxConcurrentStreams != 0 { + conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + } + if h2.MaxReadFrameSize != 0 { + conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize) + } + if h2.MaxReceiveBufferPerConnection != 0 { + conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection) + } + if h2.MaxReceiveBufferPerStream != 0 { + conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream) + } + if h2.SendPingTimeout != 0 { + conf.SendPingTimeout = h2.SendPingTimeout + } + if h2.PingTimeout != 0 { + conf.PingTimeout = h2.PingTimeout + } + if h2.WriteByteTimeout != 0 { + conf.WriteByteTimeout = h2.WriteByteTimeout + } + if h2.PermitProhibitedCipherSuites { + conf.PermitProhibitedCipherSuites = true + } + if h2.CountError != nil { + conf.CountError = h2.CountError + } +} diff --git a/http2/config_go124.go b/http2/config_go124.go deleted file mode 100644 index 5b516c55ff..0000000000 --- a/http2/config_go124.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.24 - -package http2 - -import "net/http" - -// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2. -func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) { - fillNetHTTPConfig(conf, srv.HTTP2) -} - -// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2. 
-func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) { - fillNetHTTPConfig(conf, tr.HTTP2) -} - -func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) { - if h2 == nil { - return - } - if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) - } - if h2.MaxEncoderHeaderTableSize != 0 { - conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize) - } - if h2.MaxDecoderHeaderTableSize != 0 { - conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize) - } - if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) - } - if h2.MaxReadFrameSize != 0 { - conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize) - } - if h2.MaxReceiveBufferPerConnection != 0 { - conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection) - } - if h2.MaxReceiveBufferPerStream != 0 { - conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream) - } - if h2.SendPingTimeout != 0 { - conf.SendPingTimeout = h2.SendPingTimeout - } - if h2.PingTimeout != 0 { - conf.PingTimeout = h2.PingTimeout - } - if h2.WriteByteTimeout != 0 { - conf.WriteByteTimeout = h2.WriteByteTimeout - } - if h2.PermitProhibitedCipherSuites { - conf.PermitProhibitedCipherSuites = true - } - if h2.CountError != nil { - conf.CountError = h2.CountError - } -} diff --git a/http2/config_go125.go b/http2/config_go125.go new file mode 100644 index 0000000000..b4373fe33c --- /dev/null +++ b/http2/config_go125.go @@ -0,0 +1,15 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.26 + +package http2 + +import ( + "net/http" +) + +func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool { + return false +} diff --git a/http2/config_go126.go b/http2/config_go126.go new file mode 100644 index 0000000000..6b071c149d --- /dev/null +++ b/http2/config_go126.go @@ -0,0 +1,15 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.26 + +package http2 + +import ( + "net/http" +) + +func http2ConfigStrictMaxConcurrentRequests(h2 *http.HTTP2Config) bool { + return h2.StrictMaxConcurrentRequests +} diff --git a/http2/config_pre_go124.go b/http2/config_pre_go124.go deleted file mode 100644 index 060fd6c64c..0000000000 --- a/http2/config_pre_go124.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.24 - -package http2 - -import "net/http" - -// Pre-Go 1.24 fallback. -// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24. - -func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {} - -func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {} diff --git a/http2/config_test.go b/http2/config_test.go index b8e7a7b043..88e05e0aa4 100644 --- a/http2/config_test.go +++ b/http2/config_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-//go:build go1.24 +//go:build go1.25 || goexperiment.synctest package http2 @@ -12,7 +12,8 @@ import ( "time" ) -func TestConfigServerSettings(t *testing.T) { +func TestConfigServerSettings(t *testing.T) { synctestTest(t, testConfigServerSettings) } +func testConfigServerSettings(t testing.TB) { config := &http.HTTP2Config{ MaxConcurrentStreams: 1, MaxDecoderHeaderTableSize: 1<<20 + 2, @@ -37,7 +38,8 @@ func TestConfigServerSettings(t *testing.T) { }) } -func TestConfigTransportSettings(t *testing.T) { +func TestConfigTransportSettings(t *testing.T) { synctestTest(t, testConfigTransportSettings) } +func testConfigTransportSettings(t testing.TB) { config := &http.HTTP2Config{ MaxConcurrentStreams: 1, // ignored by Transport MaxDecoderHeaderTableSize: 1<<20 + 2, @@ -60,7 +62,8 @@ func TestConfigTransportSettings(t *testing.T) { tc.wantWindowUpdate(0, uint32(config.MaxReceiveBufferPerConnection)) } -func TestConfigPingTimeoutServer(t *testing.T) { +func TestConfigPingTimeoutServer(t *testing.T) { synctestTest(t, testConfigPingTimeoutServer) } +func testConfigPingTimeoutServer(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { s.ReadIdleTimeout = 2 * time.Second @@ -68,13 +71,14 @@ func TestConfigPingTimeoutServer(t *testing.T) { }) st.greet() - st.advance(2 * time.Second) + time.Sleep(2 * time.Second) _ = readFrame[*PingFrame](t, st) - st.advance(3 * time.Second) + time.Sleep(3 * time.Second) st.wantClosed() } -func TestConfigPingTimeoutTransport(t *testing.T) { +func TestConfigPingTimeoutTransport(t *testing.T) { synctestTest(t, testConfigPingTimeoutTransport) } +func testConfigPingTimeoutTransport(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 2 * time.Second tr.PingTimeout = 3 * time.Second @@ -85,9 +89,9 @@ func TestConfigPingTimeoutTransport(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(2 * time.Second) + time.Sleep(2 * time.Second) tc.wantFrameType(FramePing) - tc.advance(3 * time.Second) + time.Sleep(3 * time.Second) err := rt.err() if err == nil { t.Fatalf("expected connection to close") diff --git a/http2/connframes_test.go b/http2/connframes_test.go index 2c4532571a..f2e6eb520a 100644 --- a/http2/connframes_test.go +++ b/http2/connframes_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -108,7 +110,7 @@ frame: if typ.Kind() != reflect.Func || typ.NumIn() != 1 || typ.NumOut() != 1 || - typ.Out(0) != reflect.TypeOf(true) { + typ.Out(0) != reflect.TypeFor[bool]() { tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f) } if typ.In(0) == reflect.TypeOf(fr) { diff --git a/http2/frame.go b/http2/frame.go index db3264da8c..9a4bd123c9 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -280,6 +280,8 @@ type Framer struct { // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. lastHeaderStream uint32 + // lastFrameType holds the type of the last frame for verifying frame order. + lastFrameType FrameType maxReadSize uint32 headerBuf [frameHeaderLen]byte @@ -347,7 +349,7 @@ func (fr *Framer) maxHeaderListSize() uint32 { func (f *Framer) startWrite(ftype FrameType, flags Flags, streamID uint32) { // Write the FrameHeader. 
f.wbuf = append(f.wbuf[:0], - 0, // 3 bytes of length, filled in in endWrite + 0, // 3 bytes of length, filled in endWrite 0, 0, byte(ftype), @@ -488,30 +490,41 @@ func terminalReadFrameError(err error) bool { return err != nil } -// ReadFrame reads a single frame. The returned Frame is only valid -// until the next call to ReadFrame. +// ReadFrameHeader reads the header of the next frame. +// It reads the 9-byte fixed frame header, and does not read any portion of the +// frame payload. The caller is responsible for consuming the payload, either +// with ReadFrameForHeader or directly from the Framer's io.Reader. // -// If the frame is larger than previously set with SetMaxReadFrameSize, the -// returned error is ErrFrameTooLarge. Other errors may be of type -// ConnectionError, StreamError, or anything else from the underlying -// reader. +// If the frame is larger than previously set with SetMaxReadFrameSize, it +// returns the frame header and ErrFrameTooLarge. // -// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID -// indicates the stream responsible for the error. -func (fr *Framer) ReadFrame() (Frame, error) { +// If the returned FrameHeader.StreamID is non-zero, it indicates the stream +// responsible for the error. +func (fr *Framer) ReadFrameHeader() (FrameHeader, error) { fr.errDetail = nil - if fr.lastFrame != nil { - fr.lastFrame.invalidate() - } fh, err := readFrameHeader(fr.headerBuf[:], fr.r) if err != nil { - return nil, err + return fh, err } if fh.Length > fr.maxReadSize { if fh == invalidHTTP1LookingFrameHeader() { - return nil, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge) + return fh, fmt.Errorf("http2: failed reading the frame payload: %w, note that the frame header looked like an HTTP/1.1 header", ErrFrameTooLarge) } - return nil, ErrFrameTooLarge + return fh, ErrFrameTooLarge + } + if err := fr.checkFrameOrder(fh); err != nil { + return fh, err + } + return fh, nil +} + +// ReadFrameForHeader reads the payload for the frame with the given FrameHeader. +// +// It behaves identically to ReadFrame, other than not checking the maximum +// frame size. +func (fr *Framer) ReadFrameForHeader(fh FrameHeader) (Frame, error) { + if fr.lastFrame != nil { + fr.lastFrame.invalidate() } payload := fr.getReadBuf(fh.Length) if _, err := io.ReadFull(fr.r, payload); err != nil { @@ -527,9 +540,7 @@ func (fr *Framer) ReadFrame() (Frame, error) { } return nil, err } - if err := fr.checkFrameOrder(f); err != nil { - return nil, err - } + fr.lastFrame = f if fr.logReads { fr.debugReadLoggerf("http2: Framer %p: read %v", fr, summarizeFrame(f)) } @@ -539,6 +550,24 @@ func (fr *Framer) ReadFrame() (Frame, error) { return f, nil } +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame or ReadFrameBodyForHeader. +// +// If the frame is larger than previously set with SetMaxReadFrameSize, the +// returned error is ErrFrameTooLarge. Other errors may be of type +// ConnectionError, StreamError, or anything else from the underlying +// reader. +// +// If ReadFrame returns an error and a non-nil Frame, the Frame's StreamID +// indicates the stream responsible for the error. 
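Before the ReadFrame convenience wrapper below, a sketch of how the new ReadFrameHeader/ReadFrameForHeader split can be used: a caller may inspect the header first and consume uninteresting payloads straight from the Framer's reader, as the new doc comments permit. skimFrames is illustrative only (it is not part of this change and ignores flow-control accounting):

package framerskim

import (
	"fmt"
	"io"
	"net"

	"golang.org/x/net/http2"
)

// skimFrames parses every frame except DATA, whose payloads it discards
// directly from the connection instead of allocating a Frame for them.
// A real client would also need to return flow-control credit for the
// discarded bytes.
func skimFrames(conn net.Conn) error {
	fr := http2.NewFramer(conn, conn)
	for {
		fh, err := fr.ReadFrameHeader()
		if err != nil {
			return err
		}
		if fh.Type == http2.FrameData {
			// Consume the payload from the Framer's reader, as the
			// ReadFrameHeader contract allows.
			if _, err := io.CopyN(io.Discard, conn, int64(fh.Length)); err != nil {
				return err
			}
			continue
		}
		f, err := fr.ReadFrameForHeader(fh)
		if err != nil {
			return err
		}
		fmt.Printf("%v frame on stream %d\n", f.Header().Type, f.Header().StreamID)
	}
}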
+func (fr *Framer) ReadFrame() (Frame, error) { + fh, err := fr.ReadFrameHeader() + if err != nil { + return nil, err + } + return fr.ReadFrameForHeader(fh) +} + // connError returns ConnectionError(code) but first // stashes away a public reason to the caller can optionally relay it // to the peer before hanging up on them. This might help others debug @@ -551,20 +580,19 @@ func (fr *Framer) connError(code ErrCode, reason string) error { // checkFrameOrder reports an error if f is an invalid frame to return // next from ReadFrame. Mostly it checks whether HEADERS and // CONTINUATION frames are contiguous. -func (fr *Framer) checkFrameOrder(f Frame) error { - last := fr.lastFrame - fr.lastFrame = f +func (fr *Framer) checkFrameOrder(fh FrameHeader) error { + lastType := fr.lastFrameType + fr.lastFrameType = fh.Type if fr.AllowIllegalReads { return nil } - fh := f.Header() if fr.lastHeaderStream != 0 { if fh.Type != FrameContinuation { return fr.connError(ErrCodeProtocol, fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d", fh.Type, fh.StreamID, - last.Header().Type, fr.lastHeaderStream)) + lastType, fr.lastHeaderStream)) } if fh.StreamID != fr.lastHeaderStream { return fr.connError(ErrCodeProtocol, @@ -1152,7 +1180,16 @@ type PriorityFrame struct { PriorityParam } -// PriorityParam are the stream prioritzation parameters. +var defaultRFC9218Priority = PriorityParam{ + incremental: 0, + urgency: 3, +} + +// Note that HTTP/2 has had two different prioritization schemes, and +// PriorityParam struct below is a superset of both schemes. The exported +// symbols are from RFC 7540 and the non-exported ones are from RFC 9218. + +// PriorityParam are the stream prioritization parameters. type PriorityParam struct { // StreamDep is a 31-bit stream identifier for the // stream that this stream depends on. Zero means no @@ -1167,6 +1204,20 @@ type PriorityParam struct { // the spec, "Add one to the value to obtain a weight between // 1 and 256." Weight uint8 + + // "The urgency (u) parameter value is Integer (see Section 3.3.1 of + // [STRUCTURED-FIELDS]), between 0 and 7 inclusive, in descending order of + // priority. The default is 3." + urgency uint8 + + // "The incremental (i) parameter value is Boolean (see Section 3.3.6 of + // [STRUCTURED-FIELDS]). It indicates if an HTTP response can be processed + // incrementally, i.e., provide some meaningful output as chunks of the + // response arrive." + // + // We use uint8 (i.e. 0 is false, 1 is true) instead of bool so we can + // avoid unnecessary type conversions and because either type takes 1 byte. + incremental uint8 } func (p PriorityParam) IsZero() bool { diff --git a/http2/frame_test.go b/http2/frame_test.go index 68505317e1..a2b136d136 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -823,7 +825,7 @@ func TestReadFrameOrder(t *testing.T) { }, }, 9: { - wantErr: "CONTINUATION frame with stream ID 0", + wantErr: "unexpected CONTINUATION for stream 0", w: func(f *Framer) { cont(f, 0, true) }, @@ -873,7 +875,7 @@ func TestReadFrameOrder(t *testing.T) { continue } if !((f.errDetail == nil && tt.wantErr == "") || (fmt.Sprint(f.errDetail) == tt.wantErr)) { - t.Errorf("%d. framer eror = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes()) + t.Errorf("%d. 
framer error = %q; want %q\n%s", i, f.errDetail, tt.wantErr, log.Bytes()) } if n < tt.atLeast { t.Errorf("%d. framer only read %d frames; want at least %d\n%s", i, n, tt.atLeast, log.Bytes()) @@ -1276,3 +1278,110 @@ func TestTypeFrameParser(t *testing.T) { t.Errorf("expected UnknownFrame, got %T", frame) } } + +func TestReadFrameHeaderAndBody(t *testing.T) { + fr, _ := testFramer() + var streamID uint32 = 1 + data := []byte("ABC") + if err := fr.WriteData(streamID, true, data); err != nil { + t.Fatalf("WriteData(%d, true, %q) failed: %v", streamID, data, err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + wantHeader := FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream, + Length: 3, + StreamID: 1, + valid: true, + } + if !fh.Equal(wantHeader) { + t.Fatalf("ReadFrameHeader = %+v; want %+v", fh, wantHeader) + } + + f, err := fr.ReadFrameForHeader(fh) + if err != nil { + t.Fatalf("ReadFrameForHeader failed: %v", err) + } + + if !fh.Equal(f.Header()) { + t.Fatalf("Frame.Header() = %+v; want %+v", f.Header(), fh) + } + + df, ok := f.(*DataFrame) + if !ok { + t.Fatalf("got %T; want *DataFrame", f) + } + if got, want := df.Data(), data; !bytes.Equal(got, want) { + t.Errorf("DataFrame.Data() = %q; want %q", string(got), string(want)) + } + if got, want := df.StreamEnded(), true; got != want { + t.Errorf("DataFrame.StreamEnded() = %v; want %v", got, want) + } +} + +func TestReadFrameHeaderFrameTooLarge(t *testing.T) { + fr, _ := testFramer() + fr.SetMaxReadFrameSize(2) + if err := fr.WriteData(1, true, []byte("ABC")); err != nil { + t.Fatalf("WriteData failed: %v", err) + } + fh, err := fr.ReadFrameHeader() + if gotErr, wantErr := err, ErrFrameTooLarge; gotErr != wantErr { + t.Fatalf("ReadFrameHeader returned error %v; want %v", gotErr, wantErr) + } + if fh.StreamID != 1 { + t.Errorf("ReadFrameHeader = %v, %v; want StreamID 1", fh, err) + } +} + +func TestReadFrameHeaderBadFrameOrder(t *testing.T) { + fr, _ := testFramer() + if err := fr.WriteHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: []byte("foo"), // unused, but non-empty + EndHeaders: false, + }); err != nil { + t.Fatalf("WriteHeaders failed: %v", err) + } + + // Write a CONTINUATION frame for stream 2 without first finishing the headers for stream 1. + if err := fr.WriteContinuation(2, true, []byte("foo")); err != nil { + t.Fatalf("WriteContinuation failed: %v", err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + if _, err = fr.ReadFrameForHeader(fh); err != nil { + t.Fatalf("ReadFrameForHeader failed: %v", err) + } + + if _, err := fr.ReadFrameHeader(); err != ConnectionError(ErrCodeProtocol) { + t.Fatalf("ReadFrameHeader returned error %v; want ConnectionError(ErrCodeProtocol)", err) + } +} + +func TestReadFrameForHeaderUnexpectedEOF(t *testing.T) { + fr, b := testFramer() + if err := fr.WriteData(1, true, []byte("ABC")); err != nil { + t.Fatalf("WriteData failed: %v", err) + } + + fh, err := fr.ReadFrameHeader() + if err != nil { + t.Fatalf("ReadFrameHeader failed: %v", err) + } + + // Remove one byte from the body, corrupting the frame body. 
+ b.Truncate(b.Len() - 1) + + _, err = fr.ReadFrameForHeader(fh) + if err != io.ErrUnexpectedEOF { + t.Fatalf("ReadFrameForHeader with short body = %v; want io.ErrUnexpectedEOF", err) + } +} diff --git a/http2/gotrack.go b/http2/gotrack.go index 9933c9f8c7..9921ca096d 100644 --- a/http2/gotrack.go +++ b/http2/gotrack.go @@ -15,21 +15,32 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" ) var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" +// Setting DebugGoroutines to false during a test to disable goroutine debugging +// results in race detector complaints when a test leaves goroutines running before +// returning. Tests shouldn't do this, of course, but when they do it generally shows +// up as infrequent, hard-to-debug flakes. (See #66519.) +// +// Disable goroutine debugging during individual tests with an atomic bool. +// (Note that it's safe to enable/disable debugging mid-test, so the actual race condition +// here is harmless.) +var disableDebugGoroutines atomic.Bool + type goroutineLock uint64 func newGoroutineLock() goroutineLock { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return 0 } return goroutineLock(curGoroutineID()) } func (g goroutineLock) check() { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return } if curGoroutineID() != uint64(g) { @@ -38,7 +49,7 @@ func (g goroutineLock) check() { } func (g goroutineLock) checkNotOn() { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return } if curGoroutineID() == uint64(g) { diff --git a/http2/gotrack_test.go b/http2/gotrack_test.go index 06db61231d..18b8961c05 100644 --- a/http2/gotrack_test.go +++ b/http2/gotrack_test.go @@ -11,10 +11,6 @@ import ( ) func TestGoroutineLock(t *testing.T) { - oldDebug := DebugGoroutines - DebugGoroutines = true - defer func() { DebugGoroutines = oldDebug }() - g := newGoroutineLock() g.check() diff --git a/http2/http2.go b/http2/http2.go index ea5ae629fd..105fe12fef 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -15,7 +15,6 @@ package http2 // import "golang.org/x/net/http2" import ( "bufio" - "context" "crypto/tls" "errors" "fmt" @@ -35,7 +34,6 @@ var ( VerboseLogs bool logFrameWrites bool logFrameReads bool - inTests bool // Enabling extended CONNECT by causes browsers to attempt to use // WebSockets-over-HTTP/2. This results in problems when the server's websocket @@ -255,15 +253,13 @@ func (cw closeWaiter) Wait() { // idle memory usage with many connections. type bufferedWriter struct { _ incomparable - group synctestGroupInterface // immutable - conn net.Conn // immutable - bw *bufio.Writer // non-nil when data is buffered - byteTimeout time.Duration // immutable, WriteByteTimeout + conn net.Conn // immutable + bw *bufio.Writer // non-nil when data is buffered + byteTimeout time.Duration // immutable, WriteByteTimeout } -func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter { +func newBufferedWriter(conn net.Conn, timeout time.Duration) *bufferedWriter { return &bufferedWriter{ - group: group, conn: conn, byteTimeout: timeout, } @@ -314,24 +310,18 @@ func (w *bufferedWriter) Flush() error { type bufferedWriterTimeoutWriter bufferedWriter func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) { - return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p) + return writeWithByteTimeout(w.conn, w.byteTimeout, p) } // writeWithByteTimeout writes to conn. 
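The gotrack.go hunk above introduces disableDebugGoroutines; the disableGoroutineTracking(t) helper called later in server_test.go can plausibly be built on it. A sketch of such a helper, assuming package http2 scope (the change's actual helper is not shown in these hunks and may differ):

package http2

import "testing"

// disableGoroutineTrackingSketch turns goroutine-ownership checks off for
// the duration of one test via the disableDebugGoroutines flag added above,
// restoring the previous value on cleanup.
func disableGoroutineTrackingSketch(t testing.TB) {
	t.Helper()
	prev := disableDebugGoroutines.Swap(true)
	t.Cleanup(func() { disableDebugGoroutines.Store(prev) })
}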
// If more than timeout passes without any bytes being written to the connection, // the write fails. -func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) { +func writeWithByteTimeout(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) { if timeout <= 0 { return conn.Write(p) } for { - var now time.Time - if group == nil { - now = time.Now() - } else { - now = group.Now() - } - conn.SetWriteDeadline(now.Add(timeout)) + conn.SetWriteDeadline(time.Now().Add(timeout)) nn, err := conn.Write(p[n:]) n += nn if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) { @@ -417,14 +407,3 @@ func (s *sorter) SortStrings(ss []string) { // makes that struct also non-comparable, and generally doesn't add // any size (as long as it's first). type incomparable [0]func() - -// synctestGroupInterface is the methods of synctestGroup used by Server and Transport. -// It's defined as an interface here to let us keep synctestGroup entirely test-only -// and not a part of non-test builds. -type synctestGroupInterface interface { - Join() - Now() time.Time - NewTimer(d time.Duration) timer - AfterFunc(d time.Duration, f func()) timer - ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) -} diff --git a/http2/http2_test.go b/http2/http2_test.go index c7774133a7..cd38b96d15 100644 --- a/http2/http2_test.go +++ b/http2/http2_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -28,7 +30,6 @@ func condSkipFailingTest(t *testing.T) { } func init() { - inTests = true DebugGoroutines = true flag.BoolVar(&VerboseLogs, "verboseh2", VerboseLogs, "Verbose HTTP/2 debug logging") } @@ -68,7 +69,7 @@ func (w twriter) Write(p []byte) (n int, err error) { } // like encodeHeader, but don't add implicit pseudo headers. -func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte { +func encodeHeaderNoImplicit(t testing.TB, headers ...string) []byte { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) for len(headers) > 0 { @@ -81,35 +82,6 @@ func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte { return buf.Bytes() } -type puppetCommand struct { - fn func(w http.ResponseWriter, r *http.Request) - done chan<- bool -} - -type handlerPuppet struct { - ch chan puppetCommand -} - -func newHandlerPuppet() *handlerPuppet { - return &handlerPuppet{ - ch: make(chan puppetCommand), - } -} - -func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) { - for cmd := range p.ch { - cmd.fn(w, r) - cmd.done <- true - } -} - -func (p *handlerPuppet) done() { close(p.ch) } -func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) { - done := make(chan bool) - p.ch <- puppetCommand{fn, done} - <-done -} - func cleanDate(res *http.Response) { if d := res.Header["Date"]; len(d) == 1 { d[0] = "XXX" @@ -285,7 +257,7 @@ func TestNoUnicodeStrings(t *testing.T) { } // setForTest sets *p = v, and restores its original value in t.Cleanup. -func setForTest[T any](t *testing.T, p *T, v T) { +func setForTest[T any](t testing.TB, p *T, v T) { orig := *p t.Cleanup(func() { *p = orig @@ -300,3 +272,11 @@ func must[T any](v T, err error) T { } return v } + +// synctestSubtest starts a subtest and runs f in a synctest bubble within it. 
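synctestSubtest, defined just below, delegates to synctestTest, which is not part of the hunks shown here. A plausible shape for it, assuming Go 1.25's testing/synctest.Test (sketch only; the change's real helper may do more, such as disabling goroutine tracking):

package http2

import (
	"testing"
	"testing/synctest"
)

// synctestTest runs f inside a synctest bubble so that timers, time.Sleep,
// and synctest.Wait in the test body operate on the bubble's fake clock.
func synctestTest(t *testing.T, f func(testing.TB)) {
	t.Helper()
	synctest.Test(t, func(t *testing.T) {
		f(t)
	})
}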
+func synctestSubtest(t *testing.T, name string, f func(testing.TB)) { + t.Helper() + t.Run(name, func(t *testing.T) { + synctestTest(t, f) + }) +} diff --git a/http2/netconn_test.go b/http2/netconn_test.go index 5a1759579e..4d4124dc69 100644 --- a/http2/netconn_test.go +++ b/http2/netconn_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -14,6 +16,7 @@ import ( "net/netip" "os" "sync" + "testing/synctest" "time" ) @@ -23,13 +26,13 @@ import ( // Unlike net.Pipe, the connection is not synchronous. // Writes are made to a buffer, and return immediately. // By default, the buffer size is unlimited. -func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) { +func synctestNetPipe() (r, w *synctestNetConn) { s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")) s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")) s1 := newSynctestNetConnHalf(s1addr) s2 := newSynctestNetConnHalf(s2addr) - r = &synctestNetConn{group: group, loc: s1, rem: s2} - w = &synctestNetConn{group: group, loc: s2, rem: s1} + r = &synctestNetConn{loc: s1, rem: s2} + w = &synctestNetConn{loc: s2, rem: s1} r.peer = w w.peer = r return r, w @@ -37,8 +40,6 @@ func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) { // A synctestNetConn is one endpoint of the connection created by synctestNetPipe. type synctestNetConn struct { - group *synctestGroup - // local and remote connection halves. // Each half contains a buffer. // Reads pull from the local buffer, and writes push to the remote buffer. @@ -54,7 +55,7 @@ type synctestNetConn struct { // Read reads data from the connection. func (c *synctestNetConn) Read(b []byte) (n int, err error) { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.read(b) } @@ -63,7 +64,7 @@ func (c *synctestNetConn) Read(b []byte) (n int, err error) { // without consuming its contents. func (c *synctestNetConn) Peek() []byte { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.peek() } @@ -71,7 +72,7 @@ func (c *synctestNetConn) Peek() []byte { // Write writes data to the connection. func (c *synctestNetConn) Write(b []byte) (n int, err error) { if c.autoWait { - defer c.group.Wait() + defer synctest.Wait() } return c.rem.write(b) } @@ -79,7 +80,7 @@ func (c *synctestNetConn) Write(b []byte) (n int, err error) { // IsClosedByPeer reports whether the peer has closed its end of the connection. func (c *synctestNetConn) IsClosedByPeer() bool { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.isClosedByPeer() } @@ -89,7 +90,7 @@ func (c *synctestNetConn) Close() error { c.loc.setWriteError(errors.New("connection closed by peer")) c.rem.setReadError(io.EOF) if c.autoWait { - c.group.Wait() + synctest.Wait() } return nil } @@ -99,7 +100,7 @@ func (c *synctestNetConn) LocalAddr() net.Addr { return c.loc.addr } -// LocalAddr returns the (fake) remote network address. +// RemoteAddr returns the (fake) remote network address. func (c *synctestNetConn) RemoteAddr() net.Addr { return c.rem.addr } @@ -113,13 +114,13 @@ func (c *synctestNetConn) SetDeadline(t time.Time) error { // SetReadDeadline sets the read deadline for the connection. func (c *synctestNetConn) SetReadDeadline(t time.Time) error { - c.loc.rctx.setDeadline(c.group, t) + c.loc.rctx.setDeadline(t) return nil } // SetWriteDeadline sets the write deadline for the connection. 
func (c *synctestNetConn) SetWriteDeadline(t time.Time) error { - c.rem.wctx.setDeadline(c.group, t) + c.rem.wctx.setDeadline(t) return nil } @@ -300,12 +301,12 @@ func (h *synctestNetConnHalf) setWriteError(err error) { } } -// deadlineContext converts a changable deadline (as in net.Conn.SetDeadline) into a Context. +// deadlineContext converts a changeable deadline (as in net.Conn.SetDeadline) into a Context. type deadlineContext struct { mu sync.Mutex ctx context.Context cancel context.CancelCauseFunc - timer timer + timer *time.Timer } // context returns a Context which expires when the deadline does. @@ -319,7 +320,7 @@ func (t *deadlineContext) context() context.Context { } // setDeadline sets the current deadline. -func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) { +func (t *deadlineContext) setDeadline(deadline time.Time) { t.mu.Lock() defer t.mu.Unlock() // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled @@ -335,7 +336,7 @@ func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) // No deadline. return } - if !deadline.After(group.Now()) { + if !deadline.After(time.Now()) { // Deadline has already expired. t.cancel(os.ErrDeadlineExceeded) t.cancel = nil @@ -343,11 +344,11 @@ func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) } if t.timer != nil { // Reuse existing deadline timer. - t.timer.Reset(deadline.Sub(group.Now())) + t.timer.Reset(deadline.Sub(time.Now())) return } // Create a new timer to cancel the context at the deadline. - t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() { + t.timer = time.AfterFunc(deadline.Sub(time.Now()), func() { t.mu.Lock() defer t.mu.Unlock() t.cancel(os.ErrDeadlineExceeded) diff --git a/http2/server.go b/http2/server.go index 51fca38f61..bdc5520ebd 100644 --- a/http2/server.go +++ b/http2/server.go @@ -176,44 +176,15 @@ type Server struct { // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. state *serverInternalState - - // Synchronization group used for testing. - // Outside of tests, this is nil. - group synctestGroupInterface -} - -func (s *Server) markNewGoroutine() { - if s.group != nil { - s.group.Join() - } -} - -func (s *Server) now() time.Time { - if s.group != nil { - return s.group.Now() - } - return time.Now() -} - -// newTimer creates a new time.Timer, or a synthetic timer in tests. -func (s *Server) newTimer(d time.Duration) timer { - if s.group != nil { - return s.group.NewTimer(d) - } - return timeTimer{time.NewTimer(d)} -} - -// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. -func (s *Server) afterFunc(d time.Duration, f func()) timer { - if s.group != nil { - return s.group.AfterFunc(d, f) - } - return timeTimer{time.AfterFunc(d, f)} } type serverInternalState struct { mu sync.Mutex activeConns map[*serverConn]struct{} + + // Pool of error channels. This is per-Server rather than global + // because channels can't be reused across synctest bubbles. + errChanPool sync.Pool } func (s *serverInternalState) registerConn(sc *serverConn) { @@ -245,6 +216,27 @@ func (s *serverInternalState) startGracefulShutdown() { s.mu.Unlock() } +// Global error channel pool used for uninitialized Servers. +// We use a per-Server pool when possible to avoid using channels across synctest bubbles. 
+var errChanPool = sync.Pool{ + New: func() any { return make(chan error, 1) }, +} + +func (s *serverInternalState) getErrChan() chan error { + if s == nil { + return errChanPool.Get().(chan error) // Server used without calling ConfigureServer + } + return s.errChanPool.Get().(chan error) +} + +func (s *serverInternalState) putErrChan(ch chan error) { + if s == nil { + errChanPool.Put(ch) // Server used without calling ConfigureServer + return + } + s.errChanPool.Put(ch) +} + // ConfigureServer adds HTTP/2 support to a net/http Server. // // The configuration conf may be nil. @@ -257,7 +249,10 @@ func ConfigureServer(s *http.Server, conf *Server) error { if conf == nil { conf = new(Server) } - conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})} + conf.state = &serverInternalState{ + activeConns: make(map[*serverConn]struct{}), + errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }}, + } if h1, h2 := s, conf; h2.IdleTimeout == 0 { if h1.IdleTimeout != 0 { h2.IdleTimeout = h1.IdleTimeout @@ -423,6 +418,9 @@ func (o *ServeConnOpts) handler() http.Handler { // // The opts parameter is optional. If nil, default values are used. func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { + if opts == nil { + opts = &ServeConnOpts{} + } s.serveConn(c, opts, nil) } @@ -438,7 +436,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon conn: c, baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), - bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout), + bw: newBufferedWriter(c, conf.WriteByteTimeout), handler: opts.handler(), streams: make(map[uint32]*stream), readFrameCh: make(chan readFrameResult), @@ -638,11 +636,11 @@ type serverConn struct { pingSent bool sentPingData [8]byte goAwayCode ErrCode - shutdownTimer timer // nil until used - idleTimer timer // nil if unused + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused readIdleTimeout time.Duration pingTimeout time.Duration - readIdleTimer timer // nil if unused + readIdleTimer *time.Timer // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -687,12 +685,12 @@ type stream struct { flow outflow // limits writing from Handler to client inflow inflow // what the client is allowed to POST/etc to us state streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - readDeadline timer // nil if unused - writeDeadline timer // nil if unused - closeErr error // set before cw is closed + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + readDeadline *time.Timer // nil if unused + writeDeadline *time.Timer // nil if unused + closeErr error // set before cw is closed trailer http.Header // accumulated trailers reqTrailer http.Header // handler's Request.Trailer @@ -848,7 +846,6 @@ type readFrameResult struct { // consumer is done with the frame. // It's run on its own goroutine. func (sc *serverConn) readFrames() { - sc.srv.markNewGoroutine() gate := make(chan struct{}) gateDone := func() { gate <- struct{}{} } for { @@ -881,7 +878,6 @@ type frameWriteResult struct { // At most one goroutine can be running writeFrameAsync at a time per // serverConn. 
func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { - sc.srv.markNewGoroutine() var err error if wd == nil { err = wr.write.writeFrame(sc) @@ -965,22 +961,22 @@ func (sc *serverConn) serve(conf http2Config) { sc.setConnState(http.StateIdle) if sc.srv.IdleTimeout > 0 { - sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } if conf.SendPingTimeout > 0 { sc.readIdleTimeout = conf.SendPingTimeout - sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer) + sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer) defer sc.readIdleTimer.Stop() } go sc.readFrames() // closed by defer sc.conn.Close above - settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer) + settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) defer settingsTimer.Stop() - lastFrameTime := sc.srv.now() + lastFrameTime := time.Now() loopNum := 0 for { loopNum++ @@ -994,7 +990,7 @@ func (sc *serverConn) serve(conf http2Config) { case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: - lastFrameTime = sc.srv.now() + lastFrameTime = time.Now() // Process any written frames before reading new frames from the client since a // written frame could have triggered a new stream to be started. if sc.writingFrameAsync { @@ -1077,7 +1073,7 @@ func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) { } pingAt := lastFrameReadTime.Add(sc.readIdleTimeout) - now := sc.srv.now() + now := time.Now() if pingAt.After(now) { // We received frames since arming the ping timer. // Reset it for the next possible timeout. @@ -1141,10 +1137,10 @@ func (sc *serverConn) readPreface() error { errc <- nil } }() - timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server? + timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { - case <-timer.C(): + case <-timer.C: return errPrefaceTimeout case err := <-errc: if err == nil { @@ -1156,10 +1152,6 @@ func (sc *serverConn) readPreface() error { } } -var errChanPool = sync.Pool{ - New: func() interface{} { return make(chan error, 1) }, -} - var writeDataPool = sync.Pool{ New: func() interface{} { return new(writeData) }, } @@ -1167,7 +1159,7 @@ var writeDataPool = sync.Pool{ // writeDataFromHandler writes DATA response frames from a handler on // the given stream. func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error { - ch := errChanPool.Get().(chan error) + ch := sc.srv.state.getErrChan() writeArg := writeDataPool.Get().(*writeData) *writeArg = writeData{stream.id, data, endStream} err := sc.writeFrameFromHandler(FrameWriteRequest{ @@ -1199,7 +1191,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea return errStreamClosed } } - errChanPool.Put(ch) + sc.srv.state.putErrChan(ch) if frameWriteDone { writeDataPool.Put(writeArg) } @@ -1513,7 +1505,7 @@ func (sc *serverConn) goAway(code ErrCode) { func (sc *serverConn) shutDownIn(d time.Duration) { sc.serveG.check() - sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer) + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) } func (sc *serverConn) resetStream(se StreamError) { @@ -2118,7 +2110,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. 
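The writeDataFromHandler hunk above shows the lifecycle the new per-Server pool is meant to support; condensed into one helper, the round trip looks roughly like this (an illustration only, not code from this change):

package http2

// sendFrameAndWait condenses the pooled error-channel pattern used by
// writeDataFromHandler: the channel has capacity 1, so the frame writer can
// always complete its single send, and the channel goes back to the pool
// only after that send has been consumed.
func sendFrameAndWait(state *serverInternalState, schedule func(done chan error)) error {
	ch := state.getErrChan()
	schedule(ch) // the write path sends exactly one result on ch
	err := <-ch
	state.putErrChan(ch)
	return err
}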
if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) - st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -2216,7 +2208,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.initialStreamRecvWindowSize) if sc.hs.WriteTimeout > 0 { - st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } sc.streams[id] = st @@ -2405,7 +2397,6 @@ func (sc *serverConn) handlerDone() { // Run on its own goroutine. func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { - sc.srv.markNewGoroutine() defer sc.sendServeMsg(handlerDoneMsg) didPanic := true defer func() { @@ -2454,7 +2445,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro // waiting for this frame to be written, so an http.Flush mid-handler // writes out the correct value of keys, before a handler later potentially // mutates it. - errc = errChanPool.Get().(chan error) + errc = sc.srv.state.getErrChan() } if err := sc.writeFrameFromHandler(FrameWriteRequest{ write: headerData, @@ -2466,7 +2457,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro if errc != nil { select { case err := <-errc: - errChanPool.Put(errc) + sc.srv.state.putErrChan(errc) return err case <-sc.doneServing: return errClientDisconnected @@ -2573,7 +2564,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) { if err == io.EOF { b.sawEOF = true } - if b.conn == nil && inTests { + if b.conn == nil { return } b.conn.noteBodyReadFromHandler(b.stream, n, err) @@ -2702,7 +2693,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { var date string if _, ok := rws.snapHeader["Date"]; !ok { // TODO(bradfitz): be faster here, like net/http? measure. - date = rws.conn.srv.now().UTC().Format(http.TimeFormat) + date = time.Now().UTC().Format(http.TimeFormat) } for _, v := range rws.snapHeader["Trailer"] { @@ -2824,7 +2815,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() { func (w *responseWriter) SetReadDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { + if !deadline.IsZero() && deadline.Before(time.Now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. 
st.onReadTimeout() @@ -2840,9 +2831,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { if deadline.IsZero() { st.readDeadline = nil } else if st.readDeadline == nil { - st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout) + st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) } else { - st.readDeadline.Reset(deadline.Sub(sc.srv.now())) + st.readDeadline.Reset(deadline.Sub(time.Now())) } }) return nil @@ -2850,7 +2841,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { + if !deadline.IsZero() && deadline.Before(time.Now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onWriteTimeout() @@ -2866,9 +2857,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { if deadline.IsZero() { st.writeDeadline = nil } else if st.writeDeadline == nil { - st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout) + st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) } else { - st.writeDeadline.Reset(deadline.Sub(sc.srv.now())) + st.writeDeadline.Reset(deadline.Sub(time.Now())) } }) return nil @@ -3147,7 +3138,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { method: opts.Method, url: u, header: cloneHeader(opts.Header), - done: errChanPool.Get().(chan error), + done: sc.srv.state.getErrChan(), } select { @@ -3164,7 +3155,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { case <-st.cw: return errStreamClosed case err := <-msg.done: - errChanPool.Put(msg.done) + sc.srv.state.putErrChan(msg.done) return err } } diff --git a/http2/server_push_test.go b/http2/server_push_test.go index 69e4c3b12d..ea0a1b260c 100644 --- a/http2/server_push_test.go +++ b/http2/server_push_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -10,14 +12,14 @@ import ( "io" "net/http" "reflect" - "runtime" "strconv" - "sync" "testing" + "testing/synctest" "time" ) -func TestServer_Push_Success(t *testing.T) { +func TestServer_Push_Success(t *testing.T) { synctestTest(t, testServer_Push_Success) } +func testServer_Push_Success(t testing.TB) { const ( mainBody = "index page" pushedBody = "pushed page" @@ -242,7 +244,8 @@ func TestServer_Push_Success(t *testing.T) { } } -func TestServer_Push_SuccessNoRace(t *testing.T) { +func TestServer_Push_SuccessNoRace(t *testing.T) { synctestTest(t, testServer_Push_SuccessNoRace) } +func testServer_Push_SuccessNoRace(t testing.TB) { // Regression test for issue #18326. Ensure the request handler can mutate // pushed request headers without racing with the PUSH_PROMISE write. errc := make(chan error, 2) @@ -287,6 +290,9 @@ func TestServer_Push_SuccessNoRace(t *testing.T) { } func TestServer_Push_RejectRecursivePush(t *testing.T) { + synctestTest(t, testServer_Push_RejectRecursivePush) +} +func testServer_Push_RejectRecursivePush(t testing.TB) { // Expect two requests, but might get three if there's a bug and the second push succeeds. 
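The responseWriter.SetReadDeadline/SetWriteDeadline hunks above are the code paths reached when a handler adjusts per-stream deadlines. A hedged usage sketch from the handler side via net/http's ResponseController (the handler and body text are invented, not part of this change):

package handlers

import (
	"io"
	"net/http"
	"time"
)

// slowHandler extends the per-stream write deadline before sending a large
// body; on an HTTP/2 server this reaches responseWriter.SetWriteDeadline.
func slowHandler(w http.ResponseWriter, r *http.Request) {
	rc := http.NewResponseController(w)
	if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
		http.Error(w, "cannot extend deadline", http.StatusInternalServerError)
		return
	}
	io.WriteString(w, "large response body...")
}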
errc := make(chan error, 3) handler := func(w http.ResponseWriter, r *http.Request) error { @@ -323,6 +329,11 @@ func TestServer_Push_RejectRecursivePush(t *testing.T) { } func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { + synctestTest(t, func(t testing.TB) { + testServer_Push_RejectSingleRequest_Bubble(t, doPush, settings...) + }) +} +func testServer_Push_RejectSingleRequest_Bubble(t testing.TB, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { // Expect one request, but might get two if there's a bug and the push succeeds. errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -426,6 +437,9 @@ func TestServer_Push_RejectForbiddenHeader(t *testing.T) { } func TestServer_Push_StateTransitions(t *testing.T) { + synctestTest(t, testServer_Push_StateTransitions) +} +func testServer_Push_StateTransitions(t testing.TB) { const body = "foo" gotPromise := make(chan bool) @@ -479,7 +493,9 @@ func TestServer_Push_StateTransitions(t *testing.T) { } func TestServer_Push_RejectAfterGoAway(t *testing.T) { - var readyOnce sync.Once + synctestTest(t, testServer_Push_RejectAfterGoAway) +} +func testServer_Push_RejectAfterGoAway(t testing.TB) { ready := make(chan struct{}) errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -495,30 +511,15 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { // Send GOAWAY and wait for it to be processed. st.fr.WriteGoAway(1, ErrCodeNo, nil) - go func() { - for { - select { - case <-ready: - return - default: - if runtime.GOARCH == "wasm" { - // Work around https://go.dev/issue/65178 to avoid goroutine starvation. - runtime.Gosched() - } - } - st.sc.serveMsgCh <- func(loopNum int) { - if !st.sc.pushEnabled { - readyOnce.Do(func() { close(ready) }) - } - } - } - }() + synctest.Wait() + close(ready) if err := <-errc; err != nil { t.Error(err) } } -func TestServer_Push_Underflow(t *testing.T) { +func TestServer_Push_Underflow(t *testing.T) { synctestTest(t, testServer_Push_Underflow) } +func testServer_Push_Underflow(t testing.TB) { // Test for #63511: Send several requests which generate PUSH_PROMISE responses, // verify they all complete successfully. st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/http2/server_test.go b/http2/server_test.go index b27a127a5e..c61c53db17 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -26,6 +28,7 @@ import ( "strings" "sync" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -67,7 +70,6 @@ func (sb *safeBuffer) Len() int { type serverTester struct { cc net.Conn // client conn t testing.TB - group *synctestGroup h1server *http.Server h2server *Server serverLogBuf safeBuffer // logger for httptest.Server @@ -76,6 +78,9 @@ type serverTester struct { sc *serverConn testConnFramer + callsMu sync.Mutex + calls []*serverHandlerCall + // If http2debug!=2, then we capture Frame debug logs that will be written // to t.Log after a test fails. 
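The push tests above drive handlers built around http.Pusher; the pattern they exercise is roughly the following (an illustrative handler, with the response text borrowed from TestServer_Push_Success's fixtures):

package handlers

import (
	"io"
	"log"
	"net/http"
)

// pushingHandler pushes a secondary resource before writing the main
// response, when the ResponseWriter supports HTTP/2 server push.
func pushingHandler(w http.ResponseWriter, r *http.Request) {
	if pusher, ok := w.(http.Pusher); ok {
		if err := pusher.Push("/pushed", nil); err != nil {
			log.Printf("push failed: %v", err)
		}
	}
	io.WriteString(w, "index page")
}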
The read and write logs use separate locks // and buffers so we don't accidentally introduce synchronization between @@ -149,15 +154,9 @@ var optQuiet = func(server *http.Server) { func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { t.Helper() - g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)) - t.Cleanup(func() { - g.Close(t) - }) h1server := &http.Server{} - h2server := &Server{ - group: g, - } + h2server := &Server{} tlsState := tls.ConnectionState{ Version: tls.VersionTLS13, ServerName: "go.dev", @@ -177,14 +176,13 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } ConfigureServer(h1server, h2server) - cli, srv := synctestNetPipe(g) - cli.SetReadDeadline(g.Now()) + cli, srv := synctestNetPipe() + cli.SetReadDeadline(time.Now()) cli.autoWait = true st := &serverTester{ t: t, cc: cli, - group: g, h1server: h1server, h2server: h2server, } @@ -193,14 +191,17 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) } + if handler == nil { + handler = serverTesterHandler{st}.ServeHTTP + } + t.Cleanup(func() { st.Close() - g.AdvanceTime(goAwayTimeout) // give server time to shut down + time.Sleep(goAwayTimeout) // give server time to shut down }) connc := make(chan *serverConn) go func() { - g.Join() h2server.serveConn(&netConnWithConnectionState{ Conn: srv, state: tlsState, @@ -219,7 +220,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} fr: NewFramer(st.cc, st.cc), dec: hpack.NewDecoder(initialHeaderTableSize, nil), } - g.Wait() + synctest.Wait() return st } @@ -232,6 +233,50 @@ func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState { return c.state } +type serverTesterHandler struct { + st *serverTester +} + +func (h serverTesterHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + call := &serverHandlerCall{ + w: w, + req: req, + ch: make(chan func()), + } + h.st.t.Cleanup(call.exit) + h.st.callsMu.Lock() + h.st.calls = append(h.st.calls, call) + h.st.callsMu.Unlock() + for f := range call.ch { + f() + } +} + +// serverHandlerCall is a call to the server handler's ServeHTTP method. +type serverHandlerCall struct { + w http.ResponseWriter + req *http.Request + closeOnce sync.Once + ch chan func() +} + +// do executes f in the handler's goroutine. +func (call *serverHandlerCall) do(f func(http.ResponseWriter, *http.Request)) { + donec := make(chan struct{}) + call.ch <- func() { + defer close(donec) + f(call.w, call.req) + } + <-donec +} + +// exit causes the handler to return. +func (call *serverHandlerCall) exit() { + call.closeOnce.Do(func() { + close(call.ch) + }) +} + // newServerTesterWithRealConn creates a test server listening on a localhost port. // Mostly superseded by newServerTester, which creates a test server using a fake // net.Conn and synthetic time. This function is still around because some benchmarks @@ -333,14 +378,13 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts .. // sync waits for all goroutines to idle. func (st *serverTester) sync() { - if st.group != nil { - st.group.Wait() - } + synctest.Wait() } // advance advances synthetic time by a duration. 
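With the synctest migration, st.advance (just below) reduces to time.Sleep plus synctest.Wait: inside a bubble, sleeping advances a fake clock as soon as every bubbled goroutine is durably blocked. A standalone illustration of that behavior (TestFakeClockSketch is a hypothetical test, not part of this change):

package http2

import (
	"testing"
	"testing/synctest"
	"time"
)

// TestFakeClockSketch shows the property st.advance relies on: inside a
// bubble, time.Sleep moves the fake clock forward instantly, so elapsed
// time is exact rather than approximate.
func TestFakeClockSketch(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		start := time.Now()
		time.Sleep(2 * time.Second) // advances the bubble's fake clock, not real time
		synctest.Wait()             // let any timers that fired finish their work
		if elapsed := time.Since(start); elapsed != 2*time.Second {
			t.Fatalf("elapsed = %v; want exactly 2s on the fake clock", elapsed)
		}
	})
}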
func (st *serverTester) advance(d time.Duration) { - st.group.AdvanceTime(d) + time.Sleep(d) + synctest.Wait() } func (st *serverTester) authority() string { @@ -357,6 +401,19 @@ func (st *serverTester) addLogFilter(phrase string) { st.logFilter = append(st.logFilter, phrase) } +func (st *serverTester) nextHandlerCall() *serverHandlerCall { + st.t.Helper() + synctest.Wait() + st.callsMu.Lock() + defer st.callsMu.Unlock() + if len(st.calls) == 0 { + st.t.Fatal("expected server handler call, got none") + } + call := st.calls[0] + st.calls = st.calls[1:] + return call +} + func (st *serverTester) stream(id uint32) *stream { ch := make(chan *stream, 1) st.sc.serveMsgCh <- func(int) { @@ -383,23 +440,6 @@ func (st *serverTester) loopNum() int { return <-lastc } -// awaitIdle heuristically awaits for the server conn's select loop to be idle. -// The heuristic is that the server connection's serve loop must schedule -// 50 times in a row without any channel sends or receives occurring. -func (st *serverTester) awaitIdle() { - remain := 50 - last := st.loopNum() - for remain > 0 { - n := st.loopNum() - if n == last+1 { - remain-- - } else { - remain = 50 - } - last = n - } -} - func (st *serverTester) Close() { if st.t.Failed() { st.frameReadLogMu.Lock() @@ -591,30 +631,23 @@ func (st *serverTester) bodylessReq1(headers ...string) { }) } -func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { +func (st *serverTester) wantConnFlowControlConsumed(consumed int32) { conf := configFromServer(st.sc.hs, st.sc.srv) - var initial int32 - if streamID == 0 { - initial = conf.MaxUploadBufferPerConnection - } else { - initial = conf.MaxUploadBufferPerStream - } donec := make(chan struct{}) st.sc.sendServeMsg(func(sc *serverConn) { defer close(donec) var avail int32 - if streamID == 0 { - avail = sc.inflow.avail + sc.inflow.unsent - } else { - } + initial := conf.MaxUploadBufferPerConnection + avail = sc.inflow.avail + sc.inflow.unsent if got, want := initial-avail, consumed; got != want { - st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want) + st.t.Errorf("connection flow control consumed: %v, want %v", got, want) } }) <-donec } -func TestServer(t *testing.T) { +func TestServer(t *testing.T) { synctestTest(t, testServer) } +func testServer(t testing.TB) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Foo", "Bar") @@ -633,7 +666,8 @@ func TestServer(t *testing.T) { <-gotReq } -func TestServer_Request_Get(t *testing.T) { +func TestServer_Request_Get(t *testing.T) { synctestTest(t, testServer_Request_Get) } +func testServer_Request_Get(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -673,6 +707,9 @@ func TestServer_Request_Get(t *testing.T) { } func TestServer_Request_Get_PathSlashes(t *testing.T) { + synctestTest(t, testServer_Request_Get_PathSlashes) +} +func testServer_Request_Get_PathSlashes(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -695,6 +732,9 @@ func TestServer_Request_Get_PathSlashes(t *testing.T) { // zero? 
func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { + synctestTest(t, testServer_Request_Post_NoContentLength_EndStream) +} +func testServer_Request_Post_NoContentLength_EndStream(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -716,6 +756,9 @@ func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { } func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ImmediateEOF) +} +func testServer_Request_Post_Body_ImmediateEOF(t testing.TB) { testBodyContents(t, -1, "", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -728,6 +771,9 @@ func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { } func TestServer_Request_Post_Body_OneData(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_OneData) +} +func testServer_Request_Post_Body_OneData(t testing.TB) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -741,6 +787,9 @@ func TestServer_Request_Post_Body_OneData(t *testing.T) { } func TestServer_Request_Post_Body_TwoData(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_TwoData) +} +func testServer_Request_Post_Body_TwoData(t testing.TB) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -755,6 +804,9 @@ func TestServer_Request_Post_Body_TwoData(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ContentLength_Correct) +} +func testServer_Request_Post_Body_ContentLength_Correct(t testing.TB) { const content = "Some content" testBodyContents(t, int64(len(content)), content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -771,6 +823,9 @@ func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ContentLength_TooLarge) +} +func testServer_Request_Post_Body_ContentLength_TooLarge(t testing.TB) { testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -787,6 +842,9 @@ func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ContentLength_TooSmall) +} +func testServer_Request_Post_Body_ContentLength_TooSmall(t testing.TB) { testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -802,11 +860,11 @@ func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { // Return flow control bytes back, since the data handler closed // the stream. 
st.wantRSTStream(1, ErrCodeProtocol) - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) }) } -func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) { +func testBodyContents(t testing.TB, wantContentLength int64, wantBody string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) @@ -827,7 +885,7 @@ func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, wr }) } -func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) { +func testBodyContentsFail(t testing.TB, wantContentLength int64, wantReadError string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) @@ -850,7 +908,8 @@ func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError s } // Using a Host header, instead of :authority -func TestServer_Request_Get_Host(t *testing.T) { +func TestServer_Request_Get_Host(t *testing.T) { synctestTest(t, testServer_Request_Get_Host) } +func testServer_Request_Get_Host(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -868,6 +927,9 @@ func TestServer_Request_Get_Host(t *testing.T) { // Using an :authority pseudo-header, instead of Host func TestServer_Request_Get_Authority(t *testing.T) { + synctestTest(t, testServer_Request_Get_Authority) +} +func testServer_Request_Get_Authority(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -884,6 +946,9 @@ func TestServer_Request_Get_Authority(t *testing.T) { } func TestServer_Request_WithContinuation(t *testing.T) { + synctestTest(t, testServer_Request_WithContinuation) +} +func testServer_Request_WithContinuation(t testing.TB) { wantHeader := http.Header{ "Foo-One": []string{"value-one"}, "Foo-Two": []string{"value-two"}, @@ -931,7 +996,8 @@ func TestServer_Request_WithContinuation(t *testing.T) { } // Concatenated cookie headers. 
("8.1.2.5 Compressing the Cookie Header Field") -func TestServer_Request_CookieConcat(t *testing.T) { +func TestServer_Request_CookieConcat(t *testing.T) { synctestTest(t, testServer_Request_CookieConcat) } +func testServer_Request_CookieConcat(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.bodylessReq1( @@ -1053,17 +1119,19 @@ func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) { } func testRejectRequest(t *testing.T, send func(*serverTester)) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - t.Error("server request made it to handler; should've been rejected") - }) - defer st.Close() + synctestTest(t, func(t testing.TB) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + t.Error("server request made it to handler; should've been rejected") + }) + defer st.Close() - st.greet() - send(st) - st.wantRSTStream(1, ErrCodeProtocol) + st.greet() + send(st) + st.wantRSTStream(1, ErrCodeProtocol) + }) } -func newServerTesterForError(t *testing.T) *serverTester { +func newServerTesterForError(t testing.TB) *serverTester { t.Helper() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") @@ -1076,22 +1144,28 @@ func newServerTesterForError(t *testing.T) *serverTester { // HEADERS or PRIORITY on a stream in this state MUST be treated as a // connection error (Section 5.4.1) of type PROTOCOL_ERROR." func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) { + synctestTest(t, testRejectFrameOnIdle_WindowUpdate) +} +func testRejectFrameOnIdle_WindowUpdate(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteWindowUpdate(123, 456) st.wantGoAway(123, ErrCodeProtocol) } -func TestRejectFrameOnIdle_Data(t *testing.T) { +func TestRejectFrameOnIdle_Data(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_Data) } +func testRejectFrameOnIdle_Data(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteData(123, true, nil) st.wantGoAway(123, ErrCodeProtocol) } -func TestRejectFrameOnIdle_RSTStream(t *testing.T) { +func TestRejectFrameOnIdle_RSTStream(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_RSTStream) } +func testRejectFrameOnIdle_RSTStream(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteRSTStream(123, ErrCodeCancel) st.wantGoAway(123, ErrCodeProtocol) } -func TestServer_Request_Connect(t *testing.T) { +func TestServer_Request_Connect(t *testing.T) { synctestTest(t, testServer_Request_Connect) } +func testServer_Request_Connect(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1116,6 +1190,9 @@ func TestServer_Request_Connect(t *testing.T) { } func TestServer_Request_Connect_InvalidPath(t *testing.T) { + synctestTest(t, testServer_Request_Connect_InvalidPath) +} +func testServer_Request_Connect_InvalidPath(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1131,6 +1208,9 @@ func TestServer_Request_Connect_InvalidPath(t *testing.T) { } func TestServer_Request_Connect_InvalidScheme(t *testing.T) { + synctestTest(t, testServer_Request_Connect_InvalidScheme) +} +func testServer_Request_Connect_InvalidScheme(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1145,7 +1225,8 @@ func TestServer_Request_Connect_InvalidScheme(t *testing.T) { }) } -func 
TestServer_Ping(t *testing.T) { +func TestServer_Ping(t *testing.T) { synctestTest(t, testServer_Ping) } +func testServer_Ping(t testing.TB) { st := newServerTester(t, nil) defer st.Close() st.greet() @@ -1185,6 +1266,9 @@ func (l *filterListener) Accept() (net.Conn, error) { } func TestServer_MaxQueuedControlFrames(t *testing.T) { + synctestTest(t, testServer_MaxQueuedControlFrames) +} +func testServer_MaxQueuedControlFrames(t testing.TB) { // Goroutine debugging makes this test very slow. disableGoroutineTracking(t) @@ -1201,7 +1285,7 @@ func TestServer_MaxQueuedControlFrames(t *testing.T) { pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} st.fr.WritePing(false, pingData) } - st.group.Wait() + synctest.Wait() // Unblock the server. // It should have closed the connection after exceeding the control frame limit. @@ -1217,7 +1301,8 @@ func TestServer_MaxQueuedControlFrames(t *testing.T) { st.wantClosed() } -func TestServer_RejectsLargeFrames(t *testing.T) { +func TestServer_RejectsLargeFrames(t *testing.T) { synctestTest(t, testServer_RejectsLargeFrames) } +func testServer_RejectsLargeFrames(t testing.TB) { if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "zos" { t.Skip("see golang.org/issue/13434, golang.org/issue/37321") } @@ -1236,20 +1321,19 @@ func TestServer_RejectsLargeFrames(t *testing.T) { } func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Handler_Sends_WindowUpdate) +} +func testServer_Handler_Sends_WindowUpdate(t testing.TB) { // Need to set this to at least twice the initial window size, // or st.greet gets stuck waiting for a WINDOW_UPDATE. // // This also needs to be less than MAX_FRAME_SIZE. const windowSize = 65535 * 2 - puppet := newHandlerPuppet() - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - puppet.act(w, r) - }, func(s *Server) { + st := newServerTester(t, nil, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize s.MaxUploadBufferPerStream = windowSize }) defer st.Close() - defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1258,13 +1342,14 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { EndStream: false, // data coming EndHeaders: true, }) + call := st.nextHandlerCall() // Write less than half the max window of data and consume it. // The server doesn't return flow control yet, buffering the 1024 bytes to // combine with a future update. data := make([]byte, windowSize) st.writeData(1, false, data[:1024]) - puppet.do(readBodyHandler(t, string(data[:1024]))) + call.do(readBodyHandler(t, string(data[:1024]))) // Write up to the window limit. // The server returns the buffered credit. @@ -1273,7 +1358,7 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { st.wantWindowUpdate(1, 1024) // The handler consumes the data and the server returns credit. - puppet.do(readBodyHandler(t, string(data[1024:]))) + call.do(readBodyHandler(t, string(data[1024:]))) st.wantWindowUpdate(0, windowSize-1024) st.wantWindowUpdate(1, windowSize-1024) } @@ -1281,16 +1366,15 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { // the version of the TestServer_Handler_Sends_WindowUpdate with padding. 
// See golang.org/issue/16556 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { + synctestTest(t, testServer_Handler_Sends_WindowUpdate_Padding) +} +func testServer_Handler_Sends_WindowUpdate_Padding(t testing.TB) { const windowSize = 65535 * 2 - puppet := newHandlerPuppet() - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - puppet.act(w, r) - }, func(s *Server) { + st := newServerTester(t, nil, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize s.MaxUploadBufferPerStream = windowSize }) defer st.Close() - defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1299,6 +1383,7 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { EndStream: false, EndHeaders: true, }) + call := st.nextHandlerCall() // Write half a window of data, with some padding. // The server doesn't return the padding yet, buffering the 5 bytes to combine @@ -1310,12 +1395,15 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { // The handler consumes the body. // The server returns flow control for the body and padding // (4 bytes of padding + 1 byte of length). - puppet.do(readBodyHandler(t, string(data))) + call.do(readBodyHandler(t, string(data))) st.wantWindowUpdate(0, uint32(len(data)+1+len(pad))) st.wantWindowUpdate(1, uint32(len(data)+1+len(pad))) } func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Send_GoAway_After_Bogus_WindowUpdate) +} +func testServer_Send_GoAway_After_Bogus_WindowUpdate(t testing.TB) { st := newServerTester(t, nil) defer st.Close() st.greet() @@ -1326,6 +1414,9 @@ func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { } func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Send_RstStream_After_Bogus_WindowUpdate) +} +func testServer_Send_RstStream_After_Bogus_WindowUpdate(t testing.TB) { inHandler := make(chan bool) blockHandler := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -1352,7 +1443,7 @@ func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { // testServerPostUnblock sends a hanging POST with unsent data to handler, // then runs fn once in the handler, and verifies that the error returned from // handler is acceptable. It fails if takes over 5 seconds for handler to exit. 
-func testServerPostUnblock(t *testing.T, +func testServerPostUnblock(t testing.TB, handler func(http.ResponseWriter, *http.Request) error, fn func(*serverTester), checkErr func(error), @@ -1380,6 +1471,9 @@ func testServerPostUnblock(t *testing.T, } func TestServer_RSTStream_Unblocks_Read(t *testing.T) { + synctestTest(t, testServer_RSTStream_Unblocks_Read) +} +func testServer_RSTStream_Unblocks_Read(t testing.TB) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) @@ -1407,11 +1501,11 @@ func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) { n = 5 } for i := 0; i < n; i++ { - testServer_RSTStream_Unblocks_Header_Write(t) + synctestTest(t, testServer_RSTStream_Unblocks_Header_Write) } } -func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) { +func testServer_RSTStream_Unblocks_Header_Write(t testing.TB) { inHandler := make(chan bool, 1) unblockHandler := make(chan bool, 1) headerWritten := make(chan bool, 1) @@ -1440,12 +1534,15 @@ func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) { t.Fatal(err) } wroteRST <- true - st.awaitIdle() + synctest.Wait() <-headerWritten unblockHandler <- true } func TestServer_DeadConn_Unblocks_Read(t *testing.T) { + synctestTest(t, testServer_DeadConn_Unblocks_Read) +} +func testServer_DeadConn_Unblocks_Read(t testing.TB) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) @@ -1466,6 +1563,9 @@ var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error { } func TestServer_CloseNotify_After_RSTStream(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_RSTStream) +} +func testServer_CloseNotify_After_RSTStream(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) @@ -1474,6 +1574,9 @@ func TestServer_CloseNotify_After_RSTStream(t *testing.T) { } func TestServer_CloseNotify_After_ConnClose(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_ConnClose) +} +func testServer_CloseNotify_After_ConnClose(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil) } @@ -1481,13 +1584,17 @@ func TestServer_CloseNotify_After_ConnClose(t *testing.T) { // problem that's unrelated to them explicitly canceling it (which is // TestServer_CloseNotify_After_RSTStream above) func TestServer_CloseNotify_After_StreamError(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_StreamError) +} +func testServer_CloseNotify_After_StreamError(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { // data longer than declared Content-Length => stream error st.writeData(1, true, []byte("1234")) }, nil, "content-length", "3") } -func TestServer_StateTransitions(t *testing.T) { +func TestServer_StateTransitions(t *testing.T) { synctestTest(t, testServer_StateTransitions) } +func testServer_StateTransitions(t testing.TB) { var st *serverTester inHandler := make(chan bool) writeData := make(chan bool) @@ -1544,6 +1651,9 @@ func TestServer_StateTransitions(t *testing.T) { // test HEADERS w/o EndHeaders + another HEADERS (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Headers) +} +func testServer_Rejects_HeadersNoEnd_Then_Headers(t testing.TB) { st := newServerTesterForError(t) 
st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1562,6 +1672,9 @@ func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { // test HEADERS w/o EndHeaders + PING (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Ping) +} +func testServer_Rejects_HeadersNoEnd_Then_Ping(t testing.TB) { st := newServerTesterForError(t) st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1577,6 +1690,9 @@ func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { // test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected) func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersEnd_Then_Continuation) +} +func testServer_Rejects_HeadersEnd_Then_Continuation(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optQuiet) st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1597,6 +1713,9 @@ func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { // test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream) +} +func testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t testing.TB) { st := newServerTesterForError(t) st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1611,7 +1730,8 @@ func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) } // No HEADERS on stream 0. -func TestServer_Rejects_Headers0(t *testing.T) { +func TestServer_Rejects_Headers0(t *testing.T) { synctestTest(t, testServer_Rejects_Headers0) } +func testServer_Rejects_Headers0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ @@ -1625,6 +1745,9 @@ func TestServer_Rejects_Headers0(t *testing.T) { // No CONTINUATION on stream 0. func TestServer_Rejects_Continuation0(t *testing.T) { + synctestTest(t, testServer_Rejects_Continuation0) +} +func testServer_Rejects_Continuation0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil { @@ -1634,7 +1757,8 @@ func TestServer_Rejects_Continuation0(t *testing.T) { } // No PRIORITY on stream 0. -func TestServer_Rejects_Priority0(t *testing.T) { +func TestServer_Rejects_Priority0(t *testing.T) { synctestTest(t, testServer_Rejects_Priority0) } +func testServer_Rejects_Priority0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true st.writePriority(0, PriorityParam{StreamDep: 1}) @@ -1643,6 +1767,9 @@ func TestServer_Rejects_Priority0(t *testing.T) { // No HEADERS frame with a self-dependence. func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersSelfDependence) +} +func testServer_Rejects_HeadersSelfDependence(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ @@ -1657,13 +1784,17 @@ func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { // No PRIORITY frame with a self-dependence. 
func TestServer_Rejects_PrioritySelfDependence(t *testing.T) { + synctestTest(t, testServer_Rejects_PrioritySelfDependence) +} +func testServer_Rejects_PrioritySelfDependence(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writePriority(1, PriorityParam{StreamDep: 1}) }) } -func TestServer_Rejects_PushPromise(t *testing.T) { +func TestServer_Rejects_PushPromise(t *testing.T) { synctestTest(t, testServer_Rejects_PushPromise) } +func testServer_Rejects_PushPromise(t testing.TB) { st := newServerTesterForError(t) pp := PushPromiseParam{ StreamID: 1, @@ -1677,7 +1808,7 @@ func TestServer_Rejects_PushPromise(t *testing.T) { // testServerRejectsStream tests that the server sends a RST_STREAM with the provided // error code after a client sends a bogus request. -func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) { +func testServerRejectsStream(t testing.TB, code ErrCode, writeReq func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() st.greet() @@ -1688,7 +1819,7 @@ func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTe // testServerRequest sets up an idle HTTP/2 connection and lets you // write a single request with writeReq, and then verify that the // *http.Request is built correctly in checkReq. -func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) { +func testServerRequest(t testing.TB, writeReq func(*serverTester), checkReq func(*http.Request)) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if r.Body == nil { @@ -1706,7 +1837,8 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func func getSlash(st *serverTester) { st.bodylessReq1() } -func TestServer_Response_NoData(t *testing.T) { +func TestServer_Response_NoData(t *testing.T) { synctestTest(t, testServer_Response_NoData) } +func testServer_Response_NoData(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // Nothing. return nil @@ -1720,6 +1852,9 @@ func TestServer_Response_NoData(t *testing.T) { } func TestServer_Response_NoData_Header_FooBar(t *testing.T) { + synctestTest(t, testServer_Response_NoData_Header_FooBar) +} +func testServer_Response_NoData_Header_FooBar(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Foo-Bar", "some-value") return nil @@ -1740,6 +1875,9 @@ func TestServer_Response_NoData_Header_FooBar(t *testing.T) { // Reject content-length headers containing a sign. 
// See https://golang.org/issue/39017 func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) { + synctestTest(t, testServerIgnoresContentLengthSignWhenWritingChunks) +} +func testServerIgnoresContentLengthSignWhenWritingChunks(t testing.TB) { tests := []struct { name string cl string @@ -1827,7 +1965,7 @@ func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) { for _, tt := range tests { tt := tt - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { writeReq := func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -1839,7 +1977,7 @@ func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) { } checkReq := func(r *http.Request) { if r.ContentLength != tt.wantCL { - t.Fatalf("Got: %q\nWant: %q", r.ContentLength, tt.wantCL) + t.Fatalf("Got: %d\nWant: %d", r.ContentLength, tt.wantCL) } } testServerRequest(t, writeReq, checkReq) @@ -1848,6 +1986,9 @@ func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) { } func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { + synctestTest(t, testServer_Response_Data_Sniff_DoesntOverride) +} +func testServer_Response_Data_Sniff_DoesntOverride(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Type", "foo/bar") @@ -1873,6 +2014,9 @@ func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { } func TestServer_Response_TransferEncoding_chunked(t *testing.T) { + synctestTest(t, testServer_Response_TransferEncoding_chunked) +} +func testServer_Response_TransferEncoding_chunked(t testing.TB) { const msg = "hi" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Transfer-Encoding", "chunked") // should be stripped @@ -1894,6 +2038,9 @@ func TestServer_Response_TransferEncoding_chunked(t *testing.T) { // Header accessed only after the initial write. func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { + synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_After) +} +func testServer_Response_Data_IgnoreHeaderAfterWrite_After(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) @@ -1915,6 +2062,9 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { // Header accessed before the initial write and later mutated. func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { + synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite) +} +func testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("foo", "proper value") @@ -1937,6 +2087,9 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { } func TestServer_Response_Data_SniffLenType(t *testing.T) { + synctestTest(t, testServer_Response_Data_SniffLenType) +} +func testServer_Response_Data_SniffLenType(t testing.TB) { const msg = "this is HTML." 
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) @@ -1961,6 +2114,9 @@ func TestServer_Response_Data_SniffLenType(t *testing.T) { } func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { + synctestTest(t, testServer_Response_Header_Flush_MidWrite) +} +func testServer_Response_Header_Flush_MidWrite(t testing.TB) { const msg = "this is HTML" const msg2 = ", and this is the next chunk" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -1992,7 +2148,8 @@ func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { }) } -func TestServer_Response_LargeWrite(t *testing.T) { +func TestServer_Response_LargeWrite(t *testing.T) { synctestTest(t, testServer_Response_LargeWrite) } +func testServer_Response_LargeWrite(t testing.TB) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2058,6 +2215,9 @@ func TestServer_Response_LargeWrite(t *testing.T) { // Test that the handler can't write more than the client allows func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { + synctestTest(t, testServer_Response_LargeWrite_FlowControlled) +} +func testServer_Response_LargeWrite_FlowControlled(t testing.TB) { // Make these reads. Before each read, the client adds exactly enough // flow-control to satisfy the read. Numbers chosen arbitrarily. reads := []int{123, 1, 13, 127} @@ -2112,6 +2272,9 @@ func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { // Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM. func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { + synctestTest(t, testServer_Response_RST_Unblocks_LargeWrite) +} +func testServer_Response_RST_Unblocks_LargeWrite(t testing.TB) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2144,6 +2307,9 @@ func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { } func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { + synctestTest(t, testServer_Response_Empty_Data_Not_FlowControlled) +} +func testServer_Response_Empty_Data_Not_FlowControlled(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.(http.Flusher).Flush() // Nothing; send empty DATA @@ -2171,6 +2337,9 @@ func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { } func TestServer_Response_Automatic100Continue(t *testing.T) { + synctestTest(t, testServer_Response_Automatic100Continue) +} +func testServer_Response_Automatic100Continue(t testing.TB) { const msg = "foo" const reply = "bar" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2222,6 +2391,9 @@ func TestServer_Response_Automatic100Continue(t *testing.T) { } func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { + synctestTest(t, testServer_HandlerWriteErrorOnDisconnect) +} +func testServer_HandlerWriteErrorOnDisconnect(t testing.TB) { errc := make(chan error, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { p := []byte("some data.\n") @@ -2250,16 +2422,17 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { } func TestServer_Rejects_Too_Many_Streams(t *testing.T) { - const testPath = "/some/path" - + synctestTest(t, testServer_Rejects_Too_Many_Streams) +} +func testServer_Rejects_Too_Many_Streams(t testing.TB) { inHandler := make(chan uint32) leaveHandler := make(chan bool) st := 
newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - id := w.(*responseWriter).rws.stream.id - inHandler <- id - if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath { - t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath) + var streamID uint32 + if _, err := fmt.Sscanf(r.URL.Path, "/%d", &streamID); err != nil { + t.Errorf("parsing %q: %v", r.URL.Path, err) } + inHandler <- streamID <-leaveHandler }) defer st.Close() @@ -2274,12 +2447,14 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { defer func() { nextStreamID += 2 }() return nextStreamID } - sendReq := func(id uint32, headers ...string) { + sendReq := func(id uint32) { st.writeHeaders(HeadersFrameParam{ - StreamID: id, - BlockFragment: st.encodeHeader(headers...), - EndStream: true, - EndHeaders: true, + StreamID: id, + BlockFragment: st.encodeHeader( + ":path", fmt.Sprintf("/%v", id), + ), + EndStream: true, + EndHeaders: true, }) } for i := 0; i < defaultMaxStreams; i++ { @@ -2296,7 +2471,7 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { // (It's also sent as a CONTINUATION, to verify we still track the decoder context, // even if we're rejecting it) rejectID := streamID() - headerBlock := st.encodeHeader(":path", testPath) + headerBlock := st.encodeHeader(":path", fmt.Sprintf("/%v", rejectID)) frag1, frag2 := headerBlock[:3], headerBlock[3:] st.writeHeaders(HeadersFrameParam{ StreamID: rejectID, @@ -2320,7 +2495,7 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { // And now another stream should be able to start: goodID := streamID() - sendReq(goodID, ":path", testPath) + sendReq(goodID) if got := <-inHandler; got != goodID { t.Errorf("Got stream %d; want %d", got, goodID) } @@ -2328,6 +2503,9 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { // So many response headers that the server needs to use CONTINUATION frames: func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { + synctestTest(t, testServer_Response_ManyHeaders_With_Continuation) +} +func testServer_Response_ManyHeaders_With_Continuation(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { h := w.Header() for i := 0; i < 5000; i++ { @@ -2362,6 +2540,9 @@ func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { // defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop // ended. func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { + synctestTest(t, testServer_NoCrash_HandlerClose_Then_ClientClose) +} +func testServer_NoCrash_HandlerClose_Then_ClientClose(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // nothing return nil @@ -2396,7 +2577,7 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { // We should have our flow control bytes back, // since the handler didn't get them. - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) // Set up a bunch of machinery to record the panic we saw // previously. @@ -2416,7 +2597,7 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { // Now force the serve loop to end, via closing the connection. 
st.cc.Close() - <-st.sc.doneServing + synctest.Wait() panMu.Lock() got := panicVal @@ -2431,17 +2612,20 @@ func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) } func testRejectTLS(t *testing.T, version uint16) { - st := newServerTester(t, nil, func(state *tls.ConnectionState) { - // As of 1.18 the default minimum Go TLS version is - // 1.2. In order to test rejection of lower versions, - // manually set the version to 1.0 - state.Version = version + synctestTest(t, func(t testing.TB) { + st := newServerTester(t, nil, func(state *tls.ConnectionState) { + // As of 1.18 the default minimum Go TLS version is + // 1.2. In order to test rejection of lower versions, + // manually set the version to 1.0 + state.Version = version + }) + defer st.Close() + st.wantGoAway(0, ErrCodeInadequateSecurity) }) - defer st.Close() - st.wantGoAway(0, ErrCodeInadequateSecurity) } -func TestServer_Rejects_TLSBadCipher(t *testing.T) { +func TestServer_Rejects_TLSBadCipher(t *testing.T) { synctestTest(t, testServer_Rejects_TLSBadCipher) } +func testServer_Rejects_TLSBadCipher(t testing.TB) { st := newServerTester(t, nil, func(state *tls.ConnectionState) { state.Version = tls.VersionTLS12 state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA @@ -2451,6 +2635,9 @@ func TestServer_Rejects_TLSBadCipher(t *testing.T) { } func TestServer_Advertises_Common_Cipher(t *testing.T) { + synctestTest(t, testServer_Advertises_Common_Cipher) +} +func testServer_Advertises_Common_Cipher(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { }, func(srv *http.Server) { // Have the server configured with no specific cipher suites. @@ -2508,7 +2695,7 @@ func testServerResponse(t testing.TB, // readBodyHandler returns an http Handler func that reads len(want) // bytes from r.Body and fails t if the contents read were not // the value of want. 
-func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) { +func readBodyHandler(t testing.TB, want string) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { buf := make([]byte, len(want)) _, err := io.ReadFull(r.Body, buf) @@ -2523,6 +2710,9 @@ func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *h } func TestServer_MaxDecoderHeaderTableSize(t *testing.T) { + synctestTest(t, testServer_MaxDecoderHeaderTableSize) +} +func testServer_MaxDecoderHeaderTableSize(t testing.TB) { wantHeaderTableSize := uint32(initialHeaderTableSize * 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxDecoderHeaderTableSize = wantHeaderTableSize @@ -2546,6 +2736,9 @@ func TestServer_MaxDecoderHeaderTableSize(t *testing.T) { } func TestServer_MaxEncoderHeaderTableSize(t *testing.T) { + synctestTest(t, testServer_MaxEncoderHeaderTableSize) +} +func testServer_MaxEncoderHeaderTableSize(t testing.TB) { wantHeaderTableSize := uint32(initialHeaderTableSize / 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxEncoderHeaderTableSize = wantHeaderTableSize @@ -2560,7 +2753,8 @@ func TestServer_MaxEncoderHeaderTableSize(t *testing.T) { } // Issue 12843 -func TestServerDoS_MaxHeaderListSize(t *testing.T) { +func TestServerDoS_MaxHeaderListSize(t *testing.T) { synctestTest(t, testServerDoS_MaxHeaderListSize) } +func testServerDoS_MaxHeaderListSize(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() @@ -2630,6 +2824,9 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) { } func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) { + synctestTest(t, testServer_Response_Stream_With_Missing_Trailer) +} +func testServer_Response_Stream_With_Missing_Trailer(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Trailer", "test-trailer") return nil @@ -2647,7 +2844,8 @@ func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) { }) } -func TestCompressionErrorOnWrite(t *testing.T) { +func TestCompressionErrorOnWrite(t *testing.T) { synctestTest(t, testCompressionErrorOnWrite) } +func testCompressionErrorOnWrite(t testing.TB) { const maxStrLen = 8 << 10 var serverConfig *http.Server st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -2709,7 +2907,8 @@ func TestCompressionErrorOnWrite(t *testing.T) { st.wantGoAway(3, ErrCodeCompression) } -func TestCompressionErrorOnClose(t *testing.T) { +func TestCompressionErrorOnClose(t *testing.T) { synctestTest(t, testCompressionErrorOnClose) } +func testCompressionErrorOnClose(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. 
}) @@ -2729,7 +2928,8 @@ func TestCompressionErrorOnClose(t *testing.T) { } // test that a server handler can read trailers from a client -func TestServerReadsTrailers(t *testing.T) { +func TestServerReadsTrailers(t *testing.T) { synctestTest(t, testServerReadsTrailers) } +func testServerReadsTrailers(t testing.TB) { const testBody = "some test body" writeReq := func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -2780,10 +2980,18 @@ func TestServerReadsTrailers(t *testing.T) { } // test that a server handler can send trailers -func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) } -func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) } +func TestServerWritesTrailers_WithFlush(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testServerWritesTrailers(t, true) + }) +} +func TestServerWritesTrailers_WithoutFlush(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testServerWritesTrailers(t, false) + }) +} -func testServerWritesTrailers(t *testing.T, withFlush bool) { +func testServerWritesTrailers(t testing.TB, withFlush bool) { // See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") @@ -2851,6 +3059,9 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) { } func TestServerWritesUndeclaredTrailers(t *testing.T) { + synctestTest(t, testServerWritesUndeclaredTrailers) +} +func testServerWritesUndeclaredTrailers(t testing.TB) { const trailer = "Trailer-Header" const value = "hi1" ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -2876,6 +3087,9 @@ func TestServerWritesUndeclaredTrailers(t *testing.T) { // validate transmitted header field names & values // golang.org/issue/14048 func TestServerDoesntWriteInvalidHeaders(t *testing.T) { + synctestTest(t, testServerDoesntWriteInvalidHeaders) +} +func testServerDoesntWriteInvalidHeaders(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Add("OK1", "x") w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key @@ -3054,7 +3268,8 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { // go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53 // Verify we don't hang. -func TestIssue53(t *testing.T) { +func TestIssue53(t *testing.T) { synctestTest(t, testIssue53) } +func testIssue53(t testing.TB) { const data = "PRI * HTTP/2.0\r\n\r\nSM" + "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad" s := &http.Server{ @@ -3109,28 +3324,53 @@ func (c *issue53Conn) SetDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil } +// TestServeConnNilOpts ensures that Server.ServeConn(conn, nil) works. +// // golang.org/issue/33839 -func TestServeConnOptsNilReceiverBehavior(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("got a panic that should not happen: %v", r) - } - }() +func TestServeConnNilOpts(t *testing.T) { synctestTest(t, testServeConnNilOpts) } +func testServeConnNilOpts(t testing.TB) { + // A nil ServeConnOpts uses http.DefaultServeMux as the handler. 
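+	// To observe which request reaches the handler, the test installs a
+	// throwaway ServeMux as http.DefaultServeMux via setForTest, then drives
+	// the connection by hand over a net.Pipe: client preface, SETTINGS
+	// exchange, and a single HEADERS frame encoded with hpackEncoder.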
+ var gotRequest string + var mux http.ServeMux + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + gotRequest = r.URL.Path + }) + setForTest(t, &http.DefaultServeMux, &mux) + + srvConn, cliConn := net.Pipe() + defer srvConn.Close() + defer cliConn.Close() + + s2 := &Server{} + go s2.ServeConn(srvConn, nil) + + fr := NewFramer(cliConn, cliConn) + io.WriteString(cliConn, ClientPreface) + fr.WriteSettings() + fr.WriteSettingsAck() + var henc hpackEncoder + const reqPath = "/request" + fr.WriteHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: henc.encodeHeaderRaw(t, + ":method", "GET", + ":path", reqPath, + ":scheme", "https", + ":authority", "foo.com", + ), + EndStream: true, + EndHeaders: true, + }) - var o *ServeConnOpts - if o.context() == nil { - t.Error("o.context should not return nil") - } - if o.baseConfig() == nil { - t.Error("o.baseConfig should not return nil") - } - if o.handler() == nil { - t.Error("o.handler should not return nil") + synctest.Wait() + if got, want := gotRequest, reqPath; got != want { + t.Errorf("got request: %q, want %q", got, want) } } // golang.org/issue/12895 -func TestConfigureServer(t *testing.T) { +func TestConfigureServer(t *testing.T) { synctestTest(t, testConfigureServer) } +func testConfigureServer(t testing.TB) { tests := []struct { name string tlsConfig *tls.Config @@ -3202,6 +3442,9 @@ func TestConfigureServer(t *testing.T) { } func TestServerNoAutoContentLengthOnHead(t *testing.T) { + synctestTest(t, testServerNoAutoContentLengthOnHead) +} +func testServerNoAutoContentLengthOnHead(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. (or smaller than one frame) }) @@ -3224,6 +3467,9 @@ func TestServerNoAutoContentLengthOnHead(t *testing.T) { // golang.org/issue/13495 func TestServerNoDuplicateContentType(t *testing.T) { + synctestTest(t, testServerNoDuplicateContentType) +} +func testServerNoDuplicateContentType(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") @@ -3248,6 +3494,9 @@ func TestServerNoDuplicateContentType(t *testing.T) { } func TestServerContentLengthCanBeDisabled(t *testing.T) { + synctestTest(t, testServerContentLengthCanBeDisabled) +} +func testServerContentLengthCanBeDisabled(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header()["Content-Length"] = nil fmt.Fprintf(w, "OK") @@ -3271,9 +3520,10 @@ func TestServerContentLengthCanBeDisabled(t *testing.T) { } func disableGoroutineTracking(t testing.TB) { - old := DebugGoroutines - DebugGoroutines = false - t.Cleanup(func() { DebugGoroutines = old }) + disableDebugGoroutines.Store(true) + t.Cleanup(func() { + disableDebugGoroutines.Store(false) + }) } func BenchmarkServer_GetRequest(b *testing.B) { @@ -3349,7 +3599,8 @@ func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs } // golang.org/issue/12737 -- handle any net.Conn, not just // *tls.Conn. 
-func TestServerHandleCustomConn(t *testing.T) { +func TestServerHandleCustomConn(t *testing.T) { synctestTest(t, testServerHandleCustomConn) } +func testServerHandleCustomConn(t testing.TB) { var s Server c1, c2 := net.Pipe() clientDone := make(chan struct{}) @@ -3414,7 +3665,8 @@ func TestServerHandleCustomConn(t *testing.T) { } // golang.org/issue/14214 -func TestServer_Rejects_ConnHeaders(t *testing.T) { +func TestServer_Rejects_ConnHeaders(t *testing.T) { synctestTest(t, testServer_Rejects_ConnHeaders) } +func testServer_Rejects_ConnHeaders(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("should not get to Handler") }) @@ -3438,7 +3690,7 @@ type hpackEncoder struct { buf bytes.Buffer } -func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { +func (he *hpackEncoder) encodeHeaderRaw(t testing.TB, headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } @@ -3457,7 +3709,8 @@ func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte return he.buf.Bytes() } -func TestCheckValidHTTP2Request(t *testing.T) { +func TestCheckValidHTTP2Request(t *testing.T) { synctestTest(t, testCheckValidHTTP2Request) } +func testCheckValidHTTP2Request(t testing.TB) { tests := []struct { h http.Header want error @@ -3501,6 +3754,9 @@ func TestCheckValidHTTP2Request(t *testing.T) { // golang.org/issue/14030 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) { + synctestTest(t, testExpect100ContinueAfterHandlerWrites) +} +func testExpect100ContinueAfterHandlerWrites(t testing.TB) { const msg = "Hello" const msg2 = "World" @@ -3578,7 +3834,7 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { }, }, } { - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { unblock := make(chan bool, 1) defer close(unblock) @@ -3618,6 +3874,9 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { } func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) { + synctestTest(t, testServerReturnsStreamAndConnFlowControlOnBodyClose) +} +func testServerReturnsStreamAndConnFlowControlOnBodyClose(t testing.TB) { unblockHandler := make(chan struct{}) defer close(unblockHandler) @@ -3649,7 +3908,8 @@ func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) { }) } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { synctestTest(t, testServerIdleTimeout) } +func testServerIdleTimeout(t testing.TB) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3666,6 +3926,9 @@ func TestServerIdleTimeout(t *testing.T) { } func TestServerIdleTimeout_AfterRequest(t *testing.T) { + synctestTest(t, testServerIdleTimeout_AfterRequest) +} +func testServerIdleTimeout_AfterRequest(t testing.TB) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3676,7 +3939,7 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { var st *serverTester st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - st.group.Sleep(requestTimeout) + time.Sleep(requestTimeout) }, func(h2s *Server) { h2s.IdleTimeout = idleTimeout }) @@ -3702,28 +3965,46 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { // grpc-go closes the Request.Body currently with a Read. // Verify that it doesn't race. 
// See https://github.com/grpc/grpc-go/pull/938 -func TestRequestBodyReadCloseRace(t *testing.T) { - for i := 0; i < 100; i++ { - body := &requestBody{ - pipe: &pipe{ - b: new(bytes.Buffer), - }, - } - body.pipe.CloseWithError(io.EOF) +func TestRequestBodyReadCloseRace(t *testing.T) { synctestTest(t, testRequestBodyReadCloseRace) } +func testRequestBodyReadCloseRace(t testing.TB) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + go r.Body.Close() + io.Copy(io.Discard, r.Body) + }) + st.greet() + + data := make([]byte, 1024) + for i := range 100 { + streamID := uint32(1 + (i * 2)) // clients send odd numbers + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndHeaders: true, + }) + st.writeData(1, false, data) - done := make(chan bool, 1) - buf := make([]byte, 10) - go func() { - time.Sleep(1 * time.Millisecond) - body.Close() - done <- true - }() - body.Read(buf) - <-done + for { + // Look for a RST_STREAM frame. + // Skip over anything else (HEADERS and WINDOW_UPDATE). + fr := st.readFrame() + if fr == nil { + t.Fatalf("got no RSTStreamFrame, want one") + } + rst, ok := fr.(*RSTStreamFrame) + if !ok { + continue + } + // We can get NO or STREAM_CLOSED depending on scheduling. + if rst.ErrCode != ErrCodeNo && rst.ErrCode != ErrCodeStreamClosed { + t.Fatalf("got RSTStreamFrame with error code %v, want ErrCodeNo or ErrCodeStreamClosed", rst.ErrCode) + } + break + } } } -func TestIssue20704Race(t *testing.T) { +func TestIssue20704Race(t *testing.T) { synctestTest(t, testIssue20704Race) } +func testIssue20704Race(t testing.TB) { if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" { t.Skip("skipping in short mode") } @@ -3756,7 +4037,8 @@ func TestIssue20704Race(t *testing.T) { } } -func TestServer_Rejects_TooSmall(t *testing.T) { +func TestServer_Rejects_TooSmall(t *testing.T) { synctestTest(t, testServer_Rejects_TooSmall) } +func testServer_Rejects_TooSmall(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.ReadAll(r.Body) return nil @@ -3772,13 +4054,16 @@ func TestServer_Rejects_TooSmall(t *testing.T) { }) st.writeData(1, true, []byte("12345")) st.wantRSTStream(1, ErrCodeProtocol) - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) }) } // Tests that a handler setting "Connection: close" results in a GOAWAY being sent, // and the connection still completing. 
func TestServerHandlerConnectionClose(t *testing.T) { + synctestTest(t, testServerHandlerConnectionClose) +} +func testServerHandlerConnectionClose(t testing.TB) { unblockHandler := make(chan bool, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Connection", "close") @@ -3872,6 +4157,9 @@ func TestServerHandlerConnectionClose(t *testing.T) { } func TestServer_Headers_HalfCloseRemote(t *testing.T) { + synctestTest(t, testServer_Headers_HalfCloseRemote) +} +func testServer_Headers_HalfCloseRemote(t testing.TB) { var st *serverTester writeData := make(chan bool) writeHeaders := make(chan bool) @@ -3919,7 +4207,8 @@ func TestServer_Headers_HalfCloseRemote(t *testing.T) { st.wantRSTStream(1, ErrCodeStreamClosed) } -func TestServerGracefulShutdown(t *testing.T) { +func TestServerGracefulShutdown(t *testing.T) { synctestTest(t, testServerGracefulShutdown) } +func testServerGracefulShutdown(t testing.TB) { handlerDone := make(chan struct{}) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { <-handlerDone @@ -4016,7 +4305,7 @@ func TestContentEncodingNoSniffing(t *testing.T) { } for _, tt := range resps { - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { if tt.contentEncoding != nil { w.Header().Set("Content-Encoding", tt.contentEncoding.(string)) @@ -4055,10 +4344,12 @@ func TestContentEncodingNoSniffing(t *testing.T) { } func TestServerWindowUpdateOnBodyClose(t *testing.T) { + synctestTest(t, testServerWindowUpdateOnBodyClose) +} +func testServerWindowUpdateOnBodyClose(t testing.TB) { const windowSize = 65535 * 2 content := make([]byte, windowSize) - blockCh := make(chan bool) - errc := make(chan error, 1) + errc := make(chan error) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { buf := make([]byte, 4) n, err := io.ReadFull(r.Body, buf) @@ -4070,8 +4361,7 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { errc <- fmt.Errorf("too few bytes read: %d", n) return } - blockCh <- true - <-blockCh + r.Body.Close() errc <- nil }, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize @@ -4090,9 +4380,9 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { EndHeaders: true, }) st.writeData(1, false, content[:windowSize/2]) - <-blockCh - st.stream(1).body.CloseWithError(io.EOF) - blockCh <- true + if err := <-errc; err != nil { + t.Fatal(err) + } // Wait for flow control credit for the portion of the request written so far. increments := windowSize / 2 @@ -4112,13 +4402,12 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { // Writing data after the stream is reset immediately returns flow control credit. 
st.writeData(1, false, content[windowSize/2:]) st.wantWindowUpdate(0, windowSize/2) - - if err := <-errc; err != nil { - t.Error(err) - } } func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { + synctestTest(t, testNoErrorLoggedOnPostAfterGOAWAY) +} +func testNoErrorLoggedOnPostAfterGOAWAY(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() @@ -4151,7 +4440,8 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { } } -func TestServerSendsProcessing(t *testing.T) { +func TestServerSendsProcessing(t *testing.T) { synctestTest(t, testServerSendsProcessing) } +func testServerSendsProcessing(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.WriteHeader(http.StatusProcessing) w.Write([]byte("stuff")) @@ -4178,7 +4468,8 @@ func TestServerSendsProcessing(t *testing.T) { }) } -func TestServerSendsEarlyHints(t *testing.T) { +func TestServerSendsEarlyHints(t *testing.T) { synctestTest(t, testServerSendsEarlyHints) } +func testServerSendsEarlyHints(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { h := w.Header() h.Add("Content-Length", "123") @@ -4234,7 +4525,8 @@ func TestServerSendsEarlyHints(t *testing.T) { }) } -func TestProtocolErrorAfterGoAway(t *testing.T) { +func TestProtocolErrorAfterGoAway(t *testing.T) { synctestTest(t, testProtocolErrorAfterGoAway) } +func testProtocolErrorAfterGoAway(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.Copy(io.Discard, r.Body) }) @@ -4279,7 +4571,7 @@ func TestServerInitialFlowControlWindow(t *testing.T) { // test this case, but we currently do not. 65535 * 2, } { - t.Run(fmt.Sprint(want), func(t *testing.T) { + synctestSubtest(t, fmt.Sprint(want), func(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4322,7 +4614,8 @@ func TestServerInitialFlowControlWindow(t *testing.T) { // TestCanonicalHeaderCacheGrowth verifies that the canonical header cache // size is capped to a reasonable level. -func TestCanonicalHeaderCacheGrowth(t *testing.T) { +func TestCanonicalHeaderCacheGrowth(t *testing.T) { synctestTest(t, testCanonicalHeaderCacheGrowth) } +func testCanonicalHeaderCacheGrowth(t testing.TB) { for _, size := range []int{1, (1 << 20) - 10} { base := strings.Repeat("X", size) sc := &serverConn{ @@ -4355,6 +4648,9 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { // Terminating the request stream on the client causes Write to return. // We should not access the slice after this point. func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { + synctestTest(t, testServerWriteDoesNotRetainBufferAfterReturn) +} +func testServerWriteDoesNotRetainBufferAfterReturn(t testing.TB) { donec := make(chan struct{}) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(donec) @@ -4390,6 +4686,9 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { // Shutting down the Server causes Write to return. // We should not access the slice after this point. 
func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { + synctestTest(t, testServerWriteDoesNotRetainBufferAfterServerClose) +} +func testServerWriteDoesNotRetainBufferAfterServerClose(t testing.TB) { donec := make(chan struct{}, 1) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { donec <- struct{}{} @@ -4422,7 +4721,8 @@ func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { <-donec } -func TestServerMaxHandlerGoroutines(t *testing.T) { +func TestServerMaxHandlerGoroutines(t *testing.T) { synctestTest(t, testServerMaxHandlerGoroutines) } +func testServerMaxHandlerGoroutines(t testing.TB) { const maxHandlers = 10 handlerc := make(chan chan bool) donec := make(chan struct{}) @@ -4522,7 +4822,8 @@ func TestServerMaxHandlerGoroutines(t *testing.T) { } } -func TestServerContinuationFlood(t *testing.T) { +func TestServerContinuationFlood(t *testing.T) { synctestTest(t, testServerContinuationFlood) } +func testServerContinuationFlood(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Header) }, func(s *http.Server) { @@ -4575,6 +4876,9 @@ func TestServerContinuationFlood(t *testing.T) { } func TestServerContinuationAfterInvalidHeader(t *testing.T) { + synctestTest(t, testServerContinuationAfterInvalidHeader) +} +func testServerContinuationAfterInvalidHeader(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Header) }) @@ -4613,6 +4917,9 @@ func TestServerContinuationAfterInvalidHeader(t *testing.T) { } func TestServerUpgradeRequestPrefaceFailure(t *testing.T) { + synctestTest(t, testServerUpgradeRequestPrefaceFailure) +} +func testServerUpgradeRequestPrefaceFailure(t testing.TB) { // An h2c upgrade request fails when the client preface is not as expected. s2 := &Server{ // Setting IdleTimeout triggers #67168. @@ -4633,7 +4940,8 @@ func TestServerUpgradeRequestPrefaceFailure(t *testing.T) { } // Issue 67036: A stream error should result in the handler's request context being canceled. 
-func TestServerRequestCancelOnError(t *testing.T) { +func TestServerRequestCancelOnError(t *testing.T) { synctestTest(t, testServerRequestCancelOnError) } +func testServerRequestCancelOnError(t testing.TB) { recvc := make(chan struct{}) // handler has started donec := make(chan struct{}) // handler has finished st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -4667,6 +4975,9 @@ func TestServerRequestCancelOnError(t *testing.T) { } func TestServerSetReadWriteDeadlineRace(t *testing.T) { + synctestTest(t, testServerSetReadWriteDeadlineRace) +} +func testServerSetReadWriteDeadlineRace(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { ctl := http.NewResponseController(w) ctl.SetReadDeadline(time.Now().Add(3600 * time.Second)) @@ -4679,7 +4990,8 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) { resp.Body.Close() } -func TestServerWriteByteTimeout(t *testing.T) { +func TestServerWriteByteTimeout(t *testing.T) { synctestTest(t, testServerWriteByteTimeout) } +func testServerWriteByteTimeout(t testing.TB) { const timeout = 1 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Write(make([]byte, 100)) @@ -4711,7 +5023,8 @@ func TestServerWriteByteTimeout(t *testing.T) { st.wantClosed() } -func TestServerPingSent(t *testing.T) { +func TestServerPingSent(t *testing.T) { synctestTest(t, testServerPingSent) } +func testServerPingSent(t testing.TB) { const readIdleTimeout = 15 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4731,7 +5044,8 @@ func TestServerPingSent(t *testing.T) { st.wantClosed() } -func TestServerPingResponded(t *testing.T) { +func TestServerPingResponded(t *testing.T) { synctestTest(t, testServerPingResponded) } +func testServerPingResponded(t testing.TB) { const readIdleTimeout = 15 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4753,3 +5067,128 @@ func TestServerPingResponded(t *testing.T) { st.advance(2 * time.Second) st.wantIdle() } + +// golang.org/issue/15425: test that a handler closing the request +// body doesn't terminate the stream to the peer. (It just stops +// readability from the handler's side, and eventually the client +// runs out of flow control tokens) +func TestServerSendDataAfterRequestBodyClose(t *testing.T) { + synctestTest(t, testServerSendDataAfterRequestBodyClose) +} +func testServerSendDataAfterRequestBodyClose(t testing.TB) { + st := newServerTester(t, nil) + st.greet() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: false, + EndHeaders: true, + }) + + // Handler starts writing the response body. + call := st.nextHandlerCall() + call.do(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("one")) + http.NewResponseController(w).Flush() + }) + st.wantFrameType(FrameHeaders) + st.wantData(wantData{ + streamID: 1, + endStream: false, + data: []byte("one"), + }) + st.wantIdle() + + // Handler closes the request body. + // This is not observable by the client. + call.do(func(w http.ResponseWriter, req *http.Request) { + req.Body.Close() + }) + st.wantIdle() + + // The client can still send request data, which is discarded. + st.writeData(1, false, []byte("client-sent data")) + st.wantIdle() + + // Handler can still write more response body, + // which is sent to the client. 
+ call.do(func(w http.ResponseWriter, req *http.Request) {
+ w.Write([]byte("two"))
+ http.NewResponseController(w).Flush()
+ })
+ st.wantData(wantData{
+ streamID: 1,
+ endStream: false,
+ data: []byte("two"),
+ })
+ st.wantIdle()
+}
+
+// This test documents current behavior, which is not necessarily the behavior
+// we would like to see. Refer to go.dev/issues/75936 for details.
+func TestServerRFC7540PrioritySmallPayload(t *testing.T) {
+ synctestTest(t, testServerRFC7540PrioritySmallPayload)
+}
+func testServerRFC7540PrioritySmallPayload(t testing.TB) {
+ endTest := false
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ for !endTest {
+ w.Write([]byte("a"))
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+ }
+ }, func(s *Server) {
+ s.NewWriteScheduler = func() WriteScheduler {
+ return NewPriorityWriteScheduler(nil)
+ }
+ })
+ if syncConn, ok := st.cc.(*synctestNetConn); ok {
+ syncConn.SetReadBufferSize(1)
+ } else {
+ t.Fatal("Server connection is not synctestNetConn")
+ }
+ defer st.Close()
+ defer func() { endTest = true }()
+ st.greet()
+
+ // Create 5 streams with a weight of 1, and another 5 streams with a weight
+ // of 255.
+ // Since each stream has an unbounded amount of response data to send, we
+ // would expect almost all of the data we receive to be for the streams with
+ // a weight of 255.
+ for i := 1; i <= 19; i += 2 {
+ weight := 1
+ if i > 10 {
+ weight = 255
+ }
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: uint32(i),
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ Priority: PriorityParam{StreamDep: 0, Weight: uint8(weight)},
+ })
+ synctest.Wait()
+ }
+
+ // In the current implementation, however, the responses we receive are
+ // distributed equally among all the streams, regardless of weight.
+ streamWriteCount := make(map[uint32]int)
+ totalWriteCount := 10000
+ for range totalWriteCount {
+ f := st.readFrame()
+ if f == nil {
+ break
+ }
+ streamWriteCount[f.Header().StreamID] += 1
+ }
+ for streamID, writeCount := range streamWriteCount {
+ expectedWriteCount := totalWriteCount / len(streamWriteCount)
+ errorMargin := expectedWriteCount / 100
+ if writeCount >= expectedWriteCount+errorMargin || writeCount <= expectedWriteCount-errorMargin {
+ t.Errorf("Expected stream %v to receive %v±%v writes, got %v", streamID, expectedWriteCount, errorMargin, writeCount)
+ }
+ }
+}
diff --git a/http2/sync_test.go b/http2/sync_test.go
deleted file mode 100644
index 6687202d2c..0000000000
--- a/http2/sync_test.go
+++ /dev/null
@@ -1,329 +0,0 @@
-// Copyright 2024 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package http2
-
-import (
- "context"
- "fmt"
- "runtime"
- "strconv"
- "strings"
- "sync"
- "testing"
- "time"
-)
-
-// A synctestGroup synchronizes between a set of cooperating goroutines.
-type synctestGroup struct {
- mu sync.Mutex
- gids map[int]bool
- now time.Time
- timers map[*fakeTimer]struct{}
-}
-
-type goroutine struct {
- id int
- parent int
- state string
- syscall bool
-}
-
-// newSynctest creates a new group with the synthetic clock set the provided time.
-func newSynctest(now time.Time) *synctestGroup {
- return &synctestGroup{
- gids: map[int]bool{
- currentGoroutine(): true,
- },
- now: now,
- }
-}
-
-// Join adds the current goroutine to the group.
-func (g *synctestGroup) Join() { - g.mu.Lock() - defer g.mu.Unlock() - g.gids[currentGoroutine()] = true -} - -// Count returns the number of goroutines in the group. -func (g *synctestGroup) Count() int { - gs := stacks(true) - count := 0 - for _, gr := range gs { - if !g.gids[gr.id] && !g.gids[gr.parent] { - continue - } - count++ - } - return count -} - -// Close calls t.Fatal if the group contains any running goroutines. -func (g *synctestGroup) Close(t testing.TB) { - if count := g.Count(); count != 1 { - buf := make([]byte, 16*1024) - n := runtime.Stack(buf, true) - t.Logf("stacks:\n%s", buf[:n]) - t.Fatalf("%v goroutines still running after test completed, expect 1", count) - } -} - -// Wait blocks until every goroutine in the group and their direct children are idle. -func (g *synctestGroup) Wait() { - for i := 0; ; i++ { - if g.idle() { - return - } - runtime.Gosched() - if runtime.GOOS == "js" { - // When GOOS=js, we appear to need to time.Sleep to make progress - // on some syscalls. In particular, without this sleep - // writing to stdout (including via t.Log) can block forever. - for range 10 { - time.Sleep(1) - } - } - } -} - -func (g *synctestGroup) idle() bool { - gs := stacks(true) - g.mu.Lock() - defer g.mu.Unlock() - for _, gr := range gs[1:] { - if !g.gids[gr.id] && !g.gids[gr.parent] { - continue - } - if gr.syscall { - return false - } - // From runtime/runtime2.go. - switch gr.state { - case "IO wait": - case "chan receive (nil chan)": - case "chan send (nil chan)": - case "select": - case "select (no cases)": - case "chan receive": - case "chan send": - case "sync.Cond.Wait": - default: - return false - } - } - return true -} - -func currentGoroutine() int { - s := stacks(false) - return s[0].id -} - -func stacks(all bool) []goroutine { - buf := make([]byte, 16*1024) - for { - n := runtime.Stack(buf, all) - if n < len(buf) { - buf = buf[:n] - break - } - buf = make([]byte, len(buf)*2) - } - - var goroutines []goroutine - for _, gs := range strings.Split(string(buf), "\n\n") { - skip, rest, ok := strings.Cut(gs, "goroutine ") - if skip != "" || !ok { - panic(fmt.Errorf("1 unparsable goroutine stack:\n%s", gs)) - } - ids, rest, ok := strings.Cut(rest, " [") - if !ok { - panic(fmt.Errorf("2 unparsable goroutine stack:\n%s", gs)) - } - id, err := strconv.Atoi(ids) - if err != nil { - panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs)) - } - state, rest, ok := strings.Cut(rest, "]") - isSyscall := false - if strings.Contains(rest, "\nsyscall.") { - isSyscall = true - } - var parent int - _, rest, ok = strings.Cut(rest, "\ncreated by ") - if ok && strings.Contains(rest, " in goroutine ") { - _, rest, ok := strings.Cut(rest, " in goroutine ") - if !ok { - panic(fmt.Errorf("4 unparsable goroutine stack:\n%s", gs)) - } - parents, rest, ok := strings.Cut(rest, "\n") - if !ok { - panic(fmt.Errorf("5 unparsable goroutine stack:\n%s", gs)) - } - parent, err = strconv.Atoi(parents) - if err != nil { - panic(fmt.Errorf("6 unparsable goroutine stack:\n%s", gs)) - } - } - goroutines = append(goroutines, goroutine{ - id: id, - parent: parent, - state: state, - syscall: isSyscall, - }) - } - return goroutines -} - -// AdvanceTime advances the synthetic clock by d. -func (g *synctestGroup) AdvanceTime(d time.Duration) { - defer g.Wait() - g.mu.Lock() - defer g.mu.Unlock() - g.now = g.now.Add(d) - for tm := range g.timers { - if tm.when.After(g.now) { - continue - } - tm.run() - delete(g.timers, tm) - } -} - -// Now returns the current synthetic time. 
-func (g *synctestGroup) Now() time.Time { - g.mu.Lock() - defer g.mu.Unlock() - return g.now -} - -// TimeUntilEvent returns the amount of time until the next scheduled timer. -func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) { - g.mu.Lock() - defer g.mu.Unlock() - for tm := range g.timers { - if dd := tm.when.Sub(g.now); !scheduled || dd < d { - d = dd - scheduled = true - } - } - return d, scheduled -} - -// Sleep is time.Sleep, but using synthetic time. -func (g *synctestGroup) Sleep(d time.Duration) { - tm := g.NewTimer(d) - <-tm.C() -} - -// NewTimer is time.NewTimer, but using synthetic time. -func (g *synctestGroup) NewTimer(d time.Duration) Timer { - return g.addTimer(d, &fakeTimer{ - ch: make(chan time.Time), - }) -} - -// AfterFunc is time.AfterFunc, but using synthetic time. -func (g *synctestGroup) AfterFunc(d time.Duration, f func()) Timer { - return g.addTimer(d, &fakeTimer{ - f: f, - }) -} - -// ContextWithTimeout is context.WithTimeout, but using synthetic time. -func (g *synctestGroup) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - tm := g.AfterFunc(d, cancel) - return ctx, func() { - tm.Stop() - cancel() - } -} - -func (g *synctestGroup) addTimer(d time.Duration, tm *fakeTimer) *fakeTimer { - g.mu.Lock() - defer g.mu.Unlock() - tm.g = g - tm.when = g.now.Add(d) - if g.timers == nil { - g.timers = make(map[*fakeTimer]struct{}) - } - if tm.when.After(g.now) { - g.timers[tm] = struct{}{} - } else { - tm.run() - } - return tm -} - -type Timer = interface { - C() <-chan time.Time - Reset(d time.Duration) bool - Stop() bool -} - -type fakeTimer struct { - g *synctestGroup - when time.Time - ch chan time.Time - f func() -} - -func (tm *fakeTimer) run() { - if tm.ch != nil { - tm.ch <- tm.g.now - } else { - go func() { - tm.g.Join() - tm.f() - }() - } -} - -func (tm *fakeTimer) C() <-chan time.Time { return tm.ch } - -func (tm *fakeTimer) Reset(d time.Duration) bool { - tm.g.mu.Lock() - defer tm.g.mu.Unlock() - _, stopped := tm.g.timers[tm] - if d <= 0 { - delete(tm.g.timers, tm) - tm.run() - } else { - tm.when = tm.g.now.Add(d) - tm.g.timers[tm] = struct{}{} - } - return stopped -} - -func (tm *fakeTimer) Stop() bool { - tm.g.mu.Lock() - defer tm.g.mu.Unlock() - _, stopped := tm.g.timers[tm] - delete(tm.g.timers, tm) - return stopped -} - -// TestSynctestLogs verifies that t.Log works, -// in particular that the GOOS=js workaround in synctestGroup.Wait is working. -// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops -// calling runtime.Gosched; see Wait for the workaround.) -func TestSynctestLogs(t *testing.T) { - g := newSynctest(time.Now()) - donec := make(chan struct{}) - go func() { - g.Join() - for range 100 { - t.Logf("logging a long line") - } - close(donec) - }() - g.Wait() - select { - case <-donec: - default: - panic("done") - } -} diff --git a/http2/synctest_go124_test.go b/http2/synctest_go124_test.go new file mode 100644 index 0000000000..59f66ac2da --- /dev/null +++ b/http2/synctest_go124_test.go @@ -0,0 +1,42 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.25 && goexperiment.synctest + +package http2 + +import ( + "slices" + "testing" + "testing/synctest" +) + +// synctestTest emulates the Go 1.25 synctest.Test function on Go 1.24. 
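+// It runs f inside a synctest.Run bubble and, before the bubble exits, runs
+// any cleanup functions registered through the cleanupT wrapper defined below.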
+func synctestTest(t *testing.T, f func(t testing.TB)) { + t.Helper() + synctest.Run(func() { + t.Helper() + ct := &cleanupT{T: t} + defer ct.done() + f(ct) + }) +} + +// cleanupT wraps a testing.T and adds its own Cleanup method. +// Used to execute cleanup functions within a synctest bubble. +type cleanupT struct { + *testing.T + cleanups []func() +} + +// Cleanup replaces T.Cleanup. +func (t *cleanupT) Cleanup(f func()) { + t.cleanups = append(t.cleanups, f) +} + +func (t *cleanupT) done() { + for _, f := range slices.Backward(t.cleanups) { + f() + } +} diff --git a/http2/synctest_go125_test.go b/http2/synctest_go125_test.go new file mode 100644 index 0000000000..a0c5696160 --- /dev/null +++ b/http2/synctest_go125_test.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.25 + +package http2 + +import ( + "testing" + "testing/synctest" +) + +func synctestTest(t *testing.T, f func(t testing.TB)) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + t.Helper() + f(t) + }) +} diff --git a/http2/timer.go b/http2/timer.go deleted file mode 100644 index 0b1c17b812..0000000000 --- a/http2/timer.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -package http2 - -import "time" - -// A timer is a time.Timer, as an interface which can be replaced in tests. -type timer = interface { - C() <-chan time.Time - Reset(d time.Duration) bool - Stop() bool -} - -// timeTimer adapts a time.Timer to the timer interface. -type timeTimer struct { - *time.Timer -} - -func (t timeTimer) C() <-chan time.Time { return t.Timer.C } diff --git a/http2/transport.go b/http2/transport.go index f26356b9cd..ccb87e6da3 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -9,6 +9,7 @@ package http2 import ( "bufio" "bytes" + "compress/flate" "compress/gzip" "context" "crypto/rand" @@ -193,50 +194,6 @@ type Transport struct { type transportTestHooks struct { newclientconn func(*ClientConn) - group synctestGroupInterface -} - -func (t *Transport) markNewGoroutine() { - if t != nil && t.transportTestHooks != nil { - t.transportTestHooks.group.Join() - } -} - -func (t *Transport) now() time.Time { - if t != nil && t.transportTestHooks != nil { - return t.transportTestHooks.group.Now() - } - return time.Now() -} - -func (t *Transport) timeSince(when time.Time) time.Duration { - if t != nil && t.transportTestHooks != nil { - return t.now().Sub(when) - } - return time.Since(when) -} - -// newTimer creates a new time.Timer, or a synthetic timer in tests. -func (t *Transport) newTimer(d time.Duration) timer { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.NewTimer(d) - } - return timeTimer{time.NewTimer(d)} -} - -// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. 
-func (t *Transport) afterFunc(d time.Duration, f func()) timer { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.AfterFunc(d, f) - } - return timeTimer{time.AfterFunc(d, f)} -} - -func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.ContextWithTimeout(ctx, d) - } - return context.WithTimeout(ctx, d) } func (t *Transport) maxHeaderListSize() uint32 { @@ -366,7 +323,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer timer + idleTimer *time.Timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -399,6 +356,7 @@ type ClientConn struct { readIdleTimeout time.Duration pingTimeout time.Duration extendedConnectAllowed bool + strictMaxConcurrentStreams bool // rstStreamPingsBlocked works around an unfortunate gRPC behavior. // gRPC strictly limits the number of PING frames that it will receive. @@ -418,11 +376,24 @@ type ClientConn struct { // completely unresponsive connection. pendingResets int + // readBeforeStreamID is the smallest stream ID that has not been followed by + // a frame read from the peer. We use this to determine when a request may + // have been sent to a completely unresponsive connection: + // If the request ID is less than readBeforeStreamID, then we have had some + // indication of life on the connection since sending the request. + readBeforeStreamID uint32 + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. // Write to reqHeaderMu to lock it, read from it to unlock. // Lock reqmu BEFORE mu or wmu. reqHeaderMu chan struct{} + // internalStateHook reports state changes back to the net/http.ClientConn. + // Note that this is different from the user state hook registered by + // net/http.ClientConn.SetStateHook: The internal hook calls ClientConn, + // which calls the user hook. + internalStateHook func() + // wmu is held while writing. // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. // Only acquire both at the same time when changing peer settings. 
@@ -534,14 +505,12 @@ func (cs *clientStream) closeReqBodyLocked() { cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed go func() { - cs.cc.t.markNewGoroutine() cs.reqBody.Close() close(reqBodyClosed) }() } type stickyErrWriter struct { - group synctestGroupInterface conn net.Conn timeout time.Duration err *error @@ -551,7 +520,7 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } - n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p) + n, err = writeWithByteTimeout(sew.conn, sew.timeout, p) *sew.err = err return n, err } @@ -650,9 +619,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - tm := t.newTimer(d) + tm := time.NewTimer(d) select { - case <-tm.C(): + case <-tm.C: t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): @@ -699,6 +668,7 @@ var ( errClientConnUnusable = errors.New("http2: client conn not usable") errClientConnNotEstablished = errors.New("http2: client conn could not be established") errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + errClientConnForceClosed = errors.New("http2: client connection force closed via ClientConn.Close") ) // shouldRetryRequest is called by RoundTrip when a request fails to get @@ -753,7 +723,7 @@ func canRetryError(err error) bool { func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { if t.transportTestHooks != nil { - return t.newClientConn(nil, singleUse) + return t.newClientConn(nil, singleUse, nil) } host, _, err := net.SplitHostPort(addr) if err != nil { @@ -763,7 +733,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -815,10 +785,10 @@ func (t *Transport) expectContinueTimeout() time.Duration { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, internalStateHook func()) (*ClientConn, error) { conf := configFromTransport(t) cc := &ClientConn{ t: t, @@ -829,7 +799,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro initialWindowSize: 65535, // spec default initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream, maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. - peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + strictMaxConcurrentStreams: conf.StrictMaxConcurrentRequests, + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. 
streams: make(map[uint32]*clientStream), singleUse: singleUse, seenSettingsChan: make(chan struct{}), @@ -838,14 +809,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro pingTimeout: conf.PingTimeout, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), - lastActive: t.now(), + lastActive: time.Now(), + internalStateHook: internalStateHook, } - var group synctestGroupInterface if t.transportTestHooks != nil { - t.markNewGoroutine() t.transportTestHooks.newclientconn(cc) c = cc.tconn - group = t.group } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -857,7 +826,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. cc.bw = bufio.NewWriter(stickyErrWriter{ - group: group, conn: c, timeout: conf.WriteByteTimeout, err: &cc.werr, @@ -906,7 +874,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // Start the idle timer after the connection is fully initialized. if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout) + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } go cc.readLoop() @@ -917,7 +885,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.pingTimeout // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout) + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1067,7 +1035,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) { return } var maxConcurrentOkay bool - if cc.t.StrictMaxConcurrentStreams { + if cc.strictMaxConcurrentStreams { // We'll tell the caller we can take a new request to // prevent the caller from dialing a new TCP // connection, but then we'll block later before @@ -1083,10 +1051,7 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) { maxConcurrentOkay = cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams) } - st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && - !cc.doNotReuse && - int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && - !cc.tooIdleLocked() + st.canTakeNewRequest = maxConcurrentOkay && cc.isUsableLocked() // If this connection has never been used for a request and is closed, // then let it take a request (which will fail). @@ -1102,6 +1067,31 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) { return } +func (cc *ClientConn) isUsableLocked() bool { + return cc.goAway == nil && + !cc.closed && + !cc.closing && + !cc.doNotReuse && + int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && + !cc.tooIdleLocked() +} + +// canReserveLocked reports whether a net/http.ClientConn can reserve a slot on this conn. +// +// This follows slightly different rules than clientConnIdleState.canTakeNewRequest. +// We only permit reservations up to the conn's concurrency limit. +// This differs from ClientConn.ReserveNewRequest, which permits reservations +// past the limit when StrictMaxConcurrentStreams is set. 
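+//
+// The caller must hold cc.mu.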
+func (cc *ClientConn) canReserveLocked() bool { + if cc.currentRequestCountLocked() >= int(cc.maxConcurrentStreams) { + return false + } + if !cc.isUsableLocked() { + return false + } + return true +} + // currentRequestCountLocked reports the number of concurrency slots currently in use, // including active streams, reserved slots, and reset streams waiting for acknowledgement. func (cc *ClientConn) currentRequestCountLocked() int { @@ -1113,6 +1103,14 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool { return st.canTakeNewRequest } +// availableLocked reports the number of concurrency slots available. +func (cc *ClientConn) availableLocked() int { + if !cc.canTakeNewRequestLocked() { + return 0 + } + return max(0, int(cc.maxConcurrentStreams)-cc.currentRequestCountLocked()) +} + // tooIdleLocked reports whether this connection has been been sitting idle // for too much wall time. func (cc *ClientConn) tooIdleLocked() bool { @@ -1120,7 +1118,7 @@ func (cc *ClientConn) tooIdleLocked() bool { // times are compared based on their wall time. We don't want // to reuse a connection that's been sitting idle during // VM/laptop suspend if monotonic time was also frozen. - return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && cc.t.timeSince(cc.lastIdle.Round(0)) > cc.idleTimeout + return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout } // onIdleTimeout is called from a time.AfterFunc goroutine. It will @@ -1137,6 +1135,7 @@ func (cc *ClientConn) closeConn() { t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn) defer t.Stop() cc.tconn.Close() + cc.maybeCallStateHook() } // A tls.Conn.Close can hang for a long time if the peer is unresponsive. @@ -1186,7 +1185,6 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { done := make(chan struct{}) cancelled := false // guarded by cc.mu go func() { - cc.t.markNewGoroutine() cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1257,8 +1255,7 @@ func (cc *ClientConn) closeForError(err error) { // // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. func (cc *ClientConn) Close() error { - err := errors.New("http2: client connection force closed via ClientConn.Close") - cc.closeForError(err) + cc.closeForError(errClientConnForceClosed) return nil } @@ -1427,7 +1424,6 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) { - cs.cc.t.markNewGoroutine() err := cs.writeRequest(req, streamf) cs.cleanupWriteRequest(err) } @@ -1558,9 +1554,9 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := cc.t.newTimer(d) + timer := time.NewTimer(d) defer timer.Stop() - respHeaderTimer = timer.C() + respHeaderTimer = timer.C respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, @@ -1665,6 +1661,8 @@ func (cs *clientStream) cleanupWriteRequest(err error) { } bodyClosed := cs.reqBodyClosed closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil + // Have we read any frames from the connection since sending this request? 
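+ // readBeforeStreamID advances as frames arrive from the peer, so a stream ID
+ // below it means the connection has shown some sign of life since this
+ // request was sent.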
+ readSinceStream := cc.readBeforeStreamID > cs.ID cc.mu.Unlock() if mustCloseBody { cs.reqBody.Close() @@ -1696,8 +1694,10 @@ func (cs *clientStream) cleanupWriteRequest(err error) { // // This could be due to the server becoming unresponsive. // To avoid sending too many requests on a dead connection, - // we let the request continue to consume a concurrency slot - // until we can confirm the server is still responding. + // if we haven't read any frames from the connection since + // sending this request, we let it continue to consume + // a concurrency slot until we can confirm the server is + // still responding. // We do this by sending a PING frame along with the RST_STREAM // (unless a ping is already in flight). // @@ -1708,7 +1708,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) { // because it's short lived and will probably be closed before // we get the ping response. ping := false - if !closeOnIdle { + if !closeOnIdle && !readSinceStream { cc.mu.Lock() // rstStreamPingsBlocked works around a gRPC behavior: // see comment on the field for details. @@ -1742,6 +1742,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) { } close(cs.donec) + cc.maybeCallStateHook() } // awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams. @@ -1753,7 +1754,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { // Return a fatal error which aborts the retry loop. return errClientConnNotEstablished } - cc.lastActive = cc.t.now() + cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { return errClientConnUnusable } @@ -2092,10 +2093,10 @@ func (cc *ClientConn) forgetStreamID(id uint32) { if len(cc.streams) != slen-1 { panic("forgetting unknown stream id") } - cc.lastActive = cc.t.now() + cc.lastActive = time.Now() if len(cc.streams) == 0 && cc.idleTimer != nil { cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = cc.t.now() + cc.lastIdle = time.Now() } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. @@ -2121,7 +2122,6 @@ type clientConnReadLoop struct { // readLoop runs in its own goroutine and reads and dispatches frames. func (cc *ClientConn) readLoop() { - cc.t.markNewGoroutine() rl := &clientConnReadLoop{cc: cc} defer rl.cleanup() cc.readerErr = rl.run() @@ -2188,9 +2188,9 @@ func (rl *clientConnReadLoop) cleanup() { if cc.idleTimeout > 0 && unusedWaitTime > cc.idleTimeout { unusedWaitTime = cc.idleTimeout } - idleTime := cc.t.now().Sub(cc.lastActive) + idleTime := time.Now().Sub(cc.lastActive) if atomic.LoadUint32(&cc.atomicReused) == 0 && idleTime < unusedWaitTime && !cc.closedOnIdle { - cc.idleTimer = cc.t.afterFunc(unusedWaitTime-idleTime, func() { + cc.idleTimer = time.AfterFunc(unusedWaitTime-idleTime, func() { cc.t.connPool().MarkDead(cc) }) } else { @@ -2250,9 +2250,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.readIdleTimeout - var t timer + var t *time.Timer if readIdleTimeout != 0 { - t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck) + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2795,6 +2795,7 @@ func (rl *clientConnReadLoop) streamByID(id uint32, headerOrData bool) *clientSt // See comment on ClientConn.rstStreamPingsBlocked for details. 
rl.cc.rstStreamPingsBlocked = false } + rl.cc.readBeforeStreamID = rl.cc.nextStreamID cs := rl.cc.streams[id] if cs != nil && !cs.readAborted { return cs @@ -2845,6 +2846,7 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { cc := rl.cc + defer cc.maybeCallStateHook() cc.mu.Lock() defer cc.mu.Unlock() @@ -2998,7 +3000,6 @@ func (cc *ClientConn) Ping(ctx context.Context) error { var pingError error errc := make(chan struct{}) go func() { - cc.t.markNewGoroutine() cc.wmu.Lock() defer cc.wmu.Unlock() if pingError = cc.fr.WritePing(false, p); pingError != nil { @@ -3026,6 +3027,7 @@ func (cc *ClientConn) Ping(ctx context.Context) error { func (rl *clientConnReadLoop) processPing(f *PingFrame) error { if f.IsAck() { cc := rl.cc + defer cc.maybeCallStateHook() cc.mu.Lock() defer cc.mu.Unlock() // If ack, notify listener if any @@ -3128,35 +3130,102 @@ type erringRoundTripper struct{ err error } func (rt erringRoundTripper) RoundTripErr() error { return rt.err } func (rt erringRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { return nil, rt.err } +var errConcurrentReadOnResBody = errors.New("http2: concurrent read on response body") + // gzipReader wraps a response body so it can lazily -// call gzip.NewReader on the first call to Read +// get gzip.Reader from the pool on the first call to Read. +// After Close is called it puts gzip.Reader to the pool immediately +// if there is no Read in progress or later when Read completes. type gzipReader struct { _ incomparable body io.ReadCloser // underlying Response.Body - zr *gzip.Reader // lazily-initialized gzip reader - zerr error // sticky error + mu sync.Mutex // guards zr and zerr + zr *gzip.Reader // stores gzip reader from the pool between reads + zerr error // sticky gzip reader init error or sentinel value to detect concurrent read and read after close } -func (gz *gzipReader) Read(p []byte) (n int, err error) { +type eofReader struct{} + +func (eofReader) Read([]byte) (int, error) { return 0, io.EOF } +func (eofReader) ReadByte() (byte, error) { return 0, io.EOF } + +var gzipPool = sync.Pool{New: func() any { return new(gzip.Reader) }} + +// gzipPoolGet gets a gzip.Reader from the pool and resets it to read from r. +func gzipPoolGet(r io.Reader) (*gzip.Reader, error) { + zr := gzipPool.Get().(*gzip.Reader) + if err := zr.Reset(r); err != nil { + gzipPoolPut(zr) + return nil, err + } + return zr, nil +} + +// gzipPoolPut puts a gzip.Reader back into the pool. +func gzipPoolPut(zr *gzip.Reader) { + // Reset will allocate bufio.Reader if we pass it anything + // other than a flate.Reader, so ensure that it's getting one. + var r flate.Reader = eofReader{} + zr.Reset(r) + gzipPool.Put(zr) +} + +// acquire returns a gzip.Reader for reading response body. +// The reader must be released after use. +func (gz *gzipReader) acquire() (*gzip.Reader, error) { + gz.mu.Lock() + defer gz.mu.Unlock() if gz.zerr != nil { - return 0, gz.zerr + return nil, gz.zerr } if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) - if err != nil { - gz.zerr = err - return 0, err + gz.zr, gz.zerr = gzipPoolGet(gz.body) + if gz.zerr != nil { + return nil, gz.zerr } } - return gz.zr.Read(p) + ret := gz.zr + gz.zr, gz.zerr = nil, errConcurrentReadOnResBody + return ret, nil } -func (gz *gzipReader) Close() error { - if err := gz.body.Close(); err != nil { - return err +// release returns the gzip.Reader to the pool if Close was called during Read. 
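+// Otherwise it stores the reader back in gz.zr so a subsequent Read can reuse it.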
+func (gz *gzipReader) release(zr *gzip.Reader) { + gz.mu.Lock() + defer gz.mu.Unlock() + if gz.zerr == errConcurrentReadOnResBody { + gz.zr, gz.zerr = zr, nil + } else { // fs.ErrClosed + gzipPoolPut(zr) + } +} + +// close returns the gzip.Reader to the pool immediately or +// signals release to do so after Read completes. +func (gz *gzipReader) close() { + gz.mu.Lock() + defer gz.mu.Unlock() + if gz.zerr == nil && gz.zr != nil { + gzipPoolPut(gz.zr) + gz.zr = nil } gz.zerr = fs.ErrClosed - return nil +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + zr, err := gz.acquire() + if err != nil { + return 0, err + } + defer gz.release(zr) + + return zr.Read(p) +} + +func (gz *gzipReader) Close() error { + gz.close() + + return gz.body.Close() } type errorReader struct{ err error } @@ -3182,9 +3251,13 @@ func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err erro } // noDialH2RoundTripper is a RoundTripper which only tries to complete the request -// if there's already has a cached connection to the host. +// if there's already a cached connection to the host. // (The field is exported so it can be accessed via reflect from net/http; tested // by TestNoDialH2RoundTripperType) +// +// A noDialH2RoundTripper is registered with http1.Transport.RegisterProtocol, +// and the http1.Transport can use type assertions to call non-RoundTrip methods on it. +// This lets us expose, for example, NewClientConn to net/http. type noDialH2RoundTripper struct{ *Transport } func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -3195,6 +3268,85 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err return res, err } +func (rt noDialH2RoundTripper) NewClientConn(conn net.Conn, internalStateHook func()) (http.RoundTripper, error) { + tr := rt.Transport + cc, err := tr.newClientConn(conn, tr.disableKeepAlives(), internalStateHook) + if err != nil { + return nil, err + } + + // RoundTrip should block when the conn is at its concurrency limit, + // not return an error. Setting strictMaxConcurrentStreams enables this. + cc.strictMaxConcurrentStreams = true + + return netHTTPClientConn{cc}, nil +} + +// netHTTPClientConn wraps ClientConn and implements the interface net/http expects from +// the RoundTripper returned by NewClientConn. +type netHTTPClientConn struct { + cc *ClientConn +} + +func (cc netHTTPClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.cc.RoundTrip(req) +} + +func (cc netHTTPClientConn) Close() error { + return cc.cc.Close() +} + +func (cc netHTTPClientConn) Err() error { + cc.cc.mu.Lock() + defer cc.cc.mu.Unlock() + if cc.cc.closed { + return errors.New("connection closed") + } + return nil +} + +func (cc netHTTPClientConn) Reserve() error { + defer cc.cc.maybeCallStateHook() + cc.cc.mu.Lock() + defer cc.cc.mu.Unlock() + if !cc.cc.canReserveLocked() { + return errors.New("connection is unavailable") + } + cc.cc.streamsReserved++ + return nil +} + +func (cc netHTTPClientConn) Release() { + defer cc.cc.maybeCallStateHook() + cc.cc.mu.Lock() + defer cc.cc.mu.Unlock() + // We don't complain if streamsReserved is 0. + // + // This is consistent with RoundTrip: both Release and RoundTrip will + // consume a reservation iff one exists. 
+ if cc.cc.streamsReserved > 0 { + cc.cc.streamsReserved-- + } +} + +func (cc netHTTPClientConn) Available() int { + cc.cc.mu.Lock() + defer cc.cc.mu.Unlock() + return cc.cc.availableLocked() +} + +func (cc netHTTPClientConn) InFlight() int { + cc.cc.mu.Lock() + defer cc.cc.mu.Unlock() + return cc.cc.currentRequestCountLocked() +} + +func (cc *ClientConn) maybeCallStateHook() { + if cc.internalStateHook != nil { + cc.internalStateHook() + } +} + func (t *Transport) idleConnTimeout() time.Duration { // to keep things backwards compatible, we use non-zero values of // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying @@ -3228,7 +3380,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) { cc.mu.Lock() ci.WasIdle = len(cc.streams) == 0 && reused if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = cc.t.timeSince(cc.lastActive) + ci.IdleTime = time.Since(cc.lastActive) } cc.mu.Unlock() diff --git a/http2/transport_test.go b/http2/transport_test.go index f94d9e400b..532ebd870a 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -9,6 +11,7 @@ import ( "bytes" "compress/gzip" "context" + crand "crypto/rand" "crypto/tls" "encoding/hex" "errors" @@ -26,13 +29,13 @@ import ( "net/url" "os" "reflect" - "runtime" "sort" "strconv" "strings" "sync" "sync/atomic" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -121,7 +124,7 @@ func TestIdleConnTimeout(t *testing.T) { }, wantNewConn: false, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { tt := newTestTransport(t, func(tr *Transport) { tr.IdleConnTimeout = test.idleConnTimeout }) @@ -166,7 +169,7 @@ func TestIdleConnTimeout(t *testing.T) { tc.wantFrameType(FrameSettings) // ACK to our settings } - tt.advance(test.wait) + time.Sleep(test.wait) if got, want := tc.isClosed(), test.wantNewConn; got != want { t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) } @@ -849,10 +852,18 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } -func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } +func TestTransportReqBodyAfterResponse_200(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testTransportReqBodyAfterResponse(t, 200) + }) +} +func TestTransportReqBodyAfterResponse_403(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testTransportReqBodyAfterResponse(t, 403) + }) +} -func testTransportReqBodyAfterResponse(t *testing.T, status int) { +func testTransportReqBodyAfterResponse(t testing.TB, status int) { const bodySize = 1 << 10 tc := newTestClientConn(t) @@ -1083,6 +1094,11 @@ func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) } func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) { + synctestTest(t, func(t testing.TB) { + testTransportResPatternBubble(t, expect100Continue, resHeader, withData, trailers) + }) +} +func testTransportResPatternBubble(t testing.TB, expect100Continue, resHeader headerType, withData bool, trailers headerType) { const 
reqBody = "some request body" const resBody = "some response body" @@ -1163,7 +1179,8 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy } // Issue 26189, Issue 17739: ignore unknown 1xx responses -func TestTransportUnknown1xx(t *testing.T) { +func TestTransportUnknown1xx(t *testing.T) { synctestTest(t, testTransportUnknown1xx) } +func testTransportUnknown1xx(t testing.TB) { var buf bytes.Buffer defer func() { got1xxFuncForTests = nil }() got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { @@ -1213,6 +1230,9 @@ code=114 header=map[Foo-Bar:[114]] } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { + synctestTest(t, testTransportReceiveUndeclaredTrailer) +} +func testTransportReceiveUndeclaredTrailer(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1280,6 +1300,11 @@ func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { } func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + synctestTest(t, func(t testing.TB) { + testInvalidTrailerBubble(t, mode, wantErr, trailers...) + }) +} +func testInvalidTrailerBubble(t testing.TB, mode headerType, wantErr error, trailers ...string) { tc := newTestClientConn(t) tc.greet() @@ -1334,7 +1359,7 @@ func headerListSize(h http.Header) (size uint32) { // space for an empty "Pad-Headers" key, then adds as many copies of // filler as possible. Any remaining bytes necessary to push the // header list size up to limit are added to h["Pad-Headers"]. -func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) { +func padHeaders(t testing.TB, h http.Header, limit uint64, filler string) { if limit > 0xffffffff { t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit) } @@ -1427,61 +1452,35 @@ func TestPadHeaders(t *testing.T) { } func TestTransportChecksRequestHeaderListSize(t *testing.T) { - ts := newTestServer(t, - func(w http.ResponseWriter, r *http.Request) { - // Consume body & force client to send - // trailers before writing response. - // io.ReadAll returns non-nil err for - // requests that attempt to send greater than - // maxHeaderListSize bytes of trailers, since - // those requests generate a stream reset. - io.ReadAll(r.Body) - r.Body.Close() - }, - func(ts *httptest.Server) { - ts.Config.MaxHeaderBytes = 16 << 10 - }, - optQuiet, - ) + synctestTest(t, testTransportChecksRequestHeaderListSize) +} +func testTransportChecksRequestHeaderListSize(t testing.TB) { + const peerSize = 16 << 10 - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() + tc := newTestClientConn(t) + tc.greet(Setting{SettingMaxHeaderListSize, peerSize}) checkRoundTrip := func(req *http.Request, wantErr error, desc string) { - // Make an arbitrary request to ensure we get the server's - // settings frame and initialize peerMaxHeaderListSize. 
- req0, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatalf("newRequest: NewRequest: %v", err) - } - res0, err := tr.RoundTrip(req0) - if err != nil { - t.Errorf("%v: Initial RoundTrip err = %v", desc, err) - } - res0.Body.Close() - - res, err := tr.RoundTrip(req) - if !errors.Is(err, wantErr) { - if res != nil { - res.Body.Close() - } - t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) - return - } - if err == nil { - if res == nil { - t.Errorf("%v: response nil; want non-nil.", desc) - return - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK) + t.Helper() + rt := tc.roundTrip(req) + if wantErr != nil { + if err := rt.err(); !errors.Is(err, wantErr) { + t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) } return } - if res != nil { - t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err) - } + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + rt.wantStatus(http.StatusOK) } headerListSizeForRequest := func(req *http.Request) (size uint64) { const addGzipHeader = true @@ -1501,56 +1500,15 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { newRequest := func() *http.Request { // Body must be non-nil to enable writing trailers. body := strings.NewReader("hello") - req, err := http.NewRequest("POST", ts.URL, body) + req, err := http.NewRequest("POST", "https://example.tld/", body) if err != nil { t.Fatalf("newRequest: NewRequest: %v", err) } return req } - var ( - scMu sync.Mutex - sc *serverConn - ) - testHookGetServerConn = func(v *serverConn) { - scMu.Lock() - defer scMu.Unlock() - if sc != nil { - panic("testHookGetServerConn called multiple times") - } - sc = v - } - defer func() { - testHookGetServerConn = nil - }() - - // Validate peerMaxHeaderListSize. - req := newRequest() - checkRoundTrip(req, nil, "Initial request") - addr := authorityAddr(req.URL.Scheme, req.URL.Host) - cc, err := tr.connPool().GetClientConn(req, addr) - if err != nil { - t.Fatalf("GetClientConn: %v", err) - } - cc.mu.Lock() - peerSize := cc.peerMaxHeaderListSize - cc.mu.Unlock() - scMu.Lock() - wantSize := uint64(sc.maxHeaderListSize()) - scMu.Unlock() - if peerSize != wantSize { - t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize) - } - - // Sanity check peerSize. (*serverConn) maxHeaderListSize adds - // 320 bytes of padding. - wantHeaderBytes := uint64(ts.Config.MaxHeaderBytes) + 320 - if peerSize != wantHeaderBytes { - t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes) - } - // Pad headers & trailers, but stay under peerSize. 
- req = newRequest() + req := newRequest() req.Header = make(http.Header) req.Trailer = make(http.Header) filler := strings.Repeat("*", 1024) @@ -1588,6 +1546,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { + synctestTest(t, testTransportChecksResponseHeaderListSize) +} +func testTransportChecksResponseHeaderListSize(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1633,7 +1594,8 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { } } -func TestTransportCookieHeaderSplit(t *testing.T) { +func TestTransportCookieHeaderSplit(t *testing.T) { synctestTest(t, testTransportCookieHeaderSplit) } +func testTransportCookieHeaderSplit(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1862,13 +1824,17 @@ func isTimeout(err error) bool { // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { - testTransportResponseHeaderTimeout(t, false) + synctestTest(t, func(t testing.TB) { + testTransportResponseHeaderTimeout(t, false) + }) } func TestTransportResponseHeaderTimeout_Body(t *testing.T) { - testTransportResponseHeaderTimeout(t, true) + synctestTest(t, func(t testing.TB) { + testTransportResponseHeaderTimeout(t, true) + }) } -func testTransportResponseHeaderTimeout(t *testing.T, body bool) { +func testTransportResponseHeaderTimeout(t testing.TB, body bool) { const bodySize = 4 << 20 tc := newTestClientConn(t, func(tr *Transport) { tr.t1 = &http.Transport{ @@ -1904,11 +1870,11 @@ func testTransportResponseHeaderTimeout(t *testing.T, body bool) { }) } - tc.advance(4 * time.Millisecond) + time.Sleep(4 * time.Millisecond) if rt.done() { t.Fatalf("RoundTrip is done after 4ms; want still waiting") } - tc.advance(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) if err := rt.err(); !isTimeout(err) { t.Fatalf("RoundTrip error: %v; want timeout error", err) @@ -2304,7 +2270,8 @@ func TestTransportNewTLSConfig(t *testing.T) { // The Google GFE responds to HEAD requests with a HEADERS frame // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) -func TestTransportReadHeadResponse(t *testing.T) { +func TestTransportReadHeadResponse(t *testing.T) { synctestTest(t, testTransportReadHeadResponse) } +func testTransportReadHeadResponse(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2331,6 +2298,9 @@ func TestTransportReadHeadResponse(t *testing.T) { } func TestTransportReadHeadResponseWithBody(t *testing.T) { + synctestTest(t, testTransportReadHeadResponseWithBody) +} +func testTransportReadHeadResponseWithBody(t testing.TB) { // This test uses an invalid response format. // Discard logger output to not spam tests output. log.SetOutput(io.Discard) @@ -2371,101 +2341,102 @@ func (b neverEnding) Read(p []byte) (int, error) { return len(p), nil } -// golang.org/issue/15425: test that a handler closing the request -// body doesn't terminate the stream to the peer. 
(It just stops -// readability from the handler's side, and eventually the client -// runs out of flow control tokens) -func TestTransportHandlerBodyClose(t *testing.T) { - const bodySize = 10 << 20 - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - r.Body.Close() - io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) - }) - - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() +// #15425: Transport goroutine leak while the transport is still trying to +// write its body after the stream has completed. +func TestTransportStreamEndsWhileBodyIsBeingWritten(t *testing.T) { + synctestTest(t, testTransportStreamEndsWhileBodyIsBeingWritten) +} +func testTransportStreamEndsWhileBodyIsBeingWritten(t testing.TB) { + body := "this is the client request body" + const windowSize = 10 // less than len(body) - g0 := runtime.NumGoroutine() + tc := newTestClientConn(t) + tc.greet(Setting{SettingInitialWindowSize, windowSize}) - const numReq = 10 - for i := 0; i < numReq; i++ { - req, err := http.NewRequest("POST", ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - n, err := io.Copy(io.Discard, res.Body) - res.Body.Close() - if n != bodySize || err != nil { - t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) - } - } - tr.CloseIdleConnections() + // Client sends a request, and as much body as fits into the stream window. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", strings.NewReader(body)) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: windowSize, + }) - if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { - gd := runtime.NumGoroutine() - g0 - return gd < numReq/2 - }) { - t.Errorf("appeared to leak goroutines") - } + // Server responds without permitting the rest of the body to be sent. 
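+ // The 413 response carries END_STREAM, so the stream completes while part of
+ // the request body remains unsent.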
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "413",
+ ),
+ })
+ rt.wantStatus(413)
 }
-// https://golang.org/issue/15930
-func TestTransportFlowControl(t *testing.T) {
- const bufLen = 64 << 10
- var total int64 = 100 << 20 // 100MB
- if testing.Short() {
- total = 10 << 20
- }
-
- var wrote int64 // updated atomically
- ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
- b := make([]byte, bufLen)
- for wrote < total {
- n, err := w.Write(b)
- atomic.AddInt64(&wrote, int64(n))
- if err != nil {
- t.Errorf("ResponseWriter.Write error: %v", err)
- break
- }
- w.(http.Flusher).Flush()
+func TestTransportFlowControl(t *testing.T) { synctestTest(t, testTransportFlowControl) }
+func testTransportFlowControl(t testing.TB) {
+ const maxBuffer = 64 << 10 // 64KiB
+ tc := newTestClientConn(t, func(tr *http.Transport) {
+ tr.HTTP2 = &http.HTTP2Config{
+ MaxReceiveBufferPerConnection: maxBuffer,
+ MaxReceiveBufferPerStream: maxBuffer,
+ MaxReadFrameSize: 16 << 20, // 16MiB
 }
 })
+ tc.greet()
- tr := &Transport{TLSClientConfig: tlsConfigInsecure}
- defer tr.CloseIdleConnections()
- req, err := http.NewRequest("GET", ts.URL, nil)
- if err != nil {
- t.Fatal("NewRequest error:", err)
- }
- resp, err := tr.RoundTrip(req)
- if err != nil {
- t.Fatal("RoundTrip error:", err)
- }
- defer resp.Body.Close()
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
- var read int64
- b := make([]byte, bufLen)
+ // Server fills up its transmit buffer.
+ // The client does not provide more flow control tokens,
+ // since the data hasn't been consumed by the user.
+ tc.writeData(rt.streamID(), false, make([]byte, maxBuffer))
+ tc.wantIdle()
+
+ // User reads data from the response body.
+ // The client sends more flow control tokens.
+ resp := rt.response()
+ if _, err := io.ReadFull(resp.Body, make([]byte, maxBuffer)); err != nil {
+ t.Fatalf("io.ReadFull(resp.Body): %v", err)
+ }
+ var connTokens, streamTokens uint32
 for {
- n, err := resp.Body.Read(b)
- if err == io.EOF {
+ f := tc.readFrame()
+ if f == nil {
 break
 }
- if err != nil {
- t.Fatal("Read error:", err)
+ wu, ok := f.(*WindowUpdateFrame)
+ if !ok {
+ t.Fatalf("received unexpected frame %T (want WINDOW_UPDATE)", f)
 }
- read += int64(n)
-
- const max = transportDefaultStreamFlow
- if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
- t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
+ switch wu.StreamID {
+ case 0:
+ connTokens += wu.Increment
+ case rt.streamID():
+ streamTokens += wu.Increment
+ default:
+ t.Fatalf("received unexpected WINDOW_UPDATE for stream %v", wu.StreamID)
 }
-
- // Let the server get ahead of the client.
- time.Sleep(1 * time.Millisecond) + } + if got, want := connTokens, uint32(maxBuffer); got != want { + t.Errorf("transport provided %v bytes of connection WINDOW_UPDATE, want %v", got, want) + } + if got, want := streamTokens, uint32(maxBuffer); got != want { + t.Errorf("transport provided %v bytes of stream WINDOW_UPDATE, want %v", got, want) } } @@ -2475,14 +2446,18 @@ func TestTransportFlowControl(t *testing.T) { // proceeds to close the TCP connection before the client gets its // response) func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { - testTransportUsesGoAwayDebugError(t, false) + synctestTest(t, func(t testing.TB) { + testTransportUsesGoAwayDebugError(t, false) + }) } func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { - testTransportUsesGoAwayDebugError(t, true) + synctestTest(t, func(t testing.TB) { + testTransportUsesGoAwayDebugError(t, true) + }) } -func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { +func testTransportUsesGoAwayDebugError(t testing.TB, failMidBody bool) { tc := newTestClientConn(t) tc.greet() @@ -2532,7 +2507,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { } } -func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { +func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) { tc := newTestClientConn(t) tc.greet() @@ -2573,7 +2548,7 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { t.Fatalf("body read = %v, %v; want 1, nil", n, err) } res.Body.Close() // leaving 4999 bytes unread - tc.sync() + synctest.Wait() sentAdditionalData := false tc.wantUnorderedFrames( @@ -2588,9 +2563,6 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { } return true }, - func(f *PingFrame) bool { - return true - }, func(f *WindowUpdateFrame) bool { if !oneDataFrame && !sentAdditionalData { t.Fatalf("Got WindowUpdateFrame, don't expect one yet") @@ -2609,17 +2581,22 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { // See golang.org/issue/16481 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, true) + synctestTest(t, func(t testing.TB) { + testTransportReturnsUnusedFlowControl(t, true) + }) } // See golang.org/issue/20469 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, false) + synctestTest(t, func(t testing.TB) { + testTransportReturnsUnusedFlowControl(t, false) + }) } // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. 
-func TestTransportAdjustsFlowControl(t *testing.T) { +func TestTransportAdjustsFlowControl(t *testing.T) { synctestTest(t, testTransportAdjustsFlowControl) } +func testTransportAdjustsFlowControl(t testing.TB) { const bodySize = 1 << 20 tc := newTestClientConn(t) @@ -2676,6 +2653,9 @@ func TestTransportAdjustsFlowControl(t *testing.T) { // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { + synctestTest(t, testTransportReturnsDataPaddingFlowControl) +} +func testTransportReturnsDataPaddingFlowControl(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2711,6 +2691,9 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { + synctestTest(t, testTransportReturnsErrorOnBadResponseHeaders) +} +func testTransportReturnsErrorOnBadResponseHeaders(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2762,6 +2745,9 @@ func (b byteAndEOFReader) Read(p []byte) (n int, err error) { // which returns (non-0, io.EOF) and also needs to set the ContentLength // explicitly. func TestTransportBodyDoubleEndStream(t *testing.T) { + synctestTest(t, testTransportBodyDoubleEndStream) +} +func testTransportBodyDoubleEndStream(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { // Nothing. }) @@ -2916,17 +2902,20 @@ func TestTransportRequestPathPseudo(t *testing.T) { // golang.org/issue/17071 -- don't sniff the first byte of the request body // before we've determined that the ClientConn is usable. func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { + synctestTest(t, testRoundTripDoesntConsumeRequestBodyEarly) +} +func testRoundTripDoesntConsumeRequestBodyEarly(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + tc.closeWrite() + const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) - cc := &ClientConn{ - closed: true, - reqHeaderMu: make(chan struct{}, 1), - t: &Transport{}, - } - _, err := cc.RoundTrip(req) - if err != errClientConnUnusable { - t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) + rt := tc.roundTrip(req) + if err := rt.err(); err != errClientConnNotEstablished { + t.Fatalf("RoundTrip = %v; want errClientConnNotEstablished", err) } + slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("ReadAll = %v", err) @@ -3031,7 +3020,8 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { req.Header = http.Header{} } -func TestTransportCloseAfterLostPing(t *testing.T) { +func TestTransportCloseAfterLostPing(t *testing.T) { synctestTest(t, testTransportCloseAfterLostPing) } +func testTransportCloseAfterLostPing(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.PingTimeout = 1 * time.Second tr.ReadIdleTimeout = 1 * time.Second @@ -3042,10 +3032,10 @@ func TestTransportCloseAfterLostPing(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.wantFrameType(FramePing) - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) err := rt.err() if err == nil || !strings.Contains(err.Error(), "client connection lost") { t.Fatalf("expected to get error about \"connection lost\", got %v", err) @@ -3081,6 +3071,9 @@ func TestTransportPingWriteBlocks(t *testing.T) { } func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + 
synctestTest(t, testTransportPingWhenReadingMultiplePings) +} +func testTransportPingWhenReadingMultiplePings(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 1000 * time.Millisecond }) @@ -3102,20 +3095,20 @@ func TestTransportPingWhenReadingMultiplePings(t *testing.T) { for i := 0; i < 5; i++ { // No ping yet... - tc.advance(999 * time.Millisecond) + time.Sleep(999 * time.Millisecond) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } // ...ping now. - tc.advance(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) f := readFrame[*PingFrame](t, tc) tc.writePing(true, f.Data) } // Cancel the request, Transport resets it and returns an error from body reads. cancel() - tc.sync() + synctest.Wait() tc.wantFrameType(FrameRSTStream) _, err := rt.readBody() @@ -3125,6 +3118,9 @@ func TestTransportPingWhenReadingMultiplePings(t *testing.T) { } func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + synctestTest(t, testTransportPingWhenReadingPingDisabled) +} +func testTransportPingWhenReadingPingDisabled(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 0 // PINGs disabled }) @@ -3144,13 +3140,16 @@ func TestTransportPingWhenReadingPingDisabled(t *testing.T) { }) // No PING is sent, even after a long delay. - tc.advance(1 * time.Minute) + time.Sleep(1 * time.Minute) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } } func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYNoRetry) +} +func testTransportRetryAfterGOAWAYNoRetry(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3175,6 +3174,9 @@ func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) { } func TestTransportRetryAfterGOAWAYRetry(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYRetry) +} +func testTransportRetryAfterGOAWAYRetry(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3219,6 +3221,9 @@ func TestTransportRetryAfterGOAWAYRetry(t *testing.T) { } func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYSecondRequest) +} +func testTransportRetryAfterGOAWAYSecondRequest(t testing.TB) { tt := newTestTransport(t) // First request succeeds. @@ -3282,6 +3287,9 @@ func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) { } func TestTransportRetryAfterRefusedStream(t *testing.T) { + synctestTest(t, testTransportRetryAfterRefusedStream) +} +func testTransportRetryAfterRefusedStream(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3320,20 +3328,21 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { rt.wantStatus(204) } -func TestTransportRetryHasLimit(t *testing.T) { +func TestTransportRetryHasLimit(t *testing.T) { synctestTest(t, testTransportRetryHasLimit) } +func testTransportRetryHasLimit(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) rt := tt.roundTrip(req) - // First attempt: Server sends a GOAWAY. 
tc := tt.getConn() + tc.netconn.SetReadDeadline(time.Time{}) tc.wantFrameType(FrameSettings) tc.wantFrameType(FrameWindowUpdate) - var totalDelay time.Duration count := 0 - for streamID := uint32(1); ; streamID += 2 { + start := time.Now() + for streamID := uint32(1); !rt.done(); streamID += 2 { count++ tc.wantHeaders(wantHeader{ streamID: streamID, @@ -3345,18 +3354,9 @@ func TestTransportRetryHasLimit(t *testing.T) { } tc.writeRSTStream(streamID, ErrCodeRefusedStream) - d, scheduled := tt.group.TimeUntilEvent() - if !scheduled { - if streamID == 1 { - continue - } - break - } - totalDelay += d - if totalDelay > 5*time.Minute { + if totalDelay := time.Since(start); totalDelay > 5*time.Minute { t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay) } - tt.advance(d) } if got, want := count, 5; got < count { t.Errorf("RoundTrip made %v attempts, want at least %v", got, want) @@ -3367,6 +3367,9 @@ func TestTransportRetryHasLimit(t *testing.T) { } func TestTransportResponseDataBeforeHeaders(t *testing.T) { + synctestTest(t, testTransportResponseDataBeforeHeaders) +} +func testTransportResponseDataBeforeHeaders(t testing.TB) { // Discard log output complaining about protocol error. log.SetOutput(io.Discard) t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done @@ -3408,7 +3411,7 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - t.Run(fmt.Sprint(test.maxReadFrameSize), func(t *testing.T) { + synctestSubtest(t, fmt.Sprint(test.maxReadFrameSize), func(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.MaxReadFrameSize = test.maxReadFrameSize }) @@ -3470,11 +3473,28 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { // tests Transport.StrictMaxConcurrentStreams func TestTransportRequestsStallAtServerLimit(t *testing.T) { + synctestSubtest(t, "Transport", func(t testing.TB) { + testTransportRequestsStallAtServerLimit(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) + }) + synctestSubtest(t, "HTTP2Config", func(t testing.TB) { + // HTTP2Config.StrictMaxConcurrentRequests was added in Go 1.26. + h2 := &http.HTTP2Config{} + v := reflect.ValueOf(h2).Elem().FieldByName("StrictMaxConcurrentRequests") + if !v.IsValid() { + t.Skip("HTTP2Config does not contain StrictMaxConcurrentRequests") + } + v.SetBool(true) + testTransportRequestsStallAtServerLimit(t, func(tr *http.Transport) { + tr.HTTP2 = h2 + }) + }) +} +func testTransportRequestsStallAtServerLimit(t testing.TB, opt any) { const maxConcurrent = 2 - tc := newTestClientConn(t, func(tr *Transport) { - tr.StrictMaxConcurrentStreams = true - }) + tc := newTestClientConn(t, opt) tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) cancelClientRequest := make(chan struct{}) @@ -3517,7 +3537,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { // Cancel the maxConcurrent'th request. // The request should fail. 
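The HTTP2Config subtest above probes for a struct field via reflection so the test can skip cleanly on toolchains that predate the field instead of failing. A minimal standalone sketch of that probe (illustration only, not part of the patch; it assumes a toolchain new enough to ship http.HTTP2Config):

package main

import (
	"fmt"
	"net/http"
	"reflect"
)

// hasStrictMaxConcurrentRequests reports whether http.HTTP2Config has the
// StrictMaxConcurrentRequests field. FieldByName returns the zero Value when
// the field is absent, so IsValid doubles as the version check.
func hasStrictMaxConcurrentRequests() bool {
	h2 := &http.HTTP2Config{}
	return reflect.ValueOf(h2).Elem().FieldByName("StrictMaxConcurrentRequests").IsValid()
}

func main() {
	fmt.Println("StrictMaxConcurrentRequests available:", hasStrictMaxConcurrentRequests())
}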
close(cancelClientRequest) - tc.sync() + synctest.Wait() if err := rts[maxConcurrent].err(); err == nil { t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) } @@ -3551,6 +3571,9 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { + synctestTest(t, testTransportMaxDecoderHeaderTableSize) +} +func testTransportMaxDecoderHeaderTableSize(t testing.TB) { var reqSize, resSize uint32 = 8192, 16384 tc := newTestClientConn(t, func(tr *Transport) { tr.MaxDecoderHeaderTableSize = reqSize @@ -3572,6 +3595,9 @@ func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { + synctestTest(t, testTransportMaxEncoderHeaderTableSize) +} +func testTransportMaxEncoderHeaderTableSize(t testing.TB) { var peerAdvertisedMaxHeaderTableSize uint32 = 16384 tc := newTestClientConn(t, func(tr *Transport) { tr.MaxEncoderHeaderTableSize = 8192 @@ -3610,59 +3636,52 @@ func TestAuthorityAddr(t *testing.T) { // Issue 20448: stop allocating for DATA frames' payload after // Response.Body.Close is called. func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { - megabyteZero := make([]byte, 1<<20) + synctestTest(t, testTransportAllocationsAfterResponseBodyClose) +} +func testTransportAllocationsAfterResponseBodyClose(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() - writeErr := make(chan error, 1) + // Send request. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - w.(http.Flusher).Flush() - var sum int64 - for i := 0; i < 100; i++ { - n, err := w.Write(megabyteZero) - sum += int64(n) - if err != nil { - writeErr <- err - return - } - } - t.Logf("wrote all %d bytes", sum) - writeErr <- nil + // Receive response with some body. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), }) + tc.writeData(rt.streamID(), false, make([]byte, 64)) + tc.wantIdle() - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } + // Client reads a byte of the body, and then closes it. + respBody := rt.response().Body var buf [1]byte - if _, err := res.Body.Read(buf[:]); err != nil { + if _, err := respBody.Read(buf[:]); err != nil { t.Error(err) } - if err := res.Body.Close(); err != nil { + if err := respBody.Close(); err != nil { t.Error(err) } + tc.wantFrameType(FrameRSTStream) - trb, ok := res.Body.(transportResponseBody) - if !ok { - t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) - } - if trb.cs.bufPipe.b != nil { - t.Errorf("response body pipe is still open") - } + // Server sends more of the body, which is ignored. + tc.writeData(rt.streamID(), false, make([]byte, 64)) - gotErr := <-writeErr - if gotErr == nil { - t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") - } else if gotErr != errStreamClosed { - t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) + if _, err := respBody.Read(buf[:]); err == nil { + t.Error("read from closed body unexpectedly succeeded") } } // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. 
-func TestTransportNoBodyMeansNoDATA(t *testing.T) { +func TestTransportNoBodyMeansNoDATA(t *testing.T) { synctestTest(t, testTransportNoBodyMeansNoDATA) } +func testTransportNoBodyMeansNoDATA(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -3756,6 +3775,9 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { + synctestTest(t, testTransportHandlesInvalidStatuslessResponse) +} +func testTransportHandlesInvalidStatuslessResponse(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -3842,172 +3864,53 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { } } -func activeStreams(cc *ClientConn) int { - count := 0 - cc.mu.Lock() - defer cc.mu.Unlock() - for _, cs := range cc.streams { - select { - case <-cs.abort: - default: - count++ - } - } - return count -} +func BenchmarkClientGzip(b *testing.B) { + disableGoroutineTracking(b) + b.ReportAllocs() -type closeMode int + const responseSize = 1024 * 1024 -const ( - closeAtHeaders closeMode = iota - closeAtBody - shutdown - shutdownCancel -) + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil { + b.Fatal(err) + } + gz.Close() + + data := buf.Bytes() + ts := newTestServer(b, + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(data) + }, + optQuiet, + ) -// See golang.org/issue/17292 -func testClientConnClose(t *testing.T, closeMode closeMode) { - clientDone := make(chan struct{}) - defer close(clientDone) - handlerDone := make(chan struct{}) - closeDone := make(chan struct{}) - beforeHeader := func() {} - bodyWrite := func(w http.ResponseWriter) {} - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - defer close(handlerDone) - beforeHeader() - w.WriteHeader(http.StatusOK) - w.(http.Flusher).Flush() - bodyWrite(w) - select { - case <-w.(http.CloseNotifier).CloseNotify(): - // client closed connection before completion - if closeMode == shutdown || closeMode == shutdownCancel { - t.Error("expected request to complete") - } - case <-clientDone: - if closeMode == closeAtHeaders || closeMode == closeAtBody { - t.Error("expected connection closed by client") - } - } - }) tr := &Transport{TLSClientConfig: tlsConfigInsecure} defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) + req, err := http.NewRequest("GET", ts.URL, nil) if err != nil { - t.Fatal(err) - } - if closeMode == closeAtHeaders { - beforeHeader = func() { - if err := cc.Close(); err != nil { - t.Error(err) - } - close(closeDone) - } - } - var sendBody chan struct{} - if closeMode == closeAtBody { - sendBody = make(chan struct{}) - bodyWrite = func(w http.ResponseWriter) { - <-sendBody - b := make([]byte, 32) - w.Write(b) - w.(http.Flusher).Flush() - if err := cc.Close(); err != nil { - t.Errorf("unexpected ClientConn close error: %v", err) - } - close(closeDone) - w.Write(b) - w.(http.Flusher).Flush() - } - } - res, err := cc.RoundTrip(req) - if res != nil { - defer res.Body.Close() + b.Fatal(err) } - if closeMode == closeAtHeaders { - got := fmt.Sprint(err) - want := "http2: client connection force closed via ClientConn.Close" - if got != want { - t.Fatalf("RoundTrip error = %v, want %v", got, want) - } - } else { + + b.ResetTimer() + + for i := 0; i 
< b.N; i++ { + res, err := tr.RoundTrip(req) if err != nil { - t.Fatalf("RoundTrip: %v", err) - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - } - switch closeMode { - case shutdownCancel: - if err = cc.Shutdown(canceledCtx); err != context.Canceled { - t.Errorf("got %v, want %v", err, context.Canceled) - } - if cc.closing == false { - t.Error("expected closing to be true") - } - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if v, want := len(cc.streams), 1; v != want { - t.Errorf("expected %d active streams, got %d", want, v) - } - clientDone <- struct{}{} - <-handlerDone - case shutdown: - wait := make(chan struct{}) - shutdownEnterWaitStateHook = func() { - close(wait) - shutdownEnterWaitStateHook = func() {} - } - defer func() { shutdownEnterWaitStateHook = func() {} }() - shutdown := make(chan struct{}, 1) - go func() { - if err = cc.Shutdown(context.Background()); err != nil { - t.Error(err) - } - close(shutdown) - }() - // Let the shutdown to enter wait state - <-wait - cc.mu.Lock() - if cc.closing == false { - t.Error("expected closing to be true") - } - cc.mu.Unlock() - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - // Let the active request finish - clientDone <- struct{}{} - // Wait for the shutdown to end - select { - case <-shutdown: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") + b.Fatalf("RoundTrip err = %v; want nil", err) } - case closeAtHeaders, closeAtBody: - if closeMode == closeAtBody { - go close(sendBody) - if _, err := io.Copy(io.Discard, res.Body); err == nil { - t.Error("expected a Copy error, got nil") - } + if res.StatusCode != http.StatusOK { + b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) } - <-closeDone - if got, want := activeStreams(cc), 0; got != want { - t.Errorf("got %d active streams, want %d", got, want) + n, err := io.Copy(io.Discard, res.Body) + res.Body.Close() + if err != nil { + b.Fatalf("RoundTrip err = %v; want nil", err) } - // wait for server to get the connection close notice - select { - case <-handlerDone: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") + if n != responseSize { + b.Fatalf("RoundTrip expected %d bytes, got %d", responseSize, n) } } } @@ -4015,27 +3918,125 @@ func testClientConnClose(t *testing.T, closeMode closeMode) { // The client closes the connection just after the server got the client's HEADERS // frame, but before the server sends its HEADERS response back. The expected // result is an error on RoundTrip explaining the client closed the connection. -func TestClientConnCloseAtHeaders(t *testing.T) { - testClientConnClose(t, closeAtHeaders) +func TestClientConnCloseAtHeaders(t *testing.T) { synctestTest(t, testClientConnCloseAtHeaders) } +func testClientConnCloseAtHeaders(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.cc.Close() + synctest.Wait() + if err := rt.err(); err != errClientConnForceClosed { + t.Fatalf("RoundTrip error = %v, want errClientConnForceClosed", err) + } } -// The client closes the connection between two server's response DATA frames. 
+// The client closes the connection while reading the response. // The expected behavior is a response body io read error on the client. -func TestClientConnCloseAtBody(t *testing.T) { - testClientConnClose(t, closeAtBody) +func TestClientConnCloseAtBody(t *testing.T) { synctestTest(t, testClientConnCloseAtBody) } +func testClientConnCloseAtBody(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeData(rt.streamID(), false, make([]byte, 64)) + tc.cc.Close() + synctest.Wait() + + if _, err := io.Copy(io.Discard, rt.response().Body); err == nil { + t.Error("expected a Copy error, got nil") + } } // The client sends a GOAWAY frame before the server finished processing a request. // We expect the connection not to close until the request is completed. -func TestClientConnShutdown(t *testing.T) { - testClientConnClose(t, shutdown) +func TestClientConnShutdown(t *testing.T) { synctestTest(t, testClientConnShutdown) } +func testClientConnShutdown(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + go tc.cc.Shutdown(context.Background()) + synctest.Wait() + + tc.wantFrameType(FrameGoAway) + tc.wantIdle() // connection is not closed + body := []byte("body") + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeData(rt.streamID(), true, body) + + rt.wantStatus(200) + rt.wantBody(body) + + // Now that the client has received the response, it closes the connection. + tc.wantClosed() } // The client sends a GOAWAY frame before the server finishes processing a request, // but cancels the passed context before the request is completed. The expected // behavior is the client closing the connection after the context is canceled. -func TestClientConnShutdownCancel(t *testing.T) { - testClientConnClose(t, shutdownCancel) +func TestClientConnShutdownCancel(t *testing.T) { synctestTest(t, testClientConnShutdownCancel) } +func testClientConnShutdownCancel(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + ctx, cancel := context.WithCancel(t.Context()) + var shutdownErr error + go func() { + shutdownErr = tc.cc.Shutdown(ctx) + }() + synctest.Wait() + + tc.wantFrameType(FrameGoAway) + tc.wantIdle() // connection is not closed + + cancel() + synctest.Wait() + + if shutdownErr != context.Canceled { + t.Fatalf("ClientConn.Shutdown(ctx) did not return context.Canceled after cancelling context") + } + + // The documentation for this test states: + // The expected behavior is the client closing the connection + // after the context is canceled. + // + // This seems reasonable, but it isn't what we do. + // When ClientConn.Shutdown's context is canceled, Shutdown returns but + // the connection is not closed. + // + // TODO: Figure out the correct behavior. 
+ if rt.done() { + t.Fatal("RoundTrip unexpectedly returned during shutdown") + } } // Issue 25009: use Request.GetBody if present, even if it seems like @@ -4117,6 +4118,11 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { + synctestTest(t, func(t testing.TB) { + testTransportBodyReadErrorBubble(t, body) + }) +} +func testTransportBodyReadErrorBubble(t testing.TB, body []byte) { tc := newTestClientConn(t) tc.greet() @@ -4149,10 +4155,6 @@ readFrames: if err := rt.err(); err != bodyReadError { t.Fatalf("err = %v; want %v", err, bodyReadError) } - - if got := activeStreams(tc.cc); got != 0 { - t.Fatalf("active streams count: %v; want 0", got) - } } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -4161,7 +4163,8 @@ func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyRea // Issue 32254: verify that the client sends END_STREAM flag eagerly with the last // (or in this test-case the only one) request body data frame, and does not send // extra zero-len data frames. -func TestTransportBodyEagerEndStream(t *testing.T) { +func TestTransportBodyEagerEndStream(t *testing.T) { synctestTest(t, testTransportBodyEagerEndStream) } +func testTransportBodyEagerEndStream(t testing.TB) { const reqBody = "some request body" const resBody = "some response body" @@ -4205,17 +4208,21 @@ func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) { []byte("123"), []byte("456"), }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) + synctestTest(t, func(t testing.TB) { + testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) + }) } func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) { body := &chunkReader{[][]byte{ []byte("123"), }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) + synctestTest(t, func(t testing.TB) { + testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) + }) } -func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) { +func testTransportBodyLargerThanSpecifiedContentLength(t testing.TB, body *chunkReader, contentLen int64) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { r.Body.Read(make([]byte, 6)) }) @@ -4299,35 +4306,28 @@ func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { } func TestTransportRoundtripCloseOnWriteError(t *testing.T) { - req, err := http.NewRequest("GET", "https://dummy.tld/", nil) - if err != nil { - t.Fatal(err) - } - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}) + synctestTest(t, testTransportRoundtripCloseOnWriteError) +} +func testTransportRoundtripCloseOnWriteError(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) - if err != nil { - t.Fatal(err) - } + body := tc.newRequestBody() + body.writeBytes(1) + req, _ := http.NewRequest("GET", "https://dummy.tld/", body) + rt := tc.roundTrip(req) writeErr := errors.New("write error") - cc.wmu.Lock() - cc.werr = writeErr - cc.wmu.Unlock() + tc.closeWriteWithError(writeErr) - _, err = cc.RoundTrip(req) - if err != writeErr { - t.Fatalf("expected %v, got %v", writeErr, err) + body.writeBytes(1) + if err := rt.err(); err != writeErr { + t.Fatalf("RoundTrip error %v, want %v", err, writeErr) } - cc.mu.Lock() - 
closed := cc.closed - cc.mu.Unlock() - if !closed { - t.Fatal("expected closed") + rt2 := tc.roundTrip(req) + if err := rt2.err(); err != errClientConnUnusable { + t.Fatalf("RoundTrip error %v, want errClientConnUnusable", err) } } @@ -4360,7 +4360,7 @@ func TestTransportBodyRewindRace(t *testing.T) { for i := 0; i < clients; i++ { req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("abcdef")) if err != nil { - t.Fatalf("unexpect new request error: %v", err) + t.Fatalf("unexpected new request error: %v", err) } go func() { @@ -4399,7 +4399,7 @@ func TestTransportServerResetStreamAtHeaders(t *testing.T) { req, err := http.NewRequest("POST", ts.URL, errorReader{io.EOF}) if err != nil { - t.Fatalf("unexpect new request error: %v", err) + t.Fatalf("unexpected new request error: %v", err) } req.ContentLength = 0 // so transport is tempted to sniff it req.Header.Set("Expect", "100-continue") @@ -4818,6 +4818,9 @@ func TestTransportCloseRequestBody(t *testing.T) { } func TestTransportRetriesOnStreamProtocolError(t *testing.T) { + synctestTest(t, testTransportRetriesOnStreamProtocolError) +} +func testTransportRetriesOnStreamProtocolError(t testing.TB) { // This test verifies that // - receiving a protocol error on a connection does not interfere with // other requests in flight on that connection; @@ -4893,7 +4896,8 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { rt1.wantStatus(200) } -func TestClientConnReservations(t *testing.T) { +func TestClientConnReservations(t *testing.T) { synctestTest(t, testClientConnReservations) } +func testClientConnReservations(t testing.TB) { tc := newTestClientConn(t) tc.greet( Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams}, @@ -4944,7 +4948,8 @@ func TestClientConnReservations(t *testing.T) { } } -func TestTransportTimeoutServerHangs(t *testing.T) { +func TestTransportTimeoutServerHangs(t *testing.T) { synctestTest(t, testTransportTimeoutServerHangs) } +func testTransportTimeoutServerHangs(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -4953,7 +4958,7 @@ func TestTransportTimeoutServerHangs(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(5 * time.Second) + time.Sleep(5 * time.Second) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } @@ -4962,20 +4967,13 @@ func TestTransportTimeoutServerHangs(t *testing.T) { } cancel() - tc.sync() + synctest.Wait() if rt.err() != context.Canceled { t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } } func TestTransportContentLengthWithoutBody(t *testing.T) { - contentLength := "" - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", contentLength) - }) - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - for _, test := range []struct { name string contentLength string @@ -4996,7 +4994,14 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { wantContentLength: 0, }, } { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { + contentLength := "" + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", contentLength) + }) + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + contentLength = test.contentLength req, _ := http.NewRequest("GET", ts.URL, nil) @@ -5021,6 +5026,9 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { } func 
TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { + synctestTest(t, testTransportCloseResponseBodyWhileRequestBodyHangs) +} +func testTransportCloseResponseBodyWhileRequestBodyHangs(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.(http.Flusher).Flush() @@ -5044,7 +5052,8 @@ func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { pw.Close() } -func TestTransport300ResponseBody(t *testing.T) { +func TestTransport300ResponseBody(t *testing.T) { synctestTest(t, testTransport300ResponseBody) } +func testTransport300ResponseBody(t testing.TB) { reqc := make(chan struct{}) body := []byte("response body") ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -5120,7 +5129,8 @@ func (c *slowWriteConn) Write(b []byte) (n int, err error) { return c.Conn.Write(b) } -func TestTransportSlowWrites(t *testing.T) { +func TestTransportSlowWrites(t *testing.T) { synctestTest(t, testTransportSlowWrites) } +func testTransportSlowWrites(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, ) @@ -5145,10 +5155,14 @@ func TestTransportSlowWrites(t *testing.T) { } func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 0) + synctestTest(t, func(t testing.TB) { + testTransportClosesConnAfterGoAway(t, 0) + }) } func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 1) + synctestTest(t, func(t testing.TB) { + testTransportClosesConnAfterGoAway(t, 1) + }) } // testTransportClosesConnAfterGoAway verifies that the transport @@ -5157,7 +5171,7 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { // lastStream is the last stream ID in the GOAWAY frame. // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. -func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { +func testTransportClosesConnAfterGoAway(t testing.TB, lastStream uint32) { tc := newTestClientConn(t) tc.greet() @@ -5384,7 +5398,8 @@ func TestDialRaceResumesDial(t *testing.T) { } } -func TestTransportDataAfter1xxHeader(t *testing.T) { +func TestTransportDataAfter1xxHeader(t *testing.T) { synctestTest(t, testTransportDataAfter1xxHeader) } +func testTransportDataAfter1xxHeader(t testing.TB) { // Discard logger output to avoid spamming stderr. log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) @@ -5514,7 +5529,7 @@ func TestTransport1xxLimits(t *testing.T) { hcount: 20, limited: false, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { tc := newTestClientConn(t, test.opt) tc.greet() @@ -5549,7 +5564,10 @@ func TestTransport1xxLimits(t *testing.T) { } } -func TestTransportSendPingWithReset(t *testing.T) { +// TestTransportSendPingWithReset verifies that when a request to an unresponsive server +// is canceled, it continues to consume a concurrency slot until the server responds to a PING. +func TestTransportSendPingWithReset(t *testing.T) { synctestTest(t, testTransportSendPingWithReset) } +func testTransportSendPingWithReset(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.StrictMaxConcurrentStreams = true }) @@ -5559,7 +5577,7 @@ func TestTransportSendPingWithReset(t *testing.T) { // Start several requests. 
var rts []*testRoundTrip - for i := 0; i < maxConcurrent+1; i++ { + for i := range maxConcurrent + 1 { req := must(http.NewRequest("GET", "https://dummy.tld/", nil)) rt := tc.roundTrip(req) if i >= maxConcurrent { @@ -5567,25 +5585,17 @@ func TestTransportSendPingWithReset(t *testing.T) { continue } tc.wantFrameType(FrameHeaders) - tc.writeHeaders(HeadersFrameParam{ - StreamID: rt.streamID(), - EndHeaders: true, - BlockFragment: tc.makeHeaderBlockFragment( - ":status", "200", - ), - }) - rt.wantStatus(200) rts = append(rts, rt) } // Cancel one request. We send a PING frame along with the RST_STREAM. - rts[0].response().Body.Close() + rts[0].cancel() tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel) pf := readFrame[*PingFrame](t, tc) tc.wantIdle() // Cancel another request. No PING frame, since one is in flight. - rts[1].response().Body.Close() + rts[1].cancel() tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel) tc.wantIdle() @@ -5594,21 +5604,63 @@ func TestTransportSendPingWithReset(t *testing.T) { tc.writePing(true, pf.Data) tc.wantFrameType(FrameHeaders) tc.wantIdle() +} - // Receive a byte of data for the remaining stream, which resets our ability - // to send pings (see comment on ClientConn.rstStreamPingsBlocked). - tc.writeData(rts[2].streamID(), false, []byte{0}) +// TestTransportNoPingAfterResetWithFrames verifies that when a request to a responsive +// server is canceled (specifically: when frames have been received from the server +// in the time since the request was first sent), the request is immediately canceled and +// does not continue to consume a concurrency slot. +func TestTransportNoPingAfterResetWithFrames(t *testing.T) { + synctestTest(t, testTransportNoPingAfterResetWithFrames) +} +func testTransportNoPingAfterResetWithFrames(t testing.TB) { + tc := newTestClientConn(t, func(tr *Transport) { + tr.StrictMaxConcurrentStreams = true + }) - // Cancel the last request. We send another PING, since none are in flight. - rts[2].response().Body.Close() - tc.wantRSTStream(rts[2].streamID(), ErrCodeCancel) - tc.wantFrameType(FramePing) + const maxConcurrent = 1 + tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent}) + + // Start request #1. + // The server immediately responds with request headers. + req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + rt1 := tc.roundTrip(req1) + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt1.streamID(), + EndHeaders: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt1.wantStatus(200) + + // Start request #2. + // The connection is at its concurrency limit, so this request is not yet sent. + req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil)) + rt2 := tc.roundTrip(req2) tc.wantIdle() + + // Cancel request #1. + // This frees a concurrency slot, and request #2 is sent. + rt1.cancel() + tc.wantRSTStream(rt1.streamID(), ErrCodeCancel) + tc.wantFrameType(FrameHeaders) + + // Cancel request #2. + // We send a PING along with the RST_STREAM, since no frames have been received + // since this request was sent. + rt2.cancel() + tc.wantRSTStream(rt2.streamID(), ErrCodeCancel) + tc.wantFrameType(FramePing) } // Issue #70505: gRPC gets upset if we send more than 2 pings per HEADERS/DATA frame // sent by the server. 
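Throughout this patch, tests are converted to run inside a synctest bubble: tc.advance becomes time.Sleep and tc.sync becomes synctest.Wait. The synctestTest and synctestSubtest helpers are package-local (their definitions are not shown here), presumably thin wrappers over the standard testing/synctest package. A small self-contained sketch of that pattern, assuming Go 1.25's testing/synctest API rather than this package's helpers:

package example

import (
	"testing"
	"testing/synctest"
	"time"
)

func TestVirtualTime(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		done := false
		go func() {
			time.Sleep(1 * time.Second) // fake clock: only advances when the bubble is idle
			done = true
		}()
		synctest.Wait() // returns once the goroutine is durably blocked in Sleep

		start := time.Now()
		time.Sleep(2 * time.Second) // wakes the goroutine at the 1s mark along the way
		synctest.Wait()             // the goroutine has finished by now
		if !done {
			t.Fatal("background goroutine did not finish")
		}
		if elapsed := time.Since(start); elapsed != 2*time.Second {
			t.Fatalf("elapsed %v, want exactly 2s of virtual time", elapsed)
		}
	})
}

Sleeps cost no wall-clock time inside the bubble, which is why the converted tests can afford literal one-minute delays.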
func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) { + synctestTest(t, testTransportSendNoMoreThanOnePingWithReset) +} +func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -5674,6 +5726,9 @@ func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) { } func TestTransportConnBecomesUnresponsive(t *testing.T) { + synctestTest(t, testTransportConnBecomesUnresponsive) +} +func testTransportConnBecomesUnresponsive(t testing.TB) { // We send a number of requests in series to an unresponsive connection. // Each request is canceled or times out without a response. // Eventually, we open a new connection rather than trying to use the old one. @@ -5744,19 +5799,19 @@ func TestTransportConnBecomesUnresponsive(t *testing.T) { } // Test that the Transport can use a conn provided to it by a TLSNextProto hook. -func TestTransportTLSNextProtoConnOK(t *testing.T) { +func TestTransportTLSNextProtoConnOK(t *testing.T) { synctestTest(t, testTransportTLSNextProtoConnOK) } +func testTransportTLSNextProtoConnOK(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() tc.greet() @@ -5787,18 +5842,20 @@ func TestTransportTLSNextProtoConnOK(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook immediately encounters an error. func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUsed) +} +func testTransportTLSNextProtoConnImmediateFailureUsed(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() // The connection encounters an error before we send a request that uses it. @@ -5825,6 +5882,9 @@ func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook is closed for idleness // before we use it. func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnIdleTimoutBeforeUse) +} +func testTransportTLSNextProtoConnIdleTimoutBeforeUse(t testing.TB) { t1 := &http.Transport{ IdleConnTimeout: 1 * time.Second, } @@ -5832,17 +5892,17 @@ func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() - tc := tt.getConn() + synctest.Wait() + _ = tt.getConn() // The connection encounters an error before we send a request that uses it. - tc.advance(2 * time.Second) + time.Sleep(2 * time.Second) + synctest.Wait() // Send a request on the Transport. 
// @@ -5857,18 +5917,20 @@ func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook immediately encounters an error, // but no requests are sent which would use the bad connection. func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUnused) +} +func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() // The connection encounters an error before we send a request that uses it. @@ -5876,7 +5938,7 @@ func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) { // Some time passes. // The dead connection is removed from the pool. - tc.advance(10 * time.Second) + time.Sleep(10 * time.Second) // Send a request on the Transport. // @@ -5959,6 +6021,9 @@ func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { // Issue #70658: Make sure extended CONNECT requests don't get stuck if a // connection fails early in its lifetime. func TestExtendedConnectReadFrameError(t *testing.T) { + synctestTest(t, testExtendedConnectReadFrameError) +} +func testExtendedConnectReadFrameError(t testing.TB) { tc := newTestClientConn(t) tc.wantFrameType(FrameSettings) tc.wantFrameType(FrameWindowUpdate) diff --git a/http2/writesched.go b/http2/writesched.go index cc893adc29..7de27be525 100644 --- a/http2/writesched.go +++ b/http2/writesched.go @@ -42,6 +42,8 @@ type OpenStreamOptions struct { // PusherID is zero if the stream was initiated by the client. Otherwise, // PusherID names the stream that pushed the newly opened stream. PusherID uint32 + // priority is used to set the priority of the newly opened stream. + priority PriorityParam } // FrameWriteRequest is a request to write a frame. @@ -183,45 +185,75 @@ func (wr *FrameWriteRequest) replyToWriter(err error) { } // writeQueue is used by implementations of WriteScheduler. +// +// Each writeQueue contains a queue of FrameWriteRequests, meant to store all +// FrameWriteRequests associated with a given stream. This is implemented as a +// two-stage queue: currQueue[currPos:] and nextQueue. Removing an item is done +// by incrementing currPos of currQueue. Adding an item is done by appending it +// to the nextQueue. If currQueue is empty when trying to remove an item, we +// can swap currQueue and nextQueue to remedy the situation. +// This two-stage queue is analogous to the use of two lists in Okasaki's +// purely functional queue but without the overhead of reversing the list when +// swapping stages. +// +// writeQueue also contains prev and next, this can be used by implementations +// of WriteScheduler to construct data structures that represent the order of +// writing between different streams (e.g. circular linked list). 
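The comment above describes the new two-stage layout of writeQueue. A minimal self-contained sketch of the same idea, with ints standing in for FrameWriteRequests (illustration only, not part of the patch):

package main

import "fmt"

// twoStageQueue mirrors the currQueue/nextQueue/currPos scheme: items are
// consumed from curr[pos:], new items are appended to next, and the two
// slices swap (reusing curr's backing array) once curr is drained.
type twoStageQueue struct {
	curr []int
	next []int
	pos  int
}

func (q *twoStageQueue) empty() bool { return len(q.curr)-q.pos+len(q.next) == 0 }

func (q *twoStageQueue) push(v int) { q.next = append(q.next, v) }

func (q *twoStageQueue) shift() int {
	if q.empty() {
		panic("invalid use of queue")
	}
	if q.pos >= len(q.curr) {
		q.curr, q.pos, q.next = q.next, 0, q.curr[:0]
	}
	v := q.curr[q.pos]
	q.pos++
	return v
}

func main() {
	var q twoStageQueue
	for i := 1; i <= 3; i++ {
		q.push(i)
	}
	for !q.empty() {
		fmt.Println(q.shift()) // 1, 2, 3: FIFO order is preserved across the swap
	}
}

Compared with the removed single-slice version, which copied the whole slice down on every dequeue (the old "TODO: less copy-happy queue"), shift here is amortized O(1).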
type writeQueue struct { - s []FrameWriteRequest + currQueue []FrameWriteRequest + nextQueue []FrameWriteRequest + currPos int + prev, next *writeQueue } -func (q *writeQueue) empty() bool { return len(q.s) == 0 } +func (q *writeQueue) empty() bool { + return (len(q.currQueue) - q.currPos + len(q.nextQueue)) == 0 +} func (q *writeQueue) push(wr FrameWriteRequest) { - q.s = append(q.s, wr) + q.nextQueue = append(q.nextQueue, wr) } func (q *writeQueue) shift() FrameWriteRequest { - if len(q.s) == 0 { + if q.empty() { panic("invalid use of queue") } - wr := q.s[0] - // TODO: less copy-happy queue. - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = FrameWriteRequest{} - q.s = q.s[:len(q.s)-1] + if q.currPos >= len(q.currQueue) { + q.currQueue, q.currPos, q.nextQueue = q.nextQueue, 0, q.currQueue[:0] + } + wr := q.currQueue[q.currPos] + q.currQueue[q.currPos] = FrameWriteRequest{} + q.currPos++ return wr } +func (q *writeQueue) peek() *FrameWriteRequest { + if q.currPos < len(q.currQueue) { + return &q.currQueue[q.currPos] + } + if len(q.nextQueue) > 0 { + return &q.nextQueue[0] + } + return nil +} + // consume consumes up to n bytes from q.s[0]. If the frame is // entirely consumed, it is removed from the queue. If the frame // is partially consumed, the frame is kept with the consumed // bytes removed. Returns true iff any bytes were consumed. func (q *writeQueue) consume(n int32) (FrameWriteRequest, bool) { - if len(q.s) == 0 { + if q.empty() { return FrameWriteRequest{}, false } - consumed, rest, numresult := q.s[0].Consume(n) + consumed, rest, numresult := q.peek().Consume(n) switch numresult { case 0: return FrameWriteRequest{}, false case 1: q.shift() case 2: - q.s[0] = rest + *q.peek() = rest } return consumed, true } @@ -230,10 +262,15 @@ type writeQueuePool []*writeQueue // put inserts an unused writeQueue into the pool. func (p *writeQueuePool) put(q *writeQueue) { - for i := range q.s { - q.s[i] = FrameWriteRequest{} + for i := range q.currQueue { + q.currQueue[i] = FrameWriteRequest{} + } + for i := range q.nextQueue { + q.nextQueue[i] = FrameWriteRequest{} } - q.s = q.s[:0] + q.currQueue = q.currQueue[:0] + q.nextQueue = q.nextQueue[:0] + q.currPos = 0 *p = append(*p, q) } diff --git a/http2/writesched_benchmarks_test.go b/http2/writesched_benchmarks_test.go new file mode 100644 index 0000000000..ca5f99d3f7 --- /dev/null +++ b/http2/writesched_benchmarks_test.go @@ -0,0 +1,197 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "testing" +) + +func benchmarkThroughput(b *testing.B, wsFunc func() WriteScheduler, priority PriorityParam) { + const maxFrameSize = 16 + const streamCount = 100 + + ws := wsFunc() + sc := &serverConn{maxFrameSize: maxFrameSize} + streams := make([]*stream, streamCount) + // Possible stream payloads. We vary the payload size of different streams + // to simulate real traffic somewhat. 
+ streamsFrame := [][]byte{ + make([]byte, maxFrameSize*5), + make([]byte, maxFrameSize*10), + make([]byte, maxFrameSize*15), + make([]byte, maxFrameSize*20), + make([]byte, maxFrameSize*25), + } + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 30) // arbitrary large value + + ws.OpenStream(streamID, OpenStreamOptions{ + priority: priority, + }) + } + + for b.Loop() { + for i := range streams { + streamID := uint32(i) + 1 + ws.Push(FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: streamsFrame[i%len(streamsFrame)], + endStream: false, + }, + stream: streams[i], + }) + } + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + b.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + } + } + + for i := range streams { + streamID := uint32(i) + 1 + ws.CloseStream(streamID) + } +} + +func benchmarkStreamLifetime(b *testing.B, wsFunc func() WriteScheduler, priority PriorityParam) { + const maxFrameSize = 16 + const streamCount = 100 + + ws := wsFunc() + sc := &serverConn{maxFrameSize: maxFrameSize} + streams := make([]*stream, streamCount) + // Possible stream payloads. We vary the payload size of different streams + // to simulate real traffic somewhat. + streamsFrame := [][]byte{ + make([]byte, maxFrameSize*5), + make([]byte, maxFrameSize*10), + make([]byte, maxFrameSize*15), + make([]byte, maxFrameSize*20), + make([]byte, maxFrameSize*25), + } + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 30) // arbitrary large value + } + + for b.Loop() { + for i := range streams { + streamID := uint32(i) + 1 + ws.OpenStream(streamID, OpenStreamOptions{ + priority: priority, + }) + ws.Push(FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: streamsFrame[i%len(streamsFrame)], + endStream: false, + }, + stream: streams[i], + }) + } + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + b.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + } + for i := range streams { + streamID := uint32(i) + 1 + ws.CloseStream(streamID) + } + } + +} + +func BenchmarkWriteSchedulerThroughputRoundRobin(b *testing.B) { + benchmarkThroughput(b, newRoundRobinWriteScheduler, PriorityParam{}) +} + +func BenchmarkWriteSchedulerLifetimeRoundRobin(b *testing.B) { + benchmarkStreamLifetime(b, newRoundRobinWriteScheduler, PriorityParam{}) +} + +func BenchmarkWriteSchedulerThroughputRandom(b *testing.B) { + benchmarkThroughput(b, NewRandomWriteScheduler, PriorityParam{}) +} + +func BenchmarkWriteSchedulerLifetimeRandom(b *testing.B) { + benchmarkStreamLifetime(b, NewRandomWriteScheduler, PriorityParam{}) +} + +func BenchmarkWriteSchedulerThroughputPriorityRFC7540(b *testing.B) { + benchmarkThroughput(b, func() WriteScheduler { return NewPriorityWriteScheduler(nil) }, PriorityParam{}) +} + +func BenchmarkWriteSchedulerLifetimePriorityRFC7540(b *testing.B) { + // RFC7540 priority scheduler does not always succeed in closing the + // stream, causing this benchmark to panic due to opening an already open + // stream. 
+ b.SkipNow() + benchmarkStreamLifetime(b, func() WriteScheduler { return NewPriorityWriteScheduler(nil) }, PriorityParam{}) +} + +func BenchmarkWriteSchedulerThroughputPriorityRFC9218Incremental(b *testing.B) { + benchmarkThroughput(b, newPriorityWriteSchedulerRFC9218, PriorityParam{ + urgency: defaultRFC9218Priority.urgency, + incremental: 1, + }) +} + +func BenchmarkWriteSchedulerLifetimePriorityRFC9218Incremental(b *testing.B) { + benchmarkStreamLifetime(b, newPriorityWriteSchedulerRFC9218, PriorityParam{ + urgency: defaultRFC9218Priority.urgency, + incremental: 1, + }) +} + +func BenchmarkWriteSchedulerThroughputPriorityRFC9218NonIncremental(b *testing.B) { + benchmarkThroughput(b, newPriorityWriteSchedulerRFC9218, PriorityParam{ + urgency: defaultRFC9218Priority.urgency, + incremental: 0, + }) +} + +func BenchmarkWriteSchedulerLifetimePriorityRFC9218NonIncremental(b *testing.B) { + benchmarkStreamLifetime(b, newPriorityWriteSchedulerRFC9218, PriorityParam{ + urgency: defaultRFC9218Priority.urgency, + incremental: 0, + }) +} + +func BenchmarkWriteQueue(b *testing.B) { + var qp writeQueuePool + frameCount := 25 + for b.Loop() { + q := qp.get() + for range frameCount { + q.push(FrameWriteRequest{}) + } + for !q.empty() { + // Since we pushed empty frames, consuming 1 byte is enough to + // consume the entire frame. + q.consume(1) + } + qp.put(q) + } +} diff --git a/http2/writesched_priority.go b/http2/writesched_priority_rfc7540.go similarity index 77% rename from http2/writesched_priority.go rename to http2/writesched_priority_rfc7540.go index f6783339d1..4e33c29a24 100644 --- a/http2/writesched_priority.go +++ b/http2/writesched_priority_rfc7540.go @@ -11,7 +11,7 @@ import ( ) // RFC 7540, Section 5.3.5: the default weight is 16. -const priorityDefaultWeight = 15 // 16 = 15 + 1 +const priorityDefaultWeightRFC7540 = 15 // 16 = 15 + 1 // PriorityWriteSchedulerConfig configures a priorityWriteScheduler. type PriorityWriteSchedulerConfig struct { @@ -66,8 +66,8 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler } } - ws := &priorityWriteScheduler{ - nodes: make(map[uint32]*priorityNode), + ws := &priorityWriteSchedulerRFC7540{ + nodes: make(map[uint32]*priorityNodeRFC7540), maxClosedNodesInTree: cfg.MaxClosedNodesInTree, maxIdleNodesInTree: cfg.MaxIdleNodesInTree, enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, @@ -81,32 +81,32 @@ func NewPriorityWriteScheduler(cfg *PriorityWriteSchedulerConfig) WriteScheduler return ws } -type priorityNodeState int +type priorityNodeStateRFC7540 int const ( - priorityNodeOpen priorityNodeState = iota - priorityNodeClosed - priorityNodeIdle + priorityNodeOpenRFC7540 priorityNodeStateRFC7540 = iota + priorityNodeClosedRFC7540 + priorityNodeIdleRFC7540 ) -// priorityNode is a node in an HTTP/2 priority tree. +// priorityNodeRFC7540 is a node in an HTTP/2 priority tree. // Each node is associated with a single stream ID. // See RFC 7540, Section 5.3. 
-type priorityNode struct { - q writeQueue // queue of pending frames to write - id uint32 // id of the stream, or 0 for the root of the tree - weight uint8 // the actual weight is weight+1, so the value is in [1,256] - state priorityNodeState // open | closed | idle - bytes int64 // number of bytes written by this node, or 0 if closed - subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree +type priorityNodeRFC7540 struct { + q writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state priorityNodeStateRFC7540 // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree // These links form the priority tree. - parent *priorityNode - kids *priorityNode // start of the kids list - prev, next *priorityNode // doubly-linked list of siblings + parent *priorityNodeRFC7540 + kids *priorityNodeRFC7540 // start of the kids list + prev, next *priorityNodeRFC7540 // doubly-linked list of siblings } -func (n *priorityNode) setParent(parent *priorityNode) { +func (n *priorityNodeRFC7540) setParent(parent *priorityNodeRFC7540) { if n == parent { panic("setParent to self") } @@ -141,7 +141,7 @@ func (n *priorityNode) setParent(parent *priorityNode) { } } -func (n *priorityNode) addBytes(b int64) { +func (n *priorityNodeRFC7540) addBytes(b int64) { n.bytes += b for ; n != nil; n = n.parent { n.subtreeBytes += b @@ -154,7 +154,7 @@ func (n *priorityNode) addBytes(b int64) { // // f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true // if any ancestor p of n is still open (ignoring the root node). -func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f func(*priorityNode, bool) bool) bool { +func (n *priorityNodeRFC7540) walkReadyInOrder(openParent bool, tmp *[]*priorityNodeRFC7540, f func(*priorityNodeRFC7540, bool) bool) bool { if !n.q.empty() && f(n, openParent) { return true } @@ -165,7 +165,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f // Don't consider the root "open" when updating openParent since // we can't send data frames on the root stream (only control frames). if n.id != 0 { - openParent = openParent || (n.state == priorityNodeOpen) + openParent = openParent || (n.state == priorityNodeOpenRFC7540) } // Common case: only one kid or all kids have the same weight. 
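One behavioral change rides along with the rename: in the Less method below, float64(z[i].weight+1) becomes float64(z[i].weight)+1. Since weight is a uint8, the old form wrapped to 0 for the maximum weight of 255; converting before adding keeps the effective weight in the documented [1,256] range (exercised by TestPriorityWeightsMinMax later in the patch). A tiny sketch of the wraparound, not part of the patch:

package main

import "fmt"

func main() {
	var weight uint8 = 255
	fmt.Println(float64(weight + 1)) // 0: the uint8 addition wraps before the conversion
	fmt.Println(float64(weight) + 1) // 256: the intended effective weight
}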
@@ -195,7 +195,7 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f *tmp = append(*tmp, n.kids) n.kids.setParent(nil) } - sort.Sort(sortPriorityNodeSiblings(*tmp)) + sort.Sort(sortPriorityNodeSiblingsRFC7540(*tmp)) for i := len(*tmp) - 1; i >= 0; i-- { (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids } @@ -207,15 +207,15 @@ func (n *priorityNode) walkReadyInOrder(openParent bool, tmp *[]*priorityNode, f return false } -type sortPriorityNodeSiblings []*priorityNode +type sortPriorityNodeSiblingsRFC7540 []*priorityNodeRFC7540 -func (z sortPriorityNodeSiblings) Len() int { return len(z) } -func (z sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } -func (z sortPriorityNodeSiblings) Less(i, k int) bool { +func (z sortPriorityNodeSiblingsRFC7540) Len() int { return len(z) } +func (z sortPriorityNodeSiblingsRFC7540) Swap(i, k int) { z[i], z[k] = z[k], z[i] } +func (z sortPriorityNodeSiblingsRFC7540) Less(i, k int) bool { // Prefer the subtree that has sent fewer bytes relative to its weight. // See sections 5.3.2 and 5.3.4. - wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) - wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + wi, bi := float64(z[i].weight)+1, float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight)+1, float64(z[k].subtreeBytes) if bi == 0 && bk == 0 { return wi >= wk } @@ -225,13 +225,13 @@ func (z sortPriorityNodeSiblings) Less(i, k int) bool { return bi/bk <= wi/wk } -type priorityWriteScheduler struct { +type priorityWriteSchedulerRFC7540 struct { // root is the root of the priority tree, where root.id = 0. // The root queues control frames that are not associated with any stream. - root priorityNode + root priorityNodeRFC7540 // nodes maps stream ids to priority tree nodes. - nodes map[uint32]*priorityNode + nodes map[uint32]*priorityNodeRFC7540 // maxID is the maximum stream id in nodes. maxID uint32 @@ -239,7 +239,7 @@ type priorityWriteScheduler struct { // lists of nodes that have been closed or are idle, but are kept in // the tree for improved prioritization. When the lengths exceed either // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. - closedNodes, idleNodes []*priorityNode + closedNodes, idleNodes []*priorityNodeRFC7540 // From the config. maxClosedNodesInTree int @@ -248,19 +248,19 @@ type priorityWriteScheduler struct { enableWriteThrottle bool // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. - tmp []*priorityNode + tmp []*priorityNodeRFC7540 // pool of empty queues for reuse. queuePool writeQueuePool } -func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStreamOptions) { +func (ws *priorityWriteSchedulerRFC7540) OpenStream(streamID uint32, options OpenStreamOptions) { // The stream may be currently idle but cannot be opened or closed. 
if curr := ws.nodes[streamID]; curr != nil { - if curr.state != priorityNodeIdle { + if curr.state != priorityNodeIdleRFC7540 { panic(fmt.Sprintf("stream %d already opened", streamID)) } - curr.state = priorityNodeOpen + curr.state = priorityNodeOpenRFC7540 return } @@ -272,11 +272,11 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream if parent == nil { parent = &ws.root } - n := &priorityNode{ + n := &priorityNodeRFC7540{ q: *ws.queuePool.get(), id: streamID, - weight: priorityDefaultWeight, - state: priorityNodeOpen, + weight: priorityDefaultWeightRFC7540, + state: priorityNodeOpenRFC7540, } n.setParent(parent) ws.nodes[streamID] = n @@ -285,24 +285,23 @@ func (ws *priorityWriteScheduler) OpenStream(streamID uint32, options OpenStream } } -func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { +func (ws *priorityWriteSchedulerRFC7540) CloseStream(streamID uint32) { if streamID == 0 { panic("violation of WriteScheduler interface: cannot close stream 0") } if ws.nodes[streamID] == nil { panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ws.nodes[streamID].state != priorityNodeOpen { + if ws.nodes[streamID].state != priorityNodeOpenRFC7540 { panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } n := ws.nodes[streamID] - n.state = priorityNodeClosed + n.state = priorityNodeClosedRFC7540 n.addBytes(-n.bytes) q := n.q ws.queuePool.put(&q) - n.q.s = nil if ws.maxClosedNodesInTree > 0 { ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) } else { @@ -310,7 +309,7 @@ func (ws *priorityWriteScheduler) CloseStream(streamID uint32) { } } -func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority PriorityParam) { +func (ws *priorityWriteSchedulerRFC7540) AdjustStream(streamID uint32, priority PriorityParam) { if streamID == 0 { panic("adjustPriority on root") } @@ -324,11 +323,11 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit return } ws.maxID = streamID - n = &priorityNode{ + n = &priorityNodeRFC7540{ q: *ws.queuePool.get(), id: streamID, - weight: priorityDefaultWeight, - state: priorityNodeIdle, + weight: priorityDefaultWeightRFC7540, + state: priorityNodeIdleRFC7540, } n.setParent(&ws.root) ws.nodes[streamID] = n @@ -340,7 +339,7 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit parent := ws.nodes[priority.StreamDep] if parent == nil { n.setParent(&ws.root) - n.weight = priorityDefaultWeight + n.weight = priorityDefaultWeightRFC7540 return } @@ -381,8 +380,8 @@ func (ws *priorityWriteScheduler) AdjustStream(streamID uint32, priority Priorit n.weight = priority.Weight } -func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { - var n *priorityNode +func (ws *priorityWriteSchedulerRFC7540) Push(wr FrameWriteRequest) { + var n *priorityNodeRFC7540 if wr.isControl() { n = &ws.root } else { @@ -401,8 +400,8 @@ func (ws *priorityWriteScheduler) Push(wr FrameWriteRequest) { n.q.push(wr) } -func (ws *priorityWriteScheduler) Pop() (wr FrameWriteRequest, ok bool) { - ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNode, openParent bool) bool { +func (ws *priorityWriteSchedulerRFC7540) Pop() (wr FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *priorityNodeRFC7540, openParent bool) bool { limit := int32(math.MaxInt32) if openParent { limit = ws.writeThrottleLimit @@ -428,7 +427,7 @@ func (ws *priorityWriteScheduler) Pop() (wr 
FrameWriteRequest, ok bool) { return wr, ok } -func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, maxSize int, n *priorityNode) { +func (ws *priorityWriteSchedulerRFC7540) addClosedOrIdleNode(list *[]*priorityNodeRFC7540, maxSize int, n *priorityNodeRFC7540) { if maxSize == 0 { return } @@ -442,7 +441,7 @@ func (ws *priorityWriteScheduler) addClosedOrIdleNode(list *[]*priorityNode, max *list = append(*list, n) } -func (ws *priorityWriteScheduler) removeNode(n *priorityNode) { +func (ws *priorityWriteSchedulerRFC7540) removeNode(n *priorityNodeRFC7540) { for n.kids != nil { n.kids.setParent(n.parent) } diff --git a/http2/writesched_priority_test.go b/http2/writesched_priority_rfc7540_test.go similarity index 89% rename from http2/writesched_priority_test.go rename to http2/writesched_priority_rfc7540_test.go index 5aad057bea..6fcee2353a 100644 --- a/http2/writesched_priority_test.go +++ b/http2/writesched_priority_rfc7540_test.go @@ -11,11 +11,11 @@ import ( "testing" ) -func defaultPriorityWriteScheduler() *priorityWriteScheduler { - return NewPriorityWriteScheduler(nil).(*priorityWriteScheduler) +func defaultPriorityWriteScheduler() *priorityWriteSchedulerRFC7540 { + return NewPriorityWriteScheduler(nil).(*priorityWriteSchedulerRFC7540) } -func checkPriorityWellFormed(ws *priorityWriteScheduler) error { +func checkPriorityWellFormed(ws *priorityWriteSchedulerRFC7540) error { for id, n := range ws.nodes { if id != n.id { return fmt.Errorf("bad ws.nodes: ws.nodes[%d] = %d", id, n.id) @@ -40,7 +40,7 @@ func checkPriorityWellFormed(ws *priorityWriteScheduler) error { return nil } -func fmtTree(ws *priorityWriteScheduler, fmtNode func(*priorityNode) string) string { +func fmtTree(ws *priorityWriteSchedulerRFC7540, fmtNode func(*priorityNodeRFC7540) string) string { var ids []int for _, n := range ws.nodes { ids = append(ids, int(n.id)) @@ -61,7 +61,7 @@ func fmtTree(ws *priorityWriteScheduler, fmtNode func(*priorityNode) string) str return buf.String() } -func fmtNodeParentSkipRoot(n *priorityNode) string { +func fmtNodeParentSkipRoot(n *priorityNodeRFC7540) string { switch { case n.id == 0: return "" @@ -72,7 +72,7 @@ func fmtNodeParentSkipRoot(n *priorityNode) string { } } -func fmtNodeWeightParentSkipRoot(n *priorityNode) string { +func fmtNodeWeightParentSkipRoot(n *priorityNodeRFC7540) string { switch { case n.id == 0: return "" @@ -158,7 +158,7 @@ func TestPriorityAdjustOwnParent(t *testing.T) { } func TestPriorityClosedStreams(t *testing.T) { - ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxClosedNodesInTree: 2}).(*priorityWriteScheduler) + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxClosedNodesInTree: 2}).(*priorityWriteSchedulerRFC7540) ws.OpenStream(1, OpenStreamOptions{}) ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) ws.OpenStream(3, OpenStreamOptions{PusherID: 2}) @@ -196,7 +196,7 @@ func TestPriorityClosedStreams(t *testing.T) { } func TestPriorityClosedStreamsDisabled(t *testing.T) { - ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteScheduler) + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteSchedulerRFC7540) ws.OpenStream(1, OpenStreamOptions{}) ws.OpenStream(2, OpenStreamOptions{PusherID: 1}) ws.OpenStream(3, OpenStreamOptions{PusherID: 2}) @@ -215,7 +215,7 @@ func TestPriorityClosedStreamsDisabled(t *testing.T) { } func TestPriorityIdleStreams(t *testing.T) { - ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxIdleNodesInTree: 
2}).(*priorityWriteScheduler) + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{MaxIdleNodesInTree: 2}).(*priorityWriteSchedulerRFC7540) ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 15}) // idle ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 15}) // idle ws.AdjustStream(3, PriorityParam{StreamDep: 2, Weight: 20}) // idle @@ -236,7 +236,7 @@ func TestPriorityIdleStreams(t *testing.T) { } func TestPriorityIdleStreamsDisabled(t *testing.T) { - ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteScheduler) + ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{}).(*priorityWriteSchedulerRFC7540) ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 15}) // idle ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 15}) // idle ws.AdjustStream(3, PriorityParam{StreamDep: 2, Weight: 20}) // idle @@ -295,7 +295,7 @@ func TestPrioritySection531Exclusive(t *testing.T) { } } -func makeSection533Tree() *priorityWriteScheduler { +func makeSection533Tree() *priorityWriteSchedulerRFC7540 { // Initial tree from RFC 7540 Section 5.3.3. // A,B,C,D,E,F = 1,2,3,4,5,6 ws := defaultPriorityWriteScheduler() @@ -548,6 +548,39 @@ func TestPriorityWeights(t *testing.T) { } } +func TestPriorityWeightsMinMax(t *testing.T) { + ws := defaultPriorityWriteScheduler() + ws.OpenStream(1, OpenStreamOptions{}) + ws.OpenStream(2, OpenStreamOptions{}) + + sc := &serverConn{maxFrameSize: 8} + st1 := &stream{id: 1, sc: sc} + st2 := &stream{id: 2, sc: sc} + st1.flow.add(40) + st2.flow.add(40) + + // st2 gets 256x the bandwidth of st1 (256 = (255+1)/(0+1)). + // The maximum frame size is 8 bytes. The write sequence should be: + // st2, total bytes so far is (st1=0, st=8) + // st1, total bytes so far is (st1=8, st=8) + // st2, total bytes so far is (st1=8, st=16) + // st2, total bytes so far is (st1=8, st=24) + // st2, total bytes so far is (st1=8, st=32) + // st2, total bytes so far is (st1=8, st=40) // 5x bandwidth + // st1, total bytes so far is (st1=16, st=40) + // st1, total bytes so far is (st1=24, st=40) + // st1, total bytes so far is (st1=32, st=40) + // st1, total bytes so far is (st1=40, st=40) + ws.Push(FrameWriteRequest{&writeData{1, make([]byte, 40), false}, st1, nil}) + ws.Push(FrameWriteRequest{&writeData{2, make([]byte, 40), false}, st2, nil}) + ws.AdjustStream(1, PriorityParam{StreamDep: 0, Weight: 0}) + ws.AdjustStream(2, PriorityParam{StreamDep: 0, Weight: 255}) + + if err := checkPopAll(ws, []uint32{2, 1, 2, 2, 2, 2, 1, 1, 1, 1}); err != nil { + t.Error(err) + } +} + func TestPriorityRstStreamOnNonOpenStreams(t *testing.T) { ws := NewPriorityWriteScheduler(&PriorityWriteSchedulerConfig{ MaxClosedNodesInTree: 0, @@ -565,7 +598,7 @@ func TestPriorityRstStreamOnNonOpenStreams(t *testing.T) { // https://go.dev/issue/66514 func TestPriorityIssue66514(t *testing.T) { - addDep := func(ws *priorityWriteScheduler, child uint32, parent uint32) { + addDep := func(ws *priorityWriteSchedulerRFC7540, child uint32, parent uint32) { ws.AdjustStream(child, PriorityParam{ StreamDep: parent, Exclusive: false, @@ -573,7 +606,7 @@ func TestPriorityIssue66514(t *testing.T) { }) } - validateDepTree := func(ws *priorityWriteScheduler, id uint32, t *testing.T) { + validateDepTree := func(ws *priorityWriteSchedulerRFC7540, id uint32, t *testing.T) { for n := ws.nodes[id]; n != nil; n = n.parent { if n.parent == nil { if n.id != uint32(0) { @@ -583,7 +616,7 @@ func TestPriorityIssue66514(t *testing.T) { } } - ws := 
NewPriorityWriteScheduler(nil).(*priorityWriteScheduler) + ws := NewPriorityWriteScheduler(nil).(*priorityWriteSchedulerRFC7540) // Root entry addDep(ws, uint32(1), uint32(0)) diff --git a/http2/writesched_priority_rfc9218.go b/http2/writesched_priority_rfc9218.go new file mode 100644 index 0000000000..cb4cadc32d --- /dev/null +++ b/http2/writesched_priority_rfc9218.go @@ -0,0 +1,209 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "fmt" + "math" +) + +type streamMetadata struct { + location *writeQueue + priority PriorityParam +} + +type priorityWriteSchedulerRFC9218 struct { + // control contains control frames (SETTINGS, PING, etc.). + control writeQueue + + // heads contain the head of a circular list of streams. + // We put these heads within a nested array that represents urgency and + // incremental, as defined in + // https://www.rfc-editor.org/rfc/rfc9218.html#name-priority-parameters. + // 8 represents u=0 up to u=7, and 2 represents i=false and i=true. + heads [8][2]*writeQueue + + // streams contains a mapping between each stream ID and their metadata, so + // we can quickly locate them when needing to, for example, adjust their + // priority. + streams map[uint32]streamMetadata + + // queuePool are empty queues for reuse. + queuePool writeQueuePool + + // prioritizeIncremental is used to determine whether we should prioritize + // incremental streams or not, when urgency is the same in a given Pop() + // call. + prioritizeIncremental bool +} + +func newPriorityWriteSchedulerRFC9218() WriteScheduler { + ws := &priorityWriteSchedulerRFC9218{ + streams: make(map[uint32]streamMetadata), + } + return ws +} + +func (ws *priorityWriteSchedulerRFC9218) OpenStream(streamID uint32, opt OpenStreamOptions) { + if ws.streams[streamID].location != nil { + panic(fmt.Errorf("stream %d already opened", streamID)) + } + q := ws.queuePool.get() + ws.streams[streamID] = streamMetadata{ + location: q, + priority: opt.priority, + } + + u, i := opt.priority.urgency, opt.priority.incremental + if ws.heads[u][i] == nil { + ws.heads[u][i] = q + q.next = q + q.prev = q + } else { + // Queues are stored in a ring. + // Insert the new stream before ws.head, putting it at the end of the list. + q.prev = ws.heads[u][i].prev + q.next = ws.heads[u][i] + q.prev.next = q + q.next.prev = q + } +} + +func (ws *priorityWriteSchedulerRFC9218) CloseStream(streamID uint32) { + metadata := ws.streams[streamID] + q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental + if q == nil { + return + } + if q.next == q { + // This was the only open stream. + ws.heads[u][i] = nil + } else { + q.prev.next = q.next + q.next.prev = q.prev + if ws.heads[u][i] == q { + ws.heads[u][i] = q.next + } + } + delete(ws.streams, streamID) + ws.queuePool.put(q) +} + +func (ws *priorityWriteSchedulerRFC9218) AdjustStream(streamID uint32, priority PriorityParam) { + metadata := ws.streams[streamID] + q, u, i := metadata.location, metadata.priority.urgency, metadata.priority.incremental + if q == nil { + return + } + + // Remove stream from current location. + if q.next == q { + // This was the only open stream. + ws.heads[u][i] = nil + } else { + q.prev.next = q.next + q.next.prev = q.prev + if ws.heads[u][i] == q { + ws.heads[u][i] = q.next + } + } + + // Insert stream to the new queue. 
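+	// The stream keeps its existing *writeQueue; only its position in the
+	// ring changes. Re-linking it just before ws.heads[u][i] places it at the
+	// tail of its new (urgency, incremental) bucket, so it does not jump
+	// ahead of streams already waiting at that urgency.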
+ u, i = priority.urgency, priority.incremental + if ws.heads[u][i] == nil { + ws.heads[u][i] = q + q.next = q + q.prev = q + } else { + // Queues are stored in a ring. + // Insert the new stream before ws.head, putting it at the end of the list. + q.prev = ws.heads[u][i].prev + q.next = ws.heads[u][i] + q.prev.next = q + q.next.prev = q + } + + // Update the metadata. + ws.streams[streamID] = streamMetadata{ + location: q, + priority: priority, + } +} + +func (ws *priorityWriteSchedulerRFC9218) Push(wr FrameWriteRequest) { + if wr.isControl() { + ws.control.push(wr) + return + } + q := ws.streams[wr.StreamID()].location + if q == nil { + // This is a closed stream. + // wr should not be a HEADERS or DATA frame. + // We push the request onto the control queue. + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + ws.control.push(wr) + return + } + q.push(wr) +} + +func (ws *priorityWriteSchedulerRFC9218) Pop() (FrameWriteRequest, bool) { + // Control and RST_STREAM frames first. + if !ws.control.empty() { + return ws.control.shift(), true + } + + // On the next Pop(), we want to prioritize incremental if we prioritized + // non-incremental request of the same urgency this time. Vice-versa. + // i.e. when there are incremental and non-incremental requests at the same + // priority, we give 50% of our bandwidth to the incremental ones in + // aggregate and 50% to the first non-incremental one (since + // non-incremental streams do not use round-robin writes). + ws.prioritizeIncremental = !ws.prioritizeIncremental + + // Always prioritize lowest u (i.e. highest urgency level). + for u := range ws.heads { + for i := range ws.heads[u] { + // When we want to prioritize incremental, we try to pop i=true + // first before i=false when u is the same. + if ws.prioritizeIncremental { + i = (i + 1) % 2 + } + q := ws.heads[u][i] + if q == nil { + continue + } + for { + if wr, ok := q.consume(math.MaxInt32); ok { + if i == 1 { + // For incremental streams, we update head to q.next so + // we can round-robin between multiple streams that can + // immediately benefit from partial writes. + ws.heads[u][i] = q.next + } else { + // For non-incremental streams, we try to finish one to + // completion rather than doing round-robin. However, + // we update head here so that if q.consume() is !ok + // (e.g. the stream has no more frame to consume), head + // is updated to the next q that has frames to consume + // on future iterations. This way, we do not prioritize + // writing to unavailable stream on next Pop() calls, + // preventing head-of-line blocking. + ws.heads[u][i] = q + } + return wr, true + } + q = q.next + if q == ws.heads[u][i] { + break + } + } + + } + } + return FrameWriteRequest{}, false +} diff --git a/http2/writesched_priority_rfc9218_test.go b/http2/writesched_priority_rfc9218_test.go new file mode 100644 index 0000000000..28a820185d --- /dev/null +++ b/http2/writesched_priority_rfc9218_test.go @@ -0,0 +1,326 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package http2 + +import ( + "reflect" + "testing" +) + +func TestPrioritySchedulerUrgency(t *testing.T) { + const maxFrameSize = 16 + sc := &serverConn{maxFrameSize: maxFrameSize} + ws := newPriorityWriteSchedulerRFC9218() + streams := make([]*stream, 5) + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 20) // arbitrary large value + ws.OpenStream(streamID, OpenStreamOptions{ + priority: PriorityParam{ + urgency: 7, + incremental: 0, + }, + }) + wr := FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: make([]byte, maxFrameSize*(i+1)), + endStream: false, + }, + stream: streams[i], + } + ws.Push(wr) + } + // Raise the urgency of all even-numbered streams. + for i := range streams { + streamID := uint32(i) + 1 + if streamID%2 == 1 { + continue + } + ws.AdjustStream(streamID, PriorityParam{ + urgency: 0, + incremental: 0, + }) + } + const controlFrames = 2 + for range controlFrames { + ws.Push(makeWriteNonStreamRequest()) + } + + // We should get the control frames first. + for range controlFrames { + wr, ok := ws.Pop() + if !ok || wr.StreamID() != 0 { + t.Fatalf("wr.Pop() = stream %v, %v; want 0, true", wr.StreamID(), ok) + } + } + + // Each stream should write maxFrameSize bytes until it runs out of data. + // Higher-urgency even-numbered streams should come first. + want := []uint32{2, 2, 4, 4, 4, 4, 1, 3, 3, 3, 5, 5, 5, 5, 5} + var got []uint32 + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + t.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + got = append(got, wr.StreamID()) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("popped streams %v, want %v", got, want) + } +} + +func TestPrioritySchedulerIncremental(t *testing.T) { + const maxFrameSize = 16 + sc := &serverConn{maxFrameSize: maxFrameSize} + ws := newPriorityWriteSchedulerRFC9218() + streams := make([]*stream, 5) + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 20) // arbitrary large value + ws.OpenStream(streamID, OpenStreamOptions{ + priority: PriorityParam{ + urgency: 7, + incremental: 0, + }, + }) + wr := FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: make([]byte, maxFrameSize*(i+1)), + endStream: false, + }, + stream: streams[i], + } + ws.Push(wr) + } + // Make even-numbered streams incremental. + for i := range streams { + streamID := uint32(i) + 1 + if streamID%2 == 1 { + continue + } + ws.AdjustStream(streamID, PriorityParam{ + urgency: 7, + incremental: 1, + }) + } + const controlFrames = 2 + for range controlFrames { + ws.Push(makeWriteNonStreamRequest()) + } + + // We should get the control frames first. + for range controlFrames { + wr, ok := ws.Pop() + if !ok || wr.StreamID() != 0 { + t.Fatalf("wr.Pop() = stream %v, %v; want 0, true", wr.StreamID(), ok) + } + } + + // Each stream should write maxFrameSize bytes until it runs out of data. + // We should: + // - Round-robin between even and odd-numbered streams as they have + // different i but the same u. + // - Amongst even-numbered streams, round-robin writes as they are + // incremental. + // - Among odd-numbered streams, do not round-robin as they are + // non-incremental. 
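+	// Concretely: pop 1 prefers incremental and takes stream 2; pop 2 prefers
+	// non-incremental and takes stream 1, which then runs out of data; pop 3
+	// round-robins the incremental streams to 4; pop 4 falls through to 3
+	// because 1 is drained; and so on until only stream 5 has data left.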
+ want := []uint32{2, 1, 4, 3, 2, 3, 4, 3, 4, 5, 4, 5, 5, 5, 5} + var got []uint32 + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + t.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + got = append(got, wr.StreamID()) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("popped streams %v, want %v", got, want) + } +} + +func TestPrioritySchedulerUrgencyAndIncremental(t *testing.T) { + const maxFrameSize = 16 + sc := &serverConn{maxFrameSize: maxFrameSize} + ws := newPriorityWriteSchedulerRFC9218() + streams := make([]*stream, 6) + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 20) // arbitrary large value + ws.OpenStream(streamID, OpenStreamOptions{ + priority: PriorityParam{ + urgency: 7, + incremental: 0, + }, + }) + wr := FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: make([]byte, maxFrameSize*(i+1)), + endStream: false, + }, + stream: streams[i], + } + ws.Push(wr) + } + // Make even-numbered streams incremental and of higher urgency. + for i := range streams { + streamID := uint32(i) + 1 + if streamID%2 == 1 { + continue + } + ws.AdjustStream(streamID, PriorityParam{ + urgency: 0, + incremental: 1, + }) + } + // Close stream 1 and 4 + ws.CloseStream(1) + ws.CloseStream(4) + const controlFrames = 2 + for range controlFrames { + ws.Push(makeWriteNonStreamRequest()) + } + + // We should get the control frames first. + for range controlFrames { + wr, ok := ws.Pop() + if !ok || wr.StreamID() != 0 { + t.Fatalf("wr.Pop() = stream %v, %v; want 0, true", wr.StreamID(), ok) + } + } + + // Each stream should write maxFrameSize bytes until it runs out of data. + // We should: + // - Get even-numbered streams first that are written in a round-robin + // manner as they have higher urgency and are incremental. + // - Get odd-numbered streams after that are written one-by-one to + // completion as they are of lower urgency and are not incremental. + // - Skip stream 1 and 4 that have been closed. + want := []uint32{2, 6, 2, 6, 6, 6, 6, 6, 3, 3, 3, 5, 5, 5, 5, 5} + var got []uint32 + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + t.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + got = append(got, wr.StreamID()) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("popped streams %v, want %v", got, want) + } +} + +func TestPrioritySchedulerIdempotentUpdate(t *testing.T) { + const maxFrameSize = 16 + sc := &serverConn{maxFrameSize: maxFrameSize} + ws := newPriorityWriteSchedulerRFC9218() + streams := make([]*stream, 6) + for i := range streams { + streamID := uint32(i) + 1 + streams[i] = &stream{ + id: streamID, + sc: sc, + } + streams[i].flow.add(1 << 20) // arbitrary large value + ws.OpenStream(streamID, OpenStreamOptions{ + priority: PriorityParam{ + urgency: 7, + incremental: 0, + }, + }) + wr := FrameWriteRequest{ + write: &writeData{ + streamID: streamID, + p: make([]byte, maxFrameSize*(i+1)), + endStream: false, + }, + stream: streams[i], + } + ws.Push(wr) + } + // Make even-numbered streams incremental and of higher urgency. + for i := range streams { + streamID := uint32(i) + 1 + if streamID%2 == 1 { + continue + } + ws.AdjustStream(streamID, PriorityParam{ + urgency: 0, + incremental: 1, + }) + } + ws.CloseStream(1) + // Repeat the same priority update to ensure idempotency. 
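+	// AdjustStream with an unchanged priority must remove the stream from its
+	// current position in the heads[u][i] ring and re-link it in the same
+	// bucket without corrupting the ring, even after another stream has
+	// already been closed.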
+ for i := range streams { + streamID := uint32(i) + 1 + if streamID%2 == 1 { + continue + } + ws.AdjustStream(streamID, PriorityParam{ + urgency: 0, + incremental: 1, + }) + } + ws.CloseStream(2) + const controlFrames = 2 + for range controlFrames { + ws.Push(makeWriteNonStreamRequest()) + } + + // We should get the control frames first. + for range controlFrames { + wr, ok := ws.Pop() + if !ok || wr.StreamID() != 0 { + t.Fatalf("wr.Pop() = stream %v, %v; want 0, true", wr.StreamID(), ok) + } + } + + // Each stream should write maxFrameSize bytes until it runs out of data. + // We should: + // - Get even-numbered streams first that are written in a round-robin + // manner as they have higher urgency and are incremental. + // - Get odd-numbered streams after that are written one-by-one to + // completion as they are of lower urgency and are not incremental. + // - Skip stream 1 and 4 that have been closed. + want := []uint32{4, 6, 4, 6, 4, 6, 4, 6, 6, 6, 3, 3, 3, 5, 5, 5, 5, 5} + var got []uint32 + for { + wr, ok := ws.Pop() + if !ok { + break + } + if wr.DataSize() != maxFrameSize { + t.Fatalf("wr.Pop() = %v data bytes, want %v", wr.DataSize(), maxFrameSize) + } + got = append(got, wr.StreamID()) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("popped streams %v, want %v", got, want) + } +} diff --git a/http2/writesched_roundrobin.go b/http2/writesched_roundrobin.go index 54fe86322d..737cff9ecb 100644 --- a/http2/writesched_roundrobin.go +++ b/http2/writesched_roundrobin.go @@ -25,7 +25,7 @@ type roundRobinWriteScheduler struct { } // newRoundRobinWriteScheduler constructs a new write scheduler. -// The round robin scheduler priorizes control frames +// The round robin scheduler prioritizes control frames // like SETTINGS and PING over DATA frames. // When there are no control frames to send, it performs a round-robin // selection from the ready streams. diff --git a/internal/http3/qpack.go b/internal/http3/qpack.go index 66f4e29762..8fb4860b54 100644 --- a/internal/http3/qpack.go +++ b/internal/http3/qpack.go @@ -224,7 +224,7 @@ func (st *stream) readPrefixedInt(prefixLen uint8) (firstByte byte, v int64, err return firstByte, v, err } -// readPrefixedInt reads an RFC 7541 prefixed integer from st. +// readPrefixedIntWithByte reads an RFC 7541 prefixed integer from st. // The first byte has already been read from the stream. func (st *stream) readPrefixedIntWithByte(firstByte byte, prefixLen uint8) (v int64, err error) { prefixMask := (byte(1) << prefixLen) - 1 @@ -285,7 +285,7 @@ func (st *stream) readPrefixedString(prefixLen uint8) (firstByte byte, s string, return firstByte, s, err } -// readPrefixedString reads an RFC 7541 string from st. +// readPrefixedStringWithByte reads an RFC 7541 string from st. // The first byte has already been read from the stream. 
func (st *stream) readPrefixedStringWithByte(firstByte byte, prefixLen uint8) (s string, err error) { size, err := st.readPrefixedIntWithByte(firstByte, prefixLen) diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index acd8613d0e..ba6a234af7 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -125,7 +125,7 @@ func TestRoundTripResponseContentLength(t *testing.T) { }, wantContentLength: -1, }, { - name: "unparseable", + name: "unparsable", respHeader: http.Header{ ":status": []string{"200"}, "content-length": []string{"1 1"}, @@ -185,7 +185,7 @@ func TestRoundTripMalformedResponses(t *testing.T) { ":status": []string{"200", "204"}, }, }, { - name: "unparseable :status", + name: "unparsable :status", respHeader: http.Header{ ":status": []string{"frogpants"}, }, diff --git a/internal/http3/stream.go b/internal/http3/stream.go index 0f975407be..345e2f507f 100644 --- a/internal/http3/stream.go +++ b/internal/http3/stream.go @@ -68,7 +68,7 @@ func newStream(qs *quic.Stream) *stream { // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1 func (st *stream) readFrameHeader() (ftype frameType, err error) { if st.lim >= 0 { - // We shoudn't call readFrameHeader before ending the previous frame. + // We shouldn't call readFrameHeader before ending the previous frame. return 0, errH3FrameError } ftype, err = readVarint[frameType](st) diff --git a/internal/http3/stream_test.go b/internal/http3/stream_test.go index 12b281c558..a034cc7697 100644 --- a/internal/http3/stream_test.go +++ b/internal/http3/stream_test.go @@ -198,7 +198,7 @@ func TestStreamReadFrameHeaderPartial(t *testing.T) { st1.stream.CloseWrite() if _, err := st2.readFrameHeader(); err == nil { - t.Fatalf("%v/%v bytes of frame available: st.readFrameHeader() succeded; want error", i, len(frame)) + t.Fatalf("%v/%v bytes of frame available: st.readFrameHeader() succeeded; want error", i, len(frame)) } } } diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go index 4b70553179..1e10f89ebf 100644 --- a/internal/httpcommon/request.go +++ b/internal/httpcommon/request.go @@ -51,7 +51,7 @@ type EncodeHeadersParam struct { DefaultUserAgent string } -// EncodeHeadersParam is the result of EncodeHeaders. +// EncodeHeadersResult is the result of EncodeHeaders. type EncodeHeadersResult struct { HasBody bool HasTrailers bool @@ -399,7 +399,7 @@ type ServerRequestResult struct { // If the request should be rejected, this is a short string suitable for passing // to the http2 package's CountError function. - // It might be a bit odd to return errors this way rather than returing an error, + // It might be a bit odd to return errors this way rather than returning an error, // but this ensures we don't forget to include a CountError reason. InvalidReason string } diff --git a/internal/httpsfv/httpsfv.go b/internal/httpsfv/httpsfv.go new file mode 100644 index 0000000000..4ae2ca5b8e --- /dev/null +++ b/internal/httpsfv/httpsfv.go @@ -0,0 +1,665 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httpsfv provides functionality for dealing with HTTP Structured +// Field Values. 
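+//
+// For example, the field value "a, b;x=1, (c d);y" is a list whose members
+// are the item "a", the item "b" with parameter ";x=1", and the inner list
+// "(c d)" with parameter ";y".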
+package httpsfv + +import ( + "slices" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +func isLCAlpha(b byte) bool { + return (b >= 'a' && b <= 'z') +} + +func isAlpha(b byte) bool { + return isLCAlpha(b) || (b >= 'A' && b <= 'Z') +} + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} + +func isVChar(b byte) bool { + return b >= 0x21 && b <= 0x7e +} + +func isSP(b byte) bool { + return b == 0x20 +} + +func isTChar(b byte) bool { + if isAlpha(b) || isDigit(b) { + return true + } + return slices.Contains([]byte{'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~'}, b) +} + +func countLeftWhitespace(s string) int { + i := 0 + for _, ch := range []byte(s) { + if ch != ' ' && ch != '\t' { + break + } + i++ + } + return i +} + +// https://www.rfc-editor.org/rfc/rfc4648#section-8. +func decOctetHex(ch1, ch2 byte) (ch byte, ok bool) { + decBase16 := func(in byte) (out byte, ok bool) { + if !isDigit(in) && !(in >= 'a' && in <= 'f') { + return 0, false + } + if isDigit(in) { + return in - '0', true + } + return in - 'a' + 10, true + } + + if ch1, ok = decBase16(ch1); !ok { + return 0, ok + } + if ch2, ok = decBase16(ch2); !ok { + return 0, ok + } + return ch1<<4 | ch2, true +} + +// ParseList parses a list from a given HTTP Structured Field Values. +// +// Given an HTTP SFV string that represents a list, it will call the given +// function using each of the members and parameters contained in the list. +// This allows the caller to extract information out of the list. +// +// This function will return once it encounters the end of the string, or +// something that is not a list. If it cannot consume the entire given +// string, the ok value returned will be false. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-list. +func ParseList(s string, f func(member, param string)) (ok bool) { + for len(s) != 0 { + var member, param string + if len(s) != 0 && s[0] == '(' { + if member, s, ok = consumeBareInnerList(s, nil); !ok { + return ok + } + } else { + if member, s, ok = consumeBareItem(s); !ok { + return ok + } + } + if param, s, ok = consumeParameter(s, nil); !ok { + return ok + } + if f != nil { + f(member, param) + } + + s = s[countLeftWhitespace(s):] + if len(s) == 0 { + break + } + if s[0] != ',' { + return false + } + s = s[1:] + s = s[countLeftWhitespace(s):] + if len(s) == 0 { + return false + } + } + return true +} + +// consumeBareInnerList consumes an inner list +// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list), +// except for the inner list's top-most parameter. +// For example, given `(a;b c;d);e`, it will consume only `(a;b c;d)`. +func consumeBareInnerList(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) { + if len(s) == 0 || s[0] != '(' { + return "", s, false + } + rest = s[1:] + for len(rest) != 0 { + var bareItem, param string + rest = rest[countLeftWhitespace(rest):] + if len(rest) != 0 && rest[0] == ')' { + rest = rest[1:] + break + } + if bareItem, rest, ok = consumeBareItem(rest); !ok { + return "", s, ok + } + if param, rest, ok = consumeParameter(rest, nil); !ok { + return "", s, ok + } + if len(rest) == 0 || (rest[0] != ')' && !isSP(rest[0])) { + return "", s, false + } + if f != nil { + f(bareItem, param) + } + } + return s[:len(s)-len(rest)], rest, true +} + +// ParseBareInnerList parses a bare inner list from a given HTTP Structured +// Field Values. 
+// +// We define a bare inner list as an inner list +// (https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-inner-list), +// without the top-most parameter of the inner list. For example, given the +// inner list `(a;b c;d);e`, the bare inner list would be `(a;b c;d)`. +// +// Given an HTTP SFV string that represents a bare inner list, it will call the +// given function using each of the bare item and parameter within the bare +// inner list. This allows the caller to extract information out of the bare +// inner list. +// +// This function will return once it encounters the end of the bare inner list, +// or something that is not a bare inner list. If it cannot consume the entire +// given string, the ok value returned will be false. +func ParseBareInnerList(s string, f func(bareItem, param string)) (ok bool) { + _, rest, ok := consumeBareInnerList(s, f) + return rest == "" && ok +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item. +func consumeItem(s string, f func(bareItem, param string)) (consumed, rest string, ok bool) { + var bareItem, param string + if bareItem, rest, ok = consumeBareItem(s); !ok { + return "", s, ok + } + if param, rest, ok = consumeParameter(rest, nil); !ok { + return "", s, ok + } + if f != nil { + f(bareItem, param) + } + return s[:len(s)-len(rest)], rest, true +} + +// ParseItem parses an item from a given HTTP Structured Field Values. +// +// Given an HTTP SFV string that represents an item, it will call the given +// function once, with the bare item and the parameter of the item. This allows +// the caller to extract information out of the item. +// +// This function will return once it encounters the end of the string, or +// something that is not an item. If it cannot consume the entire given +// string, the ok value returned will be false. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-item. +func ParseItem(s string, f func(bareItem, param string)) (ok bool) { + _, rest, ok := consumeItem(s, f) + return rest == "" && ok +} + +// ParseDictionary parses a dictionary from a given HTTP Structured Field +// Values. +// +// Given an HTTP SFV string that represents a dictionary, it will call the +// given function using each of the keys, values, and parameters contained in +// the dictionary. This allows the caller to extract information out of the +// dictionary. +// +// This function will return once it encounters the end of the string, or +// something that is not a dictionary. If it cannot consume the entire given +// string, the ok value returned will be false. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-dictionary. +func ParseDictionary(s string, f func(key, val, param string)) (ok bool) { + for len(s) != 0 { + var key, val, param string + val = "?1" // Default value for empty val is boolean true. + if key, s, ok = consumeKey(s); !ok { + return ok + } + if len(s) != 0 && s[0] == '=' { + s = s[1:] + if len(s) != 0 && s[0] == '(' { + if val, s, ok = consumeBareInnerList(s, nil); !ok { + return ok + } + } else { + if val, s, ok = consumeBareItem(s); !ok { + return ok + } + } + } + if param, s, ok = consumeParameter(s, nil); !ok { + return ok + } + if f != nil { + f(key, val, param) + } + s = s[countLeftWhitespace(s):] + if len(s) == 0 { + break + } + if s[0] == ',' { + s = s[1:] + } + s = s[countLeftWhitespace(s):] + if len(s) == 0 { + return false + } + } + return true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param. 
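+// For example, ";a=1;b" is consumed entirely: key "a" has the value "1" and
+// key "b" takes the default boolean value "?1".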
+func consumeParameter(s string, f func(key, val string)) (consumed, rest string, ok bool) { + rest = s + for len(rest) != 0 { + var key, val string + val = "?1" // Default value for empty val is boolean true. + if rest[0] != ';' { + break + } + rest = rest[1:] + rest = rest[countLeftWhitespace(rest):] + key, rest, ok = consumeKey(rest) + if !ok { + return "", s, ok + } + if len(rest) != 0 && rest[0] == '=' { + rest = rest[1:] + val, rest, ok = consumeBareItem(rest) + if !ok { + return "", s, ok + } + } + if f != nil { + f(key, val) + } + } + return s[:len(s)-len(rest)], rest, true +} + +// ParseParameter parses a parameter from a given HTTP Structured Field Values. +// +// Given an HTTP SFV string that represents a parameter, it will call the given +// function using each of the keys and values contained in the parameter. This +// allows the caller to extract information out of the parameter. +// +// This function will return once it encounters the end of the string, or +// something that is not a parameter. If it cannot consume the entire given +// string, the ok value returned will be false. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#parse-param. +func ParseParameter(s string, f func(key, val string)) (ok bool) { + _, rest, ok := consumeParameter(s, f) + return rest == "" && ok +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-key. +func consumeKey(s string) (consumed, rest string, ok bool) { + if len(s) == 0 || (!isLCAlpha(s[0]) && s[0] != '*') { + return "", s, false + } + i := 0 + for _, ch := range []byte(s) { + if !isLCAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("_-.*"), ch) { + break + } + i++ + } + return s[:i], s[i:], true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim. +func consumeIntegerOrDecimal(s string) (consumed, rest string, ok bool) { + var i, signOffset, periodIndex int + var isDecimal bool + if i < len(s) && s[i] == '-' { + i++ + signOffset++ + } + if i >= len(s) { + return "", s, false + } + if !isDigit(s[i]) { + return "", s, false + } + for i < len(s) { + ch := s[i] + if isDigit(ch) { + i++ + continue + } + if !isDecimal && ch == '.' { + if i-signOffset > 12 { + return "", s, false + } + periodIndex = i + isDecimal = true + i++ + continue + } + break + } + if !isDecimal && i-signOffset > 15 { + return "", s, false + } + if isDecimal { + if i-signOffset > 16 { + return "", s, false + } + if s[i-1] == '.' { + return "", s, false + } + if i-periodIndex-1 > 3 { + return "", s, false + } + } + return s[:i], s[i:], true +} + +// ParseInteger parses an integer from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid integer. It returns the +// parsed integer and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim. +func ParseInteger(s string) (parsed int64, ok bool) { + if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" { + return 0, false + } + if n, err := strconv.ParseInt(s, 10, 64); err == nil { + return n, true + } + return 0, false +} + +// ParseDecimal parses a decimal from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid decimal. It returns the +// parsed decimal and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-an-integer-or-decim. 
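+// For example, ParseDecimal("1.25") returns 1.25, while ParseDecimal("10")
+// reports ok == false because the value has no fraction component.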
+func ParseDecimal(s string) (parsed float64, ok bool) { + if _, rest, ok := consumeIntegerOrDecimal(s); !ok || rest != "" { + return 0, false + } + if !strings.Contains(s, ".") { + return 0, false + } + if n, err := strconv.ParseFloat(s, 64); err == nil { + return n, true + } + return 0, false +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string. +func consumeString(s string) (consumed, rest string, ok bool) { + if len(s) == 0 || s[0] != '"' { + return "", s, false + } + for i := 1; i < len(s); i++ { + switch ch := s[i]; ch { + case '\\': + if i+1 >= len(s) { + return "", s, false + } + i++ + if ch = s[i]; ch != '"' && ch != '\\' { + return "", s, false + } + case '"': + return s[:i+1], s[i+1:], true + default: + if !isVChar(ch) && !isSP(ch) { + return "", s, false + } + } + } + return "", s, false +} + +// ParseString parses a Go string from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid string. It returns the +// parsed string and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-string. +func ParseString(s string) (parsed string, ok bool) { + if _, rest, ok := consumeString(s); !ok || rest != "" { + return "", false + } + return s[1 : len(s)-1], true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token +func consumeToken(s string) (consumed, rest string, ok bool) { + if len(s) == 0 || (!isAlpha(s[0]) && s[0] != '*') { + return "", s, false + } + i := 0 + for _, ch := range []byte(s) { + if !isTChar(ch) && !slices.Contains([]byte(":/"), ch) { + break + } + i++ + } + return s[:i], s[i:], true +} + +// ParseToken parses a token from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid token. It returns the +// parsed token and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-token +func ParseToken(s string) (parsed string, ok bool) { + if _, rest, ok := consumeToken(s); !ok || rest != "" { + return "", false + } + return s, true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence. +func consumeByteSequence(s string) (consumed, rest string, ok bool) { + if len(s) == 0 || s[0] != ':' { + return "", s, false + } + for i := 1; i < len(s); i++ { + if ch := s[i]; ch == ':' { + return s[:i+1], s[i+1:], true + } + if ch := s[i]; !isAlpha(ch) && !isDigit(ch) && !slices.Contains([]byte("+/="), ch) { + return "", s, false + } + } + return "", s, false +} + +// ParseByteSequence parses a byte sequence from a given HTTP Structured Field +// Values. +// +// The entire HTTP SFV string must consist of a valid byte sequence. It returns +// the parsed byte sequence and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-byte-sequence. +func ParseByteSequence(s string) (parsed []byte, ok bool) { + if _, rest, ok := consumeByteSequence(s); !ok || rest != "" { + return nil, false + } + return []byte(s[1 : len(s)-1]), true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean. +func consumeBoolean(s string) (consumed, rest string, ok bool) { + if len(s) >= 2 && (s[:2] == "?0" || s[:2] == "?1") { + return s[:2], s[2:], true + } + return "", s, false +} + +// ParseBoolean parses a boolean from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid boolean. 
It returns the +// parsed boolean and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-boolean. +func ParseBoolean(s string) (parsed bool, ok bool) { + if _, rest, ok := consumeBoolean(s); !ok || rest != "" { + return false, false + } + return s == "?1", true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date. +func consumeDate(s string) (consumed, rest string, ok bool) { + if len(s) == 0 || s[0] != '@' { + return "", s, false + } + if _, rest, ok = consumeIntegerOrDecimal(s[1:]); !ok { + return "", s, ok + } + consumed = s[:len(s)-len(rest)] + if slices.Contains([]byte(consumed), '.') { + return "", s, false + } + return consumed, rest, ok +} + +// ParseDate parses a date from a given HTTP Structured Field Values. +// +// The entire HTTP SFV string must consist of a valid date. It returns the +// parsed date and an ok boolean value, indicating success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-date. +func ParseDate(s string) (parsed time.Time, ok bool) { + if _, rest, ok := consumeDate(s); !ok || rest != "" { + return time.Time{}, false + } + if n, ok := ParseInteger(s[1:]); !ok { + return time.Time{}, false + } else { + return time.Unix(n, 0), true + } +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string. +func consumeDisplayString(s string) (consumed, rest string, ok bool) { + // To prevent excessive allocation, especially when input is large, we + // maintain a buffer of 4 bytes to keep track of the last rune we + // encounter. This way, we can validate that the display string conforms to + // UTF-8 without actually building the whole string. + var lastRune [4]byte + var runeLen int + isPartOfValidRune := func(ch byte) bool { + lastRune[runeLen] = ch + runeLen++ + if utf8.FullRune(lastRune[:runeLen]) { + r, s := utf8.DecodeRune(lastRune[:runeLen]) + if r == utf8.RuneError { + return false + } + copy(lastRune[:], lastRune[s:runeLen]) + runeLen -= s + return true + } + return runeLen <= 4 + } + + if len(s) <= 1 || s[:2] != `%"` { + return "", s, false + } + i := 2 + for i < len(s) { + ch := s[i] + if !isVChar(ch) && !isSP(ch) { + return "", s, false + } + switch ch { + case '"': + if runeLen > 0 { + return "", s, false + } + return s[:i+1], s[i+1:], true + case '%': + if i+2 >= len(s) { + return "", s, false + } + if ch, ok = decOctetHex(s[i+1], s[i+2]); !ok { + return "", s, ok + } + if ok = isPartOfValidRune(ch); !ok { + return "", s, ok + } + i += 3 + default: + if ok = isPartOfValidRune(ch); !ok { + return "", s, ok + } + i++ + } + } + return "", s, false +} + +// ParseDisplayString parses a display string from a given HTTP Structured +// Field Values. +// +// The entire HTTP SFV string must consist of a valid display string. It +// returns the parsed display string and an ok boolean value, indicating +// success or not. +// +// https://www.rfc-editor.org/rfc/rfc9651.html#name-parsing-a-display-string. +func ParseDisplayString(s string) (parsed string, ok bool) { + if _, rest, ok := consumeDisplayString(s); !ok || rest != "" { + return "", false + } + // consumeDisplayString() already validates that we have a valid display + // string. Therefore, we can just construct the display string, without + // validating it again. 
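+	// For example, the encoded input %"Gr%c3%bc%c3%9fe" decodes to "Grüße".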
+ s = s[2 : len(s)-1] + var b strings.Builder + for i := 0; i < len(s); { + if s[i] == '%' { + decoded, _ := decOctetHex(s[i+1], s[i+2]) + b.WriteByte(decoded) + i += 3 + continue + } + b.WriteByte(s[i]) + i++ + } + return b.String(), true +} + +// https://www.rfc-editor.org/rfc/rfc9651.html#parse-bare-item. +func consumeBareItem(s string) (consumed, rest string, ok bool) { + if len(s) == 0 { + return "", s, false + } + ch := s[0] + switch { + case ch == '-' || isDigit(ch): + return consumeIntegerOrDecimal(s) + case ch == '"': + return consumeString(s) + case ch == '*' || isAlpha(ch): + return consumeToken(s) + case ch == ':': + return consumeByteSequence(s) + case ch == '?': + return consumeBoolean(s) + case ch == '@': + return consumeDate(s) + case ch == '%': + return consumeDisplayString(s) + default: + return "", s, false + } +} diff --git a/internal/httpsfv/httpsfv_test.go b/internal/httpsfv/httpsfv_test.go new file mode 100644 index 0000000000..9e911ba9b9 --- /dev/null +++ b/internal/httpsfv/httpsfv_test.go @@ -0,0 +1,1439 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package httpsfv + +import ( + "slices" + "strconv" + "strings" + "testing" + "time" +) + +func TestParseList(t *testing.T) { + tests := []struct { + name string + in string + wantMembers []string + wantParams []string + wantOk bool + }{ + { + name: "valid list", + in: `a, b,c`, + wantMembers: []string{"a", "b", "c"}, + wantParams: []string{"", "", ""}, + wantOk: true, + }, + { + name: "valid list with params", + in: `a;foo=bar, b,c; baz=baz`, + wantMembers: []string{"a", "b", "c"}, + wantParams: []string{";foo=bar", "", "; baz=baz"}, + wantOk: true, + }, + { + name: "valid list with fake commas", + in: `a;foo=",", (",")`, + wantMembers: []string{"a", `(",")`}, + wantParams: []string{`;foo=","`, ""}, + wantOk: true, + }, + { + name: "valid list with inner list member", + in: `(a b c); foo, bar;baz`, + wantMembers: []string{"(a b c)", "bar"}, + wantParams: []string{"; foo", ";baz"}, + wantOk: true, + }, + { + name: "invalid list with trailing comma", + in: `a;foo=bar, b,c; baz=baz,`, + wantMembers: []string{"a", "b", "c"}, + wantParams: []string{";foo=bar", "", "; baz=baz"}, + }, + { + name: "invalid list with unclosed string", + in: `", b, c,d`, + }, + } + + for _, tc := range tests { + var gotMembers, gotParams []string + f := func(member, param string) { + gotMembers = append(gotMembers, member) + gotParams = append(gotParams, param) + } + ok := ParseList(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if !slices.Equal(tc.wantMembers, gotMembers) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotMembers, tc.wantMembers) + } + if !slices.Equal(tc.wantParams, gotParams) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParams, tc.wantParams) + } + } +} + +func TestConsumeBareInnerList(t *testing.T) { + tests := []struct { + name string + in string + wantBareItems []string + wantParams []string + wantListParam string + wantOk bool + }{ + { + name: "valid inner list without param", + in: `(a b c)`, + wantBareItems: []string{"a", "b", "c"}, + wantParams: []string{"", "", ""}, + wantOk: true, + }, + { + name: "valid inner list with param", + in: `(a;d b c;e)`, + wantBareItems: []string{"a", "b", "c"}, + wantParams: []string{";d", "", ";e"}, + wantOk: true, + }, + { + name: "valid inner list with fake 
ending parenthesis", + in: `(")";foo=")")`, + wantBareItems: []string{`")"`}, + wantParams: []string{`;foo=")"`}, + wantOk: true, + }, + { + name: "valid inner list with list parameter", + in: `(a b;c); d`, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + wantOk: true, + }, + { + name: "valid inner list with more content after", + in: `(a b;c); d, a`, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + wantOk: true, + }, + { + name: "invalid inner list", + in: `(a b;c `, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + }, + } + + for _, tc := range tests { + var gotBareItems, gotParams []string + f := func(bareItem, param string) { + gotBareItems = append(gotBareItems, bareItem) + gotParams = append(gotParams, param) + } + gotConsumed, gotRest, ok := consumeBareInnerList(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if !slices.Equal(tc.wantBareItems, gotBareItems) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotBareItems, tc.wantBareItems) + } + if !slices.Equal(tc.wantParams, gotParams) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParams, tc.wantParams) + } + if gotConsumed+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, gotConsumed, gotRest, tc.in) + } + } +} + +func TestParseBareInnerList(t *testing.T) { + tests := []struct { + name string + in string + wantBareItems []string + wantParams []string + wantOk bool + }{ + { + name: "valid inner list", + in: `(a b;c)`, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + wantOk: true, + }, + { + name: "valid inner list with list parameter", + in: `(a b;c); d`, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + }, + { + name: "invalid inner list", + in: `(a b;c `, + wantBareItems: []string{"a", "b"}, + wantParams: []string{"", ";c"}, + }, + } + + for _, tc := range tests { + var gotBareItems, gotParams []string + f := func(bareItem, param string) { + gotBareItems = append(gotBareItems, bareItem) + gotParams = append(gotParams, param) + } + ok := ParseBareInnerList(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if !slices.Equal(tc.wantBareItems, gotBareItems) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotBareItems, tc.wantBareItems) + } + if !slices.Equal(tc.wantParams, gotParams) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParams, tc.wantParams) + } + } +} + +func TestConsumeItem(t *testing.T) { + tests := []struct { + name string + in string + wantBareItem string + wantParam string + wantOk bool + }{ + { + name: "valid bare item", + in: `fookey`, + wantBareItem: `fookey`, + wantOk: true, + }, + { + name: "valid bare item and param", + in: `fookey; a="123"`, + wantBareItem: `fookey`, + wantParam: `; a="123"`, + wantOk: true, + }, + { + name: "valid item with content after", + in: `fookey; a="123", otheritem; otherparam=1`, + wantBareItem: `fookey`, + wantParam: `; a="123"`, + wantOk: true, + }, + { + name: "invalid just param", + in: `;a="123"`, + }, + } + + for _, tc := range tests { + var gotBareItem, gotParam string + f := func(bareItem, param string) { + gotBareItem = bareItem + gotParam = param + } + gotConsumed, gotRest, ok := consumeItem(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.wantBareItem != 
gotBareItem { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotBareItem, tc.wantBareItem) + } + if tc.wantParam != gotParam { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParam, tc.wantParam) + } + if gotConsumed+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, gotConsumed, gotRest, tc.in) + } + } +} + +func TestParseItem(t *testing.T) { + tests := []struct { + name string + in string + wantBareItem string + wantParam string + wantOk bool + }{ + { + name: "valid bare item", + in: `fookey`, + wantBareItem: `fookey`, + wantOk: true, + }, + { + name: "valid bare item and param", + in: `fookey; a="123"`, + wantBareItem: `fookey`, + wantParam: `; a="123"`, + wantOk: true, + }, + { + name: "valid item with content after", + in: `fookey; a="123", otheritem; otherparam=1`, + wantBareItem: `fookey`, + wantParam: `; a="123"`, + }, + { + name: "invalid just param", + in: `;a="123"`, + }, + } + + for _, tc := range tests { + var gotBareItem, gotParam string + f := func(bareItem, param string) { + gotBareItem = bareItem + gotParam = param + } + ok := ParseItem(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.wantBareItem != gotBareItem { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotBareItem, tc.wantBareItem) + } + if tc.wantParam != gotParam { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParam, tc.wantParam) + } + } +} + +func TestParseDictionary(t *testing.T) { + tests := []struct { + name string + in string + wantVal string + wantParam string + wantOk bool + }{ + { + name: "valid dictionary with simple value", + in: `a=b, want=foo, c=d`, + wantVal: "foo", + wantOk: true, + }, + { + name: "valid dictionary with implicit value", + in: `a, want, c=d`, + wantVal: "?1", + wantOk: true, + }, + { + name: "valid dictionary with parameter", + in: `a, want=foo;bar=baz, c=d`, + wantVal: "foo", + wantParam: ";bar=baz", + wantOk: true, + }, + { + name: "valid dictionary with inner list", + in: `a, want=(a b c d;e;f);g=h, c=d`, + wantVal: "(a b c d;e;f)", + wantParam: ";g=h", + wantOk: true, + }, + { + name: "valid dictionary with fake commas", + in: `a=(";");b=";",want=foo;bar`, + wantVal: "foo", + wantParam: ";bar", + wantOk: true, + }, + { + name: "invalid dictionary with bad key", + in: `UPPERCASEKEY=BAD, want=foo, c=d`, + }, + { + name: "invalid dictionary with trailing comma", + in: `trailing=comma,`, + }, + { + name: "invalid dictionary with unclosed string", + in: `a=""",want=foo;bar`, + }, + } + + for _, tc := range tests { + var gotVal, gotParam string + f := func(key, val, param string) { + if key == "want" { + gotVal = val + gotParam = param + } + } + ok := ParseDictionary(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.wantVal != gotVal { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotVal, tc.wantVal) + } + if tc.wantParam != gotParam { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, gotParam, tc.wantParam) + } + } +} + +func TestConsumeParameter(t *testing.T) { + tests := []struct { + name string + in string + want any + wantOk bool + }{ + { + name: "valid string", + in: `;parameter;want="wantvalue"`, + want: "wantvalue", + wantOk: true, + }, + { + name: "valid integer", + in: `;parameter;want=123456;something`, + want: 123456, + wantOk: true, + }, + { + name: "valid decimal", + in: `;parameter;want=3.14;something`, 
+ want: 3.14, + wantOk: true, + }, + { + name: "valid implicit bool", + in: `;parameter;want;something`, + want: true, + wantOk: true, + }, + { + name: "valid token", + in: `;want=*atoken;something`, + want: "*atoken", + wantOk: true, + }, + { + name: "valid byte sequence", + in: `;want=:eWF5Cg==:;something`, + want: "eWF5Cg==", + wantOk: true, + }, + { + name: "valid repeated key", + in: `;want=:eWF5Cg==:;now;want=1;is;repeated;want="overwritten!"`, + want: "overwritten!", + wantOk: true, + }, + { + name: "valid parameter with content after", + in: `;want=:eWF5Cg==:;now;want=1;is;repeated;want="overwritten!", some=stuff`, + want: "overwritten!", + wantOk: true, + }, + { + name: "invalid parameter", + in: `;UPPERCASEKEY=NOT_ACCEPTED`, + }, + } + + for _, tc := range tests[len(tests)-1:] { + var got any + f := func(key, val string) { + if key != "want" { + return + } + switch { + case strings.HasPrefix(val, "?"): // Bool + got = val == "?1" + case strings.HasPrefix(val, `"`): // String + got = val[1 : len(val)-1] + case strings.HasPrefix(val, "*"): // Token + got = val + case strings.HasPrefix(val, ":"): // Byte sequence + got = val[1 : len(val)-1] + default: + if valConv, err := strconv.Atoi(val); err == nil { // Integer + got = valConv + return + } + if valConv, err := strconv.ParseFloat(val, 64); err == nil { // Float + got = valConv + return + } + } + } + consumed, rest, ok := consumeParameter(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if got != tc.want { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if consumed+rest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, rest, tc.in) + } + } +} + +func TestParseParameter(t *testing.T) { + tests := []struct { + name string + in string + want any + wantOk bool + }{ + { + name: "valid parameter", + in: `;parameter;want="wantvalue"`, + want: "wantvalue", + wantOk: true, + }, + { + name: "valid parameter with content after", + in: `;want=:eWF5Cg==:;now;want=1;is;repeated;want="overwritten!", some=stuff`, + want: "overwritten!", + }, + { + name: "invalid parameter", + in: `;UPPERCASEKEY=NOT_ACCEPTED`, + }, + } + + for _, tc := range tests[len(tests)-1:] { + var got any + f := func(key, val string) { + if key != "want" { + return + } + switch { + case strings.HasPrefix(val, "?"): // Bool + got = val == "?1" + case strings.HasPrefix(val, `"`): // String + got = val[1 : len(val)-1] + case strings.HasPrefix(val, "*"): // Token + got = val + case strings.HasPrefix(val, ":"): // Byte sequence + got = val[1 : len(val)-1] + default: + if valConv, err := strconv.Atoi(val); err == nil { // Integer + got = valConv + return + } + if valConv, err := strconv.ParseFloat(val, 64); err == nil { // Float + got = valConv + return + } + } + } + ok := ParseParameter(tc.in, f) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if got != tc.want { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeKey(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid basic key", + in: `fookey`, + want: `fookey`, + wantOk: true, + }, + { + name: "valid basic key with more content after", + in: `fookey,u=7`, + want: `fookey`, + wantOk: true, + }, + { + name: "invalid key", + in: `1keycannotstartwithnum`, + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeKey(tc.in) + if ok != 
tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestConsumeIntegerOrDecimal(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid integer", + in: "123456", + want: "123456", + wantOk: true, + }, + { + name: "valid integer with more content after", + in: "123456,12345", + want: "123456", + wantOk: true, + }, + { + name: "valid max integer", + in: "999999999999999", + want: "999999999999999", + wantOk: true, + }, + { + name: "valid min integer", + in: "-999999999999999", + want: "-999999999999999", + wantOk: true, + }, + { + name: "invalid integer too high", + in: "9999999999999999", + }, + { + name: "invalid integer too low", + in: "-9999999999999999", + }, + { + name: "valid decimal", + in: "-123456789012.123", + want: "-123456789012.123", + wantOk: true, + }, + { + name: "invalid decimal integer component too long", + in: "1234567890123.1", + }, + { + name: "invalid decimal fraction component too long", + in: "1.1234", + }, + { + name: "invalid decimal trailing dot", + in: "1.", + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeIntegerOrDecimal(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseInteger(t *testing.T) { + tests := []struct { + name string + in string + want int64 + wantOk bool + }{ + { + name: "valid integer", + in: "123456", + want: 123456, + wantOk: true, + }, + { + name: "valid integer with more content after", + in: "123456,12345", + }, + { + name: "valid max integer", + in: "999999999999999", + want: 999999999999999, + wantOk: true, + }, + { + name: "valid min integer", + in: "-999999999999999", + want: -999999999999999, + wantOk: true, + }, + { + name: "invalid integer too high", + in: "9999999999999999", + }, + { + name: "invalid integer too low", + in: "-9999999999999999", + }, + { + name: "invalid integer with fraction", + in: "-123456789012.123", + }, + } + + for _, tc := range tests { + got, ok := ParseInteger(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestParseDecimal(t *testing.T) { + tests := []struct { + name string + in string + want float64 + wantOk bool + }{ + { + name: "valid decimal", + in: "123456.789", + want: 123456.789, + wantOk: true, + }, + { + name: "valid decimal with more content after", + in: "123456.789, 123", + }, + { + name: "invalid decimal with no fraction", + in: "123456", + }, + { + name: "invalid decimal integer component too long", + in: "1234567890123.1", + }, + { + name: "invalid decimal fraction component too long", + in: "1.1234", + }, + { + name: "invalid decimal trailing dot", + in: "1.", + }, + } + + for _, tc := range tests { + got, ok := ParseDecimal(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: 
%#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeString(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid basic string", + in: `"foo bar"`, + want: `"foo bar"`, + wantOk: true, + }, + { + name: "valid basic string with more content after", + in: `"foo bar", a=3`, + want: `"foo bar"`, + wantOk: true, + }, + { + name: "valid string with escaped dquote", + in: `"foo bar \""`, + want: `"foo bar \""`, + wantOk: true, + }, + { + name: "invalid string no starting dquote", + in: `foo bar"`, + }, + { + name: "invalid string no closing dquote", + in: `"foo bar`, + }, + { + name: "invalid string invalid character", + in: string([]byte{'"', 0x00, '"'}), + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeString(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseString(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid basic string", + in: `"foo bar"`, + want: "foo bar", + wantOk: true, + }, + { + name: "valid basic string with more content after", + in: `"foo bar", a=3`, + }, + { + name: "valid string with escaped dquote", + in: `"foo bar \""`, + want: `foo bar \"`, + wantOk: true, + }, + { + name: "invalid string no starting dquote", + in: `foo bar"`, + }, + { + name: "invalid string no closing dquote", + in: `"foo bar`, + }, + { + name: "invalid string invalid character", + in: string([]byte{'"', 0x00, '"'}), + }, + } + + for _, tc := range tests { + got, ok := ParseString(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeToken(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid token", + in: "*atoken", + want: "*atoken", + wantOk: true, + }, + { + name: "valid token with more content after", + in: "*atoken something", + want: "*atoken", + wantOk: true, + }, + { + name: "invalid token", + in: "0invalid", + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeToken(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseToken(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid token", + in: "a_b-c.d3:f%00/*", + want: "a_b-c.d3:f%00/*", + wantOk: true, + }, + { + name: "valid token with uppercase", + in: "FOOBAR", + want: "FOOBAR", + wantOk: true, + }, + { + name: "valid token with content after", + in: "FOOBAR, foobar", + }, + { + name: "invalid token", + in: "0invalid", + }, + } + + for _, tc := range tests { + got, ok := ParseToken(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func 
TestConsumeByteSequence(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid byte sequence", + in: ":aGVsbG8gd29ybGQ=:", + want: ":aGVsbG8gd29ybGQ=:", + wantOk: true, + }, + { + name: "valid byte sequence with more content after", + in: ":aGVsbG8gd29ybGQ=::aGVsbG8gd29ybGQ=:", + want: ":aGVsbG8gd29ybGQ=:", + wantOk: true, + }, + { + name: "invalid byte sequence character", + in: ":-:", + }, + { + name: "invalid byte sequence opening", + in: "aGVsbG8gd29ybGQ=:", + }, + { + name: "invalid byte sequence closing", + in: ":aGVsbG8gd29ybGQ=", + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeByteSequence(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseByteSequence(t *testing.T) { + tests := []struct { + name string + in string + want []byte + wantOk bool + }{ + { + name: "valid byte sequence", + in: ":aGVsbG8gd29ybGQ=:", + want: []byte("aGVsbG8gd29ybGQ="), + wantOk: true, + }, + { + name: "valid byte sequence with more content after", + in: ":aGVsbG8gd29ybGQ=::aGVsbG8gd29ybGQ=:", + }, + { + name: "invalid byte sequence character", + in: ":-:", + }, + { + name: "invalid byte sequence opening", + in: "aGVsbG8gd29ybGQ=:", + }, + { + name: "invalid byte sequence closing", + in: ":aGVsbG8gd29ybGQ=", + }, + } + + for _, tc := range tests { + got, ok := ParseByteSequence(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if !slices.Equal(tc.want, got) { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeBoolean(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid boolean", + in: "?0", + want: "?0", + wantOk: true, + }, + { + name: "valid boolean with more content after", + in: "?1, a=1", + want: "?1", + wantOk: true, + }, + { + name: "invalid boolean", + in: "!2", + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeBoolean(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseBoolean(t *testing.T) { + tests := []struct { + name string + in string + want bool + wantOk bool + }{ + { + name: "valid boolean false", + in: "?0", + want: false, + wantOk: true, + }, + { + name: "valid boolean true", + in: "?1", + want: true, + wantOk: true, + }, + { + name: "valid boolean with more content after", + in: "?1, a=1", + }, + { + name: "invalid boolean", + in: "?2", + }, + } + + for _, tc := range tests { + got, ok := ParseBoolean(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeDate(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid zero date", + in: "@0", + want: "@0", + wantOk: true, + }, + { + name: "valid positive date", + 
in: "@1659578233", + want: "@1659578233", + wantOk: true, + }, + { + name: "valid negative date", + in: "@-1659578233", + want: "@-1659578233", + wantOk: true, + }, + { + name: "valid large date", + in: "@25340221440", + want: "@25340221440", + wantOk: true, + }, + { + name: "valid small date", + in: "@-62135596800", + want: "@-62135596800", + wantOk: true, + }, + { + name: "invalid decimal date", + in: "@1.2", + }, + { + name: "valid date with more content after", + in: "@1659578233, foo;bar", + want: "@1659578233", + wantOk: true, + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeDate(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseDate(t *testing.T) { + tests := []struct { + name string + in string + want time.Time + wantOk bool + }{ + { + name: "valid zero date", + in: "@0", + want: time.Unix(0, 0), + wantOk: true, + }, + { + name: "valid positive date", + in: "@1659578233", + want: time.Date(2022, 8, 4, 1, 57, 13, 0, time.UTC).Local(), + wantOk: true, + }, + { + name: "valid negative date", + in: "@-1659578233", + want: time.Date(1917, 5, 30, 22, 2, 47, 0, time.UTC).Local(), + wantOk: true, + }, + { + name: "valid max date required", + in: "@253402214400", + want: time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC).Local(), + wantOk: true, + }, + { + name: "valid min date required", + in: "@-62135596800", + want: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC).Local(), + wantOk: true, + }, + { + name: "invalid date with fraction", + in: "@0.123", + }, + { + name: "valid date with more content after", + in: "@0, @0", + }, + } + + for _, tc := range tests { + got, ok := ParseDate(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} + +func TestConsumeDisplayString(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid ascii string", + in: "%\" !%22#$%25&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"", + want: "%\" !%22#$%25&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"", + wantOk: true, + }, + { + name: "valid lowercase non-ascii string", + in: `%"f%c3%bc%c3%bc"`, + want: `%"f%c3%bc%c3%bc"`, + wantOk: true, + }, + { + name: "invalid uppercase non-ascii string", + in: `%"f%C3%BC%C3%BC"`, + }, + { + name: "invalid unquoted string", + in: "%foo", + }, + { + name: "invalid string missing initial quote", + in: `%foo"`, + }, + { + name: "invalid string missing closing quote", + in: `%"foo`, + }, + { + name: "invalid tab in string", + in: "%\"\t\"", + }, + { + name: "invalid newline in string", + in: "%\"\n\"", + }, + { + name: "invalid single quoted string", + in: `%'foo'`, + }, + { + name: "invalid string bad escaping", + in: `%\"foo %a"`, + }, + { + name: "valid string with escaped quotes", + in: `%"foo %22bar%22 \\ baz"`, + want: `%"foo %22bar%22 \\ baz"`, + wantOk: true, + }, + { + name: "invalid sequence id utf-8 string", + in: `%"%a0%a1"`, + }, + { + name: "invalid 2 bytes sequence utf-8 string", + in: `%"%c3%28"`, + }, + { + name: "invalid 3 bytes sequence utf-8 string", + in: 
`%"%e2%28%a1"`, + }, + { + name: "invalid 4 bytes sequence utf-8 string", + in: `%"%f0%28%8c%28"`, + }, + { + name: "invalid hex utf-8 string", + in: `%"%g0%1w"`, + }, + { + name: "valid byte order mark in display string", + in: `%"BOM: %ef%bb%bf"`, + want: `%"BOM: %ef%bb%bf"`, + wantOk: true, + }, + { + name: "valid string with content after", + in: `%"foo\nbar", foo;bar`, + want: `%"foo\nbar"`, + wantOk: true, + }, + { + name: "invalid unfinished 4 bytes rune", + in: `%"%f0%9f%98"`, + }, + } + + for _, tc := range tests { + got, gotRest, ok := consumeDisplayString(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + if got+gotRest != tc.in { + t.Fatalf("test %q: %#v + %#v != %#v", tc.name, got, gotRest, tc.in) + } + } +} + +func TestParseDisplayString(t *testing.T) { + tests := []struct { + name string + in string + want string + wantOk bool + }{ + { + name: "valid ascii string", + in: "%\" !%22#$%25&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\"", + want: " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~", + wantOk: true, + }, + { + name: "valid lowercase non-ascii string", + in: `%"f%c3%bc%c3%bc"`, + want: "füü", + wantOk: true, + }, + { + name: "invalid uppercase non-ascii string", + in: `%"f%C3%BC%C3%BC"`, + }, + { + name: "invalid unquoted string", + in: "%foo", + }, + { + name: "invalid string missing initial quote", + in: `%foo"`, + }, + { + name: "invalid string missing closing quote", + in: `%"foo`, + }, + { + name: "invalid tab in string", + in: "%\"\t\"", + }, + { + name: "invalid newline in string", + in: "%\"\n\"", + }, + { + name: "invalid single quoted string", + in: `%'foo'`, + }, + { + name: "invalid string bad escaping", + in: `%\"foo %a"`, + }, + { + name: "valid string with escaped quotes", + in: "%\"foo %22bar%22 \\ baz\"", + want: "foo \"bar\" \\ baz", + wantOk: true, + }, + { + name: "invalid sequence id utf-8 string", + in: `%"%a0%a1"`, + }, + { + name: "invalid 2 bytes sequence utf-8 string", + in: `%"%c3%28"`, + }, + { + name: "invalid 3 bytes sequence utf-8 string", + in: `%"%e2%28%a1"`, + }, + { + name: "invalid 4 bytes sequence utf-8 string", + in: `%"%f0%28%8c%28"`, + }, + { + name: "invalid hex utf-8 string", + in: `%"%g0%1w"`, + }, + { + name: "valid byte order mark in display string", + in: `%"BOM: %ef%bb%bf"`, + want: "BOM: \uFEFF", + wantOk: true, + }, + { + name: "valid string with content after", + in: `%"foo\nbar", foo;bar`, + }, + { + name: "invalid unfinished 4 bytes rune", + in: `%"%f0%9f%98"`, + }, + } + + for _, tc := range tests { + got, ok := ParseDisplayString(tc.in) + if ok != tc.wantOk { + t.Fatalf("test %q: want ok to be %v, got: %v", tc.name, tc.wantOk, ok) + } + if tc.want != got { + t.Fatalf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tc.name, got, tc.want) + } + } +} diff --git a/internal/quic/cmd/interop/main.go b/internal/quic/cmd/interop/main.go index 5b652a2b15..6cc8e96a59 100644 --- a/internal/quic/cmd/interop/main.go +++ b/internal/quic/cmd/interop/main.go @@ -84,7 +84,20 @@ func main() { // "[...] offer only ChaCha20 as a ciphersuite." // // crypto/tls does not support configuring TLS 1.3 ciphersuites, - // so we can't support this test. + // so we can't support this test on the client. 
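// For context on the structured-field tests above: their callbacks recover a
// typed value by dispatching on the first character of the bare item, which is
// how RFC 8941/9651 distinguishes item types on the wire. A minimal standalone
// sketch of that dispatch; the names here are illustrative, not the package's:
package sfvsketch

import "strings"

// classifyBareItem reports which RFC 9651 bare item type a raw value encodes:
// "?" boolean, '"' string, ":" byte sequence, "@" date, `%"` display string,
// "*" or a letter for a token, otherwise a number (decimal when it has a ".").
func classifyBareItem(v string) string {
	switch {
	case v == "":
		return "empty"
	case strings.HasPrefix(v, "?"):
		return "boolean"
	case strings.HasPrefix(v, `"`):
		return "string"
	case strings.HasPrefix(v, ":"):
		return "byte sequence"
	case strings.HasPrefix(v, "@"):
		return "date"
	case strings.HasPrefix(v, `%"`):
		return "display string"
	case v[0] == '*' || (v[0] >= 'a' && v[0] <= 'z') || (v[0] >= 'A' && v[0] <= 'Z'):
		return "token"
	case strings.Contains(v, "."):
		return "decimal"
	default:
		return "integer"
	}
}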
+ if *listen != "" && len(urls) == 0 { + config.TLSConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + if len(hello.CipherSuites) == 1 && hello.CipherSuites[0] == tls.TLS_CHACHA20_POLY1305_SHA256 { + return nil, nil + } + return nil, fmt.Errorf("this test requires the client to offer only ChaCha20") + } + basicTest(ctx, config, urls) + return + } + case "ecn": + // TODO: We give ECN feedback to the sender, but we don't add our own + // ECN marks to outgoing packets. case "transfer": // "The client should use small initial flow control windows // for both stream- and connection-level flow control @@ -101,7 +114,11 @@ func main() { case "resumption": // TODO case "retry": - // TODO + if *listen != "" && len(urls) == 0 { + config.RequireAddressValidation = true + } + basicTest(ctx, config, urls) + return case "versionnegotiation": // "The client should start a connection using // an unsupported version number [...]" diff --git a/internal/quic/quicwire/wire.go b/internal/quic/quicwire/wire.go index 0edf42227d..06682520c8 100644 --- a/internal/quic/quicwire/wire.go +++ b/internal/quic/quicwire/wire.go @@ -46,7 +46,7 @@ func ConsumeVarint(b []byte) (v uint64, n int) { return 0, -1 } -// consumeVarintInt64 parses a variable-length integer as an int64. +// ConsumeVarintInt64 parses a variable-length integer as an int64. func ConsumeVarintInt64(b []byte) (v int64, n int) { u, n := ConsumeVarint(b) // QUIC varints are 62-bits large, so this conversion can never overflow. diff --git a/internal/socks/socks.go b/internal/socks/socks.go index 84fcc32b63..8eedb84cec 100644 --- a/internal/socks/socks.go +++ b/internal/socks/socks.go @@ -297,7 +297,7 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, b = append(b, up.Username...) b = append(b, byte(len(up.Password))) b = append(b, up.Password...) - // TODO(mikio): handle IO deadlines and cancelation if + // TODO(mikio): handle IO deadlines and cancellation if // necessary if _, err := rw.Write(b); err != nil { return err diff --git a/nettest/conntest.go b/nettest/conntest.go index 4297d408c0..8b98dfe21c 100644 --- a/nettest/conntest.go +++ b/nettest/conntest.go @@ -142,7 +142,7 @@ func testPingPong(t *testing.T, c1, c2 net.Conn) { } // testRacyRead tests that it is safe to mutate the input Read buffer -// immediately after cancelation has occurred. +// immediately after cancellation has occurred. func testRacyRead(t *testing.T, c1, c2 net.Conn) { go chunkedCopy(c2, rand.New(rand.NewSource(0))) @@ -170,7 +170,7 @@ func testRacyRead(t *testing.T, c1, c2 net.Conn) { } // testRacyWrite tests that it is safe to mutate the input Write buffer -// immediately after cancelation has occurred. +// immediately after cancellation has occurred. func testRacyWrite(t *testing.T, c1, c2 net.Conn) { go chunkedCopy(io.Discard, c2) @@ -318,7 +318,7 @@ func testCloseTimeout(t *testing.T, c1, c2 net.Conn) { defer wg.Wait() wg.Add(3) - // Test for cancelation upon connection closure. + // Test for cancellation upon connection closure. 
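// For context on the quicwire hunk above (which only corrects the
// ConsumeVarintInt64 doc comment): QUIC variable-length integers, RFC 9000
// section 16, store the encoded length in the top two bits of the first byte.
// A from-scratch sketch of the decoding, not the package's implementation:
package varintsketch

// consumeVarint returns the decoded value and the number of bytes consumed,
// or n < 0 if the buffer is too short. For example, {0x25} decodes to (37, 1)
// and {0x7b, 0xbd} decodes to (15293, 2), matching the RFC's examples.
func consumeVarint(b []byte) (v uint64, n int) {
	if len(b) == 0 {
		return 0, -1
	}
	n = 1 << (b[0] >> 6) // two high bits select a 1, 2, 4, or 8 byte encoding
	if len(b) < n {
		return 0, -1
	}
	v = uint64(b[0] & 0x3f) // remaining six bits are the most significant
	for _, c := range b[1:n] {
		v = v<<8 | uint64(c)
	}
	return v, n
}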
c1.SetDeadline(neverTimeout) go func() { defer wg.Done() diff --git a/publicsuffix/list.go b/publicsuffix/list.go index 047cb30eb1..7ab8b3cf13 100644 --- a/publicsuffix/list.go +++ b/publicsuffix/list.go @@ -51,6 +51,7 @@ package publicsuffix // import "golang.org/x/net/publicsuffix" import ( "fmt" "net/http/cookiejar" + "net/netip" "strings" ) @@ -84,6 +85,10 @@ func (list) String() string { // domains like "foo.appspot.com" can be found at // https://wiki.mozilla.org/Public_Suffix_List/Use_Cases func PublicSuffix(domain string) (publicSuffix string, icann bool) { + if _, err := netip.ParseAddr(domain); err == nil { + return domain, false + } + lo, hi := uint32(0), uint32(numTLD) s, suffix, icannNode, wildcard := domain, len(domain), false, false loop: diff --git a/publicsuffix/list_test.go b/publicsuffix/list_test.go index 7a1bb0fe5c..7648fdb5f7 100644 --- a/publicsuffix/list_test.go +++ b/publicsuffix/list_test.go @@ -5,6 +5,7 @@ package publicsuffix import ( + "net/netip" "sort" "strings" "testing" @@ -85,6 +86,11 @@ var publicSuffixTestCases = []struct { // Empty string. {"", "", false}, + // IP addresses don't have a domain hierarchy + {"192.0.2.0", "192.0.2.0", false}, + {"::ffff:192.0.2.0", "::ffff:192.0.2.0", false}, + {"2001:db8::", "2001:db8::", false}, + // The .ao rules are: // ao // ed.ao @@ -332,6 +338,10 @@ type slowPublicSuffixRule struct { // This function returns the public suffix, not the registrable domain, and so // it stops after step 6. func slowPublicSuffix(domain string) (string, bool) { + if _, err := netip.ParseAddr(domain); err == nil { + return domain, false + } + match := func(rulePart, domainPart string) bool { switch rulePart[0] { case '*': diff --git a/quic/acks.go b/quic/acks.go index d4ac4496e1..90f82bed03 100644 --- a/quic/acks.go +++ b/quic/acks.go @@ -25,6 +25,15 @@ type ackState struct { // The number of ack-eliciting packets in seen that we have not yet acknowledged. unackedAckEliciting int + + // Total ECN counters for this packet number space. + ecn ecnCounts +} + +type ecnCounts struct { + t0 int + t1 int + ce int } // shouldProcess reports whether a packet should be handled or discarded. @@ -43,10 +52,10 @@ func (acks *ackState) shouldProcess(num packetNumber) bool { } // receive records receipt of a packet. -func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber, ackEliciting bool) { +func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber, ackEliciting bool, ecn ecnBits) { if ackEliciting { acks.unackedAckEliciting++ - if acks.mustAckImmediately(space, num) { + if acks.mustAckImmediately(space, num, ecn) { acks.nextAck = now } else if acks.nextAck.IsZero() { // This packet does not need to be acknowledged immediately, @@ -70,6 +79,15 @@ func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber acks.maxRecvTime = now } + switch ecn { + case ecnECT0: + acks.ecn.t0++ + case ecnECT1: + acks.ecn.t1++ + case ecnCE: + acks.ecn.ce++ + } + // Limit the total number of ACK ranges by dropping older ranges. // // Remembering more ranges results in larger ACK frames. @@ -92,7 +110,7 @@ func (acks *ackState) receive(now time.Time, space numberSpace, num packetNumber // mustAckImmediately reports whether an ack-eliciting packet must be acknowledged immediately, // or whether the ack may be deferred. 
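// For context on the ecnCounts field added to ackState above: a receiver
// tallies the ECT(0), ECT(1), and CE codepoints seen in the IP headers of
// incoming packets and reports the totals in ACK_ECN frames (RFC 9000
// sections 13.4.1 and 19.3.2), and a CE mark additionally calls for an
// immediate acknowledgment (section 13.2.1). A sketch with illustrative
// names, not the quic package's types:
package ecnsketch

// The four ECN codepoints carried in the IP header (RFC 3168).
type ecnBits uint8

const (
	ecnNotECT ecnBits = 0b00
	ecnECT1   ecnBits = 0b01
	ecnECT0   ecnBits = 0b10
	ecnCE     ecnBits = 0b11
)

// ecnTotals accumulates the counts an ACK_ECN frame carries.
type ecnTotals struct {
	ect0, ect1, ce uint64
}

// record tallies one received packet's codepoint and reports whether the
// packet should be acknowledged immediately (true only for CE).
func (c *ecnTotals) record(e ecnBits) (ackNow bool) {
	switch e {
	case ecnECT0:
		c.ect0++
	case ecnECT1:
		c.ect1++
	case ecnCE:
		c.ce++
		return true
	}
	return false
}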
-func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bool { +func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber, ecn ecnBits) bool { // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1 if space != appDataSpace { // "[...] all ack-eliciting Initial and Handshake packets [...]" @@ -128,6 +146,12 @@ func (acks *ackState) mustAckImmediately(space numberSpace, num packetNumber) bo // there are no gaps. If it does not, there must be a gap. return true } + // "[...] packets marked with the ECN Congestion Experienced (CE) codepoint + // in the IP header SHOULD be acknowledged immediately [...]" + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.1-9 + if ecn == ecnCE { + return true + } // "[...] SHOULD send an ACK frame after receiving at least two ack-eliciting packets." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.2.2 // diff --git a/quic/acks_test.go b/quic/acks_test.go index 7fca5617ba..2abdc31ff8 100644 --- a/quic/acks_test.go +++ b/quic/acks_test.go @@ -17,7 +17,7 @@ func TestAcksDisallowDuplicate(t *testing.T) { receive := []packetNumber{0, 1, 2, 4, 7, 6, 9} seen := map[packetNumber]bool{} for i, pnum := range receive { - acks.receive(now, appDataSpace, pnum, true) + acks.receive(now, appDataSpace, pnum, true, ecnNotECT) seen[pnum] = true for ppnum := packetNumber(0); ppnum < 11; ppnum++ { if got, want := acks.shouldProcess(ppnum), !seen[ppnum]; got != want { @@ -32,7 +32,7 @@ func TestAcksDisallowDiscardedAckRanges(t *testing.T) { acks := ackState{} now := time.Now() for pnum := packetNumber(0); ; pnum += 2 { - acks.receive(now, appDataSpace, pnum, true) + acks.receive(now, appDataSpace, pnum, true, ecnNotECT) send, _ := acks.acksToSend(now) for ppnum := packetNumber(0); ppnum < packetNumber(send.min()); ppnum++ { if acks.shouldProcess(ppnum) { @@ -158,13 +158,13 @@ func TestAcksSent(t *testing.T) { start := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) for _, p := range test.ackedPackets { t.Logf("receive %v.%v, ack-eliciting=%v", test.space, p.pnum, p.ackEliciting) - acks.receive(start, test.space, p.pnum, p.ackEliciting) + acks.receive(start, test.space, p.pnum, p.ackEliciting, ecnNotECT) } t.Logf("send an ACK frame") acks.sentAck() for _, p := range test.packets { t.Logf("receive %v.%v, ack-eliciting=%v", test.space, p.pnum, p.ackEliciting) - acks.receive(start, test.space, p.pnum, p.ackEliciting) + acks.receive(start, test.space, p.pnum, p.ackEliciting, ecnNotECT) } switch { case len(test.wantAcks) == 0: @@ -208,13 +208,13 @@ func TestAcksSent(t *testing.T) { func TestAcksDiscardAfterAck(t *testing.T) { acks := ackState{} now := time.Now() - acks.receive(now, appDataSpace, 0, true) - acks.receive(now, appDataSpace, 2, true) - acks.receive(now, appDataSpace, 4, true) - acks.receive(now, appDataSpace, 5, true) - acks.receive(now, appDataSpace, 6, true) + acks.receive(now, appDataSpace, 0, true, ecnNotECT) + acks.receive(now, appDataSpace, 2, true, ecnNotECT) + acks.receive(now, appDataSpace, 4, true, ecnNotECT) + acks.receive(now, appDataSpace, 5, true, ecnNotECT) + acks.receive(now, appDataSpace, 6, true, ecnNotECT) acks.handleAck(6) // discards all ranges prior to the one containing packet 6 - acks.receive(now, appDataSpace, 7, true) + acks.receive(now, appDataSpace, 7, true, ecnNotECT) got, _ := acks.acksToSend(now) if len(got) != 1 { t.Errorf("acks.acksToSend contains ranges prior to last acknowledged ack; got %v, want 1 range", got) @@ -224,9 +224,9 @@ func TestAcksDiscardAfterAck(t *testing.T) { 
func TestAcksLargestSeen(t *testing.T) { acks := ackState{} now := time.Now() - acks.receive(now, appDataSpace, 0, true) - acks.receive(now, appDataSpace, 4, true) - acks.receive(now, appDataSpace, 1, true) + acks.receive(now, appDataSpace, 0, true, ecnNotECT) + acks.receive(now, appDataSpace, 4, true, ecnNotECT) + acks.receive(now, appDataSpace, 1, true, ecnNotECT) if got, want := acks.largestSeen(), packetNumber(4); got != want { t.Errorf("acks.largestSeen() = %v, want %v", got, want) } diff --git a/quic/bench_test.go b/quic/bench_test.go index 9d8e5d2318..002b40e604 100644 --- a/quic/bench_test.go +++ b/quic/bench_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( diff --git a/quic/config_test.go b/quic/config_test.go index 3511cd4a54..df878dab00 100644 --- a/quic/config_test.go +++ b/quic/config_test.go @@ -2,11 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic -import "testing" +import ( + "testing" + "testing/synctest" +) func TestConfigTransportParameters(t *testing.T) { + synctest.Test(t, testConfigTransportParameters) +} +func testConfigTransportParameters(t *testing.T) { const ( wantInitialMaxData = int64(1) wantInitialMaxStreamData = int64(2) diff --git a/quic/conn.go b/quic/conn.go index b9ec0e4059..fd812b8a28 100644 --- a/quic/conn.go +++ b/quic/conn.go @@ -67,11 +67,7 @@ type Conn struct { // connTestHooks override conn behavior in tests. type connTestHooks interface { // init is called after a conn is created. - init() - - // nextMessage is called to request the next event from msgc. - // Used to give tests control of the connection event loop. - nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any) + init(first bool) // handleTLSEvent is called with each TLS event. handleTLSEvent(tls.QUICEvent) @@ -79,13 +75,6 @@ type connTestHooks interface { // newConnID is called to generate a new connection ID. // Permits tests to generate consistent connection IDs rather than random ones. newConnID(seq int64) ([]byte, error) - - // waitUntil blocks until the until func returns true or the context is done. - // Used to synchronize asynchronous blocking operations in tests. - waitUntil(ctx context.Context, until func() bool) error - - // timeNow returns the current time. - timeNow() time.Time } // newServerConnIDs is connection IDs associated with a new server connection. @@ -102,7 +91,6 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname s endpoint: e, config: config, peerAddr: unmapAddrPort(peerAddr), - msgc: make(chan any, 1), donec: make(chan struct{}), peerAckDelayExponent: -1, } @@ -177,7 +165,7 @@ func newConn(now time.Time, side connSide, cids newServerConnIDs, peerHostname s } if c.testHooks != nil { - c.testHooks.init() + c.testHooks.init(true) } go c.loop(now) return c, nil @@ -299,17 +287,12 @@ func (c *Conn) loop(now time.Time) { // The connection timer sends a message to the connection loop on expiry. // We need to give it an expiry when creating it, so set the initial timeout to // an arbitrary large value. The timer will be reset before this expires (and it - // isn't a problem if it does anyway). Skip creating the timer in tests which - // take control of the connection message loop. - var timer *time.Timer + // isn't a problem if it does anyway). 
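// For context on the testing/synctest conversions in this patch (the go1.25
// build tags, the synctest.Test wrappers, and tc.advance calls becoming
// time.Sleep): inside a synctest bubble the clock is fake and only advances
// once every goroutine in the bubble is durably blocked, so sleeping is both
// instantaneous in real time and deterministic. A minimal self-contained
// sketch of that behavior:

//go:build go1.25

package synctestsketch

import (
	"context"
	"testing"
	"testing/synctest"
	"time"
)

func TestFakeClock(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		time.Sleep(time.Second) // advances the bubble's fake clock immediately
		synctest.Wait()         // let the context's timeout goroutine run
		if err := ctx.Err(); err != context.DeadlineExceeded {
			t.Fatalf("ctx.Err() = %v, want context.DeadlineExceeded", err)
		}
	})
}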
var lastTimeout time.Time - hooks := c.testHooks - if hooks == nil { - timer = time.AfterFunc(1*time.Hour, func() { - c.sendMsg(timerEvent{}) - }) - defer timer.Stop() - } + timer := time.AfterFunc(1*time.Hour, func() { + c.sendMsg(timerEvent{}) + }) + defer timer.Stop() for c.lifetime.state != connStateDone { sendTimeout := c.maybeSend(now) // try sending @@ -326,10 +309,7 @@ func (c *Conn) loop(now time.Time) { } var m any - if hooks != nil { - // Tests only: Wait for the test to tell us to continue. - now, m = hooks.nextMessage(c.msgc, nextTimeout) - } else if !nextTimeout.IsZero() && nextTimeout.Before(now) { + if !nextTimeout.IsZero() && nextTimeout.Before(now) { // A connection timer has expired. now = time.Now() m = timerEvent{} @@ -372,6 +352,9 @@ func (c *Conn) loop(now time.Time) { case func(time.Time, *Conn): // Send a func to msgc to run it on the main Conn goroutine m(now, c) + case func(now, next time.Time, _ *Conn): + // Send a func to msgc to run it on the main Conn goroutine + m(now, nextTimeout, c) default: panic(fmt.Sprintf("quic: unrecognized conn message %T", m)) } @@ -410,31 +393,7 @@ func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) er defer close(donec) f(now, c) } - if c.testHooks != nil { - // In tests, we can't rely on being able to send a message immediately: - // c.msgc might be full, and testConnHooks.nextMessage might be waiting - // for us to block before it processes the next message. - // To avoid a deadlock, we send the message in waitUntil. - // If msgc is empty, the message is buffered. - // If msgc is full, we block and let nextMessage process the queue. - msgc := c.msgc - c.testHooks.waitUntil(ctx, func() bool { - for { - select { - case msgc <- msg: - msgc = nil // send msg only once - case <-donec: - return true - case <-c.donec: - return true - default: - return false - } - } - }) - } else { - c.sendMsg(msg) - } + c.sendMsg(msg) select { case <-donec: case <-c.donec: @@ -444,16 +403,6 @@ func (c *Conn) runOnLoop(ctx context.Context, f func(now time.Time, c *Conn)) er } func (c *Conn) waitOnDone(ctx context.Context, ch <-chan struct{}) error { - if c.testHooks != nil { - return c.testHooks.waitUntil(ctx, func() bool { - select { - case <-ch: - return true - default: - } - return false - }) - } // Check the channel before the context. // We always prefer to return results when available, // even when provided with an already-canceled context. diff --git a/quic/conn_async_test.go b/quic/conn_async_test.go index f261e90025..08cc7d337a 100644 --- a/quic/conn_async_test.go +++ b/quic/conn_async_test.go @@ -2,44 +2,21 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "context" "errors" "fmt" - "path/filepath" - "runtime" - "sync" + "testing/synctest" ) -// asyncTestState permits handling asynchronous operations in a synchronous test. -// -// For example, a test may want to write to a stream and observe that -// STREAM frames are sent with the contents of the write in response -// to MAX_STREAM_DATA frames received from the peer. -// The Stream.Write is an asynchronous operation, but the test is simpler -// if we can start the write, observe the first STREAM frame sent, -// send a MAX_STREAM_DATA frame, observe the next STREAM frame sent, etc. -// -// We do this by instrumenting points where operations can block. 
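// For context on Conn.loop and the simplified Conn.runOnLoop above: work that
// must touch connection state is sent to the event-loop goroutine as a func
// value on the message channel, so it runs serialized with every other
// connection event and needs no locking. A standalone sketch of the pattern,
// with illustrative types rather than the package's:
package loopsketch

import "time"

type loop struct {
	msgc  chan any
	donec chan struct{} // closed when the loop exits
}

func (l *loop) run() {
	defer close(l.donec)
	for m := range l.msgc {
		switch m := m.(type) {
		case func(now time.Time):
			m(time.Now()) // executed on the loop goroutine
		}
	}
}

// runOnLoop schedules f on the loop goroutine and waits for it to finish,
// giving up if the loop has already exited.
func (l *loop) runOnLoop(f func(now time.Time)) {
	done := make(chan struct{})
	msg := func(now time.Time) {
		defer close(done)
		f(now)
	}
	select {
	case l.msgc <- msg:
	case <-l.donec:
		return
	}
	select {
	case <-done:
	case <-l.donec:
	}
}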
-// We start async operations like Write in a goroutine, -// and wait for the operation to either finish or hit a blocking point. -// When the connection event loop is idle, we check a list of -// blocked operations to see if any can be woken. -type asyncTestState struct { - mu sync.Mutex - notify chan struct{} - blocked map[*blockedAsync]struct{} -} - // An asyncOp is an asynchronous operation that results in (T, error). type asyncOp[T any] struct { - v T - err error - - caller string - tc *testConn + v T + err error donec chan struct{} cancelFunc context.CancelFunc } @@ -47,17 +24,18 @@ type asyncOp[T any] struct { // cancel cancels the async operation's context, and waits for // the operation to complete. func (a *asyncOp[T]) cancel() { + synctest.Wait() select { case <-a.donec: return // already done default: } a.cancelFunc() - <-a.tc.asyncTestState.notify + synctest.Wait() select { case <-a.donec: default: - panic(fmt.Errorf("%v: async op failed to finish after being canceled", a.caller)) + panic(fmt.Errorf("async op failed to finish after being canceled")) } } @@ -71,115 +49,30 @@ var errNotDone = errors.New("async op is not done") // control over the progress of operations, an asyncOp can only // become done in reaction to the test taking some action. func (a *asyncOp[T]) result() (v T, err error) { - a.tc.wait() + synctest.Wait() select { case <-a.donec: return a.v, a.err default: - return v, errNotDone + return a.v, errNotDone } } -// A blockedAsync is a blocked async operation. -type blockedAsync struct { - until func() bool // when this returns true, the operation is unblocked - donec chan struct{} // closed when the operation is unblocked -} - -type asyncContextKey struct{} - // runAsync starts an asynchronous operation. // // The function f should call a blocking function such as // Stream.Write or Conn.AcceptStream and return its result. // It must use the provided context. func runAsync[T any](tc *testConn, f func(context.Context) (T, error)) *asyncOp[T] { - as := &tc.asyncTestState - if as.notify == nil { - as.notify = make(chan struct{}) - as.mu.Lock() - as.blocked = make(map[*blockedAsync]struct{}) - as.mu.Unlock() - } - _, file, line, _ := runtime.Caller(1) - ctx := context.WithValue(context.Background(), asyncContextKey{}, true) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(tc.t.Context()) a := &asyncOp[T]{ - tc: tc, - caller: fmt.Sprintf("%v:%v", filepath.Base(file), line), donec: make(chan struct{}), cancelFunc: cancel, } go func() { + defer close(a.donec) a.v, a.err = f(ctx) - close(a.donec) - as.notify <- struct{}{} }() - tc.t.Cleanup(func() { - if _, err := a.result(); err == errNotDone { - tc.t.Errorf("%v: async operation is still executing at end of test", a.caller) - a.cancel() - } - }) - // Wait for the operation to either finish or block. - <-as.notify - tc.wait() + synctest.Wait() return a } - -// waitUntil waits for a blocked async operation to complete. -// The operation is complete when the until func returns true. -func (as *asyncTestState) waitUntil(ctx context.Context, until func() bool) error { - if until() { - return nil - } - if err := ctx.Err(); err != nil { - // Context has already expired. - return err - } - if ctx.Value(asyncContextKey{}) == nil { - // Context is not one that we've created, and hasn't expired. - // This probably indicates that we've tried to perform a - // blocking operation without using the async test harness here, - // which may have unpredictable results. 
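// For context on the rewritten asyncOp/runAsync above: synctest.Wait runs a
// freshly started goroutine until it either finishes or blocks durably, which
// is what lets result() check synchronously whether a blocking operation has
// completed yet. A minimal sketch of that shape, not the test harness itself:

//go:build go1.25

package asyncsketch

import (
	"testing"
	"testing/synctest"
)

func TestBlockedThenUnblocked(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		ch := make(chan int)
		done := make(chan struct{})
		var got int
		go func() {
			defer close(done)
			got = <-ch // durably blocked until the test sends a value
		}()
		synctest.Wait() // the goroutine is now blocked on the receive
		select {
		case <-done:
			t.Fatal("operation finished before it was unblocked")
		default:
		}
		ch <- 7 // unblock it
		<-done  // the close makes the write to got visible here
		if got != 7 {
			t.Fatalf("got %d, want 7", got)
		}
	})
}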
- panic("blocking async point with unexpected Context") - } - b := &blockedAsync{ - until: until, - donec: make(chan struct{}), - } - // Record this as a pending blocking operation. - as.mu.Lock() - as.blocked[b] = struct{}{} - as.mu.Unlock() - // Notify the creator of the operation that we're blocked, - // and wait to be woken up. - as.notify <- struct{}{} - select { - case <-b.donec: - case <-ctx.Done(): - return ctx.Err() - } - return nil -} - -// wakeAsync tries to wake up a blocked async operation. -// It returns true if one was woken, false otherwise. -func (as *asyncTestState) wakeAsync() bool { - as.mu.Lock() - var woken *blockedAsync - for w := range as.blocked { - if w.until() { - woken = w - delete(as.blocked, w) - break - } - } - as.mu.Unlock() - if woken == nil { - return false - } - close(woken.donec) - <-as.notify // must not hold as.mu while blocked here - return true -} diff --git a/quic/conn_close.go b/quic/conn_close.go index 5001ab13f0..d22f3df5c8 100644 --- a/quic/conn_close.go +++ b/quic/conn_close.go @@ -109,7 +109,7 @@ func (c *Conn) setState(now time.Time, state connState) { } } -// confirmHandshake is called when the TLS handshake completes. +// handshakeDone is called when the TLS handshake completes. func (c *Conn) handshakeDone() { close(c.lifetime.readyc) } diff --git a/quic/conn_close_test.go b/quic/conn_close_test.go index 0b37b3ecfc..472a8f2d62 100644 --- a/quic/conn_close_test.go +++ b/quic/conn_close_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -9,10 +11,14 @@ import ( "crypto/tls" "errors" "testing" + "testing/synctest" "time" ) func TestConnCloseResponseBackoff(t *testing.T) { + synctest.Test(t, testConnCloseResponseBackoff) +} +func testConnCloseResponseBackoff(t *testing.T) { tc := newTestConn(t, clientSide, func(c *Config) { clear(c.StatelessResetKey[:]) }) @@ -34,18 +40,18 @@ func TestConnCloseResponseBackoff(t *testing.T) { tc.writeFrames(packetType1RTT, debugFramePing{}) tc.wantIdle("packets received immediately after CONN_CLOSE receive no response") - tc.advance(1100 * time.Microsecond) + time.Sleep(1100 * time.Microsecond) tc.writeFrames(packetType1RTT, debugFramePing{}) tc.wantFrame("receiving packet 1.1ms after CONN_CLOSE generates another CONN_CLOSE", packetType1RTT, debugFrameConnectionCloseTransport{ code: errNo, }) - tc.advance(1100 * time.Microsecond) + time.Sleep(1100 * time.Microsecond) tc.writeFrames(packetType1RTT, debugFramePing{}) tc.wantIdle("no response to packet, because CONN_CLOSE backoff is now 2ms") - tc.advance(1000 * time.Microsecond) + time.Sleep(1000 * time.Microsecond) tc.writeFrames(packetType1RTT, debugFramePing{}) tc.wantFrame("2ms since last CONN_CLOSE, receiving a packet generates another CONN_CLOSE", packetType1RTT, debugFrameConnectionCloseTransport{ @@ -55,7 +61,7 @@ func TestConnCloseResponseBackoff(t *testing.T) { t.Errorf("conn.Wait() = %v, want still waiting", err) } - tc.advance(100000 * time.Microsecond) + time.Sleep(100000 * time.Microsecond) tc.writeFrames(packetType1RTT, debugFramePing{}) tc.wantIdle("drain timer expired, no more responses") @@ -68,6 +74,9 @@ func TestConnCloseResponseBackoff(t *testing.T) { } func TestConnCloseWithPeerResponse(t *testing.T) { + synctest.Test(t, testConnCloseWithPeerResponse) +} +func testConnCloseWithPeerResponse(t *testing.T) { qr := &qlogRecord{} tc := newTestConn(t, clientSide, qr.config) tc.handshake() @@ -99,7 +108,7 @@ func 
TestConnCloseWithPeerResponse(t *testing.T) { t.Errorf("non-blocking conn.Wait() = %v, want %v", err, wantErr) } - tc.advance(1 * time.Second) // long enough to exit the draining state + time.Sleep(1 * time.Second) // long enough to exit the draining state qr.wantEvents(t, jsonEvent{ "name": "connectivity:connection_closed", "data": map[string]any{ @@ -109,6 +118,9 @@ func TestConnCloseWithPeerResponse(t *testing.T) { } func TestConnClosePeerCloses(t *testing.T) { + synctest.Test(t, testConnClosePeerCloses) +} +func testConnClosePeerCloses(t *testing.T) { qr := &qlogRecord{} tc := newTestConn(t, clientSide, qr.config) tc.handshake() @@ -137,7 +149,7 @@ func TestConnClosePeerCloses(t *testing.T) { reason: "because", }) - tc.advance(1 * time.Second) // long enough to exit the draining state + time.Sleep(1 * time.Second) // long enough to exit the draining state qr.wantEvents(t, jsonEvent{ "name": "connectivity:connection_closed", "data": map[string]any{ @@ -147,6 +159,9 @@ func TestConnClosePeerCloses(t *testing.T) { } func TestConnCloseReceiveInInitial(t *testing.T) { + synctest.Test(t, testConnCloseReceiveInInitial) +} +func testConnCloseReceiveInInitial(t *testing.T) { tc := newTestConn(t, clientSide) tc.wantFrame("client sends Initial CRYPTO frame", packetTypeInitial, debugFrameCrypto{ @@ -171,6 +186,9 @@ func TestConnCloseReceiveInInitial(t *testing.T) { } func TestConnCloseReceiveInHandshake(t *testing.T) { + synctest.Test(t, testConnCloseReceiveInHandshake) +} +func testConnCloseReceiveInHandshake(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) tc.wantFrame("client sends Initial CRYPTO frame", @@ -204,6 +222,9 @@ func TestConnCloseReceiveInHandshake(t *testing.T) { } func TestConnCloseClosedByEndpoint(t *testing.T) { + synctest.Test(t, testConnCloseClosedByEndpoint) +} +func testConnCloseClosedByEndpoint(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, clientSide) tc.handshake() @@ -231,6 +252,9 @@ func testConnCloseUnblocks(t *testing.T, f func(context.Context, *testConn) erro } func TestConnCloseUnblocksAcceptStream(t *testing.T) { + synctest.Test(t, testConnCloseUnblocksAcceptStream) +} +func testConnCloseUnblocksAcceptStream(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { _, err := tc.conn.AcceptStream(ctx) return err @@ -238,6 +262,9 @@ func TestConnCloseUnblocksAcceptStream(t *testing.T) { } func TestConnCloseUnblocksNewStream(t *testing.T) { + synctest.Test(t, testConnCloseUnblocksNewStream) +} +func testConnCloseUnblocksNewStream(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { _, err := tc.conn.NewStream(ctx) return err @@ -245,6 +272,9 @@ func TestConnCloseUnblocksNewStream(t *testing.T) { } func TestConnCloseUnblocksStreamRead(t *testing.T) { + synctest.Test(t, testConnCloseUnblocksStreamRead) +} +func testConnCloseUnblocksStreamRead(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) s.SetReadContext(ctx) @@ -255,6 +285,9 @@ func TestConnCloseUnblocksStreamRead(t *testing.T) { } func TestConnCloseUnblocksStreamWrite(t *testing.T) { + synctest.Test(t, testConnCloseUnblocksStreamWrite) +} +func testConnCloseUnblocksStreamWrite(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) s.SetWriteContext(ctx) @@ -267,6 +300,9 @@ func TestConnCloseUnblocksStreamWrite(t *testing.T) { } func 
TestConnCloseUnblocksStreamClose(t *testing.T) { + synctest.Test(t, testConnCloseUnblocksStreamClose) +} +func testConnCloseUnblocksStreamClose(t *testing.T) { testConnCloseUnblocks(t, func(ctx context.Context, tc *testConn) error { s := newLocalStream(t, tc, bidiStream) s.SetWriteContext(ctx) diff --git a/quic/conn_flow_test.go b/quic/conn_flow_test.go index 52ecf92254..d8d3ae76e6 100644 --- a/quic/conn_flow_test.go +++ b/quic/conn_flow_test.go @@ -2,14 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "context" "testing" + "testing/synctest" ) func TestConnInflowReturnOnRead(t *testing.T) { + synctest.Test(t, testConnInflowReturnOnRead) +} +func testConnInflowReturnOnRead(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { c.MaxConnReadBufferSize = 64 }) @@ -41,6 +47,9 @@ func TestConnInflowReturnOnRead(t *testing.T) { } func TestConnInflowReturnOnRacingReads(t *testing.T) { + synctest.Test(t, testConnInflowReturnOnRacingReads) +} +func testConnInflowReturnOnRacingReads(t *testing.T) { // Perform two reads at the same time, // one for half of MaxConnReadBufferSize // and one for one byte. @@ -91,6 +100,9 @@ func TestConnInflowReturnOnRacingReads(t *testing.T) { } func TestConnInflowReturnOnClose(t *testing.T) { + synctest.Test(t, testConnInflowReturnOnClose) +} +func testConnInflowReturnOnClose(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { c.MaxConnReadBufferSize = 64 }) @@ -107,6 +119,9 @@ func TestConnInflowReturnOnClose(t *testing.T) { } func TestConnInflowReturnOnReset(t *testing.T) { + synctest.Test(t, testConnInflowReturnOnReset) +} +func testConnInflowReturnOnReset(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { c.MaxConnReadBufferSize = 64 }) @@ -127,6 +142,9 @@ func TestConnInflowReturnOnReset(t *testing.T) { } func TestConnInflowStreamViolation(t *testing.T) { + synctest.Test(t, testConnInflowStreamViolation) +} +func testConnInflowStreamViolation(t *testing.T) { tc := newTestConn(t, serverSide, func(c *Config) { c.MaxConnReadBufferSize = 100 }) @@ -169,6 +187,9 @@ func TestConnInflowStreamViolation(t *testing.T) { } func TestConnInflowResetViolation(t *testing.T) { + synctest.Test(t, testConnInflowResetViolation) +} +func testConnInflowResetViolation(t *testing.T) { tc := newTestConn(t, serverSide, func(c *Config) { c.MaxConnReadBufferSize = 100 }) @@ -197,6 +218,9 @@ func TestConnInflowResetViolation(t *testing.T) { } func TestConnInflowMultipleStreams(t *testing.T) { + synctest.Test(t, testConnInflowMultipleStreams) +} +func testConnInflowMultipleStreams(t *testing.T) { tc := newTestConn(t, serverSide, func(c *Config) { c.MaxConnReadBufferSize = 128 }) @@ -247,6 +271,9 @@ func TestConnInflowMultipleStreams(t *testing.T) { } func TestConnOutflowBlocked(t *testing.T) { + synctest.Test(t, testConnOutflowBlocked) +} +func testConnOutflowBlocked(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters, func(p *transportParameters) { @@ -291,6 +318,9 @@ func TestConnOutflowBlocked(t *testing.T) { } func TestConnOutflowMaxDataDecreases(t *testing.T) { + synctest.Test(t, testConnOutflowMaxDataDecreases) +} +func testConnOutflowMaxDataDecreases(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters, func(p *transportParameters) { @@ -318,6 
+348,9 @@ func TestConnOutflowMaxDataDecreases(t *testing.T) { } func TestConnOutflowMaxDataRoundRobin(t *testing.T) { + synctest.Test(t, testConnOutflowMaxDataRoundRobin) +} +func testConnOutflowMaxDataRoundRobin(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, clientSide, permissiveTransportParameters, func(p *transportParameters) { @@ -370,6 +403,9 @@ func TestConnOutflowMaxDataRoundRobin(t *testing.T) { } func TestConnOutflowMetaAndData(t *testing.T) { + synctest.Test(t, testConnOutflowMetaAndData) +} +func testConnOutflowMetaAndData(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters, func(p *transportParameters) { @@ -398,6 +434,9 @@ func TestConnOutflowMetaAndData(t *testing.T) { } func TestConnOutflowResentData(t *testing.T) { + synctest.Test(t, testConnOutflowResentData) +} +func testConnOutflowResentData(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters, func(p *transportParameters) { diff --git a/quic/conn_id_test.go b/quic/conn_id_test.go index c9da0eb090..4b4da675d3 100644 --- a/quic/conn_id_test.go +++ b/quic/conn_id_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -11,9 +13,13 @@ import ( "net/netip" "strings" "testing" + "testing/synctest" ) func TestConnIDClientHandshake(t *testing.T) { + synctest.Test(t, testConnIDClientHandshake) +} +func testConnIDClientHandshake(t *testing.T) { tc := newTestConn(t, clientSide) // On initialization, the client chooses local and remote IDs. // @@ -57,6 +63,9 @@ func TestConnIDClientHandshake(t *testing.T) { } func TestConnIDServerHandshake(t *testing.T) { + synctest.Test(t, testConnIDServerHandshake) +} +func testConnIDServerHandshake(t *testing.T) { tc := newTestConn(t, serverSide) // On initialization, the server is provided with the client-chosen // transient connection ID, and allocates an ID of its own. @@ -178,6 +187,9 @@ func TestNewRandomConnID(t *testing.T) { } func TestConnIDPeerRequestsManyIDs(t *testing.T) { + synctest.Test(t, testConnIDPeerRequestsManyIDs) +} +func testConnIDPeerRequestsManyIDs(t *testing.T) { // "An endpoint SHOULD ensure that its peer has a sufficient number // of available and unused connection IDs." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 @@ -220,6 +232,9 @@ func TestConnIDPeerRequestsManyIDs(t *testing.T) { } func TestConnIDPeerProvidesTooManyIDs(t *testing.T) { + synctest.Test(t, testConnIDPeerProvidesTooManyIDs) +} +func testConnIDPeerProvidesTooManyIDs(t *testing.T) { // "An endpoint MUST NOT provide more connection IDs than the peer's limit." 
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 tc := newTestConn(t, serverSide) @@ -238,6 +253,9 @@ func TestConnIDPeerProvidesTooManyIDs(t *testing.T) { } func TestConnIDPeerTemporarilyExceedsActiveConnIDLimit(t *testing.T) { + synctest.Test(t, testConnIDPeerTemporarilyExceedsActiveConnIDLimit) +} +func testConnIDPeerTemporarilyExceedsActiveConnIDLimit(t *testing.T) { // "An endpoint MAY send connection IDs that temporarily exceed a peer's limit // if the NEW_CONNECTION_ID frame also requires the retirement of any excess [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-4 @@ -272,7 +290,7 @@ func TestConnIDPeerRetiresConnID(t *testing.T) { clientSide, serverSide, } { - t.Run(side.String(), func(t *testing.T) { + synctestSubtest(t, side.String(), func(t *testing.T) { tc := newTestConn(t, side) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -293,6 +311,9 @@ func TestConnIDPeerRetiresConnID(t *testing.T) { } func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { + synctest.Test(t, testConnIDPeerWithZeroLengthConnIDSendsNewConnectionID) +} +func testConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { // "An endpoint that selects a zero-length connection ID during the handshake // cannot issue a new connection ID." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.1-8 @@ -315,6 +336,9 @@ func TestConnIDPeerWithZeroLengthConnIDSendsNewConnectionID(t *testing.T) { } func TestConnIDPeerRequestsRetirement(t *testing.T) { + synctest.Test(t, testConnIDPeerRequestsRetirement) +} +func testConnIDPeerRequestsRetirement(t *testing.T) { // "Upon receipt of an increased Retire Prior To field, the peer MUST // stop using the corresponding connection IDs and retire them with // RETIRE_CONNECTION_ID frames [...]" @@ -339,6 +363,9 @@ func TestConnIDPeerRequestsRetirement(t *testing.T) { } func TestConnIDPeerDoesNotAcknowledgeRetirement(t *testing.T) { + synctest.Test(t, testConnIDPeerDoesNotAcknowledgeRetirement) +} +func testConnIDPeerDoesNotAcknowledgeRetirement(t *testing.T) { // "An endpoint SHOULD limit the number of connection IDs it has retired locally // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 @@ -364,6 +391,9 @@ func TestConnIDPeerDoesNotAcknowledgeRetirement(t *testing.T) { } func TestConnIDRepeatedNewConnectionIDFrame(t *testing.T) { + synctest.Test(t, testConnIDRepeatedNewConnectionIDFrame) +} +func testConnIDRepeatedNewConnectionIDFrame(t *testing.T) { // "Receipt of the same [NEW_CONNECTION_ID] frame multiple times // MUST NOT be treated as a connection error. // https://www.rfc-editor.org/rfc/rfc9000#section-19.15-7 @@ -387,6 +417,9 @@ func TestConnIDRepeatedNewConnectionIDFrame(t *testing.T) { } func TestConnIDForSequenceNumberChanges(t *testing.T) { + synctest.Test(t, testConnIDForSequenceNumberChanges) +} +func testConnIDForSequenceNumberChanges(t *testing.T) { // "[...] if a sequence number is used for different connection IDs, // the endpoint MAY treat that receipt as a connection error // of type PROTOCOL_VIOLATION." 
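// The conn_id tests above call synctestSubtest and testSidesSynctest, which
// this patch does not show. A plausible shape for such a helper, assumed here
// rather than taken from the actual code, is a t.Run wrapper that gives each
// subtest its own synctest bubble:

//go:build go1.25

package subtestsketch

import (
	"testing"
	"testing/synctest"
)

func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
	t.Helper()
	t.Run(name, func(t *testing.T) {
		synctest.Test(t, f)
	})
}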
@@ -415,6 +448,9 @@ func TestConnIDForSequenceNumberChanges(t *testing.T) { } func TestConnIDRetirePriorToAfterNewConnID(t *testing.T) { + synctest.Test(t, testConnIDRetirePriorToAfterNewConnID) +} +func testConnIDRetirePriorToAfterNewConnID(t *testing.T) { // "Receiving a value in the Retire Prior To field that is greater than // that in the Sequence Number field MUST be treated as a connection error // of type FRAME_ENCODING_ERROR. @@ -436,6 +472,9 @@ func TestConnIDRetirePriorToAfterNewConnID(t *testing.T) { } func TestConnIDAlreadyRetired(t *testing.T) { + synctest.Test(t, testConnIDAlreadyRetired) +} +func testConnIDAlreadyRetired(t *testing.T) { // "An endpoint that receives a NEW_CONNECTION_ID frame with a // sequence number smaller than the Retire Prior To field of a // previously received NEW_CONNECTION_ID frame MUST send a @@ -472,6 +511,9 @@ func TestConnIDAlreadyRetired(t *testing.T) { } func TestConnIDRepeatedRetireConnectionIDFrame(t *testing.T) { + synctest.Test(t, testConnIDRepeatedRetireConnectionIDFrame) +} +func testConnIDRepeatedRetireConnectionIDFrame(t *testing.T) { tc := newTestConn(t, clientSide) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -493,6 +535,9 @@ func TestConnIDRepeatedRetireConnectionIDFrame(t *testing.T) { } func TestConnIDRetiredUnsent(t *testing.T) { + synctest.Test(t, testConnIDRetiredUnsent) +} +func testConnIDRetiredUnsent(t *testing.T) { // "Receipt of a RETIRE_CONNECTION_ID frame containing a sequence number // greater than any previously sent to the peer MUST be treated as a // connection error of type PROTOCOL_VIOLATION." @@ -512,6 +557,9 @@ func TestConnIDRetiredUnsent(t *testing.T) { } func TestConnIDUsePreferredAddressConnID(t *testing.T) { + synctest.Test(t, testConnIDUsePreferredAddressConnID) +} +func testConnIDUsePreferredAddressConnID(t *testing.T) { // Peer gives us a connection ID in the preferred address transport parameter. // We don't use the preferred address at this time, but we should use the // connection ID. (It isn't tied to any specific address.) @@ -543,6 +591,9 @@ func TestConnIDUsePreferredAddressConnID(t *testing.T) { } func TestConnIDPeerProvidesPreferredAddrAndTooManyConnIDs(t *testing.T) { + synctest.Test(t, testConnIDPeerProvidesPreferredAddrAndTooManyConnIDs) +} +func testConnIDPeerProvidesPreferredAddrAndTooManyConnIDs(t *testing.T) { // Peer gives us more conn ids than our advertised limit, // including a conn id in the preferred address transport parameter. cid := testPeerConnID(10) @@ -568,6 +619,9 @@ func TestConnIDPeerProvidesPreferredAddrAndTooManyConnIDs(t *testing.T) { } func TestConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { + synctest.Test(t, testConnIDPeerWithZeroLengthIDProvidesPreferredAddr) +} +func testConnIDPeerWithZeroLengthIDProvidesPreferredAddr(t *testing.T) { // Peer gives us more conn ids than our advertised limit, // including a conn id in the preferred address transport parameter. tc := newTestConn(t, serverSide, func(p *transportParameters) { @@ -596,7 +650,7 @@ func TestConnIDInitialSrcConnIDMismatch(t *testing.T) { // "Endpoints MUST validate that received [initial_source_connection_id] // parameters match received connection ID values." 
// https://www.rfc-editor.org/rfc/rfc9000#section-7.3-3 - testSides(t, "", func(t *testing.T, side connSide) { + testSidesSynctest(t, "", func(t *testing.T, side connSide) { tc := newTestConn(t, side, func(p *transportParameters) { p.initialSrcConnID = []byte("invalid") }) @@ -621,7 +675,7 @@ func TestConnIDInitialSrcConnIDMismatch(t *testing.T) { } func TestConnIDsCleanedUpAfterClose(t *testing.T) { - testSides(t, "", func(t *testing.T, side connSide) { + testSidesSynctest(t, "", func(t *testing.T, side connSide) { tc := newTestConn(t, side, func(p *transportParameters) { if side == clientSide { token := testPeerStatelessResetToken(0) @@ -664,6 +718,9 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) { } func TestConnIDRetiredConnIDResent(t *testing.T) { + synctest.Test(t, testConnIDRetiredConnIDResent) +} +func testConnIDRetiredConnIDResent(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() tc.ignoreFrame(frameTypeAck) diff --git a/quic/conn_loss.go b/quic/conn_loss.go index 06761e3f83..bc6d106601 100644 --- a/quic/conn_loss.go +++ b/quic/conn_loss.go @@ -32,7 +32,7 @@ func (c *Conn) handleAckOrLoss(space numberSpace, sent *sentPacket, fate packetF switch f := sent.next(); f { default: panic(fmt.Sprintf("BUG: unhandled acked/lost frame type %x", f)) - case frameTypeAck: + case frameTypeAck, frameTypeAckECN: // Unlike most information, loss of an ACK frame does not trigger // retransmission. ACKs are sent in response to ack-eliciting packets, // and always contain the latest information available. diff --git a/quic/conn_loss_test.go b/quic/conn_loss_test.go index f13ea13d48..49c794ffab 100644 --- a/quic/conn_loss_test.go +++ b/quic/conn_loss_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -9,6 +11,8 @@ import ( "crypto/tls" "fmt" "testing" + "testing/synctest" + "time" ) // Frames may be retransmitted either when the packet containing the frame is lost, or on PTO. @@ -22,6 +26,16 @@ func lostFrameTest(t *testing.T, f func(t *testing.T, pto bool)) { }) } +func lostFrameTestSynctest(t *testing.T, f func(t *testing.T, pto bool)) { + t.Helper() + lostFrameTest(t, func(t *testing.T, pto bool) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + f(t, pto) + }) + }) +} + // triggerLossOrPTO causes the conn to declare the last sent packet lost, // or advances to the PTO timer. 
func (tc *testConn) triggerLossOrPTO(ptype packetType, pto bool) { @@ -33,7 +47,11 @@ func (tc *testConn) triggerLossOrPTO(ptype packetType, pto bool) { if *testVV { tc.t.Logf("advancing to PTO timer") } - tc.advanceTo(tc.conn.loss.timer) + var when time.Time + tc.conn.runOnLoop(tc.t.Context(), func(now time.Time, conn *Conn) { + when = conn.loss.timer + }) + time.Sleep(time.Until(when)) return } if *testVV { @@ -77,7 +95,7 @@ func TestLostResetStreamFrame(t *testing.T) { // "Cancellation of stream transmission, as carried in a RESET_STREAM frame, // is sent until acknowledged or until all stream data is acknowledged by the peer [...]" // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.4 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) tc.ignoreFrame(frameTypeAck) @@ -106,7 +124,7 @@ func TestLostStopSendingFrame(t *testing.T) { // Technically, we can stop sending a STOP_SENDING frame if the peer sends // us all the data for the stream or resets it. We don't bother tracking this, // however, so we'll keep sending the frame until it is acked. This is harmless. - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, permissiveTransportParameters) tc.ignoreFrame(frameTypeAck) @@ -127,7 +145,7 @@ func TestLostStopSendingFrame(t *testing.T) { func TestLostCryptoFrame(t *testing.T) { // "Data sent in CRYPTO frames is retransmitted [...] until all data has been acknowledged." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.1 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) @@ -171,7 +189,7 @@ func TestLostCryptoFrame(t *testing.T) { func TestLostStreamFrameEmpty(t *testing.T) { // A STREAM frame opening a stream, but containing no stream data, should // be retransmitted if lost. - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { ctx := canceledContext() tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() @@ -203,7 +221,7 @@ func TestLostStreamWithData(t *testing.T) { // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.2 // // TODO: Lost stream frame after RESET_STREAM - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { data := []byte{0, 1, 2, 3, 4, 5, 6, 7} tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, func(p *transportParameters) { p.initialMaxStreamsUni = 1 @@ -247,6 +265,9 @@ func TestLostStreamWithData(t *testing.T) { } func TestLostStreamPartialLoss(t *testing.T) { + synctest.Test(t, testLostStreamPartialLoss) +} +func testLostStreamPartialLoss(t *testing.T) { // Conn sends four STREAM packets. // ACKs are received for the packets containing bytes 0 and 2. // The remaining packets are declared lost. 
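// The pre-existing lostFrameTest helper wrapped by lostFrameTestSynctest above
// has its body elided by the hunk context. Based on how the tests invoke it,
// it is assumed to run each test body twice, once per retransmission trigger;
// a rough sketch, not taken from this diff:

// lostFrameTest runs a test in two configurations: the packet carrying the
// frame is declared lost via ack-based detection, or the PTO timer fires.
func lostFrameTest(t *testing.T, f func(t *testing.T, pto bool)) {
    t.Run("lost", func(t *testing.T) {
        f(t, false)
    })
    t.Run("pto", func(t *testing.T) {
        f(t, true)
    })
}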
@@ -295,7 +316,7 @@ func TestLostMaxDataFrame(t *testing.T) { // "An updated value is sent in a MAX_DATA frame if the packet // containing the most recently sent MAX_DATA frame is declared lost [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.7 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { const maxWindowSize = 32 buf := make([]byte, maxWindowSize) tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { @@ -340,7 +361,7 @@ func TestLostMaxStreamDataFrame(t *testing.T) { // "[...] an updated value is sent when the packet containing // the most recent MAX_STREAM_DATA frame for a stream is lost" // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.8 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { const maxWindowSize = 32 buf := make([]byte, maxWindowSize) tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { @@ -387,7 +408,7 @@ func TestLostMaxStreamDataFrameAfterStreamFinReceived(t *testing.T) { // "An endpoint SHOULD stop sending MAX_STREAM_DATA frames when // the receiving part of the stream enters a "Size Known" or "Reset Recvd" state." // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.8 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { const maxWindowSize = 10 buf := make([]byte, maxWindowSize) tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, func(c *Config) { @@ -425,7 +446,7 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) { // most recent MAX_STREAMS for a stream type frame is declared lost [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.9 testStreamTypes(t, "", func(t *testing.T, styp streamType) { - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { c.MaxUniRemoteStreams = 1 @@ -469,6 +490,9 @@ func TestLostMaxStreamsFrameMostRecent(t *testing.T) { } func TestLostMaxStreamsFrameNotMostRecent(t *testing.T) { + synctest.Test(t, testLostMaxStreamsFrameNotMostRecent) +} +func testLostMaxStreamsFrameNotMostRecent(t *testing.T) { // Send two MAX_STREAMS frames, lose the first one. // // No PTO mode for this test: The ack that causes the first frame @@ -514,7 +538,7 @@ func TestLostStreamDataBlockedFrame(t *testing.T) { // "A new [STREAM_DATA_BLOCKED] frame is sent if a packet containing // the most recent frame for a scope is lost [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.10 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, func(p *transportParameters) { p.initialMaxStreamsUni = 1 p.initialMaxData = 1 << 20 @@ -565,7 +589,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) { // "A new [STREAM_DATA_BLOCKED] frame is sent [...] only while // the endpoint is blocked on the corresponding limit." 
// https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.10 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, func(p *transportParameters) { p.initialMaxStreamsUni = 1 p.initialMaxData = 1 << 20 @@ -607,7 +631,7 @@ func TestLostStreamDataBlockedFrameAfterStreamUnblocked(t *testing.T) { func TestLostNewConnectionIDFrame(t *testing.T) { // "New connection IDs are [...] retransmitted if the packet containing them is lost." // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.13 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc := newTestConn(t, serverSide) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -637,7 +661,7 @@ func TestLostRetireConnectionIDFrame(t *testing.T) { // "[...] retired connection IDs are [...] retransmitted // if the packet containing them is lost." // https://www.rfc-editor.org/rfc/rfc9000#section-13.3-3.13 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc := newTestConn(t, clientSide) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -664,7 +688,7 @@ func TestLostRetireConnectionIDFrame(t *testing.T) { func TestLostPathResponseFrame(t *testing.T) { // "Responses to path validation using PATH_RESPONSE frames are sent just once." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.12 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc := newTestConn(t, clientSide) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -687,7 +711,7 @@ func TestLostPathResponseFrame(t *testing.T) { func TestLostHandshakeDoneFrame(t *testing.T) { // "The HANDSHAKE_DONE frame MUST be retransmitted until it is acknowledged." // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.3-3.16 - lostFrameTest(t, func(t *testing.T, pto bool) { + lostFrameTestSynctest(t, func(t *testing.T, pto bool) { tc := newTestConn(t, serverSide) tc.ignoreFrame(frameTypeAck) diff --git a/quic/conn_recv.go b/quic/conn_recv.go index a24fc36916..2bf127a479 100644 --- a/quic/conn_recv.go +++ b/quic/conn_recv.go @@ -124,7 +124,7 @@ func (c *Conn) handleLongHeader(now time.Time, dgram *datagram, ptype packetType } c.connIDState.handlePacket(c, p.ptype, p.srcConnID) ackEliciting := c.handleFrames(now, dgram, ptype, space, p.payload) - c.acks[space].receive(now, space, p.num, ackEliciting) + c.acks[space].receive(now, space, p.num, ackEliciting, dgram.ecn) if p.ptype == packetTypeHandshake && c.side == serverSide { c.loss.validateClientAddress() @@ -147,7 +147,7 @@ func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int { p, err := parse1RTTPacket(buf, &c.keysAppData, connIDLen, pnumMax) if err != nil { // A localTransportError terminates the connection. - // Other errors indicate an unparseable packet, but otherwise may be ignored. + // Other errors indicate an unparsable packet, but otherwise may be ignored. 
if _, ok := err.(localTransportError); ok { c.abort(now, err) } @@ -174,7 +174,7 @@ func (c *Conn) handle1RTT(now time.Time, dgram *datagram, buf []byte) int { c.log1RTTPacketReceived(p, buf) } ackEliciting := c.handleFrames(now, dgram, packetType1RTT, appDataSpace, p.payload) - c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting) + c.acks[appDataSpace].receive(now, appDataSpace, p.num, ackEliciting, dgram.ecn) return len(buf) } @@ -208,10 +208,14 @@ func (c *Conn) handleRetry(now time.Time, pkt []byte) { } c.retryToken = cloneBytes(p.token) c.connIDState.handleRetryPacket(p.srcConnID) + c.keysInitial = initialKeys(p.srcConnID, c.side) // We need to resend any data we've already sent in Initial packets. // We must not reuse already sent packet numbers. c.loss.discardPackets(initialSpace, c.log, c.handleAckOrLoss) // TODO: Discard 0-RTT packets as well, once we support 0-RTT. + if c.testHooks != nil { + c.testHooks.init(false) + } } var errVersionNegotiation = errors.New("server does not support QUIC version 1") @@ -420,12 +424,15 @@ func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, sp func (c *Conn) handleAckFrame(now time.Time, space numberSpace, payload []byte) int { c.loss.receiveAckStart() - largest, ackDelay, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { + largest, ackDelay, ecn, n := consumeAckFrame(payload, func(rangeIndex int, start, end packetNumber) { if err := c.loss.receiveAckRange(now, space, rangeIndex, start, end, c.handleAckOrLoss); err != nil { c.abort(now, err) return } }) + // TODO: Make use of ECN feedback. + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.3.2 + _ = ecn // Prior to receiving the peer's transport parameters, we cannot // interpret the ACK Delay field because we don't know the ack_delay_exponent // to apply. diff --git a/quic/conn_recv_test.go b/quic/conn_recv_test.go index 1a0eb3a105..6ee728e0e3 100644 --- a/quic/conn_recv_test.go +++ b/quic/conn_recv_test.go @@ -2,14 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "crypto/tls" "testing" + "testing/synctest" ) func TestConnReceiveAckForUnsentPacket(t *testing.T) { + synctest.Test(t, testConnReceiveAckForUnsentPacket) +} +func testConnReceiveAckForUnsentPacket(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() tc.writeFrames(packetType1RTT, @@ -27,6 +33,9 @@ func TestConnReceiveAckForUnsentPacket(t *testing.T) { // drop state for a number space, and also contains a valid ACK frame for that space, // we shouldn't complain about the ACK. 
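// The ecnCounts type used by the ACK plumbing in this change (acks[space].ecn,
// debugFrameAck.ecn, and the new consumeAckFrame return value) is not defined
// in this diff. From the way its fields are used (ecn.t0, ecn.t1, ecn.ce
// assigned from varints as ints), it is assumed to be a small counter struct
// along these lines; sketch only:

// ecnCounts holds the three ECN counters carried by an ACK_ECN frame
// (RFC 9000, Section 19.3.2).
type ecnCounts struct {
    t0 int // ECT(0) count
    t1 int // ECT(1) count
    ce int // ECN-CE count
}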
func TestConnReceiveAckForDroppedSpace(t *testing.T) { + synctest.Test(t, testConnReceiveAckForDroppedSpace) +} +func testConnReceiveAckForDroppedSpace(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.ignoreFrame(frameTypeAck) tc.ignoreFrame(frameTypeNewConnectionID) diff --git a/quic/conn_send.go b/quic/conn_send.go index d6fb149d9f..3e8cf526b5 100644 --- a/quic/conn_send.go +++ b/quic/conn_send.go @@ -374,7 +374,7 @@ func (c *Conn) appendAckFrame(now time.Time, space numberSpace) bool { return false } d := unscaledAckDelayFromDuration(delay, ackDelayExponent) - return c.w.appendAckFrame(seen, d) + return c.w.appendAckFrame(seen, d, c.acks[space].ecn) } func (c *Conn) appendConnectionCloseFrame(now time.Time, space numberSpace, err error) { diff --git a/quic/conn_send_test.go b/quic/conn_send_test.go index c5cf93644c..88911bd167 100644 --- a/quic/conn_send_test.go +++ b/quic/conn_send_test.go @@ -2,14 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "testing" + "testing/synctest" "time" ) func TestAckElicitingAck(t *testing.T) { + synctest.Test(t, testAckElicitingAck) +} +func testAckElicitingAck(t *testing.T) { // "A receiver that sends only non-ack-eliciting packets [...] might not receive // an acknowledgment for a long period of time. // [...] a receiver could send a [...] ack-eliciting frame occasionally [...] @@ -22,7 +28,7 @@ func TestAckElicitingAck(t *testing.T) { tc.handshake() const count = 100 for i := 0; i < count; i++ { - tc.advance(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) tc.writeFrames(packetType1RTT, debugFramePing{}, ) @@ -38,6 +44,9 @@ func TestAckElicitingAck(t *testing.T) { } func TestSendPacketNumberSize(t *testing.T) { + synctest.Test(t, testSendPacketNumberSize) +} +func testSendPacketNumberSize(t *testing.T) { tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() diff --git a/quic/conn_streams.go b/quic/conn_streams.go index bfe80c6dcf..0e4bf50094 100644 --- a/quic/conn_streams.go +++ b/quic/conn_streams.go @@ -71,7 +71,7 @@ func (c *Conn) streamsCleanup() { // AcceptStream waits for and returns the next stream created by the peer. func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) { - return c.streams.queue.get(ctx, c.testHooks) + return c.streams.queue.get(ctx) } // NewStream creates a stream. @@ -283,19 +283,14 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) return false } - // MAX_STREAM_DATA - if !c.streams.remoteLimit[uniStream].appendFrame(w, uniStream, pnum, pto) { - return false - } - if !c.streams.remoteLimit[bidiStream].appendFrame(w, bidiStream, pnum, pto) { - return false - } - if pto { return c.appendStreamFramesPTO(w, pnum) } if !c.streams.needSend.Load() { - return true + // If queueMeta includes newly-finished streams, we may extend the peer's + // stream limits. When there are no streams to process, add MAX_STREAMS + // frames here. Otherwise, wait until after we've processed queueMeta. + return c.appendMaxStreams(w, pnum, pto) } c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() @@ -354,6 +349,12 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) // If so, put the stream back on a queue. c.queueStreamForSendLocked(s, state) } + + // MAX_STREAMS (possibly triggered by finalization of remote streams above). 
+ if !c.appendMaxStreams(w, pnum, pto) { + return false + } + // queueData contains streams with flow-controlled frames. for c.streams.queueData.head != nil { avail := c.streams.outflow.avail() @@ -408,9 +409,12 @@ func (c *Conn) appendStreamFrames(w *packetWriter, pnum packetNumber, pto bool) // It returns true if no more frames need appending, // false if not everything fit in the current packet. func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { + const pto = true + if !c.appendMaxStreams(w, pnum, pto) { + return false + } c.streams.sendMu.Lock() defer c.streams.sendMu.Unlock() - const pto = true for _, ms := range c.streams.streams { s := ms.s if s == nil { @@ -434,6 +438,16 @@ func (c *Conn) appendStreamFramesPTO(w *packetWriter, pnum packetNumber) bool { return true } +func (c *Conn) appendMaxStreams(w *packetWriter, pnum packetNumber, pto bool) bool { + if !c.streams.remoteLimit[uniStream].appendFrame(w, uniStream, pnum, pto) { + return false + } + if !c.streams.remoteLimit[bidiStream].appendFrame(w, bidiStream, pnum, pto) { + return false + } + return true +} + // A streamRing is a circular linked list of streams. type streamRing struct { head *Stream diff --git a/quic/conn_streams_test.go b/quic/conn_streams_test.go index af3c1dec8f..b95aa47122 100644 --- a/quic/conn_streams_test.go +++ b/quic/conn_streams_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -11,9 +13,13 @@ import ( "math" "sync" "testing" + "testing/synctest" ) func TestStreamsCreate(t *testing.T) { + synctest.Test(t, testStreamsCreate) +} +func testStreamsCreate(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, clientSide, permissiveTransportParameters) tc.handshake() @@ -53,6 +59,9 @@ func TestStreamsCreate(t *testing.T) { } func TestStreamsAccept(t *testing.T) { + synctest.Test(t, testStreamsAccept) +} +func testStreamsAccept(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, serverSide) tc.handshake() @@ -95,6 +104,9 @@ func TestStreamsAccept(t *testing.T) { } func TestStreamsBlockingAccept(t *testing.T) { + synctest.Test(t, testStreamsBlockingAccept) +} +func testStreamsBlockingAccept(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() @@ -124,6 +136,9 @@ func TestStreamsBlockingAccept(t *testing.T) { } func TestStreamsLocalStreamNotCreated(t *testing.T) { + synctest.Test(t, testStreamsLocalStreamNotCreated) +} +func testStreamsLocalStreamNotCreated(t *testing.T) { // "An endpoint MUST terminate the connection with error STREAM_STATE_ERROR // if it receives a STREAM frame for a locally initiated stream that has // not yet been created [...]" @@ -142,6 +157,9 @@ func TestStreamsLocalStreamNotCreated(t *testing.T) { } func TestStreamsLocalStreamClosed(t *testing.T) { + synctest.Test(t, testStreamsLocalStreamClosed) +} +func testStreamsLocalStreamClosed(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters) s.CloseWrite() tc.wantFrame("FIN for closed stream", @@ -168,6 +186,9 @@ func TestStreamsLocalStreamClosed(t *testing.T) { } func TestStreamsStreamSendOnly(t *testing.T) { + synctest.Test(t, testStreamsStreamSendOnly) +} +func testStreamsStreamSendOnly(t *testing.T) { // "An endpoint MUST terminate the connection with error STREAM_STATE_ERROR // if it receives a STREAM frame for a locally initiated stream that has // not yet been created [...]" @@ -198,6 +219,9 @@ func 
TestStreamsStreamSendOnly(t *testing.T) { } func TestStreamsWriteQueueFairness(t *testing.T) { + synctest.Test(t, testStreamsWriteQueueFairness) +} +func testStreamsWriteQueueFairness(t *testing.T) { ctx := canceledContext() const dataLen = 1 << 20 const numStreams = 3 @@ -233,7 +257,7 @@ func TestStreamsWriteQueueFairness(t *testing.T) { } // Wait for the stream to finish writing whatever frames it can before // congestion control blocks it. - tc.wait() + synctest.Wait() } sent := make([]int64, len(streams)) @@ -344,7 +368,7 @@ func TestStreamsShutdown(t *testing.T) { }, }} { name := fmt.Sprintf("%v/%v/%v", test.side, test.styp, test.name) - t.Run(name, func(t *testing.T) { + synctestSubtest(t, name, func(t *testing.T) { tc, s := newTestConnAndStream(t, serverSide, test.side, test.styp, permissiveTransportParameters) tc.ignoreFrame(frameTypeStreamBase) @@ -364,6 +388,9 @@ func TestStreamsShutdown(t *testing.T) { } func TestStreamsCreateAndCloseRemote(t *testing.T) { + synctest.Test(t, testStreamsCreateAndCloseRemote) +} +func testStreamsCreateAndCloseRemote(t *testing.T) { // This test exercises creating new streams in response to frames // from the peer, and cleaning up after streams are fully closed. // @@ -473,6 +500,9 @@ func TestStreamsCreateAndCloseRemote(t *testing.T) { } func TestStreamsCreateConcurrency(t *testing.T) { + synctest.Test(t, testStreamsCreateConcurrency) +} +func testStreamsCreateConcurrency(t *testing.T) { cli, srv := newLocalConnPair(t, &Config{}, &Config{}) srvdone := make(chan int) @@ -520,6 +550,9 @@ func TestStreamsCreateConcurrency(t *testing.T) { } func TestStreamsPTOWithImplicitStream(t *testing.T) { + synctest.Test(t, testStreamsPTOWithImplicitStream) +} +func testStreamsPTOWithImplicitStream(t *testing.T) { ctx := canceledContext() tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() diff --git a/quic/conn_test.go b/quic/conn_test.go index 4b0511fce6..81eeffc5ae 100644 --- a/quic/conn_test.go +++ b/quic/conn_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 + package quic import ( @@ -17,6 +19,7 @@ import ( "reflect" "strings" "testing" + "testing/synctest" "time" "golang.org/x/net/quic/qlog" @@ -27,7 +30,8 @@ var ( qlogdir = flag.String("qlog", "", "write qlog logs to directory") ) -func TestConnTestConn(t *testing.T) { +func TestConnTestConn(t *testing.T) { synctest.Test(t, testConnTestConn) } +func testConnTestConn(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() if got, want := tc.timeUntilEvent(), defaultMaxIdleTimeout; got != want { @@ -40,13 +44,13 @@ func TestConnTestConn(t *testing.T) { }) return }).result() - if !ranAt.Equal(tc.endpoint.now) { - t.Errorf("func ran on loop at %v, want %v", ranAt, tc.endpoint.now) + if !ranAt.Equal(time.Now()) { + t.Errorf("func ran on loop at %v, want %v", ranAt, time.Now()) } - tc.wait() + synctest.Wait() - nextTime := tc.endpoint.now.Add(defaultMaxIdleTimeout / 2) - tc.advanceTo(nextTime) + nextTime := time.Now().Add(defaultMaxIdleTimeout / 2) + time.Sleep(time.Until(nextTime)) ranAt, _ = runAsync(tc, func(ctx context.Context) (when time.Time, _ error) { tc.conn.runOnLoop(ctx, func(now time.Time, c *Conn) { when = now @@ -56,7 +60,7 @@ func TestConnTestConn(t *testing.T) { if !ranAt.Equal(nextTime) { t.Errorf("func ran on loop at %v, want %v", ranAt, nextTime) } - tc.wait() + synctest.Wait() tc.advanceToTimer() if got := tc.conn.lifetime.state; got != connStateDone { @@ -125,12 +129,9 @@ const maxTestKeyPhases = 3 // A testConn is a Conn whose external interactions (sending and receiving packets, // setting timers) can be manipulated in tests. type testConn struct { - t *testing.T - conn *Conn - endpoint *testEndpoint - timer time.Time - timerLastFired time.Time - idlec chan struct{} // only accessed on the conn's loop + t *testing.T + conn *Conn + endpoint *testEndpoint // Keys are distinct from the conn's keys, // because the test may know about keys before the conn does. @@ -150,7 +151,7 @@ type testConn struct { // CRYPTO data produced by the conn's QUICConn is placed in // cryptoDataOut. // - // The peerTLSConn is is a QUICConn representing the peer. + // The peerTLSConn is a QUICConn representing the peer. // CRYPTO data produced by the conn is written to peerTLSConn, // and data produced by peerTLSConn is placed in cryptoDataIn. cryptoDataOut map[tls.QUICEncryptionLevel][]byte @@ -183,8 +184,6 @@ type testConn struct { // Values to set in packets sent to the conn. sendKeyNumber int sendKeyPhaseBit bool - - asyncTestState } type test1RTTKeys struct { @@ -198,10 +197,6 @@ type keySecret struct { } // newTestConn creates a Conn for testing. -// -// The Conn's event loop is controlled by the test, -// allowing test code to access Conn state directly -// by first ensuring the loop goroutine is idle. func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { t.Helper() config := &Config{ @@ -242,7 +237,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { endpoint.configTransportParams = configTransportParams endpoint.configTestConn = configTestConn conn, err := endpoint.e.newConn( - endpoint.now, + time.Now(), config, side, cids, @@ -252,7 +247,7 @@ func newTestConn(t *testing.T, side connSide, opts ...any) *testConn { t.Fatal(err) } tc := endpoint.conns[conn] - tc.wait() + synctest.Wait() return tc } @@ -306,76 +301,33 @@ func newTestConnForConn(t *testing.T, endpoint *testEndpoint, conn *Conn) *testC return tc } -// advance causes time to pass. 
-func (tc *testConn) advance(d time.Duration) { - tc.t.Helper() - tc.endpoint.advance(d) -} - -// advanceTo sets the current time. -func (tc *testConn) advanceTo(now time.Time) { - tc.t.Helper() - tc.endpoint.advanceTo(now) -} - // advanceToTimer sets the current time to the time of the Conn's next timer event. func (tc *testConn) advanceToTimer() { - if tc.timer.IsZero() { + when := tc.nextEvent() + if when.IsZero() { tc.t.Fatalf("advancing to timer, but timer is not set") } - tc.advanceTo(tc.timer) -} - -func (tc *testConn) timerDelay() time.Duration { - if tc.timer.IsZero() { - return math.MaxInt64 // infinite - } - if tc.timer.Before(tc.endpoint.now) { - return 0 - } - return tc.timer.Sub(tc.endpoint.now) + time.Sleep(time.Until(when)) + synctest.Wait() } const infiniteDuration = time.Duration(math.MaxInt64) // timeUntilEvent returns the amount of time until the next connection event. func (tc *testConn) timeUntilEvent() time.Duration { - if tc.timer.IsZero() { + next := tc.nextEvent() + if next.IsZero() { return infiniteDuration } - if tc.timer.Before(tc.endpoint.now) { - return 0 - } - return tc.timer.Sub(tc.endpoint.now) + return max(0, time.Until(next)) } -// wait blocks until the conn becomes idle. -// The conn is idle when it is blocked waiting for a packet to arrive or a timer to expire. -// Tests shouldn't need to call wait directly. -// testConn methods that wake the Conn event loop will call wait for them. -func (tc *testConn) wait() { - tc.t.Helper() - idlec := make(chan struct{}) - fail := false - tc.conn.sendMsg(func(now time.Time, c *Conn) { - if tc.idlec != nil { - tc.t.Errorf("testConn.wait called concurrently") - fail = true - close(idlec) - } else { - // nextMessage will close idlec. - tc.idlec = idlec - } +func (tc *testConn) nextEvent() time.Time { + nextc := make(chan time.Time) + tc.conn.sendMsg(func(now, next time.Time, c *Conn) { + nextc <- next }) - select { - case <-idlec: - case <-tc.conn.donec: - // We may have async ops that can proceed now that the conn is done. - tc.wakeAsync() - } - if fail { - panic(fail) - } + return <-nextc } func (tc *testConn) cleanup() { @@ -498,7 +450,7 @@ func (tc *testConn) ignoreFrame(frameType byte) { // It returns nil if the Conn has no more datagrams to send at this time. func (tc *testConn) readDatagram() *testDatagram { tc.t.Helper() - tc.wait() + synctest.Wait() tc.sentPackets = nil tc.sentFrames = nil buf := tc.endpoint.read() @@ -1001,11 +953,11 @@ func spaceForPacketType(ptype packetType) numberSpace { // testConnHooks implements connTestHooks. type testConnHooks testConn -func (tc *testConnHooks) init() { +func (tc *testConnHooks) init(first bool) { tc.conn.keysAppData.updateAfter = maxPacketNumber // disable key updates tc.keysInitial.r = tc.conn.keysInitial.w tc.keysInitial.w = tc.conn.keysInitial.r - if tc.conn.side == serverSide { + if first && tc.conn.side == serverSide { tc.endpoint.acceptQueue = append(tc.endpoint.acceptQueue, (*testConn)(tc)) } } @@ -1095,7 +1047,7 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { case tls.QUICTransportParameters: p, err := unmarshalTransportParams(e.Data) if err != nil { - tc.t.Logf("sent unparseable transport parameters %x %v", e.Data, err) + tc.t.Logf("sent unparsable transport parameters %x %v", e.Data, err) } else { tc.sentTransportParameters = &p } @@ -1103,48 +1055,10 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { } } -// nextMessage is called by the Conn's event loop to request its next event. 
-func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { - tc.timer = timer - for { - if !timer.IsZero() && !timer.After(tc.endpoint.now) { - if timer.Equal(tc.timerLastFired) { - // If the connection timer fires at time T, the Conn should take some - // action to advance the timer into the future. If the Conn reschedules - // the timer for the same time, it isn't making progress and we have a bug. - tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.endpoint.now, timer) - } else { - tc.timerLastFired = timer - return tc.endpoint.now, timerEvent{} - } - } - select { - case m := <-msgc: - return tc.endpoint.now, m - default: - } - if !tc.wakeAsync() { - break - } - } - // If the message queue is empty, then the conn is idle. - if tc.idlec != nil { - idlec := tc.idlec - tc.idlec = nil - close(idlec) - } - m = <-msgc - return tc.endpoint.now, m -} - func (tc *testConnHooks) newConnID(seq int64) ([]byte, error) { return testLocalConnID(seq), nil } -func (tc *testConnHooks) timeNow() time.Time { - return tc.endpoint.now -} - // testLocalConnID returns the connection ID with a given sequence number // used by a Conn under test. func testLocalConnID(seq int64) []byte { @@ -1171,7 +1085,7 @@ func testPeerStatelessResetToken(seq int64) statelessResetToken { // canceledContext returns a canceled Context. // -// Functions which take a context preference progress over cancelation. +// Functions which take a context preference progress over cancellation. // For example, a read with a canceled context will return data if any is available. // Tests use canceled contexts to perform non-blocking operations. func canceledContext() context.Context { diff --git a/quic/crypto_stream.go b/quic/crypto_stream.go index ce73cb54ff..a5b9818296 100644 --- a/quic/crypto_stream.go +++ b/quic/crypto_stream.go @@ -142,7 +142,7 @@ func (s *cryptoStream) sendData(off int64, b []byte) { func (s *cryptoStream) discardKeys() error { if s.in.end-s.in.start != 0 { // The peer sent some unprocessed CRYPTO data that we're about to discard. - // Close the connetion with a TLS unexpected_message alert. + // Close the connection with a TLS unexpected_message alert. // https://www.rfc-editor.org/rfc/rfc5246#section-7.2.2 const unexpectedMessage = 10 return localTransportError{ diff --git a/quic/doc.go b/quic/doc.go index 2fd10f0878..37b19eb13b 100644 --- a/quic/doc.go +++ b/quic/doc.go @@ -21,7 +21,7 @@ // // A [Stream] is a QUIC stream, an ordered, reliable byte stream. // -// # Cancelation +// # Cancellation // // All blocking operations may be canceled using a context.Context. // When performing an operation with a canceled context, the operation diff --git a/quic/endpoint.go b/quic/endpoint.go index 1bb901525e..3d68073cd6 100644 --- a/quic/endpoint.go +++ b/quic/endpoint.go @@ -36,7 +36,6 @@ type Endpoint struct { } type endpointTestHooks interface { - timeNow() time.Time newConn(c *Conn) } @@ -160,7 +159,7 @@ func (e *Endpoint) Close(ctx context.Context) error { // Accept waits for and returns the next connection. func (e *Endpoint) Accept(ctx context.Context) (*Conn, error) { - return e.acceptQueue.get(ctx, nil) + return e.acceptQueue.get(ctx) } // Dial creates and returns a connection to a network address. 
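// The deletions above (the fake timeNow hook, nextMessage, and the manual
// advance/advanceTo/wait machinery) rely on testing/synctest supplying a fake
// clock inside each test bubble. A minimal illustration of the behavior being
// depended on; the test name is illustrative and not part of this change:

func TestSynctestFakeClockSketch(t *testing.T) {
    synctest.Test(t, func(t *testing.T) {
        // Inside the bubble, time.Now is fake and time.Sleep returns as soon
        // as every goroutine in the bubble is durably blocked, advancing the
        // fake clock by exactly the requested duration.
        start := time.Now()
        time.Sleep(5 * time.Second) // no real-time delay
        if elapsed := time.Since(start); elapsed != 5*time.Second {
            t.Errorf("elapsed = %v, want 5s", elapsed)
        }
    })
}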
@@ -269,12 +268,7 @@ func (e *Endpoint) handleUnknownDestinationDatagram(m *datagram) { if len(m.b) < minimumValidPacketSize { return } - var now time.Time - if e.testHooks != nil { - now = e.testHooks.timeNow() - } else { - now = time.Now() - } + now := time.Now() // Check to see if this is a stateless reset. var token statelessResetToken copy(token[:], m.b[len(m.b)-len(token):]) diff --git a/quic/endpoint_test.go b/quic/endpoint_test.go index 98b8756d1c..6a62104e62 100644 --- a/quic/endpoint_test.go +++ b/quic/endpoint_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -12,8 +14,9 @@ import ( "log/slog" "net/netip" "runtime" + "sync" "testing" - "time" + "testing/synctest" "golang.org/x/net/quic/qlog" ) @@ -22,6 +25,10 @@ func TestConnect(t *testing.T) { newLocalConnPair(t, &Config{}, &Config{}) } +func TestConnectRetry(t *testing.T) { + newLocalConnPair(t, &Config{RequireAddressValidation: true}, &Config{}) +} + func TestConnectDefaultTLSConfig(t *testing.T) { serverConfig := newTestTLSConfigWithMoreDefaults(serverSide) clientConfig := newTestTLSConfigWithMoreDefaults(clientSide) @@ -122,22 +129,22 @@ func makeTestConfig(conf *Config, side connSide) *Config { type testEndpoint struct { t *testing.T e *Endpoint - now time.Time recvc chan *datagram idlec chan struct{} conns map[*Conn]*testConn acceptQueue []*testConn configTransportParams []func(*transportParameters) configTestConn []func(*testConn) - sentDatagrams [][]byte peerTLSConn *tls.QUICConn lastInitialDstConnID []byte // for parsing Retry packets + + sentDatagramsMu sync.Mutex + sentDatagrams [][]byte } func newTestEndpoint(t *testing.T, config *Config) *testEndpoint { te := &testEndpoint{ t: t, - now: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), recvc: make(chan *datagram), idlec: make(chan struct{}), conns: make(map[*Conn]*testConn), @@ -155,16 +162,6 @@ func (te *testEndpoint) cleanup() { te.e.Close(canceledContext()) } -func (te *testEndpoint) wait() { - select { - case te.idlec <- struct{}{}: - case <-te.e.closec: - } - for _, tc := range te.conns { - tc.wait() - } -} - // accept returns a server connection from the endpoint. // Unlike Endpoint.Accept, connections are available as soon as they are created. func (te *testEndpoint) accept() *testConn { @@ -178,7 +175,7 @@ func (te *testEndpoint) accept() *testConn { func (te *testEndpoint) write(d *datagram) { te.recvc <- d - te.wait() + synctest.Wait() } var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000") @@ -237,7 +234,9 @@ func (te *testEndpoint) connForSource(srcConnID []byte) *testConn { func (te *testEndpoint) read() []byte { te.t.Helper() - te.wait() + synctest.Wait() + te.sentDatagramsMu.Lock() + defer te.sentDatagramsMu.Unlock() if len(te.sentDatagrams) == 0 { return nil } @@ -275,34 +274,9 @@ func (te *testEndpoint) wantIdle(expectation string) { } } -// advance causes time to pass. -func (te *testEndpoint) advance(d time.Duration) { - te.t.Helper() - te.advanceTo(te.now.Add(d)) -} - -// advanceTo sets the current time. -func (te *testEndpoint) advanceTo(now time.Time) { - te.t.Helper() - if te.now.After(now) { - te.t.Fatalf("time moved backwards: %v -> %v", te.now, now) - } - te.now = now - for _, tc := range te.conns { - if !tc.timer.After(te.now) { - tc.conn.sendMsg(timerEvent{}) - tc.wait() - } - } -} - // testEndpointHooks implements endpointTestHooks. 
type testEndpointHooks testEndpoint -func (te *testEndpointHooks) timeNow() time.Time { - return te.now -} - func (te *testEndpointHooks) newConn(c *Conn) { tc := newTestConnForConn(te.t, (*testEndpoint)(te), c) te.conns[c] = tc @@ -334,6 +308,8 @@ func (te *testEndpointUDPConn) Read(f func(*datagram)) { } func (te *testEndpointUDPConn) Write(dgram datagram) error { + te.sentDatagramsMu.Lock() + defer te.sentDatagramsMu.Unlock() te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...)) return nil } diff --git a/quic/frame_debug.go b/quic/frame_debug.go index 7cf03faf5b..8d8fd54517 100644 --- a/quic/frame_debug.go +++ b/quic/frame_debug.go @@ -136,11 +136,12 @@ func (f debugFramePing) LogValue() slog.Value { type debugFrameAck struct { ackDelay unscaledAckDelay ranges []i64range[packetNumber] + ecn ecnCounts } func parseDebugFrameAck(b []byte) (f debugFrameAck, n int) { f.ranges = nil - _, f.ackDelay, n = consumeAckFrame(b, func(_ int, start, end packetNumber) { + _, f.ackDelay, f.ecn, n = consumeAckFrame(b, func(_ int, start, end packetNumber) { f.ranges = append(f.ranges, i64range[packetNumber]{ start: start, end: end, @@ -159,11 +160,15 @@ func (f debugFrameAck) String() string { for _, r := range f.ranges { s += fmt.Sprintf(" [%v,%v)", r.start, r.end) } + + if (f.ecn != ecnCounts{}) { + s += fmt.Sprintf(" ECN=[%d,%d,%d]", f.ecn.t0, f.ecn.t1, f.ecn.ce) + } return s } func (f debugFrameAck) write(w *packetWriter) bool { - return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay) + return w.appendAckFrame(rangeset[packetNumber](f.ranges), f.ackDelay, f.ecn) } func (f debugFrameAck) LogValue() slog.Value { diff --git a/quic/gate.go b/quic/gate.go index 1f570bb906..b8b8605e62 100644 --- a/quic/gate.go +++ b/quic/gate.go @@ -46,10 +46,7 @@ func (g *gate) lock() (set bool) { // waitAndLock waits until the condition is set before acquiring the gate. // If the context expires, waitAndLock returns an error and does not acquire the gate. 
-func (g *gate) waitAndLock(ctx context.Context, testHooks connTestHooks) error { - if testHooks != nil { - return testHooks.waitUntil(ctx, g.lockIfSet) - } +func (g *gate) waitAndLock(ctx context.Context) error { select { case <-g.set: return nil diff --git a/quic/gate_test.go b/quic/gate_test.go index 54f7a8a4ac..59c157d237 100644 --- a/quic/gate_test.go +++ b/quic/gate_test.go @@ -47,7 +47,7 @@ func TestGateWaitAndLockContext(t *testing.T) { time.Sleep(1 * time.Millisecond) cancel() }() - if err := g.waitAndLock(ctx, nil); err != context.Canceled { + if err := g.waitAndLock(ctx); err != context.Canceled { t.Errorf("g.waitAndLock() = %v, want context.Canceled", err) } // waitAndLock succeeds @@ -58,7 +58,7 @@ func TestGateWaitAndLockContext(t *testing.T) { set = true g.unlock(true) }() - if err := g.waitAndLock(context.Background(), nil); err != nil { + if err := g.waitAndLock(context.Background()); err != nil { t.Errorf("g.waitAndLock() = %v, want nil", err) } if !set { @@ -66,7 +66,7 @@ func TestGateWaitAndLockContext(t *testing.T) { } g.unlock(true) // waitAndLock succeeds when the gate is set and the context is canceled - if err := g.waitAndLock(ctx, nil); err != nil { + if err := g.waitAndLock(ctx); err != nil { t.Errorf("g.waitAndLock() = %v, want nil", err) } } @@ -89,5 +89,5 @@ func TestGateUnlockFunc(t *testing.T) { g.lock() defer g.unlockFunc(func() bool { return true }) }() - g.waitAndLock(context.Background(), nil) + g.waitAndLock(context.Background()) } diff --git a/quic/idle_test.go b/quic/idle_test.go index 29d3bd1418..d9ae16ab77 100644 --- a/quic/idle_test.go +++ b/quic/idle_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -9,10 +11,14 @@ import ( "crypto/tls" "fmt" "testing" + "testing/synctest" "time" ) func TestHandshakeTimeoutExpiresServer(t *testing.T) { + synctest.Test(t, testHandshakeTimeoutExpiresServer) +} +func testHandshakeTimeoutExpiresServer(t *testing.T) { const timeout = 5 * time.Second tc := newTestConn(t, serverSide, func(c *Config) { c.HandshakeTimeout = timeout @@ -32,18 +38,18 @@ func TestHandshakeTimeoutExpiresServer(t *testing.T) { packetTypeHandshake, debugFrameCrypto{}) tc.writeAckForAll() - if got, want := tc.timerDelay(), timeout; got != want { + if got, want := tc.timeUntilEvent(), timeout; got != want { t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) } // Client sends a packet, but this does not extend the handshake timer. 
- tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.writeFrames(packetTypeHandshake, debugFrameCrypto{ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake][:1], // partial data }) tc.wantIdle("handshake is not complete") - tc.advance(timeout - 1*time.Second) + time.Sleep(timeout - 1*time.Second) tc.wantFrame("server closes connection after handshake timeout", packetTypeHandshake, debugFrameConnectionCloseTransport{ code: errConnectionRefused, @@ -51,6 +57,9 @@ func TestHandshakeTimeoutExpiresServer(t *testing.T) { } func TestHandshakeTimeoutExpiresClient(t *testing.T) { + synctest.Test(t, testHandshakeTimeoutExpiresClient) +} +func testHandshakeTimeoutExpiresClient(t *testing.T) { const timeout = 5 * time.Second tc := newTestConn(t, clientSide, func(c *Config) { c.HandshakeTimeout = timeout @@ -77,10 +86,10 @@ func TestHandshakeTimeoutExpiresClient(t *testing.T) { tc.writeAckForAll() tc.wantIdle("client is waiting for end of handshake") - if got, want := tc.timerDelay(), timeout; got != want { + if got, want := tc.timeUntilEvent(), timeout; got != want { t.Errorf("connection timer = %v, want %v (handshake timeout)", got, want) } - tc.advance(timeout) + time.Sleep(timeout) tc.wantFrame("client closes connection after handshake timeout", packetTypeHandshake, debugFrameConnectionCloseTransport{ code: errConnectionRefused, @@ -110,7 +119,7 @@ func TestIdleTimeoutExpires(t *testing.T) { wantTimeout: 10 * time.Second, }} { name := fmt.Sprintf("local=%v/peer=%v", test.localMaxIdleTimeout, test.peerMaxIdleTimeout) - t.Run(name, func(t *testing.T) { + synctestSubtest(t, name, func(t *testing.T) { tc := newTestConn(t, serverSide, func(p *transportParameters) { p.maxIdleTimeout = test.peerMaxIdleTimeout }, func(c *Config) { @@ -120,13 +129,13 @@ func TestIdleTimeoutExpires(t *testing.T) { if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { t.Errorf("new conn timeout=%v, want %v (idle timeout)", got, want) } - tc.advance(test.wantTimeout - 1) + time.Sleep(test.wantTimeout - 1) tc.wantIdle("connection is idle and alive prior to timeout") ctx := canceledContext() if err := tc.conn.Wait(ctx); err != context.Canceled { t.Fatalf("conn.Wait() = %v, want Canceled", err) } - tc.advance(1) + time.Sleep(1) tc.wantIdle("connection exits after timeout") if err := tc.conn.Wait(ctx); err != errIdleTimeout { t.Fatalf("conn.Wait() = %v, want errIdleTimeout", err) @@ -154,7 +163,7 @@ func TestIdleTimeoutKeepAlive(t *testing.T) { wantTimeout: 30 * time.Second, }} { name := fmt.Sprintf("idle_timeout=%v/keepalive=%v", test.idleTimeout, test.keepAlive) - t.Run(name, func(t *testing.T) { + synctestSubtest(t, name, func(t *testing.T) { tc := newTestConn(t, serverSide, func(c *Config) { c.MaxIdleTimeout = test.idleTimeout c.KeepAlivePeriod = test.keepAlive @@ -163,9 +172,9 @@ func TestIdleTimeoutKeepAlive(t *testing.T) { if got, want := tc.timeUntilEvent(), test.wantTimeout; got != want { t.Errorf("new conn timeout=%v, want %v (keepalive timeout)", got, want) } - tc.advance(test.wantTimeout - 1) + time.Sleep(test.wantTimeout - 1) tc.wantIdle("connection is idle prior to timeout") - tc.advance(1) + time.Sleep(1) tc.wantFrameType("keep-alive ping is sent", packetType1RTT, debugFramePing{}) }) @@ -173,6 +182,9 @@ func TestIdleTimeoutKeepAlive(t *testing.T) { } func TestIdleLongTermKeepAliveSent(t *testing.T) { + synctest.Test(t, testIdleLongTermKeepAliveSent) +} +func testIdleLongTermKeepAliveSent(t *testing.T) { // This test examines a connection sitting idle and sending periodic keep-alive 
pings. const keepAlivePeriod = 30 * time.Second tc := newTestConn(t, clientSide, func(c *Config) { @@ -191,7 +203,7 @@ func TestIdleLongTermKeepAliveSent(t *testing.T) { if got, want := tc.timeUntilEvent(), keepAlivePeriod; got != want { t.Errorf("i=%v conn timeout=%v, want %v (keepalive timeout)", i, got, want) } - tc.advance(keepAlivePeriod) + time.Sleep(keepAlivePeriod) tc.wantFrameType("keep-alive ping is sent", packetType1RTT, debugFramePing{}) tc.writeAckForAll() @@ -199,6 +211,9 @@ func TestIdleLongTermKeepAliveSent(t *testing.T) { } func TestIdleLongTermKeepAliveReceived(t *testing.T) { + synctest.Test(t, testIdleLongTermKeepAliveReceived) +} +func testIdleLongTermKeepAliveReceived(t *testing.T) { // This test examines a connection sitting idle, but receiving periodic peer // traffic to keep the connection alive. const idleTimeout = 30 * time.Second @@ -207,7 +222,7 @@ func TestIdleLongTermKeepAliveReceived(t *testing.T) { }) tc.handshake() for i := 0; i < 10; i++ { - tc.advance(idleTimeout - 1*time.Second) + time.Sleep(idleTimeout - 1*time.Second) tc.writeFrames(packetType1RTT, debugFramePing{}) if got, want := tc.timeUntilEvent(), maxAckDelay-timerGranularity; got != want { t.Errorf("i=%v conn timeout=%v, want %v (max_ack_delay)", i, got, want) diff --git a/quic/key_update_test.go b/quic/key_update_test.go index 2daf7db97f..7a02e84907 100644 --- a/quic/key_update_test.go +++ b/quic/key_update_test.go @@ -2,13 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "testing" + "testing/synctest" ) func TestKeyUpdatePeerUpdates(t *testing.T) { + synctest.Test(t, testKeyUpdatePeerUpdates) +} +func testKeyUpdatePeerUpdates(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() tc.ignoreFrames = nil // ignore nothing @@ -56,6 +62,9 @@ func TestKeyUpdatePeerUpdates(t *testing.T) { } func TestKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) { + synctest.Test(t, testKeyUpdateAcceptPreviousPhaseKeys) +} +func testKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) { // "An endpoint SHOULD retain old keys for some time after // unprotecting a packet sent using the new keys." // https://www.rfc-editor.org/rfc/rfc9001#section-6.1-8 @@ -112,6 +121,9 @@ func TestKeyUpdateAcceptPreviousPhaseKeys(t *testing.T) { } func TestKeyUpdateRejectPacketFromPriorPhase(t *testing.T) { + synctest.Test(t, testKeyUpdateRejectPacketFromPriorPhase) +} +func testKeyUpdateRejectPacketFromPriorPhase(t *testing.T) { // "Packets with higher packet numbers MUST be protected with either // the same or newer packet protection keys than packets with lower packet numbers." // https://www.rfc-editor.org/rfc/rfc9001#section-6.4-2 @@ -161,6 +173,9 @@ func TestKeyUpdateRejectPacketFromPriorPhase(t *testing.T) { } func TestKeyUpdateLocallyInitiated(t *testing.T) { + synctest.Test(t, testKeyUpdateLocallyInitiated) +} +func testKeyUpdateLocallyInitiated(t *testing.T) { const updateAfter = 4 // initiate key update after 1-RTT packet 4 tc := newTestConn(t, serverSide) tc.conn.keysAppData.updateAfter = updateAfter diff --git a/quic/loss.go b/quic/loss.go index ffbf69ddb7..95feaba2d4 100644 --- a/quic/loss.go +++ b/quic/loss.go @@ -178,7 +178,7 @@ func (c *lossState) nextNumber(space numberSpace) packetNumber { return c.spaces[space].nextNum } -// skipPacketNumber skips a packet number as a defense against optimistic ACK attacks. +// skipNumber skips a packet number as a defense against optimistic ACK attacks. 
func (c *lossState) skipNumber(now time.Time, space numberSpace) { sent := newSentPacket() sent.num = c.spaces[space].nextNum diff --git a/quic/loss_test.go b/quic/loss_test.go index 545f2c414e..6d07d137cc 100644 --- a/quic/loss_test.go +++ b/quic/loss_test.go @@ -675,7 +675,7 @@ func TestLossPTONotSetWhenLossTimerSet(t *testing.T) { t.Logf("# PTO = smoothed_rtt + max(4*rttvar, 1ms)") test.wantTimeout(999 * time.Millisecond) - t.Logf("# ack of packet 1 starts loss timer for 0, PTO overidden") + t.Logf("# ack of packet 1 starts loss timer for 0, PTO overridden") test.advance(333 * time.Millisecond) test.ack(initialSpace, 0*time.Millisecond, i64range[packetNumber]{1, 2}) test.wantAck(initialSpace, 1) diff --git a/quic/packet_codec_test.go b/quic/packet_codec_test.go index be335d7fdf..d49f0ea69c 100644 --- a/quic/packet_codec_test.go +++ b/quic/packet_codec_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -263,6 +265,65 @@ func TestFrameEncodeDecode(t *testing.T) { 0x0f, // Gap (i) 0x0e, // ACK Range Length (i) }, + }, { + s: "ACK Delay=10 [0,16) [17,32) ECN=[1,2,3]", + j: `"error: debugFrameAck should not appear as a slog Value"`, + f: debugFrameAck{ + ackDelay: 10, + ranges: []i64range[packetNumber]{ + {0x00, 0x10}, + {0x11, 0x20}, + }, + ecn: ecnCounts{1, 2, 3}, + }, + b: []byte{ + 0x03, // TYPE (i) = 0x3 + 0x1f, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x01, // ACK Range Count (i) + 0x0e, // First ACK Range (i) + 0x00, // Gap (i) + 0x0f, // ACK Range Length (i) + 0x01, // ECT0 Count (i) + 0x02, // ECT1 Count (i) + 0x03, // ECN-CE Count (i) + }, + truncated: []byte{ + 0x03, // TYPE (i) = 0x3 + 0x1f, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x00, // ACK Range Count (i) + 0x0e, // First ACK Range (i) + 0x01, // ECT0 Count (i) + 0x02, // ECT1 Count (i) + 0x03, // ECN-CE Count (i) + }, + }, { + s: "ACK Delay=10 [17,32) ECN=[1,2,3]", + j: `"error: debugFrameAck should not appear as a slog Value"`, + f: debugFrameAck{ + ackDelay: 10, + ranges: []i64range[packetNumber]{ + {0x11, 0x20}, + }, + ecn: ecnCounts{1, 2, 3}, + }, + b: []byte{ + 0x03, // TYPE (i) = 0x3 + 0x1f, // Largest Acknowledged (i) + 10, // ACK Delay (i) + 0x00, // ACK Range Count (i) + 0x0e, // First ACK Range (i) + 0x01, // ECT0 Count (i) + 0x02, // ECT1 Count (i) + 0x03, // ECN-CE Count (i) + }, + // Downgrading to a type 0x2 ACK frame is not allowed: "Even if an + // endpoint does not set an ECT field in packets it sends, the endpoint + // MUST provide feedback about ECN markings it receives, if these are + // accessible." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.4.1-2 + truncated: nil, }, { s: "RESET_STREAM ID=1 Code=2 FinalSize=3", j: `{"frame_type":"reset_stream","stream_id":1,"final_size":3}`, @@ -675,6 +736,7 @@ func TestFrameDecode(t *testing.T) { ranges: []i64range[packetNumber]{ {0, 1}, }, + ecn: ecnCounts{1, 2, 3}, }, b: []byte{ 0x03, // TYPE (i) = 0x02..0x03 diff --git a/quic/packet_parser.go b/quic/packet_parser.go index eadf14fd18..265c4aeb3a 100644 --- a/quic/packet_parser.go +++ b/quic/packet_parser.go @@ -157,25 +157,25 @@ func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax p // which includes both general parse failures and specific violations of frame // constraints. 
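// Worked decode (annotation only, not part of the patch) of the new
// "ACK Delay=10 [0,16) [17,32) ECN=[1,2,3]" vector added above, following the
// range arithmetic already present in consumeAckFrame below:
//
//   Largest Acknowledged = 31 (0x1f), First ACK Range = 14 (0x0e)
//     rangeMax = 31, rangeMin = 31 - 14 = 17              -> reports [17, 32)
//   Gap = 0 (0x00), ACK Range Length = 15 (0x0f)
//     rangeMax = 17 - 0 - 2 = 15, rangeMin = 15 - 15 = 0  -> reports [0, 16)
//
// Because the frame type is 0x03 (ACK with ECN), the three trailing varints
// are the ECT(0), ECT(1), and ECN-CE counts (here 1, 2, 3).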
-func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, n int) { +func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, ecn ecnCounts, n int) { b := frame[1:] // type largestAck, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] v, n := quicwire.ConsumeVarintInt64(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] ackDelay = unscaledAckDelay(v) ackRangeCount, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] @@ -183,12 +183,12 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe for i := uint64(0); ; i++ { rangeLen, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] rangeMin := rangeMax - packetNumber(rangeLen) if rangeMin < 0 || rangeMin > rangeMax { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } f(int(i), rangeMin, rangeMax+1) @@ -198,7 +198,7 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe gap, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] @@ -206,32 +206,30 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe } if frame[0] != frameTypeAckECN { - return packetNumber(largestAck), ackDelay, len(frame) - len(b) + return packetNumber(largestAck), ackDelay, ecnCounts{}, len(frame) - len(b) } ect0Count, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] ect1Count, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] ecnCECount, n := quicwire.ConsumeVarint(b) if n < 0 { - return 0, 0, -1 + return 0, 0, ecnCounts{}, -1 } b = b[n:] - // TODO: Make use of ECN feedback. - // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.3.2 - _ = ect0Count - _ = ect1Count - _ = ecnCECount + ecn.t0 = int(ect0Count) + ecn.t1 = int(ect1Count) + ecn.ce = int(ecnCECount) - return packetNumber(largestAck), ackDelay, len(frame) - len(b) + return packetNumber(largestAck), ackDelay, ecn, len(frame) - len(b) } func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) { diff --git a/quic/packet_writer.go b/quic/packet_writer.go index 3560ebbe4d..f446521d2b 100644 --- a/quic/packet_writer.go +++ b/quic/packet_writer.go @@ -262,7 +262,7 @@ func (w *packetWriter) appendPingFrame() (added bool) { // to the peer potentially failing to receive an acknowledgement // for an older packet during a period of high packet loss or // reordering. This may result in unnecessary retransmissions. 
-func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscaledAckDelay) (added bool) { +func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscaledAckDelay, ecn ecnCounts) (added bool) { if len(seen) == 0 { return false } @@ -270,10 +270,20 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale largest = uint64(seen.max()) firstRange = uint64(seen[len(seen)-1].size() - 1) ) - if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange) { + var ecnLen int + ackType := byte(frameTypeAck) + if (ecn != ecnCounts{}) { + // "Even if an endpoint does not set an ECT field in packets it sends, + // the endpoint MUST provide feedback about ECN markings it receives, if + // these are accessible." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-13.4.1-2 + ecnLen = quicwire.SizeVarint(uint64(ecn.ce)) + quicwire.SizeVarint(uint64(ecn.t0)) + quicwire.SizeVarint(uint64(ecn.t1)) + ackType = frameTypeAckECN + } + if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange)+ecnLen { return false } - w.b = append(w.b, frameTypeAck) + w.b = append(w.b, ackType) w.b = quicwire.AppendVarint(w.b, largest) w.b = quicwire.AppendVarint(w.b, uint64(delay)) // The range count is technically a varint, but we'll reserve a single byte for it @@ -285,7 +295,7 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale for i := len(seen) - 2; i >= 0; i-- { gap := uint64(seen[i+1].start - seen[i].end - 1) size := uint64(seen[i].size() - 1) - if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size) || rangeCount > 62 { + if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size)+ecnLen || rangeCount > 62 { break } w.b = quicwire.AppendVarint(w.b, gap) @@ -293,7 +303,12 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale rangeCount++ } w.b[rangeCountOff] = rangeCount - w.sent.appendNonAckElicitingFrame(frameTypeAck) + if ackType == frameTypeAckECN { + w.b = quicwire.AppendVarint(w.b, uint64(ecn.t0)) + w.b = quicwire.AppendVarint(w.b, uint64(ecn.t1)) + w.b = quicwire.AppendVarint(w.b, uint64(ecn.ce)) + } + w.sent.appendNonAckElicitingFrame(ackType) w.sent.appendInt(uint64(seen.max())) return true } diff --git a/quic/path_test.go b/quic/path_test.go index 60ff51e35d..16dd9fcede 100644 --- a/quic/path_test.go +++ b/quic/path_test.go @@ -2,10 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "testing" + "testing/synctest" ) func TestPathChallengeReceived(t *testing.T) { @@ -22,30 +25,35 @@ func TestPathChallengeReceived(t *testing.T) { padTo: 1200, wantPadding: 1200, }} { - // "The recipient of [a PATH_CHALLENGE] frame MUST generate - // a PATH_RESPONSE frame [...] containing the same Data value." - // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7 - tc := newTestConn(t, clientSide) - tc.handshake() - tc.ignoreFrame(frameTypeAck) - data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} - tc.writeFrames(packetType1RTT, debugFramePathChallenge{ - data: data, - }, debugFramePadding{ - to: test.padTo, - }) - tc.wantFrame("response to PATH_CHALLENGE", - packetType1RTT, debugFramePathResponse{ + synctestSubtest(t, test.name, func(t *testing.T) { + // "The recipient of [a PATH_CHALLENGE] frame MUST generate + // a PATH_RESPONSE frame [...] 
containing the same Data value." + // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.17-7 + tc := newTestConn(t, clientSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + data := pathChallengeData{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef} + tc.writeFrames(packetType1RTT, debugFramePathChallenge{ data: data, + }, debugFramePadding{ + to: test.padTo, }) - if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want { - t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want) - } - tc.wantIdle("connection is idle") + tc.wantFrame("response to PATH_CHALLENGE", + packetType1RTT, debugFramePathResponse{ + data: data, + }) + if got, want := tc.lastDatagram.paddedSize, test.wantPadding; got != want { + t.Errorf("PATH_RESPONSE expanded to %v bytes, want %v", got, want) + } + tc.wantIdle("connection is idle") + }) } } func TestPathResponseMismatchReceived(t *testing.T) { + synctest.Test(t, testPathResponseMismatchReceived) +} +func testPathResponseMismatchReceived(t *testing.T) { // "If the content of a PATH_RESPONSE frame does not match the content of // a PATH_CHALLENGE frame previously sent by the endpoint, // the endpoint MAY generate a connection error of type PROTOCOL_VIOLATION." diff --git a/quic/ping_test.go b/quic/ping_test.go index a8e6b61ada..4589a6c7be 100644 --- a/quic/ping_test.go +++ b/quic/ping_test.go @@ -2,11 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic -import "testing" +import ( + "testing" + "testing/synctest" +) func TestPing(t *testing.T) { + synctest.Test(t, testPing) +} +func testPing(t *testing.T) { tc := newTestConn(t, clientSide) tc.handshake() @@ -22,6 +30,9 @@ func TestPing(t *testing.T) { } func TestAck(t *testing.T) { + synctest.Test(t, testAck) +} +func testAck(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() diff --git a/quic/qlog_test.go b/quic/qlog_test.go index 08c2a77a81..47e4671160 100644 --- a/quic/qlog_test.go +++ b/quic/qlog_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 + package quic import ( @@ -12,14 +14,16 @@ import ( "io" "log/slog" "reflect" + "sync" "testing" + "testing/synctest" "time" "golang.org/x/net/quic/qlog" ) func TestQLogHandshake(t *testing.T) { - testSides(t, "", func(t *testing.T, side connSide) { + testSidesSynctest(t, "", func(t *testing.T, side connSide) { qr := &qlogRecord{} tc := newTestConn(t, side, qr.config) tc.handshake() @@ -55,6 +59,9 @@ func TestQLogHandshake(t *testing.T) { } func TestQLogPacketFrames(t *testing.T) { + synctest.Test(t, testQLogPacketFrames) +} +func testQLogPacketFrames(t *testing.T) { qr := &qlogRecord{} tc := newTestConn(t, clientSide, qr.config) tc.handshake() @@ -111,7 +118,7 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { tc.ignoreFrame(frameTypeCrypto) tc.ignoreFrame(frameTypeAck) tc.ignoreFrame(frameTypePing) - tc.advance(5 * time.Second) + time.Sleep(5 * time.Second) }, }, { trigger: "idle_timeout", @@ -122,7 +129,7 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { }, f: func(tc *testConn) { tc.handshake() - tc.advance(5 * time.Second) + time.Sleep(5 * time.Second) }, }, { trigger: "error", @@ -134,7 +141,7 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { tc.conn.Abort(nil) }, }} { - t.Run(test.trigger, func(t *testing.T) { + synctestSubtest(t, test.trigger, func(t *testing.T) { qr := &qlogRecord{} tc := newTestConn(t, clientSide, append(test.connOpts, qr.config)...) test.f(tc) @@ -147,7 +154,7 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { t.Fatalf("unexpected frame: %v", fr) } tc.wantIdle("connection should be idle while closing") - tc.advance(5 * time.Second) // long enough for the drain timer to expire + time.Sleep(5 * time.Second) // long enough for the drain timer to expire qr.wantEvents(t, jsonEvent{ "name": "connectivity:connection_closed", "data": map[string]any{ @@ -159,6 +166,9 @@ func TestQLogConnectionClosedTrigger(t *testing.T) { } func TestQLogRecovery(t *testing.T) { + synctest.Test(t, testQLogRecovery) +} +func testQLogRecovery(t *testing.T) { qr := &qlogRecord{} tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters, qr.config) @@ -198,6 +208,9 @@ func TestQLogRecovery(t *testing.T) { } func TestQLogLoss(t *testing.T) { + synctest.Test(t, testQLogLoss) +} +func testQLogLoss(t *testing.T) { qr := &qlogRecord{} tc, s := newTestConnAndLocalStream(t, clientSide, uniStream, permissiveTransportParameters, qr.config) @@ -230,6 +243,9 @@ func TestQLogLoss(t *testing.T) { } func TestQLogPacketDropped(t *testing.T) { + synctest.Test(t, testQLogPacketDropped) +} +func testQLogPacketDropped(t *testing.T) { qr := &qlogRecord{} tc := newTestConn(t, clientSide, permissiveTransportParameters, qr.config) tc.handshake() @@ -324,10 +340,13 @@ func jsonPartialEqual(got, want any) (equal bool) { // A qlogRecord records events. type qlogRecord struct { + mu sync.Mutex ev []jsonEvent } func (q *qlogRecord) Write(b []byte) (int, error) { + q.mu.Lock() + defer q.mu.Unlock() // This relies on the property that the Handler always makes one Write call per event. if len(b) < 1 || b[0] != 0x1e { panic(fmt.Errorf("trace Write should start with record separator, got %q", string(b))) @@ -355,6 +374,8 @@ func (q *qlogRecord) config(c *Config) { // wantEvents checks that every event in want occurs in the order specified. 
func (q *qlogRecord) wantEvents(t *testing.T, want ...jsonEvent) { t.Helper() + q.mu.Lock() + defer q.mu.Unlock() got := q.ev if !jsonPartialEqual(got, want) { t.Fatalf("got events:\n%v\n\nwant events:\n%v", got, want) diff --git a/quic/queue.go b/quic/queue.go index 8b90ae7773..f2712f4012 100644 --- a/quic/queue.go +++ b/quic/queue.go @@ -42,9 +42,9 @@ func (q *queue[T]) put(v T) bool { // get removes the first item from the queue, blocking until ctx is done, an item is available, // or the queue is closed. -func (q *queue[T]) get(ctx context.Context, testHooks connTestHooks) (T, error) { +func (q *queue[T]) get(ctx context.Context) (T, error) { var zero T - if err := q.gate.waitAndLock(ctx, testHooks); err != nil { + if err := q.gate.waitAndLock(ctx); err != nil { return zero, err } defer q.unlock() diff --git a/quic/queue_test.go b/quic/queue_test.go index eee34e5ba7..a3907f31fb 100644 --- a/quic/queue_test.go +++ b/quic/queue_test.go @@ -16,8 +16,8 @@ func TestQueue(t *testing.T) { cancel() q := newQueue[int]() - if got, err := q.get(nonblocking, nil); err != context.Canceled { - t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err) + if got, err := q.get(nonblocking); err != context.Canceled { + t.Fatalf("q.get() = %v, %v, want nil, context.Canceled", got, err) } if !q.put(1) { @@ -26,21 +26,21 @@ func TestQueue(t *testing.T) { if !q.put(2) { t.Fatalf("q.put(2) = false, want true") } - if got, err := q.get(nonblocking, nil); got != 1 || err != nil { + if got, err := q.get(nonblocking); got != 1 || err != nil { t.Fatalf("q.get() = %v, %v, want 1, nil", got, err) } - if got, err := q.get(nonblocking, nil); got != 2 || err != nil { + if got, err := q.get(nonblocking); got != 2 || err != nil { t.Fatalf("q.get() = %v, %v, want 2, nil", got, err) } - if got, err := q.get(nonblocking, nil); err != context.Canceled { - t.Fatalf("q.get() = %v, %v, want nil, contex.Canceled", got, err) + if got, err := q.get(nonblocking); err != context.Canceled { + t.Fatalf("q.get() = %v, %v, want nil, context.Canceled", got, err) } go func() { time.Sleep(1 * time.Millisecond) q.put(3) }() - if got, err := q.get(context.Background(), nil); got != 3 || err != nil { + if got, err := q.get(context.Background()); got != 3 || err != nil { t.Fatalf("q.get() = %v, %v, want 3, nil", got, err) } @@ -48,7 +48,7 @@ func TestQueue(t *testing.T) { t.Fatalf("q.put(2) = false, want true") } q.close(io.EOF) - if got, err := q.get(context.Background(), nil); got != 0 || err != io.EOF { + if got, err := q.get(context.Background()); got != 0 || err != io.EOF { t.Fatalf("q.get() = %v, %v, want 0, io.EOF", got, err) } if q.put(5) { diff --git a/quic/quic_test.go b/quic/quic_test.go index 071003e963..cdcc0d780f 100644 --- a/quic/quic_test.go +++ b/quic/quic_test.go @@ -2,10 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 + package quic import ( "testing" + "testing/synctest" ) func testSides(t *testing.T, name string, f func(*testing.T, connSide)) { @@ -16,6 +19,16 @@ func testSides(t *testing.T, name string, f func(*testing.T, connSide)) { t.Run(name+"client", func(t *testing.T) { f(t, clientSide) }) } +func testSidesSynctest(t *testing.T, name string, f func(*testing.T, connSide)) { + t.Helper() + testSides(t, name, func(t *testing.T, side connSide) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + f(t, side) + }) + }) +} + func testStreamTypes(t *testing.T, name string, f func(*testing.T, streamType)) { if name != "" { name += "/" @@ -24,6 +37,16 @@ func testStreamTypes(t *testing.T, name string, f func(*testing.T, streamType)) t.Run(name+"uni", func(t *testing.T) { f(t, uniStream) }) } +func testStreamTypesSynctest(t *testing.T, name string, f func(*testing.T, streamType)) { + t.Helper() + testStreamTypes(t, name, func(t *testing.T, stype streamType) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + f(t, stype) + }) + }) +} + func testSidesAndStreamTypes(t *testing.T, name string, f func(*testing.T, connSide, streamType)) { if name != "" { name += "/" @@ -33,3 +56,20 @@ func testSidesAndStreamTypes(t *testing.T, name string, f func(*testing.T, connS t.Run(name+"server/uni", func(t *testing.T) { f(t, serverSide, uniStream) }) t.Run(name+"client/uni", func(t *testing.T) { f(t, clientSide, uniStream) }) } + +func testSidesAndStreamTypesSynctest(t *testing.T, name string, f func(*testing.T, connSide, streamType)) { + t.Helper() + testSidesAndStreamTypes(t, name, func(t *testing.T, side connSide, stype streamType) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + f(t, side, stype) + }) + }) +} + +func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) { + t.Run(name, func(t *testing.T) { + t.Helper() + synctest.Test(t, f) + }) +} diff --git a/quic/retry_test.go b/quic/retry_test.go index d6f025472e..7a4481c094 100644 --- a/quic/retry_test.go +++ b/quic/retry_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -10,6 +12,7 @@ import ( "crypto/tls" "net/netip" "testing" + "testing/synctest" "time" ) @@ -77,9 +80,12 @@ func newRetryServerTest(t *testing.T) *retryServerTest { } func TestRetryServerSucceeds(t *testing.T) { + synctest.Test(t, testRetryServerSucceeds) +} +func testRetryServerSucceeds(t *testing.T) { rt := newRetryServerTest(t) te := rt.te - te.advance(retryTokenValidityPeriod) + time.Sleep(retryTokenValidityPeriod) te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, @@ -117,6 +123,9 @@ func TestRetryServerSucceeds(t *testing.T) { } func TestRetryServerTokenInvalid(t *testing.T) { + synctest.Test(t, testRetryServerTokenInvalid) +} +func testRetryServerTokenInvalid(t *testing.T) { // "If a server receives a client Initial that contains an invalid Retry token [...] // the server SHOULD immediately close [...] the connection with an // INVALID_TOKEN error." @@ -147,11 +156,14 @@ func TestRetryServerTokenInvalid(t *testing.T) { } func TestRetryServerTokenTooOld(t *testing.T) { + synctest.Test(t, testRetryServerTokenTooOld) +} +func testRetryServerTokenTooOld(t *testing.T) { // "[...] 
a token SHOULD have an expiration time [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.3-3 rt := newRetryServerTest(t) te := rt.te - te.advance(retryTokenValidityPeriod + time.Second) + time.Sleep(retryTokenValidityPeriod + time.Second) te.writeDatagram(&testDatagram{ packets: []*testPacket{{ ptype: packetTypeInitial, @@ -176,6 +188,9 @@ func TestRetryServerTokenTooOld(t *testing.T) { } func TestRetryServerTokenWrongIP(t *testing.T) { + synctest.Test(t, testRetryServerTokenWrongIP) +} +func testRetryServerTokenWrongIP(t *testing.T) { // "Tokens sent in Retry packets SHOULD include information that allows the server // to verify that the source IP address and port in client packets remain constant." // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.4-3 @@ -206,6 +221,9 @@ func TestRetryServerTokenWrongIP(t *testing.T) { } func TestRetryServerIgnoresRetry(t *testing.T) { + synctest.Test(t, testRetryServerIgnoresRetry) +} +func testRetryServerIgnoresRetry(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() tc.write(&testDatagram{ @@ -225,6 +243,9 @@ func TestRetryServerIgnoresRetry(t *testing.T) { } func TestRetryClientSuccess(t *testing.T) { + synctest.Test(t, testRetryClientSuccess) +} +func testRetryClientSuccess(t *testing.T) { // "This token MUST be repeated by the client in all Initial packets it sends // for that connection after it receives the Retry packet." // https://www.rfc-editor.org/rfc/rfc9000#section-8.1.2-1 @@ -323,7 +344,7 @@ func TestRetryClientInvalidServerTransportParameters(t *testing.T) { p.retrySrcConnID = []byte("invalid") }, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t *testing.T) { tc := newTestConn(t, clientSide, func(p *transportParameters) { p.initialSrcConnID = initialSrcConnID @@ -367,6 +388,9 @@ func TestRetryClientInvalidServerTransportParameters(t *testing.T) { } func TestRetryClientIgnoresRetryAfterReceivingPacket(t *testing.T) { + synctest.Test(t, testRetryClientIgnoresRetryAfterReceivingPacket) +} +func testRetryClientIgnoresRetryAfterReceivingPacket(t *testing.T) { // "After the client has received and processed an Initial or Retry packet // from the server, it MUST discard any subsequent Retry packets that it receives." // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1 @@ -401,6 +425,9 @@ func TestRetryClientIgnoresRetryAfterReceivingPacket(t *testing.T) { } func TestRetryClientIgnoresRetryAfterReceivingRetry(t *testing.T) { + synctest.Test(t, testRetryClientIgnoresRetryAfterReceivingRetry) +} +func testRetryClientIgnoresRetryAfterReceivingRetry(t *testing.T) { // "After the client has received and processed an Initial or Retry packet // from the server, it MUST discard any subsequent Retry packets that it receives." 
// https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-1 @@ -424,6 +451,9 @@ func TestRetryClientIgnoresRetryAfterReceivingRetry(t *testing.T) { } func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { + synctest.Test(t, testRetryClientIgnoresRetryWithInvalidIntegrityTag) +} +func testRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { tc := newTestConn(t, clientSide) tc.wantFrameType("client Initial CRYPTO data", packetTypeInitial, debugFrameCrypto{}) @@ -441,6 +471,9 @@ func TestRetryClientIgnoresRetryWithInvalidIntegrityTag(t *testing.T) { } func TestRetryClientIgnoresRetryWithZeroLengthToken(t *testing.T) { + synctest.Test(t, testRetryClientIgnoresRetryWithZeroLengthToken) +} +func testRetryClientIgnoresRetryWithZeroLengthToken(t *testing.T) { // "A client MUST discard a Retry packet with a zero-length Retry Token field." // https://www.rfc-editor.org/rfc/rfc9000#section-17.2.5.2-2 tc := newTestConn(t, clientSide) diff --git a/quic/skip.go b/quic/skip.go index f5ba764f8a..f0d0234ee6 100644 --- a/quic/skip.go +++ b/quic/skip.go @@ -32,7 +32,7 @@ func (ss *skipState) init(c *Conn) { ss.updateNumberSkip(c) } -// shouldSkipAfter returns whether we should skip the given packet number. +// shouldSkip returns whether we should skip the given packet number. func (ss *skipState) shouldSkip(num packetNumber) bool { return ss.skip == num } diff --git a/quic/skip_test.go b/quic/skip_test.go index 1fcb735ff1..2c33378b02 100644 --- a/quic/skip_test.go +++ b/quic/skip_test.go @@ -2,11 +2,19 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic -import "testing" +import ( + "testing" + "testing/synctest" +) func TestSkipPackets(t *testing.T) { + synctest.Test(t, testSkipPackets) +} +func testSkipPackets(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) connWritesPacket := func() { s.WriteByte(0) @@ -39,6 +47,9 @@ expectSkip: } func TestSkipAckForSkippedPacket(t *testing.T) { + synctest.Test(t, testSkipAckForSkippedPacket) +} +func testSkipAckForSkippedPacket(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) // Cause the connection to send packets until it skips a packet number. diff --git a/quic/stateless_reset_test.go b/quic/stateless_reset_test.go index 33d467a95b..947375085a 100644 --- a/quic/stateless_reset_test.go +++ b/quic/stateless_reset_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 + package quic import ( @@ -12,10 +14,14 @@ import ( "errors" "net/netip" "testing" + "testing/synctest" "time" ) func TestStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) { + synctest.Test(t, testStatelessResetClientSendsStatelessResetTokenTransportParameter) +} +func testStatelessResetClientSendsStatelessResetTokenTransportParameter(t *testing.T) { // "[The stateless_reset_token] transport parameter MUST NOT be sent by a client [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-18.2-4.6.1 resetToken := testPeerStatelessResetToken(0) @@ -61,6 +67,9 @@ func newDatagramForReset(cid []byte, size int, addr netip.AddrPort) *datagram { } func TestStatelessResetSentSizes(t *testing.T) { + synctest.Test(t, testStatelessResetSentSizes) +} +func testStatelessResetSentSizes(t *testing.T) { config := &Config{ TLSConfig: newTestTLSConfig(serverSide), StatelessResetKey: testStatelessResetKey, @@ -126,6 +135,9 @@ func TestStatelessResetSentSizes(t *testing.T) { } func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { + synctest.Test(t, testStatelessResetSuccessfulNewConnectionID) +} +func testStatelessResetSuccessfulNewConnectionID(t *testing.T) { // "[...] Stateless Reset Token field values from [...] NEW_CONNECTION_ID frames [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1 qr := &qlogRecord{} @@ -155,7 +167,7 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { t.Errorf("conn.Wait() = %v, want errStatelessReset", err) } tc.wantIdle("closed connection is idle in draining") - tc.advance(1 * time.Second) // long enough to exit the draining state + time.Sleep(1 * time.Second) // long enough to exit the draining state tc.wantIdle("closed connection is idle after draining") qr.wantEvents(t, jsonEvent{ @@ -167,6 +179,9 @@ func TestStatelessResetSuccessfulNewConnectionID(t *testing.T) { } func TestStatelessResetSuccessfulTransportParameter(t *testing.T) { + synctest.Test(t, testStatelessResetSuccessfulTransportParameter) +} +func testStatelessResetSuccessfulTransportParameter(t *testing.T) { // "[...] Stateless Reset Token field values from [...] // the server's transport parameters [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-1 @@ -229,7 +244,7 @@ func TestStatelessResetSuccessfulPrefix(t *testing.T) { }, testLocalConnID(0)...), size: 100, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t *testing.T) { resetToken := testPeerStatelessResetToken(0) tc := newTestConn(t, clientSide, func(p *transportParameters) { p.statelessResetToken = resetToken[:] @@ -252,6 +267,9 @@ func TestStatelessResetSuccessfulPrefix(t *testing.T) { } func TestStatelessResetRetiredConnID(t *testing.T) { + synctest.Test(t, testStatelessResetRetiredConnID) +} +func testStatelessResetRetiredConnID(t *testing.T) { // "An endpoint MUST NOT check for any stateless reset tokens [...] // for connection IDs that have been retired." 
// https://www.rfc-editor.org/rfc/rfc9000#section-10.3.1-3 diff --git a/quic/stream.go b/quic/stream.go index b20cfe7fe0..4c632079a8 100644 --- a/quic/stream.go +++ b/quic/stream.go @@ -236,7 +236,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { s.inbufoff += n return n, nil } - if err := s.ingate.waitAndLock(s.inctx, s.conn.testHooks); err != nil { + if err := s.ingate.waitAndLock(s.inctx); err != nil { return 0, err } if s.inbufoff > 0 { @@ -350,7 +350,7 @@ func (s *Stream) Write(b []byte) (n int, err error) { if len(b) > 0 && !canWrite { // Our send buffer is full. Wait for the peer to ack some data. s.outUnlock() - if err := s.outgate.waitAndLock(s.outctx, s.conn.testHooks); err != nil { + if err := s.outgate.waitAndLock(s.outctx); err != nil { return n, err } // Successfully returning from waitAndLockGate means we are no longer diff --git a/quic/stream_limits.go b/quic/stream_limits.go index ed31c365d3..f1abcae99c 100644 --- a/quic/stream_limits.go +++ b/quic/stream_limits.go @@ -29,7 +29,7 @@ func (lim *localStreamLimits) init() { // open creates a new local stream, blocking until MAX_STREAMS quota is available. func (lim *localStreamLimits) open(ctx context.Context, c *Conn) (num int64, err error) { // TODO: Send a STREAMS_BLOCKED when blocked. - if err := lim.gate.waitAndLock(ctx, c.testHooks); err != nil { + if err := lim.gate.waitAndLock(ctx); err != nil { return 0, err } if lim.opened < 0 { diff --git a/quic/stream_limits_test.go b/quic/stream_limits_test.go index ad634113b8..d62b29bbfa 100644 --- a/quic/stream_limits_test.go +++ b/quic/stream_limits_test.go @@ -2,19 +2,22 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( "context" "crypto/tls" "testing" + "testing/synctest" ) func TestStreamLimitNewStreamBlocked(t *testing.T) { // "An endpoint that receives a frame with a stream ID exceeding the limit // it has sent MUST treat this as a connection error of type STREAM_LIMIT_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-3 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() tc := newTestConn(t, clientSide, permissiveTransportParameters, @@ -46,7 +49,7 @@ func TestStreamLimitNewStreamBlocked(t *testing.T) { func TestStreamLimitMaxStreamsDecreases(t *testing.T) { // "MAX_STREAMS frames that do not increase the stream limit MUST be ignored." 
// https://www.rfc-editor.org/rfc/rfc9000#section-4.6-4 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() tc := newTestConn(t, clientSide, permissiveTransportParameters, @@ -77,7 +80,7 @@ func TestStreamLimitMaxStreamsDecreases(t *testing.T) { } func TestStreamLimitViolated(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide, func(c *Config) { if styp == bidiStream { @@ -104,7 +107,7 @@ func TestStreamLimitViolated(t *testing.T) { } func TestStreamLimitImplicitStreams(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide, func(c *Config) { c.MaxBidiRemoteStreams = 1 << 60 @@ -152,7 +155,7 @@ func TestStreamLimitMaxStreamsTransportParameterTooLarge(t *testing.T) { // a value greater than 2^60 [...] the connection MUST be closed // immediately with a connection error of type TRANSPORT_PARAMETER_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-2 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide, func(p *transportParameters) { if styp == bidiStream { @@ -177,7 +180,7 @@ func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) { // greater than 2^60 [...] the connection MUST be closed immediately // with a connection error [...] of type FRAME_ENCODING_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-4.6-2 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide) tc.handshake() tc.writeFrames(packetTypeInitial, @@ -197,7 +200,7 @@ func TestStreamLimitMaxStreamsFrameTooLarge(t *testing.T) { } func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide, func(c *Config) { if styp == uniStream { c.MaxUniRemoteStreams = 4 @@ -236,6 +239,9 @@ func TestStreamLimitSendUpdatesMaxStreams(t *testing.T) { } func TestStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { + synctest.Test(t, testStreamLimitStopSendingDoesNotUpdateMaxStreams) +} +func testStreamLimitStopSendingDoesNotUpdateMaxStreams(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, bidiStream, func(c *Config) { c.MaxBidiRemoteStreams = 1 }) diff --git a/quic/stream_test.go b/quic/stream_test.go index 4119cc1e74..67d17f6546 100644 --- a/quic/stream_test.go +++ b/quic/stream_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 + package quic import ( @@ -13,12 +15,13 @@ import ( "io" "strings" "testing" + "testing/synctest" "golang.org/x/net/internal/quic/quicwire" ) func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const writeBufferSize = 4 tc := newTestConn(t, clientSide, permissiveTransportParameters, func(c *Config) { @@ -79,7 +82,7 @@ func TestStreamWriteBlockedByOutputBuffer(t *testing.T) { } func TestStreamWriteBlockedByStreamFlowControl(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} tc := newTestConn(t, clientSide, func(p *transportParameters) { @@ -149,7 +152,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { // "A sender MUST ignore any MAX_STREAM_DATA [...] frames that // do not increase flow control limits." // https://www.rfc-editor.org/rfc/rfc9000#section-4.1-9 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { ctx := canceledContext() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} tc := newTestConn(t, clientSide, func(p *transportParameters) { @@ -218,7 +221,7 @@ func TestStreamIgnoresMaxStreamDataReduction(t *testing.T) { } func TestStreamWriteBlockedByWriteBufferLimit(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} const maxWriteBuffer = 4 tc := newTestConn(t, clientSide, func(p *transportParameters) { @@ -392,7 +395,7 @@ func TestStreamReceive(t *testing.T) { wantEOF: true, }}, }} { - testStreamTypes(t, test.name, func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, test.name, func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide) tc.handshake() sid := newStreamID(clientSide, styp, 0) @@ -439,7 +442,7 @@ func TestStreamReceive(t *testing.T) { } func TestStreamReceiveExtendsStreamWindow(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { const maxWindowSize = 20 ctx := canceledContext() tc := newTestConn(t, serverSide, func(c *Config) { @@ -484,7 +487,7 @@ func TestStreamReceiveViolatesStreamDataLimit(t *testing.T) { // "A receiver MUST close the connection with an error of type FLOW_CONTROL_ERROR if // the sender violates the advertised [...] stream data limits [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-4.1-8 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { const maxStreamData = 10 for _, test := range []struct { off int64 @@ -521,7 +524,7 @@ func TestStreamReceiveViolatesStreamDataLimit(t *testing.T) { } func TestStreamReceiveDuplicateDataDoesNotViolateLimits(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { const maxData = 10 tc := newTestConn(t, serverSide, func(c *Config) { // TODO: Add connection-level maximum data here as well. 
@@ -544,7 +547,7 @@ func TestStreamReceiveEmptyEOF(t *testing.T) { // A stream receives some data, we read a byte of that data // (causing the rest to be pulled into the s.inbuf buffer), // and then we receive a FIN with no additional data. - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters) want := []byte{1, 2, 3} tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -568,7 +571,7 @@ func TestStreamReceiveEmptyEOF(t *testing.T) { func TestStreamReadByteFromOneByteStream(t *testing.T) { // ReadByte on the only byte of a stream should not return an error. - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters) want := byte(1) tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -608,7 +611,7 @@ func finalSizeTest(t *testing.T, wantErr transportError, f func(tc *testConn, si }) }, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t *testing.T) { tc := newTestConn(t, serverSide, opts...) tc.handshake() sid := newStreamID(clientSide, styp, 0) @@ -662,7 +665,7 @@ func TestStreamDataBeyondFinalSize(t *testing.T) { // "A receiver SHOULD treat receipt of data at or beyond // the final size as an error of type FINAL_SIZE_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9000#section-4.5-5 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide) tc.handshake() sid := newStreamID(clientSide, styp, 0) @@ -688,7 +691,7 @@ func TestStreamDataBeyondFinalSize(t *testing.T) { } func TestStreamReceiveUnblocksReader(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc := newTestConn(t, serverSide) tc.handshake() want := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} @@ -746,7 +749,7 @@ func TestStreamReceiveUnblocksReader(t *testing.T) { // It then sends the returned frame (STREAM, STREAM_DATA_BLOCKED, etc.) // to the conn and expects a STREAM_STATE_ERROR. func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFrame) { - testSides(t, "stream_not_created", func(t *testing.T, side connSide) { + testSidesSynctest(t, "stream_not_created", func(t *testing.T, side connSide) { tc := newTestConn(t, side, permissiveTransportParameters) tc.handshake() tc.writeFrames(packetType1RTT, f(newStreamID(side, bidiStream, 0))) @@ -755,7 +758,7 @@ func testStreamSendFrameInvalidState(t *testing.T, f func(sid streamID) debugFra code: errStreamState, }) }) - testSides(t, "uni_stream", func(t *testing.T, side connSide) { + testSidesSynctest(t, "uni_stream", func(t *testing.T, side connSide) { ctx := canceledContext() tc := newTestConn(t, side, permissiveTransportParameters) tc.handshake() @@ -823,7 +826,7 @@ func TestStreamDataBlockedInvalidState(t *testing.T) { // It then sends the returned frame (MAX_STREAM_DATA, STOP_SENDING, etc.) // to the conn and expects a STREAM_STATE_ERROR. 
func testStreamReceiveFrameInvalidState(t *testing.T, f func(sid streamID) debugFrame) { - testSides(t, "stream_not_created", func(t *testing.T, side connSide) { + testSidesSynctest(t, "stream_not_created", func(t *testing.T, side connSide) { tc := newTestConn(t, side) tc.handshake() tc.writeFrames(packetType1RTT, f(newStreamID(side, bidiStream, 0))) @@ -832,7 +835,7 @@ func testStreamReceiveFrameInvalidState(t *testing.T, f func(sid streamID) debug code: errStreamState, }) }) - testSides(t, "uni_stream", func(t *testing.T, side connSide) { + testSidesSynctest(t, "uni_stream", func(t *testing.T, side connSide) { tc := newTestConn(t, side) tc.handshake() tc.writeFrames(packetType1RTT, f(newStreamID(side.peer(), uniStream, 0))) @@ -873,6 +876,9 @@ func TestStreamMaxStreamDataInvalidState(t *testing.T) { } func TestStreamOffsetTooLarge(t *testing.T) { + synctest.Test(t, testStreamOffsetTooLarge) +} +func testStreamOffsetTooLarge(t *testing.T) { // "Receipt of a frame that exceeds [2^62-1] MUST be treated as a // connection error of type FRAME_ENCODING_ERROR or FLOW_CONTROL_ERROR." // https://www.rfc-editor.org/rfc/rfc9000.html#section-19.8-9 @@ -894,6 +900,9 @@ func TestStreamOffsetTooLarge(t *testing.T) { } func TestStreamReadFromWriteOnlyStream(t *testing.T) { + synctest.Test(t, testStreamReadFromWriteOnlyStream) +} +func testStreamReadFromWriteOnlyStream(t *testing.T) { _, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) buf := make([]byte, 10) wantErr := "read from write-only stream" @@ -903,6 +912,9 @@ func TestStreamReadFromWriteOnlyStream(t *testing.T) { } func TestStreamWriteToReadOnlyStream(t *testing.T) { + synctest.Test(t, testStreamWriteToReadOnlyStream) +} +func testStreamWriteToReadOnlyStream(t *testing.T) { _, s := newTestConnAndRemoteStream(t, serverSide, uniStream) buf := make([]byte, 10) wantErr := "write to read-only stream" @@ -912,6 +924,9 @@ func TestStreamWriteToReadOnlyStream(t *testing.T) { } func TestStreamReadFromClosedStream(t *testing.T) { + synctest.Test(t, testStreamReadFromClosedStream) +} +func testStreamReadFromClosedStream(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, bidiStream, permissiveTransportParameters) s.CloseRead() tc.wantFrame("CloseRead sends a STOP_SENDING frame", @@ -934,6 +949,9 @@ func TestStreamReadFromClosedStream(t *testing.T) { } func TestStreamCloseReadWithAllDataReceived(t *testing.T) { + synctest.Test(t, testStreamCloseReadWithAllDataReceived) +} +func testStreamCloseReadWithAllDataReceived(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, bidiStream, permissiveTransportParameters) tc.writeFrames(packetType1RTT, debugFrameStream{ id: s.id, @@ -950,6 +968,9 @@ func TestStreamCloseReadWithAllDataReceived(t *testing.T) { } func TestStreamWriteToClosedStream(t *testing.T) { + synctest.Test(t, testStreamWriteToClosedStream) +} +func testStreamWriteToClosedStream(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, bidiStream, permissiveTransportParameters) s.CloseWrite() tc.wantFrame("stream is opened after being closed", @@ -966,6 +987,9 @@ func TestStreamWriteToClosedStream(t *testing.T) { } func TestStreamResetBlockedStream(t *testing.T) { + synctest.Test(t, testStreamResetBlockedStream) +} +func testStreamResetBlockedStream(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, bidiStream, permissiveTransportParameters, func(c *Config) { c.MaxStreamWriteBufferSize = 4 @@ -1002,6 +1026,9 @@ func TestStreamResetBlockedStream(t *testing.T) { 
} func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { + synctest.Test(t, testStreamWriteMoreThanOnePacketOfData) +} +func testStreamWriteMoreThanOnePacketOfData(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, func(p *transportParameters) { p.initialMaxStreamsUni = 1 p.initialMaxData = 1 << 20 @@ -1038,6 +1065,9 @@ func TestStreamWriteMoreThanOnePacketOfData(t *testing.T) { } func TestStreamCloseWaitsForAcks(t *testing.T) { + synctest.Test(t, testStreamCloseWaitsForAcks) +} +func testStreamCloseWaitsForAcks(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.Write(data) @@ -1071,6 +1101,9 @@ func TestStreamCloseWaitsForAcks(t *testing.T) { } func TestStreamCloseReadOnly(t *testing.T) { + synctest.Test(t, testStreamCloseReadOnly) +} +func testStreamCloseReadOnly(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream, permissiveTransportParameters) if err := s.Close(); err != nil { t.Errorf("s.Close() = %v, want nil", err) @@ -1103,10 +1136,10 @@ func TestStreamCloseUnblocked(t *testing.T) { name: "stream reset", unblock: func(tc *testConn, s *Stream) { s.Reset(0) - tc.wait() // wait for test conn to process the Reset + synctest.Wait() // wait for test conn to process the Reset }, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters) data := make([]byte, 100) s.Write(data) @@ -1148,6 +1181,9 @@ func TestStreamCloseUnblocked(t *testing.T) { } func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { + synctest.Test(t, testStreamCloseWriteWhenBlockedByStreamFlowControl) +} +func testStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, permissiveTransportParameters, func(p *transportParameters) { //p.initialMaxData = 0 @@ -1185,7 +1221,7 @@ func TestStreamCloseWriteWhenBlockedByStreamFlowControl(t *testing.T) { } func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp) data := []byte{0, 1, 2, 3, 4, 5, 6, 7} tc.writeFrames(packetType1RTT, debugFrameStream{ @@ -1210,7 +1246,7 @@ func TestStreamPeerResetsWithUnreadAndUnsentData(t *testing.T) { } func TestStreamPeerResetWakesBlockedRead(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp) reader := runAsync(tc, func(ctx context.Context) (int, error) { s.SetReadContext(ctx) @@ -1231,7 +1267,7 @@ func TestStreamPeerResetWakesBlockedRead(t *testing.T) { } func TestStreamPeerResetFollowedByData(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndRemoteStream(t, serverSide, styp) tc.writeFrames(packetType1RTT, debugFrameResetStream{ id: s.id, @@ -1256,6 +1292,9 @@ func TestStreamPeerResetFollowedByData(t *testing.T) { } func TestStreamResetInvalidCode(t *testing.T) { + synctest.Test(t, testStreamResetInvalidCode) +} +func testStreamResetInvalidCode(t *testing.T) { tc, s := newTestConnAndLocalStream(t, serverSide, uniStream, 
permissiveTransportParameters) s.Reset(1 << 62) tc.wantFrame("reset with invalid code sends a RESET_STREAM anyway", @@ -1268,6 +1307,9 @@ func TestStreamResetInvalidCode(t *testing.T) { } func TestStreamResetReceiveOnly(t *testing.T) { + synctest.Test(t, testStreamResetReceiveOnly) +} +func testStreamResetReceiveOnly(t *testing.T) { tc, s := newTestConnAndRemoteStream(t, serverSide, uniStream) s.Reset(0) tc.wantIdle("resetting a receive-only stream has no effect") @@ -1277,7 +1319,7 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { // "An endpoint that receives a STOP_SENDING frame MUST send a RESET_STREAM frame if // the stream is in the "Ready" or "Send" state." // https://www.rfc-editor.org/rfc/rfc9000#section-3.5-4 - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndLocalStream(t, serverSide, styp, permissiveTransportParameters) for i := 0; i < 4; i++ { s.Write([]byte{byte(i)}) @@ -1309,6 +1351,9 @@ func TestStreamPeerStopSendingForActiveStream(t *testing.T) { } func TestStreamReceiveDataBlocked(t *testing.T) { + synctest.Test(t, testStreamReceiveDataBlocked) +} +func testStreamReceiveDataBlocked(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.handshake() tc.ignoreFrame(frameTypeAck) @@ -1326,7 +1371,7 @@ func TestStreamReceiveDataBlocked(t *testing.T) { } func TestStreamFlushExplicit(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { tc, s := newTestConnAndLocalStream(t, clientSide, styp, permissiveTransportParameters) want := []byte{0, 1, 2, 3} n, err := s.Write(want) @@ -1344,6 +1389,9 @@ func TestStreamFlushExplicit(t *testing.T) { } func TestStreamFlushClosedStream(t *testing.T) { + synctest.Test(t, testStreamFlushClosedStream) +} +func testStreamFlushClosedStream(t *testing.T) { _, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters) s.Close() @@ -1353,6 +1401,9 @@ func TestStreamFlushClosedStream(t *testing.T) { } func TestStreamFlushResetStream(t *testing.T) { + synctest.Test(t, testStreamFlushResetStream) +} +func testStreamFlushResetStream(t *testing.T) { _, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters) s.Reset(0) @@ -1362,6 +1413,9 @@ func TestStreamFlushResetStream(t *testing.T) { } func TestStreamFlushStreamAfterPeerStopSending(t *testing.T) { + synctest.Test(t, testStreamFlushStreamAfterPeerStopSending) +} +func testStreamFlushStreamAfterPeerStopSending(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters) s.Flush() // create the stream @@ -1381,6 +1435,9 @@ func TestStreamFlushStreamAfterPeerStopSending(t *testing.T) { } func TestStreamErrorsAfterConnectionClosed(t *testing.T) { + synctest.Test(t, testStreamErrorsAfterConnectionClosed) +} +func testStreamErrorsAfterConnectionClosed(t *testing.T) { tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream, permissiveTransportParameters) wantErr := &ApplicationError{Code: 42} @@ -1399,7 +1456,7 @@ func TestStreamErrorsAfterConnectionClosed(t *testing.T) { } func TestStreamFlushImplicitExact(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { const writeBufferSize = 4 tc, s := newTestConnAndLocalStream(t, clientSide, styp, 
permissiveTransportParameters, @@ -1429,7 +1486,7 @@ func TestStreamFlushImplicitExact(t *testing.T) { } func TestStreamFlushImplicitLargerThanBuffer(t *testing.T) { - testStreamTypes(t, "", func(t *testing.T, styp streamType) { + testStreamTypesSynctest(t, "", func(t *testing.T, styp streamType) { const writeBufferSize = 4 tc, s := newTestConnAndLocalStream(t, clientSide, styp, permissiveTransportParameters, diff --git a/quic/tls.go b/quic/tls.go index 171d5a3138..9f6e0bc29a 100644 --- a/quic/tls.go +++ b/quic/tls.go @@ -33,7 +33,7 @@ func (c *Conn) startTLS(now time.Time, initialConnID []byte, peerHostname string c.tls = tls.QUICServer(qconfig) } c.tls.SetTransportParameters(marshalTransportParameters(params)) - // TODO: We don't need or want a context for cancelation here, + // TODO: We don't need or want a context for cancellation here, // but users can use a context to plumb values through to hooks defined // in the tls.Config. Pass through a context. if err := c.tls.Start(context.TODO()); err != nil { diff --git a/quic/tls_test.go b/quic/tls_test.go index 21f782eade..0818c68859 100644 --- a/quic/tls_test.go +++ b/quic/tls_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -9,12 +11,12 @@ import ( "crypto/x509" "errors" "testing" + "testing/synctest" "time" ) // handshake executes the handshake. func (tc *testConn) handshake() { - tc.t.Helper() if *testVV { *testVV = false defer func() { @@ -32,16 +34,16 @@ func (tc *testConn) handshake() { i := 0 for { if i == len(dgrams)-1 { + want := time.Now().Add(maxAckDelay - timerGranularity) if tc.conn.side == clientSide { - want := tc.endpoint.now.Add(maxAckDelay - timerGranularity) - if !tc.timer.Equal(want) { - t.Fatalf("want timer = %v (max_ack_delay), got %v", want, tc.timer) + if got := tc.nextEvent(); !got.Equal(want) { + t.Fatalf("want timer = %v (max_ack_delay), got %v", want, got) } if got := tc.readDatagram(); got != nil { t.Fatalf("client unexpectedly sent: %v", got) } } - tc.advance(maxAckDelay) + time.Sleep(time.Until(want)) } // Check that we're sending exactly the data we expect. 
@@ -209,7 +211,7 @@ func handshakeDatagrams(tc *testConn) (dgrams []*testDatagram) { frames: []debugFrame{ debugFrameAck{ ackDelay: unscaledAckDelayFromDuration( - maxAckDelay, ackDelayExponent), + maxAckDelay-timerGranularity, ackDelayExponent), ranges: []i64range[packetNumber]{{0, 2}}, }, }, @@ -308,20 +310,29 @@ func (tc *testConn) uncheckedHandshake() { } func TestConnClientHandshake(t *testing.T) { + synctest.Test(t, testConnClientHandshake) +} +func testConnClientHandshake(t *testing.T) { tc := newTestConn(t, clientSide) tc.handshake() - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.wantIdle("no packets should be sent by an idle conn after the handshake") } func TestConnServerHandshake(t *testing.T) { + synctest.Test(t, testConnServerHandshake) +} +func testConnServerHandshake(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.wantIdle("no packets should be sent by an idle conn after the handshake") } func TestConnKeysDiscardedClient(t *testing.T) { + synctest.Test(t, testConnKeysDiscardedClient) +} +func testConnKeysDiscardedClient(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) @@ -370,6 +381,9 @@ func TestConnKeysDiscardedClient(t *testing.T) { } func TestConnKeysDiscardedServer(t *testing.T) { + synctest.Test(t, testConnKeysDiscardedServer) +} +func testConnKeysDiscardedServer(t *testing.T) { tc := newTestConn(t, serverSide) tc.ignoreFrame(frameTypeAck) @@ -425,6 +439,9 @@ func TestConnKeysDiscardedServer(t *testing.T) { } func TestConnInvalidCryptoData(t *testing.T) { + synctest.Test(t, testConnInvalidCryptoData) +} +func testConnInvalidCryptoData(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) @@ -455,6 +472,9 @@ func TestConnInvalidCryptoData(t *testing.T) { } func TestConnInvalidPeerCertificate(t *testing.T) { + synctest.Test(t, testConnInvalidPeerCertificate) +} +func testConnInvalidPeerCertificate(t *testing.T) { tc := newTestConn(t, clientSide, func(c *tls.Config) { c.VerifyPeerCertificate = func([][]byte, [][]*x509.Certificate) error { return errors.New("I will not buy this certificate. It is scratched.") @@ -481,6 +501,9 @@ func TestConnInvalidPeerCertificate(t *testing.T) { } func TestConnHandshakeDoneSentToServer(t *testing.T) { + synctest.Test(t, testConnHandshakeDoneSentToServer) +} +func testConnHandshakeDoneSentToServer(t *testing.T) { tc := newTestConn(t, serverSide) tc.handshake() @@ -493,6 +516,9 @@ func TestConnHandshakeDoneSentToServer(t *testing.T) { } func TestConnCryptoDataOutOfOrder(t *testing.T) { + synctest.Test(t, testConnCryptoDataOutOfOrder) +} +func testConnCryptoDataOutOfOrder(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) @@ -531,6 +557,9 @@ func TestConnCryptoDataOutOfOrder(t *testing.T) { } func TestConnCryptoBufferSizeExceeded(t *testing.T) { + synctest.Test(t, testConnCryptoBufferSizeExceeded) +} +func testConnCryptoBufferSizeExceeded(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) @@ -550,6 +579,9 @@ func TestConnCryptoBufferSizeExceeded(t *testing.T) { } func TestConnAEADLimitReached(t *testing.T) { + synctest.Test(t, testConnAEADLimitReached) +} +func testConnAEADLimitReached(t *testing.T) { // "[...] endpoints MUST count the number of received packets that // fail authentication during the lifetime of a connection. // If the total number of received packets that fail authentication [...] 
@@ -590,7 +622,7 @@ func TestConnAEADLimitReached(t *testing.T) { tc.conn.sendMsg(&datagram{ b: invalid, }) - tc.wait() + synctest.Wait() } // Set the conn's auth failure count to just before the AEAD integrity limit. @@ -610,11 +642,14 @@ func TestConnAEADLimitReached(t *testing.T) { }) tc.writeFrames(packetType1RTT, debugFramePing{}) - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.wantIdle("auth failures at limit: conn does not process additional packets") } func TestConnKeysDiscardedWithExcessCryptoData(t *testing.T) { + synctest.Test(t, testConnKeysDiscardedWithExcessCryptoData) +} +func testConnKeysDiscardedWithExcessCryptoData(t *testing.T) { tc := newTestConn(t, serverSide, permissiveTransportParameters) tc.ignoreFrame(frameTypeAck) tc.ignoreFrame(frameTypeNewConnectionID) diff --git a/quic/tlsconfig_test.go b/quic/tlsconfig_test.go index b1305ec00f..8a07b0b13c 100644 --- a/quic/tlsconfig_test.go +++ b/quic/tlsconfig_test.go @@ -21,7 +21,7 @@ func newTestTLSConfig(side connSide) *tls.Config { MinVersion: tls.VersionTLS13, // Default key exchange mechanisms as of Go 1.23 minus X25519Kyber768Draft00, // which bloats the client hello enough to spill into a second datagram. - // Tests were written with the assuption each flight in the handshake + // Tests were written with the assumption each flight in the handshake // fits in one datagram, and it's simpler to keep that property. CurvePreferences: []tls.CurveID{ tls.X25519, tls.CurveP256, tls.CurveP384, tls.CurveP521, diff --git a/quic/version_test.go b/quic/version_test.go index 60d83078d3..ac054a83cf 100644 --- a/quic/version_test.go +++ b/quic/version_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 + package quic import ( @@ -9,9 +11,13 @@ import ( "context" "crypto/tls" "testing" + "testing/synctest" ) func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { + synctest.Test(t, testVersionNegotiationServerReceivesUnknownVersion) +} +func testVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { config := &Config{ TLSConfig: newTestTLSConfig(serverSide), } @@ -55,6 +61,9 @@ func TestVersionNegotiationServerReceivesUnknownVersion(t *testing.T) { } func TestVersionNegotiationClientAborts(t *testing.T) { + synctest.Test(t, testVersionNegotiationClientAborts) +} +func testVersionNegotiationClientAborts(t *testing.T) { tc := newTestConn(t, clientSide) p := tc.readPacket() // client Initial packet tc.endpoint.write(&datagram{ @@ -67,6 +76,9 @@ func TestVersionNegotiationClientAborts(t *testing.T) { } func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { + synctest.Test(t, testVersionNegotiationClientIgnoresAfterProcessingPacket) +} +func testVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) p := tc.readPacket() // client Initial packet @@ -89,6 +101,9 @@ func TestVersionNegotiationClientIgnoresAfterProcessingPacket(t *testing.T) { } func TestVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { + synctest.Test(t, testVersionNegotiationClientIgnoresMismatchingSourceConnID) +} +func testVersionNegotiationClientIgnoresMismatchingSourceConnID(t *testing.T) { tc := newTestConn(t, clientSide) tc.ignoreFrame(frameTypeAck) p := tc.readPacket() // client Initial packet diff --git a/trace/events.go b/trace/events.go index 3aaffdd1f7..c2b3c00980 100644 --- a/trace/events.go +++ b/trace/events.go 
@@ -58,8 +58,8 @@ func RenderEvents(w http.ResponseWriter, req *http.Request, sensitive bool) { Buckets: buckets, } - data.Families = make([]string, 0, len(families)) famMu.RLock() + data.Families = make([]string, 0, len(families)) for name := range families { data.Families = append(data.Families, name) } diff --git a/webdav/if_test.go b/webdav/if_test.go index aad61a4010..fd5d18c979 100644 --- a/webdav/if_test.go +++ b/webdav/if_test.go @@ -134,7 +134,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 7.5.1", - ` + ` ()`, ifHeader{ lists: []ifList{{ @@ -180,7 +180,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 9.9.6", - `() + `() ()`, ifHeader{ lists: []ifList{{ @@ -205,7 +205,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 10.4.6", - `( + `( ["I am an ETag"]) (["I am another ETag"])`, ifHeader{ @@ -223,7 +223,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 10.4.7", - `(Not + `(Not )`, ifHeader{ lists: []ifList{{ @@ -237,7 +237,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 10.4.8", - `() + `() (Not )`, ifHeader{ lists: []ifList{{ @@ -253,8 +253,8 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 10.4.9", - ` - ( + ` + ( [W/"A weak ETag"]) (["strong ETag"])`, ifHeader{ lists: []ifList{{ @@ -273,7 +273,7 @@ func TestParseIfHeader(t *testing.T) { }, }, { "section 10.4.10", - ` + ` ()`, ifHeader{ lists: []ifList{{ diff --git a/webdav/internal/xml/marshal.go b/webdav/internal/xml/marshal.go index 4dd0f417fd..a0ec9cba8d 100644 --- a/webdav/internal/xml/marshal.go +++ b/webdav/internal/xml/marshal.go @@ -546,9 +546,9 @@ func (p *printer) setAttrPrefix(prefix, url string) { } var ( - marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() - marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem() - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + marshalerType = reflect.TypeFor[Marshaler]() + marshalerAttrType = reflect.TypeFor[MarshalerAttr]() + textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]() ) // marshalValue writes one or more XML elements representing val. diff --git a/webdav/internal/xml/marshal_test.go b/webdav/internal/xml/marshal_test.go index 226cfd013f..99b5af8ee8 100644 --- a/webdav/internal/xml/marshal_test.go +++ b/webdav/internal/xml/marshal_test.go @@ -1846,7 +1846,7 @@ func TestDecodeEncode(t *testing.T) { in.WriteString(` - + `) dec := NewDecoder(&in) enc := NewEncoder(&out) diff --git a/webdav/internal/xml/read.go b/webdav/internal/xml/read.go index bfaef6f17f..2ba3bb4a9b 100644 --- a/webdav/internal/xml/read.go +++ b/webdav/internal/xml/read.go @@ -262,9 +262,9 @@ func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { } var ( - unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() - unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() - textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + unmarshalerType = reflect.TypeFor[Unmarshaler]() + unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]() + textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() ) // Unmarshal a single XML element into val. 
diff --git a/webdav/internal/xml/read_test.go b/webdav/internal/xml/read_test.go index 02f1e10c33..e587d11fdf 100644 --- a/webdav/internal/xml/read_test.go +++ b/webdav/internal/xml/read_test.go @@ -325,10 +325,10 @@ type BadPathEmbeddedB struct { var badPathTests = []struct { v, e interface{} }{ - {&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}}, - {&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}}, - {&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}}, - {&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}}, + {&BadPathTestA{}, &TagPathError{reflect.TypeFor[BadPathTestA](), "First", "items>item1", "Second", "items"}}, + {&BadPathTestB{}, &TagPathError{reflect.TypeFor[BadPathTestB](), "First", "items>item1", "Second", "items>item1>value"}}, + {&BadPathTestC{}, &TagPathError{reflect.TypeFor[BadPathTestC](), "First", "", "Second", "First"}}, + {&BadPathTestD{}, &TagPathError{reflect.TypeFor[BadPathTestD](), "First", "", "Second", "First"}}, } func TestUnmarshalBadPaths(t *testing.T) { diff --git a/webdav/internal/xml/typeinfo.go b/webdav/internal/xml/typeinfo.go index fdde288bc3..b0b0f55a10 100644 --- a/webdav/internal/xml/typeinfo.go +++ b/webdav/internal/xml/typeinfo.go @@ -44,7 +44,7 @@ const ( var tinfoMap = make(map[reflect.Type]*typeInfo) var tinfoLock sync.RWMutex -var nameType = reflect.TypeOf(Name{}) +var nameType = reflect.TypeFor[Name]() // getTypeInfo returns the typeInfo structure with details necessary // for marshalling and unmarshalling typ. @@ -258,13 +258,6 @@ func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) { return nil } -func min(a, b int) int { - if a <= b { - return a - } - return b -} - // addFieldInfo adds finfo to tinfo.fields if there are no // conflicts, or if conflicts arise from previous fields that were // obtained from deeper embedded structures than finfo. In the latter diff --git a/webdav/prop_test.go b/webdav/prop_test.go index f4247e69b5..d085dac223 100644 --- a/webdav/prop_test.go +++ b/webdav/prop_test.go @@ -556,7 +556,7 @@ func TestMemPS(t *testing.T) { sort.Sort(byStatus(propstats)) sort.Sort(byStatus(op.wantPropstats)) if !reflect.DeepEqual(propstats, op.wantPropstats) { - t.Errorf("%s: propstat\ngot %q\nwant %q", desc, propstats, op.wantPropstats) + t.Errorf("%s: propstat\ngot %#v\nwant %#v", desc, propstats, op.wantPropstats) } } } diff --git a/webdav/webdav_test.go b/webdav/webdav_test.go index deb60fb885..54380e3743 100644 --- a/webdav/webdav_test.go +++ b/webdav/webdav_test.go @@ -53,7 +53,19 @@ func TestPrefix(t *testing.T) { return nil, err } defer res.Body.Close() - if res.StatusCode != wantStatusCode { + isRedirect := func(code int) bool { + switch code { + case http.StatusMovedPermanently, + http.StatusTemporaryRedirect, + http.StatusPermanentRedirect: + return true + default: + return false + } + } + if isRedirect(res.StatusCode) && isRedirect(wantStatusCode) { + // Allow any redirect. 
+ } else if res.StatusCode != wantStatusCode { return nil, fmt.Errorf("got status code %d, want %d", res.StatusCode, wantStatusCode) } return res.Header, nil diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go index f0715d3f6f..5db22ad553 100644 --- a/websocket/hybi_test.go +++ b/websocket/hybi_test.go @@ -190,7 +190,7 @@ Sec-WebSocket-Version: 13 t.Errorf("handshake failed: %v", err) } if code != http.StatusSwitchingProtocols { - t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + t.Errorf("status expected %d but got %d", http.StatusSwitchingProtocols, code) } expectedProtocols := []string{"chat", "superchat"} if fmt.Sprintf("%v", config.Protocol) != fmt.Sprintf("%v", expectedProtocols) { @@ -239,10 +239,10 @@ Sec-WebSocket-Version: 13 t.Errorf("handshake failed: %v", err) } if code != http.StatusSwitchingProtocols { - t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + t.Errorf("status expected %d but got %d", http.StatusSwitchingProtocols, code) } if len(config.Protocol) != 0 { - t.Errorf("len(config.Protocol) expected 0, but got %q", len(config.Protocol)) + t.Errorf("len(config.Protocol) expected 0, but got %d", len(config.Protocol)) } b := bytes.NewBuffer([]byte{}) bw := bufio.NewWriter(b) @@ -285,7 +285,7 @@ Sec-WebSocket-Version: 9 t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err) } if code != http.StatusBadRequest { - t.Errorf("status expected %q but got %q", http.StatusBadRequest, code) + t.Errorf("status expected %d but got %d", http.StatusBadRequest, code) } } @@ -583,7 +583,7 @@ Sec-WebSocket-Version: 13 t.Errorf("handshake failed: %v", err) } if code != http.StatusSwitchingProtocols { - t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code) + t.Errorf("status expected %d but got %d", http.StatusSwitchingProtocols, code) } b := bytes.NewBuffer([]byte{}) bw := bufio.NewWriter(b) diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 2054ce85a6..1ba3827a77 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -323,7 +323,7 @@ func TestHTTP(t *testing.T) { return } if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) + t.Errorf("Get: expected %d got %d", http.StatusBadRequest, resp.StatusCode) } }
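
Note on the ACK_ECN wiring earlier in this patch: consumeAckFrame now returns the ECT(0), ECT(1), and ECN-CE counters instead of discarding them, and appendAckFrame switches the frame type to ACK_ECN and appends the three counters after the ACK ranges whenever it has nonzero counts to report (RFC 9000, Sections 13.4.1 and 19.3.2). The standalone sketch below illustrates only that wire-format detail with a hand-rolled QUIC varint codec; the names used here (appendVarint, consumeVarint, ecnCounts and its ect0/ect1/ce fields) are illustrative and are not the internal quicwire API the patch relies on.

package main

import "fmt"

// appendVarint appends v to b using QUIC's variable-length integer
// encoding (RFC 9000, Section 16). Values must be below 2^62.
func appendVarint(b []byte, v uint64) []byte {
	switch {
	case v < 1<<6:
		return append(b, byte(v))
	case v < 1<<14:
		return append(b, 0x40|byte(v>>8), byte(v))
	case v < 1<<30:
		return append(b, 0x80|byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
	default:
		return append(b,
			0xc0|byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32),
			byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
	}
}

// consumeVarint decodes one QUIC varint from the front of b, returning
// the value and the number of bytes consumed, or -1 on a short buffer.
func consumeVarint(b []byte) (uint64, int) {
	if len(b) == 0 {
		return 0, -1
	}
	n := 1 << (b[0] >> 6) // the first two bits select a 1, 2, 4, or 8 byte encoding
	if len(b) < n {
		return 0, -1
	}
	v := uint64(b[0] & 0x3f)
	for i := 1; i < n; i++ {
		v = v<<8 | uint64(b[i])
	}
	return v, n
}

// ecnCounts holds the three counters carried in the ECN Counts section
// of an ACK_ECN frame; the type and field names here are illustrative only.
type ecnCounts struct {
	ect0, ect1, ce uint64
}

// appendECNCounts appends the ECN Counts section that an ACK_ECN frame
// (type 0x03) carries after its ACK ranges: ECT(0), ECT(1), then ECN-CE.
func appendECNCounts(b []byte, c ecnCounts) []byte {
	b = appendVarint(b, c.ect0)
	b = appendVarint(b, c.ect1)
	return appendVarint(b, c.ce)
}

// consumeECNCounts reads the ECN Counts section back out of b.
func consumeECNCounts(b []byte) (ecnCounts, int) {
	var c ecnCounts
	total := 0
	for _, dst := range []*uint64{&c.ect0, &c.ect1, &c.ce} {
		v, n := consumeVarint(b)
		if n < 0 {
			return ecnCounts{}, -1
		}
		*dst = v
		b = b[n:]
		total += n
	}
	return c, total
}

func main() {
	wire := appendECNCounts(nil, ecnCounts{ect0: 10, ect1: 0, ce: 2})
	got, n := consumeECNCounts(wire)
	fmt.Printf("decoded %+v from %d bytes\n", got, n)
}

The counter order (ECT0, ECT1, ECN-CE) matches both RFC 9000 and the order appendAckFrame writes them in the patch; keeping the encode and decode paths symmetric is what lets the loss-recovery code compare a peer's reported counts against the markings it actually sent.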