Skip to content

Commit 70e9ffc

Browse files
authored
apiutil, middleware: strengthen the robustness of GetIPPortFromHTTPRequest function (tikv#6958)
close tikv#6957 - Improve `GetIPPortFromHTTPRequest` to ensure it could handle different host addresses. - Make middleware set the forwarded header correctly. Signed-off-by: JmPotato <ghzpotato@gmail.com>
1 parent 602c10d commit 70e9ffc

File tree

11 files changed

+204
-44
lines changed

11 files changed

+204
-44
lines changed

pkg/audit/audit_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) {
103103
b, _ := os.ReadFile(fname)
104104
output := strings.SplitN(string(b), "]", 4)
105105
re.Equal(
106-
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+
106+
fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+
107107
"StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n",
108108
time.Unix(info.StartTimeStamp, 0).String()),
109109
output[3],

pkg/utils/apiutil/apiutil.go

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ var (
4747
)
4848

4949
const (
50+
// PDRedirectorHeader is used to mark which PD redirected this request.
51+
PDRedirectorHeader = "PD-Redirector"
52+
// PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD.
53+
PDAllowFollowerHandleHeader = "PD-Allow-follower-handle"
54+
// XForwardedForHeader is used to mark the client IP.
55+
XForwardedForHeader = "X-Forwarded-For"
56+
// XForwardedPortHeader is used to mark the client port.
57+
XForwardedPortHeader = "X-Forwarded-Port"
58+
// XRealIPHeader is used to mark the real client IP.
59+
XRealIPHeader = "X-Real-Ip"
60+
5061
// ErrRedirectFailed is the error message for redirect failed.
5162
ErrRedirectFailed = "redirect failed"
5263
// ErrRedirectToNotLeader is the error message for redirect to not leader.
@@ -101,26 +112,30 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) {
101112
}
102113
}
103114

104-
// GetIPAddrFromHTTPRequest returns http client IP from context.
115+
// GetIPPortFromHTTPRequest returns http client host IP and port from context.
105116
// Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension),
106117
// so `X-Forwarded-For` has the higher priority than `X-Real-IP`.
107118
// And both of them have the higher priority than `RemoteAddr`
108-
func GetIPAddrFromHTTPRequest(r *http.Request) string {
109-
ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",")
110-
if len(strings.Trim(ips[0], " ")) > 0 {
111-
return ips[0]
112-
}
113-
114-
ip := r.Header.Get("X-Real-Ip")
115-
if ip != "" {
116-
return ip
119+
func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) {
120+
forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",")
121+
if forwardedIP := strings.Trim(forwardedIPs[0], " "); len(forwardedIP) > 0 {
122+
ip = forwardedIP
123+
// Try to get the port from "X-Forwarded-Port" header.
124+
forwardedPorts := strings.Split(r.Header.Get(XForwardedPortHeader), ",")
125+
if forwardedPort := strings.Trim(forwardedPorts[0], " "); len(forwardedPort) > 0 {
126+
port = forwardedPort
127+
}
128+
} else if realIP := r.Header.Get(XRealIPHeader); len(realIP) > 0 {
129+
ip = realIP
130+
} else {
131+
ip = r.RemoteAddr
117132
}
118-
119-
ip, _, err := net.SplitHostPort(r.RemoteAddr)
133+
splitIP, splitPort, err := net.SplitHostPort(ip)
120134
if err != nil {
121-
return ""
135+
// Ensure we could get an IP address at least.
136+
return ip, port
122137
}
123-
return ip
138+
return splitIP, splitPort
124139
}
125140

126141
// GetComponentNameOnHTTP returns component name from Request Header

