Skip to content

Commit 46f809d

Browse files
committed
vhost: set DisableKeepAlives = false and fix websocket not work
1 parent c842558 commit 46f809d

File tree

3 files changed

+50
-20
lines changed

3 files changed

+50
-20
lines changed

pkg/util/util/http.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,4 @@ func hasPort(host string) bool {
7474
return true
7575
}
7676
return host[0] == '[' && strings.Contains(host, "]:")
77-
}
77+
}

pkg/util/vhost/http.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package vhost
1717
import (
1818
"bytes"
1919
"context"
20+
"encoding/base64"
2021
"errors"
2122
"fmt"
2223
"log"
@@ -59,20 +60,25 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
5960
req.URL.Scheme = "http"
6061
url := req.Context().Value(RouteInfoURL).(string)
6162
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
62-
host := rp.GetRealHost(oldHost, url)
63-
if host != "" {
64-
req.Host = host
63+
rc := rp.GetRouteConfig(oldHost, url)
64+
if rc != nil {
65+
if rc.RewriteHost != "" {
66+
req.Host = rc.RewriteHost
67+
}
68+
// Set {domain}.{location} as URL host here to let http transport reuse connections.
69+
req.URL.Host = rc.Domain + "." + base64.StdEncoding.EncodeToString([]byte(rc.Location))
70+
71+
for k, v := range rc.Headers {
72+
req.Header.Set(k, v)
73+
}
74+
} else {
75+
req.URL.Host = req.Host
6576
}
66-
req.URL.Host = req.Host
6777

68-
headers := rp.GetHeaders(oldHost, url)
69-
for k, v := range headers {
70-
req.Header.Set(k, v)
71-
}
7278
},
7379
Transport: &http.Transport{
7480
ResponseHeaderTimeout: rp.responseHeaderTimeout,
75-
DisableKeepAlives: true,
81+
IdleConnTimeout: 60 * time.Second,
7682
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
7783
url := ctx.Value(RouteInfoURL).(string)
7884
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
@@ -107,6 +113,14 @@ func (rp *HTTPReverseProxy) UnRegister(domain string, location string) {
107113
rp.vhostRouter.Del(domain, location)
108114
}
109115

116+
func (rp *HTTPReverseProxy) GetRouteConfig(domain string, location string) *RouteConfig {
117+
vr, ok := rp.getVhost(domain, location)
118+
if ok {
119+
return vr.payload.(*RouteConfig)
120+
}
121+
return nil
122+
}
123+
110124
func (rp *HTTPReverseProxy) GetRealHost(domain string, location string) (host string) {
111125
vr, ok := rp.getVhost(domain, location)
112126
if ok {

tests/ci/health/health_test.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ func TestHealthCheck(t *testing.T) {
139139
}
140140

141141
httpSvc3 := mock.NewHTTPServer(15005, func(w http.ResponseWriter, r *http.Request) {
142+
time.Sleep(time.Second)
142143
w.Write([]byte("http3"))
143144
})
144145
err = httpSvc3.Start()
@@ -147,6 +148,7 @@ func TestHealthCheck(t *testing.T) {
147148
}
148149

149150
httpSvc4 := mock.NewHTTPServer(15006, func(w http.ResponseWriter, r *http.Request) {
151+
time.Sleep(time.Second)
150152
w.Write([]byte("http4"))
151153
})
152154
err = httpSvc4.Start()
@@ -277,16 +279,30 @@ func TestHealthCheck(t *testing.T) {
277279

278280
// ****** load balancing type http ******
279281
result = make([]string, 0)
280-
281-
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
282-
assert.NoError(err)
283-
assert.Equal(200, code)
284-
result = append(result, body)
285-
286-
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
287-
assert.NoError(err)
288-
assert.Equal(200, code)
289-
result = append(result, body)
282+
var wait sync.WaitGroup
283+
var mu sync.Mutex
284+
wait.Add(2)
285+
286+
go func() {
287+
defer wait.Done()
288+
code, body, _, err := util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
289+
assert.NoError(err)
290+
assert.Equal(200, code)
291+
mu.Lock()
292+
result = append(result, body)
293+
mu.Unlock()
294+
}()
295+
296+
go func() {
297+
defer wait.Done()
298+
code, body, _, err = util.SendHTTPMsg("GET", "http://127.0.0.1:14000/xxx", "test.balancing.com", nil, "")
299+
assert.NoError(err)
300+
assert.Equal(200, code)
301+
mu.Lock()
302+
result = append(result, body)
303+
mu.Unlock()
304+
}()
305+
wait.Wait()
290306

291307
assert.Contains(result, "http3")
292308
assert.Contains(result, "http4")

0 commit comments

Comments
 (0)