Skip to content

Commit b1edd90

Browse files
committed
Initial commit.
0 parents  commit b1edd90

File tree

5 files changed

+863
-0
lines changed

5 files changed

+863
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
tlsrouter
2+
tlsrouter.test

config.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package main
2+
3+
import (
4+
"bufio"
5+
"fmt"
6+
"io"
7+
"os"
8+
"regexp"
9+
"strings"
10+
"sync"
11+
)
12+
13+
type Route struct {
14+
match *regexp.Regexp
15+
backend string
16+
}
17+
18+
// Config stores the TLS routing configuration.
19+
type Config struct {
20+
mu sync.Mutex
21+
routes []Route
22+
}
23+
24+
func dnsRegex(s string) (*regexp.Regexp, error) {
25+
return regexp.Compile(s)
26+
}
27+
28+
func (c *Config) Match(hostname string) string {
29+
c.mu.Lock()
30+
defer c.mu.Unlock()
31+
for _, r := range c.routes {
32+
if r.match.MatchString(hostname) {
33+
return r.backend
34+
}
35+
}
36+
return ""
37+
}
38+
39+
func (c *Config) Read(r io.Reader) error {
40+
var routes []Route
41+
42+
s := bufio.NewScanner(r)
43+
for s.Scan() {
44+
fs := strings.Fields(s.Text())
45+
switch len(fs) {
46+
case 0:
47+
continue
48+
case 1:
49+
return fmt.Errorf("invalid %q on a line by itself", s.Text())
50+
case 2:
51+
re, err := dnsRegex(fs[0])
52+
if err != nil {
53+
return err
54+
}
55+
routes = append(routes, Route{re, fs[1]})
56+
default:
57+
// TODO: multiple backends?
58+
return fmt.Errorf("too many fields on line: %q", s.Text())
59+
}
60+
}
61+
if err := s.Err(); err != nil {
62+
return err
63+
}
64+
65+
c.mu.Lock()
66+
defer c.mu.Unlock()
67+
c.routes = routes
68+
return nil
69+
}
70+
71+
func (c *Config) ReadFile(path string) error {
72+
f, err := os.Open(path)
73+
if err != nil {
74+
return err
75+
}
76+
return c.Read(f)
77+
}

main.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package main
2+
3+
import (
4+
"bytes"
5+
"flag"
6+
"fmt"
7+
"io"
8+
"log"
9+
"net"
10+
"sync"
11+
"time"
12+
)
13+
14+
var cfgFile = flag.String("conf", "", "configuration file")
15+
var listen = flag.String("listen", ":443", "listening port")
16+
17+
var config Config
18+
19+
func main() {
20+
flag.Parse()
21+
22+
if err := config.ReadFile(*cfgFile); err != nil {
23+
log.Fatalf("Failed to read config %q: %s", *cfgFile, err)
24+
}
25+
26+
l, err := net.Listen("tcp", *listen)
27+
if err != nil {
28+
log.Fatalf("Failed to listen: %s", err)
29+
}
30+
31+
for {
32+
c, err := l.Accept()
33+
if err != nil {
34+
log.Fatalf("Error while accepting: %s", err)
35+
}
36+
37+
conn := &Conn{TCPConn: c.(*net.TCPConn)}
38+
go conn.proxy()
39+
}
40+
}
41+
42+
type Conn struct {
43+
*net.TCPConn
44+
45+
tlsMinor int
46+
hostname string
47+
backend string
48+
backendConn *net.TCPConn
49+
}
50+
51+
func (c *Conn) log(msg string, args ...interface{}) {
52+
msg = fmt.Sprintf(msg, args...)
53+
log.Printf("%s <> %s: %s", c.RemoteAddr(), c.LocalAddr(), msg)
54+
}
55+
56+
func (c *Conn) abort(alert byte, msg string, args ...interface{}) {
57+
c.log(msg, args...)
58+
alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, alert}
59+
if _, err := c.Write(alertMsg); err != nil {
60+
c.log("error while sending alert: %s", err)
61+
}
62+
}
63+
64+
func (c *Conn) internalError(msg string, args ...interface{}) { c.abort(80, msg, args...) }
65+
func (c *Conn) sniFailed(msg string, args ...interface{}) { c.abort(112, msg, args...) }
66+
67+
func (c *Conn) proxy() {
68+
defer c.Close()
69+
70+
var (
71+
err error
72+
handshakeBuf bytes.Buffer
73+
)
74+
c.hostname, c.tlsMinor, err = extractSNI(io.TeeReader(c, &handshakeBuf))
75+
if err != nil {
76+
c.internalError("Extracting SNI: %s", err)
77+
return
78+
}
79+
80+
c.backend = config.Match(c.hostname)
81+
if c.backend == "" {
82+
c.sniFailed("no backend found for %q", c.hostname)
83+
return
84+
}
85+
86+
c.log("routing %q to %q", c.hostname, c.backend)
87+
backend, err := net.DialTimeout("tcp", c.backend, 10*time.Second)
88+
if err != nil {
89+
c.internalError("failed to dial backend %q for %q: %s", c.backend, c.hostname, err)
90+
return
91+
}
92+
defer backend.Close()
93+
94+
c.backendConn = backend.(*net.TCPConn)
95+
96+
// Replay the piece of the handshake we had to read to do the
97+
// routing, then blindly proxy any other bytes.
98+
if _, err = io.Copy(c.backendConn, &handshakeBuf); err != nil {
99+
c.internalError("failed to replay handshake to %q: %s", c.backend, err)
100+
return
101+
}
102+
103+
var wg sync.WaitGroup
104+
wg.Add(2)
105+
go proxy(&wg, c.TCPConn, c.backendConn)
106+
go proxy(&wg, c.backendConn, c.TCPConn)
107+
wg.Wait()
108+
}
109+
110+
func proxy(wg *sync.WaitGroup, a, b net.Conn) {
111+
defer wg.Done()
112+
atcp, btcp := a.(*net.TCPConn), b.(*net.TCPConn)
113+
if _, err := io.Copy(atcp, btcp); err != nil {
114+
log.Printf("%s<>%s -> %s<>%s: %s", atcp.RemoteAddr(), atcp.LocalAddr(), btcp.LocalAddr(), btcp.RemoteAddr(), err)
115+
}
116+
btcp.CloseWrite()
117+
atcp.CloseRead()
118+
}

0 commit comments

Comments
 (0)