pkg/utils/apiutil/apiutil_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package apiutil
1717
import (
1818
"bytes"
1919
"io"
20+
"net/http"
2021
"net/http/httptest"
2122
"testing"
2223

@@ -68,3 +69,141 @@ func TestJsonRespondErrorBadInput(t *testing.T) {
6869
re.Equal(400, result.StatusCode)
6970
}
7071
}
72+
73+
func TestGetIPPortFromHTTPRequest(t *testing.T) {
74+
t.Parallel()
75+
re := require.New(t)
76+
77+
testCases := []struct {
78+
r *http.Request
79+
ip string
80+
port string
81+
err error
82+
}{
83+
// IPv4 "X-Forwarded-For" with port
84+
{
85+
r: &http.Request{
86+
Header: map[string][]string{
87+
XForwardedForHeader: {"127.0.0.1:5299"},
88+
},
89+
},
90+
ip: "127.0.0.1",
91+
port: "5299",
92+
},
93+
// IPv4 "X-Forwarded-For" without port
94+
{
95+
r: &http.Request{
96+
Header: map[string][]string{
97+
XForwardedForHeader: {"127.0.0.1"},
98+
XForwardedPortHeader: {"5299"},
99+
},
100+
},
101+
ip: "127.0.0.1",
102+
port: "5299",
103+
},
104+
// IPv4 "X-Real-IP" with port
105+
{
106+
r: &http.Request{
107+
Header: map[string][]string{
108+
XRealIPHeader: {"127.0.0.1:5299"},
109+
},
110+
},
111+
ip: "127.0.0.1",
112+
port: "5299",
113+
},
114+
// IPv4 "X-Real-IP" without port
115+
{
116+
r: &http.Request{
117+
Header: map[string][]string{
118+
XForwardedForHeader: {"127.0.0.1"},
119+
XForwardedPortHeader: {"5299"},
120+
},
121+
},
122+
ip: "127.0.0.1",
123+
port: "5299",
124+
},
125+
// IPv4 RemoteAddr with port
126+
{
127+
r: &http.Request{
128+
RemoteAddr: "127.0.0.1:5299",
129+
},
130+
ip: "127.0.0.1",
131+
port: "5299",
132+
},
133+
// IPv4 RemoteAddr without port
134+
{
135+
r: &http.Request{
136+
RemoteAddr: "127.0.0.1",
137+
},
138+
ip: "127.0.0.1",
139+
port: "",
140+
},
141+
// IPv6 "X-Forwarded-For" with port
142+
{
143+
r: &http.Request{
144+
Header: map[string][]string{
145+
XForwardedForHeader: {"[::1]:5299"},
146+
},
147+
},
148+
ip: "::1",
149+
port: "5299",
150+
},
151+
// IPv6 "X-Forwarded-For" without port
152+
{
153+
r: &http.Request{
154+
Header: map[string][]string{
155+
XForwardedForHeader: {"::1"},
156+
},
157+
},
158+
ip: "::1",
159+
port: "",
160+
},
161+
// IPv6 "X-Real-IP" with port
162+
{
163+
r: &http.Request{
164+
Header: map[string][]string{
165+
XRealIPHeader: {"[::1]:5299"},
166+
},
167+
},
168+
ip: "::1",
169+
port: "5299",
170+
},
171+
// IPv6 "X-Real-IP" without port
172+
{
173+
r: &http.Request{
174+
Header: map[string][]string{
175+
XForwardedForHeader: {"::1"},
176+
},
177+
},
178+
ip: "::1",
179+
port: "",
180+
},
181+
// IPv6 RemoteAddr with port
182+
{
183+
r: &http.Request{
184+
RemoteAddr: "[::1]:5299",
185+
},
186+
ip: "::1",
187+
port: "5299",
188+
},
189+
// IPv6 RemoteAddr without port
190+
{
191+
r: &http.Request{
192+
RemoteAddr: "::1",
193+
},
194+
ip: "::1",
195+
port: "",
196+
},
197+
// Abnormal case
198+
{
199+
r: &http.Request{},
200+
ip: "",
201+
port: "",
202+
},
203+
}
204+
for idx, testCase := range testCases {
205+
ip, port := GetIPPortFromHTTPRequest(testCase.r)
206+
re.Equal(testCase.ip, ip, "case %d", idx)
207+
re.Equal(testCase.port, port, "case %d", idx)
208+
}
209+
}

pkg/utils/apiutil/serverapi/middleware.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ import (
2626
"go.uber.org/zap"
2727
)
2828

