Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.

Commit 3cc10a2

Browse files
committed
Introduced ability to provide custom Dial functions.
1 parent 6283500 commit 3cc10a2

File tree

5 files changed

+69
-9
lines changed

5 files changed

+69
-9
lines changed

cluster.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,16 @@ type mongoCluster struct {
5151
direct bool
5252
cachedIndex map[string]bool
5353
sync chan bool
54+
dial dialer
5455
}
5556

56-
func newCluster(userSeeds []string, direct bool) *mongoCluster {
57-
cluster := &mongoCluster{userSeeds: userSeeds, references: 1, direct: direct}
57+
func newCluster(userSeeds []string, direct bool, dial dialer) *mongoCluster {
58+
cluster := &mongoCluster{
59+
userSeeds: userSeeds,
60+
references: 1,
61+
direct: direct,
62+
dial: dial,
63+
}
5864
cluster.serverSynced.L = cluster.RWMutex.RLocker()
5965
cluster.sync = make(chan bool, 1)
6066
go cluster.syncServersLoop()
@@ -329,7 +335,7 @@ func (cluster *mongoCluster) syncServersIteration(direct bool) {
329335
go func() {
330336
defer wg.Done()
331337

332-
server, err := newServer(addr, cluster.sync)
338+
server, err := newServer(addr, cluster.sync, cluster.dial)
333339
if err != nil {
334340
log("SYNC Failed to start sync of ", addr, ": ", err.Error())
335341
return

cluster_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
package mgo_test
2828

2929
import (
30+
"fmt"
3031
"io"
3132
. "launchpad.net/gocheck"
3233
"labix.org/v2/mgo"
3334
"labix.org/v2/mgo/bson"
35+
"net"
3436
"strings"
3537
"time"
3638
)
@@ -1043,3 +1045,38 @@ func (s *S) TestSetModeEventualIterBug(c *C) {
10431045
c.Assert(iter.Err(), Equals, nil)
10441046
c.Assert(i, Equals, N)
10451047
}
1048+
1049+
func (s *S) TestCustomDial(c *C) {
1050+
dials := make(chan bool, 16)
1051+
dial := func(addr net.Addr) (net.Conn, error) {
1052+
tcpaddr, ok := addr.(*net.TCPAddr)
1053+
if !ok {
1054+
return nil, fmt.Errorf("unexpected address type: %T", addr)
1055+
}
1056+
dials <- true
1057+
return net.DialTCP("tcp", nil, tcpaddr)
1058+
}
1059+
info := mgo.DialInfo{
1060+
Addrs: []string{"localhost:40012"},
1061+
Dial: dial,
1062+
}
1063+
1064+
// Use hostname here rather than IP, to make things trickier.
1065+
session, err := mgo.DialWithInfo(&info)
1066+
c.Assert(err, IsNil)
1067+
defer session.Close()
1068+
1069+
const N = 3
1070+
for i := 0; i < N; i++ {
1071+
select {
1072+
case <-dials:
1073+
case <-time.After(5 * time.Second):
1074+
c.Fatalf("expected %d dials, got %d", N, i)
1075+
}
1076+
}
1077+
select {
1078+
case <-dials:
1079+
c.Fatalf("got more dials than expected")
1080+
case <-time.After(100 * time.Millisecond):
1081+
}
1082+
}

server.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,12 @@ type mongoServer struct {
4646
closed bool
4747
master bool
4848
sync chan bool
49+
dial dialer
4950
}
5051

51-
func newServer(addr string, sync chan bool) (server *mongoServer, err error) {
52+
type dialer func(addr net.Addr) (net.Conn, error)
53+
54+
func newServer(addr string, sync chan bool, dial dialer) (server *mongoServer, err error) {
5255
tcpaddr, err := net.ResolveTCPAddr("tcp", addr)
5356
if err != nil {
5457
log("Failed to resolve ", addr, ": ", err.Error())
@@ -64,6 +67,7 @@ func newServer(addr string, sync chan bool) (server *mongoServer, err error) {
6467
ResolvedAddr: resolvedAddr,
6568
tcpaddr: tcpaddr,
6669
sync: sync,
70+
dial: dial,
6771
}
6872
return
6973
}
@@ -115,10 +119,17 @@ func (server *mongoServer) Connect() (*mongoSocket, error) {
115119
addr := server.Addr
116120
tcpaddr := server.tcpaddr
117121
master := server.master
122+
dial := server.dial
118123
server.RUnlock()
119124

120125
log("Establishing new connection to ", addr, "...")
121-
conn, err := net.DialTCP("tcp", nil, tcpaddr)
126+
var conn net.Conn
127+
var err error
128+
if dial == nil {
129+
conn, err = net.DialTCP("tcp", nil, tcpaddr)
130+
} else {
131+
conn, err = dial(tcpaddr)
132+
}
122133
if err != nil {
123134
log("Connection to ", addr, " failed: ", err.Error())
124135
return nil, err

session.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"fmt"
3434
"labix.org/v2/mgo/bson"
3535
"math"
36+
"net"
3637
"reflect"
3738
"runtime"
3839
"sort"
@@ -244,11 +245,16 @@ type DialInfo struct {
244245
// or the "admin" database otherwise. See the Session.Login method too.
245246
Username string
246247
Password string
248+
249+
// Dial optionally specifies the dial function for creating connections.
250+
// At the moment addr will have type *net.TCPAddr, but other types may
251+
// be provided in the future, so check and fail if necessary.
252+
Dial func(addr net.Addr) (net.Conn, error)
247253
}
248254

249255
// DialWithInfo establishes a new session to the cluster identified by info.
250256
func DialWithInfo(info *DialInfo) (*Session, error) {
251-
cluster := newCluster(info.Addrs, info.Direct)
257+
cluster := newCluster(info.Addrs, info.Direct, info.Dial)
252258
session := newSession(Eventual, cluster, info.Timeout)
253259
session.defaultdb = info.Database
254260
if session.defaultdb == "" {

socket.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type replyFunc func(err error, reply *replyOp, docNum int, docData []byte)
3838
type mongoSocket struct {
3939
sync.Mutex
4040
server *mongoServer // nil when cached
41-
conn *net.TCPConn
41+
conn net.Conn
4242
addr string // For debugging only.
4343
nextRequestId uint32
4444
replyFuncs map[uint32]replyFunc
@@ -97,7 +97,7 @@ type requestInfo struct {
9797
replyFunc replyFunc
9898
}
9999

100-
func newSocket(server *mongoServer, conn *net.TCPConn) *mongoSocket {
100+
func newSocket(server *mongoServer, conn net.Conn) *mongoSocket {
101101
socket := &mongoSocket{conn: conn, addr: server.Addr}
102102
socket.gotNonce.L = &socket.Mutex
103103
socket.replyFuncs = make(map[uint32]replyFunc)
@@ -371,7 +371,7 @@ func (socket *mongoSocket) Query(ops ...interface{}) (err error) {
371371
return err
372372
}
373373

374-
func fill(r *net.TCPConn, b []byte) error {
374+
func fill(r net.Conn, b []byte) error {
375375
l := len(b)
376376
n, err := r.Read(b)
377377
for n != l && err == nil {

0 commit comments

Comments
 (0)