diff --git a/acme/acme.go b/acme/acme.go index 7a51284f91..b53ea28891 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -690,7 +690,7 @@ func (c *Client) addNonce(h http.Header) { } func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { - r, err := http.NewRequest("HEAD", url, nil) + r, err := http.NewRequestWithContext(ctx, "HEAD", url, nil) if err != nil { return "", err } diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go index ccd5b7e3a1..cde9066f6d 100644 --- a/acme/autocert/autocert.go +++ b/acme/autocert/autocert.go @@ -134,7 +134,8 @@ type Manager struct { // RenewBefore optionally specifies how early certificates should // be renewed before they expire. // - // If zero, they're renewed 30 days before expiration. + // If zero, they're renewed at the lesser of 30 days or + // 1/3 of the certificate lifetime. RenewBefore time.Duration // Client is used to perform low-level operations, such as account registration @@ -464,7 +465,7 @@ func (m *Manager) cert(ctx context.Context, ck certKey) (*tls.Certificate, error leaf: cert.Leaf, } m.state[ck] = s - m.startRenew(ck, s.key, s.leaf.NotAfter) + m.startRenew(ck, s.key, s.leaf.NotBefore, s.leaf.NotAfter) return cert, nil } @@ -610,7 +611,7 @@ func (m *Manager) createCert(ctx context.Context, ck certKey) (*tls.Certificate, } state.cert = der state.leaf = leaf - m.startRenew(ck, state.key, state.leaf.NotAfter) + m.startRenew(ck, state.key, state.leaf.NotBefore, state.leaf.NotAfter) return state.tlscert() } @@ -908,7 +909,7 @@ func httpTokenCacheKey(tokenPath string) string { // // The key argument is a certificate private key. // The exp argument is the cert expiration time (NotAfter). -func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) { +func (m *Manager) startRenew(ck certKey, key crypto.Signer, notBefore, notAfter time.Time) { m.renewalMu.Lock() defer m.renewalMu.Unlock() if m.renewal[ck] != nil { @@ -920,7 +921,7 @@ func (m *Manager) startRenew(ck certKey, key crypto.Signer, exp time.Time) { } dr := &domainRenewal{m: m, ck: ck, key: key} m.renewal[ck] = dr - dr.start(exp) + dr.start(notBefore, notAfter) } // stopRenew stops all currently running cert renewal timers. @@ -1028,13 +1029,6 @@ func (m *Manager) hostPolicy() HostPolicy { return defaultHostPolicy } -func (m *Manager) renewBefore() time.Duration { - if m.RenewBefore > renewJitter { - return m.RenewBefore - } - return 720 * time.Hour // 30 days -} - func (m *Manager) now() time.Time { if m.nowFunc != nil { return m.nowFunc() diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go index 0df7da78a6..93984f3866 100644 --- a/acme/autocert/renewal.go +++ b/acme/autocert/renewal.go @@ -11,9 +11,6 @@ import ( "time" ) -// renewJitter is the maximum deviation from Manager.RenewBefore. -const renewJitter = time.Hour - // domainRenewal tracks the state used by the periodic timers // renewing a single domain's cert. type domainRenewal struct { @@ -30,13 +27,13 @@ type domainRenewal struct { // defined by the certificate expiration time exp. // // If the timer is already started, calling start is a noop. -func (dr *domainRenewal) start(exp time.Time) { +func (dr *domainRenewal) start(notBefore, notAfter time.Time) { dr.timerMu.Lock() defer dr.timerMu.Unlock() if dr.timer != nil { return } - dr.timer = time.AfterFunc(dr.next(exp), dr.renew) + dr.timer = time.AfterFunc(dr.next(notBefore, notAfter), dr.renew) } // stop stops the cert renewal timer and waits for any in-flight calls to renew @@ -79,7 +76,7 @@ func (dr *domainRenewal) renew() { // TODO: rotate dr.key at some point? next, err := dr.do(ctx) if err != nil { - next = renewJitter / 2 + next = time.Hour / 2 next += time.Duration(pseudoRand.int63n(int64(next))) } testDidRenewLoop(next, err) @@ -107,8 +104,8 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { // a race is likely unavoidable in a distributed environment // but we try nonetheless if tlscert, err := dr.m.cacheGet(ctx, dr.ck); err == nil { - next := dr.next(tlscert.Leaf.NotAfter) - if next > dr.m.renewBefore()+renewJitter { + next := dr.next(tlscert.Leaf.NotBefore, tlscert.Leaf.NotAfter) + if next > 0 { signer, ok := tlscert.PrivateKey.(crypto.Signer) if ok { state := &certState{ @@ -139,18 +136,23 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { return 0, err } dr.updateState(state) - return dr.next(leaf.NotAfter), nil + return dr.next(leaf.NotBefore, leaf.NotAfter), nil } -func (dr *domainRenewal) next(expiry time.Time) time.Duration { - d := expiry.Sub(dr.m.now()) - dr.m.renewBefore() - // add a bit of randomness to renew deadline - n := pseudoRand.int63n(int64(renewJitter)) - d -= time.Duration(n) - if d < 0 { - return 0 +// next returns the wait time before the next renewal should start. +// If manager.RenewBefore is set, it uses that capped at 30 days, +// otherwise it uses a default of 1/3 of the cert lifetime. +// It builds in a jitter of 10% of the renew threshold, capped at 1 hour. +func (dr *domainRenewal) next(notBefore, notAfter time.Time) time.Duration { + threshold := min(notAfter.Sub(notBefore)/3, 30*24*time.Hour) + if dr.m.RenewBefore > 0 { + threshold = min(dr.m.RenewBefore, 30*24*time.Hour) } - return d + maxJitter := min(threshold/10, time.Hour) + jitter := pseudoRand.int63n(int64(maxJitter)) + renewAt := notAfter.Add(-(threshold - time.Duration(jitter))) + renewWait := renewAt.Sub(dr.m.now()) + return max(0, renewWait) } var testDidRenewLoop = func(next time.Duration, err error) {} diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go index ffe4af2a5c..67e2da2e06 100644 --- a/acme/autocert/renewal_test.go +++ b/acme/autocert/renewal_test.go @@ -17,27 +17,60 @@ import ( func TestRenewalNext(t *testing.T) { now := time.Now() - man := &Manager{ - RenewBefore: 7 * 24 * time.Hour, - nowFunc: func() time.Time { return now }, - } - defer man.stopRenew() + nowFn := func() time.Time { return now } tt := []struct { - expiry time.Time - min, max time.Duration + name string + renewBefore time.Duration // arg to Manager + // leaf cert validity + notBefore time.Time + validFor time.Duration + // wait time + waitMin, waitMax time.Duration }{ - {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - renewJitter, 83 * 24 * time.Hour}, - {now.Add(time.Hour), 0, 1}, - {now, 0, 1}, - {now.Add(-time.Hour), 0, 1}, + {"default renewal, 1h cert, valid", + 0, now, time.Hour, 40 * time.Minute, 50 * time.Minute}, + {"default renewal, 1h cert, should renew", + 0, now.Add(-50 * time.Minute), time.Hour, 0, 0}, + {"default renewal, 1h cert, expired", + 0, now.Add(-400 * 24 * time.Hour), time.Hour, 0, 0}, + {"default renewal, 6d cert, valid", + 0, now, 6 * 24 * time.Hour, 4 * 24 * time.Hour, (4*24 + 1) * time.Hour}, + {"default renewal, 6d cert, should renew", + 0, now.Add(-5 * 24 * time.Hour), 6 * 24 * time.Hour, 0, 0}, + {"default renewal, 6d cert, expired", + 0, now.Add(-400 * 24 * time.Hour), 6 * 24 * time.Hour, 0, 0}, + {"default renewal, 90d cert, valid", + 0, now, 90 * 24 * time.Hour, 60 * 24 * time.Hour, (60*24 + 1) * time.Hour}, + {"default renewal, 90d cert, should renew", + 0, now.Add(-70 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0}, + {"default renewal, 90d cert, expired", + 0, now.Add(-400 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0}, + {"default renewal, 398d cert, valid", + 0, now, 398 * 24 * time.Hour, (368 * 24) * time.Hour, (368*24 + 1) * time.Hour}, + {"default renewal, 398d cert, should renew", + 0, now.Add(-378 * 24 * time.Hour), 398 * 24 * time.Hour, 0, 0}, + {"default renewal, 398d cert, expired", + 0, now.Add(-400 * 24 * time.Hour), 398 * 24 * time.Hour, 0, 0}, + {"7d renewal, 90d cert, valid", + 7 * 24 * time.Hour, now, 90 * 24 * time.Hour, 83 * 24 * time.Hour, (83*24 + 1) * time.Hour}, + {"7d renewal, 90d cert, should not renew", + 7 * 24 * time.Hour, now.Add(-70 * 24 * time.Hour), 90 * 24 * time.Hour, 13 * 24 * time.Hour, (13*24 + 1) * time.Hour}, + {"7d renewal, 90d cert, should renew", + 7 * 24 * time.Hour, now.Add(-85 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0}, + {"7d renewal, 90d cert, expired", + 7 * 24 * time.Hour, now.Add(-400 * 24 * time.Hour), 90 * 24 * time.Hour, 0, 0}, } - dr := &domainRenewal{m: man} - for i, test := range tt { - next := dr.next(test.expiry) - if next < test.min || test.max < next { - t.Errorf("%d: next = %v; want between %v and %v", i, next, test.min, test.max) - } + for _, test := range tt { + t.Run(test.name, func(t *testing.T) { + dr := &domainRenewal{m: &Manager{RenewBefore: test.renewBefore, nowFunc: nowFn}} + defer dr.m.stopRenew() + + next := dr.next(test.notBefore, test.notBefore.Add(test.validFor)) + if next < test.waitMin || next > test.waitMax { + t.Errorf("expected wait time: %v <= %v <= %v", test.waitMin, next, test.waitMax) + } + }) } } @@ -239,7 +272,7 @@ func TestRenewFromCacheAlreadyRenewed(t *testing.T) { } // trigger renew - man.startRenew(exampleCertKey, s.key, s.leaf.NotAfter) + man.startRenew(exampleCertKey, s.key, s.leaf.NotBefore, s.leaf.NotAfter) <-renewed func() { man.renewalMu.Lock() diff --git a/acme/http.go b/acme/http.go index 8f29df56ee..7d1052acd4 100644 --- a/acme/http.go +++ b/acme/http.go @@ -128,7 +128,7 @@ func wantStatus(codes ...int) resOkay { func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { retry := c.retryTimer() for { - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } @@ -228,7 +228,7 @@ func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, if err != nil { return nil, nil, err } - req, err := http.NewRequest("POST", url, bytes.NewReader(b)) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(b)) if err != nil { return nil, nil, err } diff --git a/go.mod b/go.mod index 7c5b2e95ae..ed7433125c 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module golang.org/x/crypto go 1.24.0 require ( - golang.org/x/net v0.46.0 // tagx:ignore + golang.org/x/net v0.47.0 // tagx:ignore golang.org/x/sys v0.38.0 golang.org/x/term v0.37.0 ) diff --git a/go.sum b/go.sum index 69212f3190..3a0b108e1d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= diff --git a/ssh/agent/server.go b/ssh/agent/server.go index 88ce4da6c4..4e8ff86b61 100644 --- a/ssh/agent/server.go +++ b/ssh/agent/server.go @@ -203,6 +203,9 @@ func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse for len(constraints) != 0 { switch constraints[0] { case agentConstrainLifetime: + if len(constraints) < 5 { + return 0, false, nil, io.ErrUnexpectedEOF + } lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5]) constraints = constraints[5:] case agentConstrainConfirm: diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go index 7700d18f1a..6309e2d9ab 100644 --- a/ssh/agent/server_test.go +++ b/ssh/agent/server_test.go @@ -8,6 +8,7 @@ import ( "crypto" "crypto/rand" "fmt" + "io" pseudorand "math/rand" "reflect" "strings" @@ -258,6 +259,12 @@ func TestParseConstraints(t *testing.T) { t.Errorf("got extension %v, want %v", extensions, expect) } + // Test Malformed Constraint + _, _, _, err = parseConstraints([]byte{1}) + if err != io.ErrUnexpectedEOF { + t.Errorf("got %v, want %v", err, io.ErrUnexpectedEOF) + } + // Test Unknown Constraint _, _, _, err = parseConstraints([]byte{128}) if err == nil || !strings.Contains(err.Error(), "unknown constraint") { diff --git a/ssh/keys.go b/ssh/keys.go index a035956fcc..47a07539d9 100644 --- a/ssh/keys.go +++ b/ssh/keys.go @@ -1490,6 +1490,7 @@ type openSSHEncryptedPrivateKey struct { NumKeys uint32 PubKey []byte PrivKeyBlock []byte + Rest []byte `ssh:"rest"` } type openSSHPrivateKey struct { diff --git a/ssh/keys_test.go b/ssh/keys_test.go index 661e3cb31c..a1165ec68b 100644 --- a/ssh/keys_test.go +++ b/ssh/keys_test.go @@ -271,6 +271,21 @@ func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) { } } +func TestParseEncryptedPrivateKeysWithUnsupportedCiphers(t *testing.T) { + for _, tt := range testdata.UnsupportedCipherData { + t.Run(tt.Name, func(t *testing.T){ + _, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey)) + if err == nil { + t.Fatalf("expected 'unknown cipher' error for %q, got nil", tt.Name) + // If this cipher is now supported, remove it from testdata.UnsupportedCipherData + } + if !strings.Contains(err.Error(), "unknown cipher") { + t.Errorf("wanted 'unknown cipher' error, got %v", err.Error()) + } + }) + } +} + func TestParseEncryptedPrivateKeysWithIncorrectPassphrase(t *testing.T) { pem := testdata.PEMEncryptedKeys[0].PEMBytes for i := 0; i < 4096; i++ { diff --git a/ssh/ssh_gss.go b/ssh/ssh_gss.go index 24bd7c8e83..a6249a1227 100644 --- a/ssh/ssh_gss.go +++ b/ssh/ssh_gss.go @@ -106,6 +106,13 @@ func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { if !ok { return nil, errors.New("parse uint32 failed") } + // Each ASN.1 encoded OID must have a minimum + // of 2 bytes; 64 maximum mechanisms is an + // arbitrary, but reasonable ceiling. + const maxMechs = 64 + if n > maxMechs || int(n)*2 > len(rest) { + return nil, errors.New("invalid mechanism count") + } s := &userAuthRequestGSSAPI{ N: n, OIDS: make([]asn1.ObjectIdentifier, n), @@ -122,7 +129,6 @@ func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil { return nil, err } - } return s, nil } diff --git a/ssh/ssh_gss_test.go b/ssh/ssh_gss_test.go index 39a111288a..9e3ea8c22c 100644 --- a/ssh/ssh_gss_test.go +++ b/ssh/ssh_gss_test.go @@ -17,6 +17,37 @@ func TestParseGSSAPIPayload(t *testing.T) { } } +func TestParseDubiousGSSAPIPayload(t *testing.T) { + for _, tc := range []struct { + name string + payload []byte + wanterr bool + }{ + { + "num mechanisms is unrealistic", + []byte{0xFF, 0x00, 0x00, 0xFF, + 0x00, 0x00, 0x00, 0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02}, + true, + }, + { + "num mechanisms greater than payload", + []byte{0x00, 0x00, 0x00, 0x40, // 64, |rest| too small + 0x00, 0x00, 0x00, 0x0b, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02}, + true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := parseGSSAPIPayload(tc.payload) + if tc.wanterr && err == nil { + t.Errorf("got nil, want error") + } + if !tc.wanterr && err != nil { + t.Errorf("got %v, want nil", err) + } + }) + } +} + func TestBuildMIC(t *testing.T) { sessionID := []byte{134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121} diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go index b171b330bc..152470fcb7 100644 --- a/ssh/streamlocal.go +++ b/ssh/streamlocal.go @@ -44,7 +44,7 @@ func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { if !ok { return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") } - ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) + ch := c.forwards.add("unix", socketPath) return &unixListener{socketPath, c, ch}, nil } @@ -96,7 +96,7 @@ func (l *unixListener) Accept() (net.Conn, error) { // Close closes the listener. func (l *unixListener) Close() error { // this also closes the listener. - l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) + l.conn.forwards.remove("unix", l.socketPath) m := streamLocalChannelForwardMsg{ l.socketPath, } diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 93d844f035..78c41fe5a1 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -11,6 +11,7 @@ import ( "io" "math/rand" "net" + "net/netip" "strconv" "strings" "sync" @@ -22,14 +23,21 @@ import ( // the returned net.Listener. The listener must be serviced, or the // SSH connection may hang. // N must be "tcp", "tcp4", "tcp6", or "unix". +// +// If the address is a hostname, it is sent to the remote peer as-is, without +// being resolved locally, and the Listener Addr method will return a zero IP. func (c *Client) Listen(n, addr string) (net.Listener, error) { switch n { case "tcp", "tcp4", "tcp6": - laddr, err := net.ResolveTCPAddr(n, addr) + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseInt(portStr, 10, 32) if err != nil { return nil, err } - return c.ListenTCP(laddr) + return c.listenTCPInternal(host, int(port)) case "unix": return c.ListenUnix(addr) default: @@ -102,15 +110,24 @@ func (c *Client) handleForwards() { // ListenTCP requests the remote peer open a listening socket // on laddr. Incoming connections will be available by calling // Accept on the returned net.Listener. +// +// ListenTCP accepts an IP address, to provide a hostname use [Client.Listen] +// with "tcp", "tcp4", or "tcp6" network instead. func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { c.handleForwardsOnce.Do(c.handleForwards) if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { return c.autoPortListenWorkaround(laddr) } + return c.listenTCPInternal(laddr.IP.String(), laddr.Port) +} + +func (c *Client) listenTCPInternal(host string, port int) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + m := channelForwardMsg{ - laddr.IP.String(), - uint32(laddr.Port), + host, + uint32(port), } // send message ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) @@ -123,20 +140,33 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { // If the original port was 0, then the remote side will // supply a real port number in the response. - if laddr.Port == 0 { + if port == 0 { var p struct { Port uint32 } if err := Unmarshal(resp, &p); err != nil { return nil, err } - laddr.Port = int(p.Port) + port = int(p.Port) } + // Construct a local address placeholder for the remote listener. If the + // original host is an IP address, preserve it so that Listener.Addr() + // reports the same IP. If the host is a hostname or cannot be parsed as an + // IP, fall back to IPv4zero. The port field is always set, even if the + // original port was 0, because in that case the remote server will assign + // one, allowing callers to determine which port was selected. + ip := net.IPv4zero + if parsed, err := netip.ParseAddr(host); err == nil { + ip = net.IP(parsed.AsSlice()) + } + laddr := &net.TCPAddr{ + IP: ip, + Port: port, + } + addr := net.JoinHostPort(host, strconv.FormatInt(int64(port), 10)) + ch := c.forwards.add("tcp", addr) - // Register this forward, using the port number we obtained. - ch := c.forwards.add(laddr) - - return &tcpListener{laddr, c, ch}, nil + return &tcpListener{laddr, addr, c, ch}, nil } // forwardList stores a mapping between remote @@ -149,8 +179,9 @@ type forwardList struct { // forwardEntry represents an established mapping of a laddr on a // remote ssh server to a channel connected to a tcpListener. type forwardEntry struct { - laddr net.Addr - c chan forward + addr string // host:port or socket path + network string // tcp or unix + c chan forward } // forward represents an incoming forwarded tcpip connection. The @@ -161,12 +192,13 @@ type forward struct { raddr net.Addr // the raddr of the incoming connection } -func (l *forwardList) add(addr net.Addr) chan forward { +func (l *forwardList) add(n, addr string) chan forward { l.Lock() defer l.Unlock() f := forwardEntry{ - laddr: addr, - c: make(chan forward, 1), + addr: addr, + network: n, + c: make(chan forward, 1), } l.entries = append(l.entries, f) return f.c @@ -185,19 +217,20 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { if port == 0 || port > 65535 { return nil, fmt.Errorf("ssh: port number out of range: %d", port) } - ip := net.ParseIP(string(addr)) - if ip == nil { + ip, err := netip.ParseAddr(addr) + if err != nil { return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) } - return &net.TCPAddr{IP: ip, Port: int(port)}, nil + return &net.TCPAddr{IP: net.IP(ip.AsSlice()), Port: int(port)}, nil } func (l *forwardList) handleChannels(in <-chan NewChannel) { for ch := range in { var ( - laddr net.Addr - raddr net.Addr - err error + addr string + network string + raddr net.Addr + err error ) switch channelType := ch.ChannelType(); channelType { case "forwarded-tcpip": @@ -207,40 +240,34 @@ func (l *forwardList) handleChannels(in <-chan NewChannel) { continue } - // RFC 4254 section 7.2 specifies that incoming - // addresses should list the address, in string - // format. It is implied that this should be an IP - // address, as it would be impossible to connect to it - // otherwise. - laddr, err = parseTCPAddr(payload.Addr, payload.Port) - if err != nil { - ch.Reject(ConnectionFailed, err.Error()) - continue - } + // RFC 4254 section 7.2 specifies that incoming addresses should + // list the address that was connected, in string format. It is the + // same address used in the tcpip-forward request. The originator + // address is an IP address instead. + addr = net.JoinHostPort(payload.Addr, strconv.FormatUint(uint64(payload.Port), 10)) + raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) if err != nil { ch.Reject(ConnectionFailed, err.Error()) continue } - + network = "tcp" case "forwarded-streamlocal@openssh.com": var payload forwardedStreamLocalPayload if err = Unmarshal(ch.ExtraData(), &payload); err != nil { ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) continue } - laddr = &net.UnixAddr{ - Name: payload.SocketPath, - Net: "unix", - } + addr = payload.SocketPath raddr = &net.UnixAddr{ Name: "@", Net: "unix", } + network = "unix" default: panic(fmt.Errorf("ssh: unknown channel type %s", channelType)) } - if ok := l.forward(laddr, raddr, ch); !ok { + if ok := l.forward(network, addr, raddr, ch); !ok { // Section 7.2, implementations MUST reject spurious incoming // connections. ch.Reject(Prohibited, "no forward for address") @@ -252,11 +279,11 @@ func (l *forwardList) handleChannels(in <-chan NewChannel) { // remove removes the forward entry, and the channel feeding its // listener. -func (l *forwardList) remove(addr net.Addr) { +func (l *forwardList) remove(n, addr string) { l.Lock() defer l.Unlock() for i, f := range l.entries { - if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { + if n == f.network && addr == f.addr { l.entries = append(l.entries[:i], l.entries[i+1:]...) close(f.c) return @@ -274,11 +301,11 @@ func (l *forwardList) closeAll() { l.entries = nil } -func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { +func (l *forwardList) forward(n, addr string, raddr net.Addr, ch NewChannel) bool { l.Lock() defer l.Unlock() for _, f := range l.entries { - if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { + if n == f.network && addr == f.addr { f.c <- forward{newCh: ch, raddr: raddr} return true } @@ -288,6 +315,7 @@ func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { type tcpListener struct { laddr *net.TCPAddr + addr string conn *Client in <-chan forward @@ -314,13 +342,21 @@ func (l *tcpListener) Accept() (net.Conn, error) { // Close closes the listener. func (l *tcpListener) Close() error { + host, port, err := net.SplitHostPort(l.addr) + if err != nil { + return err + } + rport, err := strconv.ParseUint(port, 10, 32) + if err != nil { + return err + } m := channelForwardMsg{ - l.laddr.IP.String(), - uint32(l.laddr.Port), + host, + uint32(rport), } // this also closes the listener. - l.conn.forwards.remove(l.laddr) + l.conn.forwards.remove("tcp", l.addr) ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) if err == nil && !ok { err = errors.New("ssh: cancel-tcpip-forward failed") diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go index c10d1d02a6..549d46cef4 100644 --- a/ssh/test/forward_unix_test.go +++ b/ssh/test/forward_unix_test.go @@ -51,6 +51,8 @@ func testPortForward(t *testing.T, n, listenAddr string) { } }() + // The forwarded address match the listen address because we run the tests + // on the same host. forwardedAddr := sshListener.Addr().String() netConn, err := net.Dial(n, forwardedAddr) if err != nil { @@ -111,6 +113,8 @@ func testPortForward(t *testing.T, n, listenAddr string) { } func TestPortForwardTCP(t *testing.T) { + testPortForward(t, "tcp", ":0") + testPortForward(t, "tcp", "[::]:0") testPortForward(t, "tcp", "localhost:0") } diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go index 6e48841b67..adb4244eb3 100644 --- a/ssh/testdata/keys.go +++ b/ssh/testdata/keys.go @@ -310,6 +310,53 @@ gbDGyT3bXMQtagvCwoW+/oMTKXiZP5jCJpEO8= }, } +var UnsupportedCipherData = []struct { + Name string + EncryptionKey string + PEMBytes []byte +} { + 0: { + Name: "ed25519-encrypted-chacha20-poly1305", + EncryptionKey: "password", + PEMBytes: []byte(`-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAAHWNoYWNoYTIwLXBvbHkxMzA1QG9wZW5zc2guY29tAAAABm +JjcnlwdAAAABgAAAAQdPyPIjXDRAVHskY0yp9SWwAAAGQAAAABAAAAMwAAAAtzc2gtZWQy +NTUxOQAAACBi6qXITEUrmNce/c2lfozxALlKH3o/6sll8G7wzl1lvQAAAJDNlW1sEkvnK0 +8EecF1vHdPk85yClbh3KkHv09mbGAX/Gk6cJpYEGgJSkO7OEF4kG9DVGGd17+TZbTnM4LD +vYAJZExx2XLgJFEtHCVmJjYzwxx7yC7+s6u/XjrSlZS60RHunOPKyq+C+s48sejXvmX+t5 +0ZoVCI8aftT0ycis3gvLU9sCwJ2UnF6kAV226Z4g2aLkuJbgCDTEcYCRD64K1r +-----END OPENSSH PRIVATE KEY----- +`), + }, + 1: { + Name: "ed25519-encrypted-aes128-gcm", + EncryptionKey: "password", + PEMBytes: []byte(`-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAAFmFlczEyOC1nY21Ab3BlbnNzaC5jb20AAAAGYmNyeXB0AA +AAGAAAABBeMJIOqiyFwNCvDv6f8tQeAAAAZAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAA +IGYpUcb3tGp9kF6pppcUdq3EPMr85BaSUdhiXGbhS5YNAAAAkNBtMEu0UlLgToThuQc+4m +/o0DfFIERu0sspQivn5RJHCtulVKfU9BMiEnF0+LOMOABMlYesgLOtoMxwm4ZCSWH54kZk +vaFyyvvxY+RLDuWNQZCryffIA4+iLCUQR1EdxMDiJweKnGJuD64a+9xTJt47A3Vq4SYzji +EuVmM0FqS8lbT2ynYSe3va0Qyw13jEO5qbtCuyG+C5GejL7kX4Z64= +-----END OPENSSH PRIVATE KEY----- +`), + }, + 2: { + Name: "ed25519-encrypted-aes256-gcm", + EncryptionKey: "password", + PEMBytes: []byte(`-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAAFmFlczI1Ni1nY21Ab3BlbnNzaC5jb20AAAAGYmNyeXB0AA +AAGAAAABBR1p3vH2Wr/HPL+q20L2rjAAAAZAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAA +IM3tT1xrAuOHcrBdoLRo/ojWZsAw2lHfF5hJgFEOts5MAAAAkH/YGrDhDw8u+F8e4P+84B +tAzvp55Lf1Yl7y34BrVmqlWqw/7boqahOp6iYJHNpcuanzc5T6s7Z3wSSYodbY1uvFOfbj +rtP6rIHQIY5J2C40WOYJN8IkZlkwDXwZY0qoE9699ZYmWdwsXRZ7QDhjd2W8ziyZBsttiB +kv2ceuJMLT04TrKc2+RUkj4CQYnz7p8EkgZlUozx8wBSxKFGnkP7k= +-----END OPENSSH PRIVATE KEY----- +`), + }, +} + + // SKData contains a list of PubKeys backed by U2F/FIDO2 Security Keys and their test data. var SKData = []struct { Name string