29-
// HTTP headers.
30-
const (
31-
PDRedirectorHeader = "PD-Redirector"
32-
PDAllowFollowerHandle = "PD-Allow-follower-handle"
33-
ForwardedForHeader = "X-Forwarded-For"
34-
)
35-
3629
type runtimeServiceValidator struct {
3730
s *server.Server
3831
group apiutil.APIServiceGroup
@@ -130,22 +123,31 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri
130123

131124
func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
132125
matchedFlag, targetAddr := h.matchMicroServiceRedirectRules(r)
133-
allowFollowerHandle := len(r.Header.Get(PDAllowFollowerHandle)) > 0
126+
allowFollowerHandle := len(r.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
134127
isLeader := h.s.GetMember().IsLeader()
135128
if !h.s.IsClosed() && (allowFollowerHandle || isLeader) && !matchedFlag {
136129
next(w, r)
137130
return
138131
}
139132

140133
// Prevent more than one redirection.
141-
if name := r.Header.Get(PDRedirectorHeader); len(name) != 0 {
134+
if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 {
142135
log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect))
143136
http.Error(w, apiutil.ErrRedirectToNotLeader, http.StatusInternalServerError)
144137
return
145138
}
146139

147-
r.Header.Set(PDRedirectorHeader, h.s.Name())
148-
r.Header.Add(ForwardedForHeader, r.RemoteAddr)
140+
r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name())
141+
forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r)
142+
if len(forwardedIP) > 0 {
143+
r.Header.Add(apiutil.XForwardedForHeader, forwardedIP)
144+
} else {
145+
// Fallback if GetIPPortFromHTTPRequest failed to get the IP.
146+
r.Header.Add(apiutil.XForwardedForHeader, r.RemoteAddr)
147+
}
148+
if len(forwardedPort) > 0 {
149+
r.Header.Add(apiutil.XForwardedPortHeader, forwardedPort)
150+
}
149151

150152
var clientUrls []string
151153
if matchedFlag {

pkg/utils/requestutil/request_info.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,27 @@ type RequestInfo struct {
3131
Method string
3232
Component string
3333
IP string
34+
Port string
3435
URLParam string
3536
BodyParam string
3637
StartTimeStamp int64
3738
}
3839

3940
func (info *RequestInfo) String() string {
40-
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
41-
info.ServiceLabel, info.Method, info.Component, info.IP, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
41+
s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}",
42+
info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam)
4243
return s
4344
}
4445

