@@ -16,13 +16,18 @@ package caddytls
1616
1717import (
1818 "crypto/tls"
19+ "fmt"
20+ "net"
21+ "strings"
1922
2023 "github.com/caddyserver/caddy/v2"
2124 "github.com/caddyserver/certmagic"
25+ "go.uber.org/zap"
2226)
2327
2428func init () {
2529 caddy .RegisterModule (MatchServerName {})
30+ caddy .RegisterModule (MatchRemoteIP {})
2631}
2732
2833// MatchServerName matches based on SNI. Names in
@@ -48,5 +53,100 @@ func (m MatchServerName) Match(hello *tls.ClientHelloInfo) bool {
4853 return false
4954}
5055
51- // Interface guard
52- var _ ConnectionMatcher = (* MatchServerName )(nil )
56+ // MatchRemoteIP matches based on the remote IP of the
57+ // connection. Specific IPs or CIDR ranges can be specified.
58+ //
59+ // Note that IPs can sometimes be spoofed, so do not rely
60+ // on this as a replacement for actual authentication.
61+ type MatchRemoteIP struct {
62+ // The IPs or CIDR ranges to match.
63+ Ranges []string `json:"ranges,omitempty"`
64+
65+ // The IPs or CIDR ranges to *NOT* match.
66+ NotRanges []string `json:"not_ranges,omitempty"`
67+
68+ cidrs []* net.IPNet
69+ notCidrs []* net.IPNet
70+ logger * zap.Logger
71+ }
72+
73+ // CaddyModule returns the Caddy module information.
74+ func (MatchRemoteIP ) CaddyModule () caddy.ModuleInfo {
75+ return caddy.ModuleInfo {
76+ ID : "tls.handshake_match.remote_ip" ,
77+ New : func () caddy.Module { return new (MatchRemoteIP ) },
78+ }
79+ }
80+
81+ // Provision parses m's IP ranges, either from IP or CIDR expressions.
82+ func (m * MatchRemoteIP ) Provision (ctx caddy.Context ) error {
83+ m .logger = ctx .Logger (m )
84+ for _ , str := range m .Ranges {
85+ cidrs , err := m .parseIPRange (str )
86+ if err != nil {
87+ return err
88+ }
89+ m .cidrs = cidrs
90+ }
91+ for _ , str := range m .NotRanges {
92+ cidrs , err := m .parseIPRange (str )
93+ if err != nil {
94+ return err
95+ }
96+ m .notCidrs = cidrs
97+ }
98+ return nil
99+ }
100+
101+ // Match matches hello based on the connection's remote IP.
102+ func (m MatchRemoteIP ) Match (hello * tls.ClientHelloInfo ) bool {
103+ remoteAddr := hello .Conn .RemoteAddr ().String ()
104+ ipStr , _ , err := net .SplitHostPort (remoteAddr )
105+ if err != nil {
106+ ipStr = remoteAddr // weird; maybe no port?
107+ }
108+ ip := net .ParseIP (ipStr )
109+ if ip == nil {
110+ m .logger .Error ("invalid client IP addresss" , zap .String ("ip" , ipStr ))
111+ return false
112+ }
113+ return (len (m .cidrs ) == 0 || m .matches (ip , m .cidrs )) &&
114+ (len (m .notCidrs ) == 0 || ! m .matches (ip , m .notCidrs ))
115+ }
116+
117+ func (MatchRemoteIP ) parseIPRange (str string ) ([]* net.IPNet , error ) {
118+ var cidrs []* net.IPNet
119+ if strings .Contains (str , "/" ) {
120+ _ , ipNet , err := net .ParseCIDR (str )
121+ if err != nil {
122+ return nil , fmt .Errorf ("parsing CIDR expression: %v" , err )
123+ }
124+ cidrs = append (cidrs , ipNet )
125+ } else {
126+ ip := net .ParseIP (str )
127+ if ip == nil {
128+ return nil , fmt .Errorf ("invalid IP address: %s" , str )
129+ }
130+ mask := len (ip ) * 8
131+ cidrs = append (cidrs , & net.IPNet {
132+ IP : ip ,
133+ Mask : net .CIDRMask (mask , mask ),
134+ })
135+ }
136+ return cidrs , nil
137+ }
138+
139+ func (MatchRemoteIP ) matches (ip net.IP , ranges []* net.IPNet ) bool {
140+ for _ , ipRange := range ranges {
141+ if ipRange .Contains (ip ) {
142+ return true
143+ }
144+ }
145+ return false
146+ }
147+
148+ // Interface guards
149+ var (
150+ _ ConnectionMatcher = (* MatchServerName )(nil )
151+ _ ConnectionMatcher = (* MatchRemoteIP )(nil )
152+ )
0 commit comments