Skip to content

Commit 2898549

Browse files
committed
Basic prototype
1 parent 7582abb commit 2898549

File tree

12 files changed

+555
-285
lines changed

12 files changed

+555
-285
lines changed

tcpforward/balancer.go renamed to balancer/balancer.go

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tcpforward
1+
package balancer
22

33
import (
44
"encoding/json"
@@ -7,9 +7,11 @@ import (
77
"net"
88
"net/http"
99
"sync"
10-
"github.com/abiosoft/dockward/util"
1110
)
1211

12+
// TODO make dynamic
13+
const EndpointPort = 9923
14+
1315
type Message struct {
1416
Endpoint Endpoint
1517
Remove bool
@@ -22,6 +24,13 @@ type Balancer struct {
2224
sync.RWMutex
2325
}
2426

27+
func New(port int, endpoints Endpoints) *Balancer {
28+
return &Balancer{
29+
Port: port,
30+
Endpoints: endpoints,
31+
}
32+
}
33+
2534
func (b *Balancer) Start(stop chan struct{}) error {
2635
listener, err := net.Listen("tcp", ":"+fmt.Sprint(b.Port))
2736
if err != nil {
@@ -71,36 +80,29 @@ func (b *Balancer) Select(e Endpoints) Endpoint {
7180
return b.Policy.Select(e)
7281
}
7382

74-
func (b *Balancer) ListenForEndpoints() (int, error) {
75-
port, err := util.RandomPort()
76-
if err != nil {
77-
return port, err
78-
}
83+
func (b *Balancer) ListenForEndpoints(port int) {
84+
handler := http.HandlerFunc(
85+
func(w http.ResponseWriter, r *http.Request) {
86+
var message Message
87+
err := json.NewDecoder(r.Body).Decode(&message)
88+
if r.Method != "POST" || err != nil {
89+
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
90+
return
91+
}
7992

80-
go func() {
81-
err := http.ListenAndServe(":"+fmt.Sprint(port),
82-
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
83-
var message Message
84-
err := json.NewDecoder(r.Body).Decode(&message)
85-
if err != nil {
86-
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
87-
return
88-
}
89-
90-
b.Lock()
91-
if message.Remove {
92-
b.Endpoints.Delete(message.Endpoint.Id)
93-
} else {
94-
b.Endpoints.Add(message.Endpoint)
95-
}
96-
b.Unlock()
97-
98-
w.WriteHeader(200)
99-
}))
100-
101-
// should not get here
102-
log.Println(err)
103-
}()
93+
b.Lock()
94+
if message.Remove {
95+
b.Endpoints.Delete(message.Endpoint.Ip)
96+
} else {
97+
b.Endpoints.Add(message.Endpoint)
98+
}
99+
b.Unlock()
100+
101+
w.WriteHeader(200)
102+
})
103+
104+
err := http.ListenAndServe(":"+fmt.Sprint(port), handler)
104105

105-
return port, err
106+
// should not get here
107+
log.Println(err)
106108
}

balancer/endpoint.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package balancer
2+
3+
import (
4+
"fmt"
5+
"strconv"
6+
"strings"
7+
8+
"github.com/abiosoft/dockward/util"
9+
)
10+
11+
type Endpoint struct {
12+
Id string
13+
Ip string
14+
Port int
15+
}
16+
17+
func ParseEndpoint(addr string) Endpoint {
18+
// assume addr as host, port as 80
19+
ip, port, id := addr, 80, util.RandChars(10)
20+
21+
// if its valid port, assume as port, host as 127.0.0.1
22+
if p, err := strconv.Atoi(addr); err == nil {
23+
ip = "127.0.0.1"
24+
port = p
25+
}
26+
27+
// attempt parse
28+
str := strings.Split(addr, ":")
29+
30+
// valid host/port
31+
if len(str) > 1 {
32+
ip = str[0]
33+
port, _ = strconv.Atoi(str[1])
34+
}
35+
// valid id
36+
if len(str) > 2 {
37+
id = str[2]
38+
}
39+
40+
return Endpoint{
41+
Id: id,
42+
Ip: ip,
43+
Port: port,
44+
}
45+
}
46+
47+
func (ep Endpoint) Addr() string {
48+
return ep.Ip + ":" + fmt.Sprint(ep.Port)
49+
}
50+
51+
func (ep Endpoint) String() string {
52+
return ep.Addr() + ":" + ep.Id
53+
}
54+
55+
type Endpoints []Endpoint
56+
57+
func (e Endpoints) Len() int {
58+
return len(e)
59+
}
60+
61+
func (e *Endpoints) Add(ep Endpoint) {
62+
for i, endpoint := range *e {
63+
if endpoint.Id == ep.Id {
64+
// already exists, replace instead.
65+
(*e)[i] = ep
66+
return
67+
}
68+
}
69+
*e = append(*e, ep)
70+
}
71+
72+
func (e *Endpoints) Delete(id string) {
73+
pos := -1
74+
for i, ep := range *e {
75+
if ep.Id == id {
76+
pos = i
77+
break
78+
}
79+
}
80+
if pos == -1 {
81+
return
82+
}
83+
part := (*e)[:pos]
84+
if pos < len(*e)-1 {
85+
part = append(part, (*e)[pos+1:]...)
86+
}
87+
*e = part
88+
}
Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
package tcpforward
1+
package balancer
22

3-
import "math/rand"
3+
import (
4+
"math/rand"
5+
"time"
6+
)
47

58
// Policy is a selection policy.
69
type Policy interface {
@@ -12,9 +15,12 @@ type Random struct{}
1215

1316
// Select satisfies Policy.
1417
func (r Random) Select(d Endpoints) Endpoint {
15-
i := rand.Int() % (len(d) - 1)
16-
if i < 0 {
17-
i = 0
18+
if len(d) == 0 {
19+
return Endpoint{}
1820
}
19-
return d[i]
21+
return d[rand.Int()%len(d)]
22+
}
23+
24+
func init() {
25+
rand.Seed(time.Now().UnixNano())
2026
}

tcpforward/tcpforward.go renamed to balancer/tcpforward.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package tcpforward
1+
package balancer
22

33
import (
44
"fmt"
@@ -32,6 +32,9 @@ func handle(conn net.Conn, dest string) {
3232
client, err := net.Dial("tcp", dest)
3333
if err != nil {
3434
log.Println(err)
35+
if err := conn.Close(); err != nil {
36+
log.Println(err)
37+
}
3538
return
3639
}
3740
var w sync.WaitGroup

cleanup.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package main
2+
import (
3+
"os/signal"
4+
"fmt"
5+
"syscall"
6+
"os"
7+
)
8+
9+
// trapInterrupts traps OS interrupt signals.
10+
func trapInterrupts(exit chan struct{}) chan struct{} {
11+
done := make(chan struct{})
12+
c := make(chan os.Signal, 1)
13+
signal.Notify(c, os.Interrupt)
14+
signal.Notify(c, syscall.SIGTERM)
15+
go func() {
16+
select {
17+
case <-c:
18+
fmt.Print("OS Interrupt signal received. Performing cleanup...")
19+
cleanUp()
20+
fmt.Println("Done.")
21+
done <- struct{}{}
22+
case <-exit:
23+
cleanUp()
24+
done <- struct{}{}
25+
}
26+
27+
}()
28+
return done
29+
}
30+
31+
// cleanUpFuncs is list of functions to call before application exits.
32+
var cleanUpFuncs []func()
33+
34+
// addCleanUpFunc adds a function to cleanUpFuncs.
35+
func addCleanUpFunc(f func()) {
36+
cleanUpFuncs = append(cleanUpFuncs, f)
37+
}
38+
39+
// cleanUp calls all functions in cleanUpFuncs.
40+
func cleanUp() {
41+
for i := range cleanUpFuncs {
42+
cleanUpFuncs[i]()
43+
}
44+
}

cli.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"io/ioutil"
7+
"os"
8+
"strconv"
9+
)
10+
11+
type cliArgs struct {
12+
HostPort int
13+
ContainerName string
14+
ContainerId string
15+
ContainerLabel string
16+
Host bool
17+
Endpoints []string
18+
}
19+
20+
func usageErr(err error) {
21+
exit(fmt.Errorf("%v\n\n%v", err, Usage))
22+
}
23+
24+
func parseCli() cliArgs {
25+
if len(os.Args) == 1 {
26+
usageErr(fmt.Errorf("Command missing"))
27+
}
28+
29+
switch os.Args[1] {
30+
case "help":
31+
fmt.Println(Usage)
32+
exit(nil)
33+
case "version":
34+
fmt.Println("dockward version", Version)
35+
exit(nil)
36+
}
37+
hostPort, err := strconv.Atoi(os.Args[1])
38+
if err != nil {
39+
usageErr(err)
40+
}
41+
42+
args := cliArgs{HostPort: hostPort}
43+
44+
fs := flag.FlagSet{}
45+
fs.SetOutput(ioutil.Discard)
46+
47+
fs.BoolVar(&args.Host, "host", args.Host, "")
48+
fs.StringVar(&args.ContainerId, "id", args.ContainerId, "")
49+
fs.StringVar(&args.ContainerName, "name", args.ContainerName, "")
50+
fs.StringVar(&args.ContainerLabel, "label", args.ContainerLabel, "")
51+
52+
err = fs.Parse(os.Args[2:])
53+
if err != nil {
54+
exit(err)
55+
}
56+
57+
// if not host mode, require a container param.
58+
if !args.Host {
59+
if args.ContainerId == "" && args.ContainerLabel == "" && args.ContainerName == "" {
60+
exit(fmt.Errorf("One of container id, name or label is required."))
61+
}
62+
}
63+
64+
if fs.NArg() > 0 {
65+
args.Endpoints = fs.Args()
66+
}
67+
return args
68+
}
69+

0 commit comments

Comments
 (0)