4546
// GetRequestInfo returns request info needed from http.Request
4647
func GetRequestInfo(r *http.Request) RequestInfo {
48+
ip, port := apiutil.GetIPPortFromHTTPRequest(r)
4749
return RequestInfo{
4850
ServiceLabel: apiutil.GetRouteName(r),
4951
Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path),
5052
Component: apiutil.GetComponentNameOnHTTP(r),
51-
IP: apiutil.GetIPAddrFromHTTPRequest(r),
53+
IP: ip,
54+
Port: port,
5255
URLParam: getURLParam(r),
5356
BodyParam: getBodyParam(r),
5457
StartTimeStamp: time.Now().Unix(),

server/apiv2/middlewares/redirector.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import (
2222
"github.com/pingcap/log"
2323
"github.com/tikv/pd/pkg/errs"
2424
"github.com/tikv/pd/pkg/utils/apiutil"
25-
"github.com/tikv/pd/pkg/utils/apiutil/serverapi"
2625
"github.com/tikv/pd/server"
2726
"go.uber.org/zap"
2827
)
@@ -31,21 +30,21 @@ import (
3130
func Redirector() gin.HandlerFunc {
3231
return func(c *gin.Context) {
3332
svr := c.MustGet(ServerContextKey).(*server.Server)
34-
allowFollowerHandle := len(c.Request.Header.Get(serverapi.PDAllowFollowerHandle)) > 0
33+
allowFollowerHandle := len(c.Request.Header.Get(apiutil.PDAllowFollowerHandleHeader)) > 0
3534
isLeader := svr.GetMember().IsLeader()
3635
if !svr.IsClosed() && (allowFollowerHandle || isLeader) {
3736
c.Next()
3837
return
3938
}
4039

4140
// Prevent more than one redirection.
42-
if name := c.Request.Header.Get(serverapi.PDRedirectorHeader); len(name) != 0 {
41+
if name := c.Request.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 {
4342
log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", svr.Name()), errs.ZapError(errs.ErrRedirect))
4443
c.AbortWithStatusJSON(http.StatusInternalServerError, errs.ErrRedirect.FastGenByArgs().Error())
4544
return
4645
}
4746

48-
c.Request.Header.Set(serverapi.PDRedirectorHeader, svr.Name())
47+
c.Request.Header.Set(apiutil.PDRedirectorHeader, svr.Name())
4948

5049
leader := svr.GetMember().GetLeader()
5150
if leader == nil {

server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1744,7 +1744,7 @@ func (s *Server) ReplicateFileToMember(ctx context.Context, member *pdpb.Member,
17441744
}
17451745
url := clientUrls[0] + filepath.Join("/pd/api/v1/admin/persist-file", name)
17461746
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data))
1747-
req.Header.Set("PD-Allow-follower-handle", "true")
1747+
req.Header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
17481748
res, err := s.httpClient.Do(req)
17491749
if err != nil {
17501750
log.Warn("failed to replicate file", zap.String("name", name), zap.String("member", member.GetName()), errs.ZapError(err))

server/server_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/stretchr/testify/require"
2525
"github.com/stretchr/testify/suite"
2626
"github.com/tikv/pd/pkg/mcs/utils"
27+
"github.com/tikv/pd/pkg/utils/apiutil"
2728
"github.com/tikv/pd/pkg/utils/assertutil"
2829
"github.com/tikv/pd/pkg/utils/etcdutil"
2930
"github.com/tikv/pd/pkg/utils/testutil"
@@ -218,7 +219,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderForwarded() {
218219

219220
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
220221
suite.NoError(err)
221-
req.Header.Add("X-Forwarded-For", "127.0.0.2")
222+
req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2")
222223
resp, err := http.DefaultClient.Do(req)
223224
suite.NoError(err)
224225
suite.Equal(http.StatusOK, resp.StatusCode)
@@ -248,7 +249,7 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderXReal() {
248249

249250
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
250251
suite.NoError(err)
251-
req.Header.Add("X-Real-Ip", "127.0.0.2")
252+
req.Header.Add(apiutil.XRealIPHeader, "127.0.0.2")
252253
resp, err := http.DefaultClient.Do(req)
253254
suite.NoError(err)
254255
suite.Equal(http.StatusOK, resp.StatusCode)
@@ -278,8 +279,8 @@ func (suite *leaderServerTestSuite) TestSourceIpForHeaderBoth() {
278279

279280
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/apis/mock/v1/hello", svr.GetAddr()), nil)
280281
suite.NoError(err)
281-
req.Header.Add("X-Forwarded-For", "127.0.0.2")
282-
req.Header.Add("X-Real-Ip", "127.0.0.3")
282+
req.Header.Add(apiutil.XForwardedForHeader, "127.0.0.2")
283+
req.Header.Add(apiutil.XRealIPHeader, "127.0.0.3")
283284
resp, err := http.DefaultClient.Do(req)
284285
suite.NoError(err)
285286
suite.Equal(http.StatusOK, resp.StatusCode)

server/testutil.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func CreateMockHandler(re *require.Assertions, ip string) HandlerBuilder {
143143
mux.HandleFunc("/pd/apis/mock/v1/hello", func(w http.ResponseWriter, r *http.Request) {
144144
fmt.Fprintln(w, "Hello World")
145145
// test getting ip
146-
clientIP := apiutil.GetIPAddrFromHTTPRequest(r)
146+
clientIP, _ := apiutil.GetIPPortFromHTTPRequest(r)
147147
re.Equal(ip, clientIP)
148148
})
149149
info := apiutil.APIServiceGroup{

0 commit comments

Comments
 (0)