diff --git a/go.mod b/go.mod index 98b6f214ef..d9a128eaec 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,9 @@ require ( github.com/libp2p/go-conn-security-multistream v0.2.0 github.com/libp2p/go-eventbus v0.1.0 github.com/libp2p/go-libp2p-autonat v0.2.2 - github.com/libp2p/go-libp2p-blankhost v0.1.4 + github.com/libp2p/go-libp2p-blankhost v0.1.5-0.20200504035409-3dd0148936e2 github.com/libp2p/go-libp2p-circuit v0.2.2 - github.com/libp2p/go-libp2p-core v0.5.2 + github.com/libp2p/go-libp2p-core v0.5.3 github.com/libp2p/go-libp2p-discovery v0.4.0 github.com/libp2p/go-libp2p-loggables v0.1.0 github.com/libp2p/go-libp2p-mplex v0.2.3 diff --git a/go.sum b/go.sum index 245636212c..e21f1904ef 100644 --- a/go.sum +++ b/go.sum @@ -178,6 +178,8 @@ github.com/libp2p/go-libp2p-autonat v0.2.2/go.mod h1:HsM62HkqZmHR2k1xgX34WuWDzk/ github.com/libp2p/go-libp2p-blankhost v0.1.1/go.mod h1:pf2fvdLJPsC1FsVrNP3DUUvMzUts2dsLLBEpo1vW1ro= github.com/libp2p/go-libp2p-blankhost v0.1.4 h1:I96SWjR4rK9irDHcHq3XHN6hawCRTPUADzkJacgZLvk= github.com/libp2p/go-libp2p-blankhost v0.1.4/go.mod h1:oJF0saYsAXQCSfDq254GMNmLNz6ZTHTOvtF4ZydUvwU= +github.com/libp2p/go-libp2p-blankhost v0.1.5-0.20200504035409-3dd0148936e2 h1:LX7e4NrshYhQQiy2kY33fEi4Pk7sAuXd+YKMgXpjvj8= +github.com/libp2p/go-libp2p-blankhost v0.1.5-0.20200504035409-3dd0148936e2/go.mod h1:jONCAJqEP+Z8T6EQviGL4JsQcLx1LgTGtVqFNY8EMfQ= github.com/libp2p/go-libp2p-circuit v0.1.4 h1:Phzbmrg3BkVzbqd4ZZ149JxCuUWu2wZcXf/Kr6hZJj8= github.com/libp2p/go-libp2p-circuit v0.1.4/go.mod h1:CY67BrEjKNDhdTk8UgBX1Y/H5c3xkAcs3gnksxY7osU= github.com/libp2p/go-libp2p-circuit v0.2.1 h1:BDiBcQxX/ZJJ/yDl3sqZt1bjj4PkZCEi7IEpwxXr13k= @@ -196,6 +198,8 @@ github.com/libp2p/go-libp2p-core v0.5.0/go.mod h1:49XGI+kc38oGVwqSBhDEwytaAxgZas github.com/libp2p/go-libp2p-core v0.5.1/go.mod h1:uN7L2D4EvPCvzSH5SrhR72UWbnSGpt5/a35Sm4upn4Y= github.com/libp2p/go-libp2p-core v0.5.2 h1:hevsCcdLiazurKBoeNn64aPYTVOPdY4phaEGeLtHOAs= github.com/libp2p/go-libp2p-core v0.5.2/go.mod h1:uN7L2D4EvPCvzSH5SrhR72UWbnSGpt5/a35Sm4upn4Y= +github.com/libp2p/go-libp2p-core v0.5.3 h1:b9W3w7AZR2n/YJhG8d0qPFGhGhCWKIvPuJgp4hhc4MM= +github.com/libp2p/go-libp2p-core v0.5.3/go.mod h1:uN7L2D4EvPCvzSH5SrhR72UWbnSGpt5/a35Sm4upn4Y= github.com/libp2p/go-libp2p-crypto v0.1.0 h1:k9MFy+o2zGDNGsaoZl0MA3iZ75qXxr9OOoAZF+sD5OQ= github.com/libp2p/go-libp2p-crypto v0.1.0/go.mod h1:sPUokVISZiy+nNuTTH/TY+leRSxnFj/2GLjtOTW90hI= github.com/libp2p/go-libp2p-discovery v0.2.0 h1:1p3YSOq7VsgaL+xVHPi8XAmtGyas6D2J6rWBEfz/aiY= diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 91e10192a2..e20685c340 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -3,6 +3,7 @@ package basichost import ( "context" "errors" + "fmt" "io" "net" "sync" @@ -36,6 +37,9 @@ import ( // peer (for all addresses). const maxAddressResolution = 32 +// addrChangeTickrInterval is the interval between two address change ticks. +var addrChangeTickrInterval = 5 * time.Second + var log = logging.Logger("basichost") var ( @@ -156,7 +160,7 @@ func NewHost(ctx context.Context, net network.Network, opts *HostOpts) (*BasicHo if h.emitters.evtLocalProtocolsUpdated, err = h.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}); err != nil { return nil, err } - if h.emitters.evtLocalAddrsUpdated, err = h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}, eventbus.Stateful); err != nil { + if h.emitters.evtLocalAddrsUpdated, err = h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}); err != nil { return nil, err } @@ -207,6 +211,16 @@ func NewHost(ctx context.Context, net network.Network, opts *HostOpts) (*BasicHo net.SetStreamHandler(h.newStreamHandler) + // persist a signed peer record for self to the peerstore. + rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{h.ID(), h.Addrs()}) + ev, err := record.Seal(rec, h.signKey) + if err != nil { + return nil, fmt.Errorf("failed to create signed record for self: %w", err) + } + if _, err := cab.ConsumePeerRecord(ev, peerstore.PermanentAddrTTL); err != nil { + return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) + } + return h, nil } @@ -392,7 +406,7 @@ func (h *BasicHost) background() { // periodically schedules an IdentifyPush to update our peers for changes // in our address set (if needed) - ticker := time.NewTicker(5 * time.Second) + ticker := time.NewTicker(addrChangeTickrInterval) defer ticker.Stop() for { diff --git a/p2p/protocol/identify/id.go b/p2p/protocol/identify/id.go index a3649a53c6..822ee83dde 100644 --- a/p2p/protocol/identify/id.go +++ b/p2p/protocol/identify/id.go @@ -14,7 +14,6 @@ import ( "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" - "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-libp2p-core/record" "github.com/libp2p/go-eventbus" @@ -29,12 +28,9 @@ import ( var log = logging.Logger("net/identify") -// ID is the protocol.ID of the Identify Service. -const ID = "/p2p/id/1.1.0" - -// LegacyID is the protocol.ID of version 1.0.0 of the identify -// service, which does not support signed peer records. -const LegacyID = "/ipfs/id/1.0.0" +// ID is the protocol.ID of version 1.0.0 of the identify +// service. +const ID = "/ipfs/id/1.0.0" // LibP2PVersion holds the current protocol version for a client running this code // TODO(jbenet): fix the versioning mess. @@ -63,6 +59,17 @@ func init() { // transientTTL is a short ttl for invalidated previously connected addrs const transientTTL = 10 * time.Second +type addPeerHandlerReq struct { + rp peer.ID + localConnAddr ma.Multiaddr + remoteConnAddr ma.Multiaddr + resp chan *peerHandler +} + +type rmPeerHandlerReq struct { + p peer.ID +} + // IDService is a structure that implements ProtocolIdentify. // It is a trivial service that gives the other peer some // useful information about the local peer. A sort of hello. @@ -88,9 +95,6 @@ type IDService struct { addrMu sync.Mutex - peerrec *record.Envelope - peerrecMu sync.RWMutex - // our own observed addresses. observedAddrs *ObservedAddrManager @@ -99,6 +103,9 @@ type IDService struct { evtPeerIdentificationCompleted event.Emitter evtPeerIdentificationFailed event.Emitter } + + addPeerHandlerCh chan addPeerHandlerReq + rmPeerHandlerCh chan rmPeerHandlerReq } // NewIDService constructs a new *IDService and activates it by @@ -123,13 +130,16 @@ func NewIDService(h host.Host, opts ...Option) *IDService { ctxCancel: cancel, conns: make(map[network.Conn]chan struct{}), observedAddrs: NewObservedAddrManager(hostCtx, h), + + addPeerHandlerCh: make(chan addPeerHandlerReq), + rmPeerHandlerCh: make(chan rmPeerHandlerReq), } // handle local protocol handler updates, and push deltas to peers. var err error s.refCount.Add(1) - go s.handleEvents() + go s.loop() s.emitters.evtPeerProtocolsUpdated, err = h.EventBus().Emitter(&event.EvtPeerProtocolsUpdated{}) if err != nil { @@ -146,20 +156,17 @@ func NewIDService(h host.Host, opts ...Option) *IDService { // register protocols that do not depend on peer records. h.SetStreamHandler(IDDelta, s.deltaHandler) - h.SetStreamHandler(LegacyID, s.requestHandler) - h.SetStreamHandler(LegacyIDPush, s.pushHandler) - - // register protocols that depend on peer records. - h.SetStreamHandler(ID, s.requestHandler) + h.SetStreamHandler(ID, s.sendIdentifyResp) h.SetStreamHandler(IDPush, s.pushHandler) h.Network().Notify((*netNotifiee)(s)) return s } -func (ids *IDService) handleEvents() { +func (ids *IDService) loop() { defer ids.refCount.Done() + phs := make(map[peer.ID]*peerHandler) sub, err := ids.Host.EventBus().Subscribe([]interface{}{&event.EvtLocalProtocolsUpdated{}, &event.EvtLocalAddressesUpdated{}}, eventbus.BufSize(256)) if err != nil { @@ -167,19 +174,94 @@ func (ids *IDService) handleEvents() { return } - defer sub.Close() + defer func() { + sub.Close() + for pid := range phs { + phs[pid].close() + } + }() + + phClosedCh := make(chan peer.ID) for { select { + case addReq := <-ids.addPeerHandlerCh: + rp := addReq.rp + ph, ok := phs[rp] + if ok { + addReq.resp <- ph + continue + } + + if ids.Host.Network().Connectedness(rp) == network.Connected { + mes := &pb.Identify{} + ids.populateMessage(mes, rp, addReq.localConnAddr, addReq.remoteConnAddr) + ph = newPeerHandler(rp, ids, mes) + ph.start() + phs[rp] = ph + addReq.resp <- ph + } + + case rmReq := <-ids.rmPeerHandlerCh: + rp := rmReq.p + if ids.Host.Network().Connectedness(rp) != network.Connected { + // before we remove the peerhandler, we should ensure that it will not send any + // more messages. Otherwise, we might create a new handler and the Identify response + // synchronized with the new handler might be overwritten by a message sent by this "old" handler. + ph, ok := phs[rp] + if !ok { + // move on, move on, there's nothing to see here. + continue + } + ids.refCount.Add(1) + go func(req rmPeerHandlerReq, ph *peerHandler) { + defer ids.refCount.Done() + ph.close() + select { + case <-ids.ctx.Done(): + return + case phClosedCh <- req.p: + } + }(rmReq, ph) + } + + case rp := <-phClosedCh: + ph := phs[rp] + + // If we are connected to the peer, it means that we got a connection from the peer + // before we could finish removing it's handler on the previous disconnection. + // If we delete the handler, we wont be able to push updates to it + // till we see a new connection. So, we should restart the handler. + // The fact that we got the handler on this channel means that it's context and handler + // have completed because we write the handler to this chanel only after it closed. + if ids.Host.Network().Connectedness(rp) == network.Connected { + ph.start() + } else { + delete(phs, rp) + } + case e, more := <-sub.Out(): if !more { return } - switch evt := e.(type) { + switch e.(type) { case event.EvtLocalAddressesUpdated: - ids.handleLocalAddrsUpdated(evt) + for pid := range phs { + select { + case phs[pid].pushCh <- struct{}{}: + default: + log.Debugf("dropping addr updated message for %s as buffer full", pid.Pretty()) + } + } + case event.EvtLocalProtocolsUpdated: - ids.handleProtosChanged(evt) + for pid := range phs { + select { + case phs[pid].deltaCh <- struct{}{}: + default: + log.Debugf("dropping protocol updated message for %s as buffer full", pid.Pretty()) + } + } } case <-ids.ctx.Done(): @@ -197,20 +279,6 @@ func (ids *IDService) Close() error { return nil } -func (ids *IDService) handleProtosChanged(evt event.EvtLocalProtocolsUpdated) { - ids.fireProtocolDelta(evt) -} - -func (ids *IDService) handleLocalAddrsUpdated(evt event.EvtLocalAddressesUpdated) { - ids.peerrecMu.Lock() - rec := evt.SignedPeerRecord - ids.peerrec = rec - ids.peerrecMu.Unlock() - - log.Debug("triggering push based on updated local PeerRecord") - ids.Push() -} - // OwnObservedAddrs returns the addresses peers have reported we've dialed from func (ids *IDService) OwnObservedAddrs() []ma.Multiaddr { return ids.observedAddrs.Addrs() @@ -293,36 +361,51 @@ func (ids *IDService) identifyConn(c network.Conn, signal chan struct{}) { ids.removeConn(c) return } + s.SetProtocol(ID) - protocolIDs := []string{ID, LegacyID} // ok give the response to our handler. - var selectedProto string - if selectedProto, err = msmux.SelectOneOf(protocolIDs, s); err != nil { + if err = msmux.SelectProtoOrFail(ID, s); err != nil { log.Event(context.TODO(), "IdentifyOpenFailed", c.RemotePeer(), logging.Metadata{"error": err}) s.Reset() return } - s.SetProtocol(protocol.ID(selectedProto)) - ids.responseHandler(s) + ids.handleIdentifyResponse(s) } -func protoSupportsPeerRecords(proto protocol.ID) bool { - return proto == ID || proto == IDPush -} +func (ids *IDService) sendIdentifyResp(s network.Stream) { + var ph *peerHandler + + defer func() { + helpers.FullClose(s) + if ph != nil { + ph.msgMu.RUnlock() + } + }() -func (ids *IDService) requestHandler(s network.Stream) { - defer helpers.FullClose(s) c := s.Conn() + phCh := make(chan *peerHandler, 1) + select { + case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), c.LocalMultiaddr(), + c.RemoteMultiaddr(), phCh}: + case <-ids.ctx.Done(): + return + } + + select { + case ph = <-phCh: + case <-ids.ctx.Done(): + return + } + + ph.msgMu.RLock() w := ggio.NewDelimitedWriter(s) - mes := pb.Identify{} - ids.populateMessage(&mes, s.Conn(), protoSupportsPeerRecords(s.Protocol())) - w.WriteMsg(&mes) + w.WriteMsg(ph.idMsgSnapshot) log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr()) } -func (ids *IDService) responseHandler(s network.Stream) { +func (ids *IDService) handleIdentifyResponse(s network.Stream) { c := s.Conn() r := ggio.NewDelimitedReader(s, 2048) @@ -336,69 +419,10 @@ func (ids *IDService) responseHandler(s network.Stream) { defer func() { go helpers.FullClose(s) }() log.Debugf("%s received message from %s %s", s.Protocol(), c.RemotePeer(), c.RemoteMultiaddr()) - ids.consumeMessage(&mes, c, protoSupportsPeerRecords(s.Protocol())) + ids.consumeMessage(&mes, c) } -func (ids *IDService) broadcast(protos []protocol.ID, payloadWriter func(s network.Stream)) { - var wg sync.WaitGroup - - protoStrs := protocol.ConvertToStrings(protos) - ctx, cancel := context.WithTimeout(ids.ctx, 30*time.Second) - ctx = network.WithNoDial(ctx, protoStrs[0]) - - pstore := ids.Host.Peerstore() - for _, p := range ids.Host.Network().Peers() { - wg.Add(1) - - go func(p peer.ID, conns []network.Conn) { - defer wg.Done() - - // Wait till identify completes so we can check the - // supported protocols. - for _, c := range conns { - select { - case <-ids.IdentifyWait(c): - case <-ctx.Done(): - return - } - } - - // avoid the unnecessary stream if the peer does not support the protocol. - if sup, err := pstore.SupportsProtocols(p, protoStrs...); err != nil && len(sup) == 0 { - // the peer does not support the required protocol. - return - } - // if the peerstore query errors, we go ahead anyway. - - s, err := ids.Host.NewStream(ctx, p, protos...) - if err != nil { - log.Debugf("error opening push stream to %s: %s", p, err.Error()) - return - } - - rch := make(chan struct{}, 1) - go func() { - payloadWriter(s) - rch <- struct{}{} - }() - - select { - case <-rch: - case <-ctx.Done(): - // this is taking too long, abort! - s.Reset() - } - }(p, ids.Host.Network().ConnsToPeer(p)) - } - - // this supervisory goroutine is necessary to cancel the context - go func() { - wg.Wait() - cancel() - }() -} - -func (ids *IDService) populateMessage(mes *pb.Identify, c network.Conn, usePeerRecords bool) { +func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, remoteAddr ma.Multiaddr) { // set protocols this node is currently handling protos := ids.Host.Mux().Protocols() mes.Protocols = make([]string, len(protos)) @@ -408,13 +432,26 @@ func (ids *IDService) populateMessage(mes *pb.Identify, c network.Conn, usePeerR // observed address so other side is informed of their // "public" address, at least in relation to us. - mes.ObservedAddr = c.RemoteMultiaddr().Bytes() - - if usePeerRecords { - ids.peerrecMu.RLock() - rec := ids.peerrec - ids.peerrecMu.RUnlock() + mes.ObservedAddr = remoteAddr.Bytes() + + // populate unsigned addresses. + // peers that do not yet support signed addresses will need this. + // set listen addrs, get our latest addrs from Host. + laddrs := ids.Host.Addrs() + // Note: LocalMultiaddr is sometimes 0.0.0.0 + viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr) + mes.ListenAddrs = make([][]byte, 0, len(laddrs)) + for _, addr := range laddrs { + if !viaLoopback && manet.IsIPLoopback(addr) { + continue + } + mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes()) + } + // populate signed record. + cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()) + if ok { + rec := cab.GetPeerRecord(ids.Host.ID()) if rec == nil { log.Errorf("latest peer record does not exist. identify message incomplete!") } else { @@ -423,20 +460,8 @@ func (ids *IDService) populateMessage(mes *pb.Identify, c network.Conn, usePeerR log.Errorf("error marshaling peer record: %v", err) } else { mes.SignedPeerRecord = recBytes - log.Debugf("%s sent peer record to %s", c.LocalPeer(), c.RemotePeer()) - } - } - } else { - // set listen addrs, get our latest addrs from Host. - laddrs := ids.Host.Addrs() - // Note: LocalMultiaddr is sometimes 0.0.0.0 - viaLoopback := manet.IsIPLoopback(c.LocalMultiaddr()) || manet.IsIPLoopback(c.RemoteMultiaddr()) - mes.ListenAddrs = make([][]byte, 0, len(laddrs)) - for _, addr := range laddrs { - if !viaLoopback && manet.IsIPLoopback(addr) { - continue + log.Debugf("%s sent peer record to %s", ids.Host.ID(), rp) } - mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes()) } } @@ -468,7 +493,7 @@ func (ids *IDService) populateMessage(mes *pb.Identify, c network.Conn, usePeerR mes.AgentVersion = &av } -func (ids *IDService) consumeMessage(mes *pb.Identify, c network.Conn, usePeerRecords bool) { +func (ids *IDService) consumeMessage(mes *pb.Identify, c network.Conn) { p := c.RemotePeer() // mes.Protocols @@ -499,13 +524,11 @@ func (ids *IDService) consumeMessage(mes *pb.Identify, c network.Conn, usePeerRe // many undialable addresses for other peers. // add certified addresses for the peer, if they sent us a signed peer record + // otherwise use the unsigned addresses. var signedPeerRecord *record.Envelope - if usePeerRecords { - var err error - signedPeerRecord, err = signedPeerRecordFromMessage(mes) - if err != nil { - log.Errorf("error getting peer record from Identify message: %v", err) - } + signedPeerRecord, err := signedPeerRecordFromMessage(mes) + if err != nil { + log.Errorf("error getting peer record from Identify message: %v", err) } // Extend the TTLs on the known (probably) good addresses. @@ -701,6 +724,13 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) { defer ids.addrMu.Unlock() if ids.Host.Network().Connectedness(v.RemotePeer()) != network.Connected { + // consider removing the peer handler for this + select { + case ids.rmPeerHandlerCh <- rmPeerHandlerReq{v.RemotePeer()}: + case <-ids.ctx.Done(): + return + } + // Last disconnect. ps := ids.Host.Peerstore() ps.UpdateAddrs(v.RemotePeer(), peerstore.ConnectedAddrTTL, peerstore.RecentlyConnectedAddrTTL) diff --git a/p2p/protocol/identify/id_delta.go b/p2p/protocol/identify/id_delta.go index 0d5849ed44..2a09170e6b 100644 --- a/p2p/protocol/identify/id_delta.go +++ b/p2p/protocol/identify/id_delta.go @@ -7,8 +7,9 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" - ggio "github.com/gogo/protobuf/io" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" + + ggio "github.com/gogo/protobuf/io" ) const IDDelta = "/p2p/id/delta/1.0.0" @@ -40,27 +41,6 @@ func (ids *IDService) deltaHandler(s network.Stream) { } } -// fireProtocolDelta fires a delta message to all connected peers to signal a local protocol table update. -func (ids *IDService) fireProtocolDelta(evt event.EvtLocalProtocolsUpdated) { - mes := pb.Identify{ - Delta: &pb.Delta{ - AddedProtocols: protocol.ConvertToStrings(evt.Added), - RmProtocols: protocol.ConvertToStrings(evt.Removed), - }, - } - deltaWriter := func(s network.Stream) { - defer helpers.FullClose(s) - c := s.Conn() - err := ggio.NewDelimitedWriter(s).WriteMsg(&mes) - if err != nil { - log.Warningf("%s error while sending delta update to %s: %s", IDDelta, c.RemotePeer(), c.RemoteMultiaddr()) - return - } - log.Debugf("%s sent delta update to %s: %s", IDDelta, c.RemotePeer(), c.RemoteMultiaddr()) - } - ids.broadcast([]protocol.ID{IDDelta}, deltaWriter) -} - // consumeDelta processes an incoming delta from a peer, updating the peerstore // and emitting the appropriate events. func (ids *IDService) consumeDelta(id peer.ID, delta *pb.Delta) error { diff --git a/p2p/protocol/identify/id_push.go b/p2p/protocol/identify/id_push.go index 6b8dfee98e..c2977e4bfe 100644 --- a/p2p/protocol/identify/id_push.go +++ b/p2p/protocol/identify/id_push.go @@ -2,7 +2,6 @@ package identify import ( "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/protocol" ) // IDPush is the protocol.ID of the Identify push protocol. It sends full identify messages containing @@ -10,20 +9,9 @@ import ( // // It is in the process of being replaced by identify delta, which sends only diffs for better // resource utilisation. -const IDPush = "/p2p/id/push/1.1.0" - -// LegacyIDPush is the protocol.ID of the previous version of the Identify push protocol, -// which does not support exchanging signed addresses in PeerRecords. -// It is still supported for backwards compatibility if a remote peer does not support -// the current version. -const LegacyIDPush = "/ipfs/id/push/1.0.0" - -// Push pushes a full identify message to all peers containing the current state. -func (ids *IDService) Push() { - ids.broadcast([]protocol.ID{IDPush, LegacyIDPush}, ids.requestHandler) -} +const IDPush = "/ipfs/id/push/1.0.0" // pushHandler handles incoming identify push streams. The behaviour is identical to the ordinary identify protocol. func (ids *IDService) pushHandler(s network.Stream) { - ids.responseHandler(s) + ids.handleIdentifyResponse(s) } diff --git a/p2p/protocol/identify/id_test.go b/p2p/protocol/identify/id_test.go index 577f61e5ae..d0006347a6 100644 --- a/p2p/protocol/identify/id_test.go +++ b/p2p/protocol/identify/id_test.go @@ -3,14 +3,12 @@ package identify_test import ( "context" "fmt" - "github.com/libp2p/go-libp2p-core/record" "reflect" "sort" + "sync" "testing" "time" - "github.com/libp2p/go-eventbus" - libp2p "github.com/libp2p/go-libp2p" ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/event" "github.com/libp2p/go-libp2p-core/helpers" @@ -19,15 +17,21 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-core/record" coretest "github.com/libp2p/go-libp2p-core/test" + "github.com/libp2p/go-eventbus" + libp2p "github.com/libp2p/go-libp2p" blhost "github.com/libp2p/go-libp2p-blankhost" + "github.com/libp2p/go-libp2p-peerstore/pstoremem" swarmt "github.com/libp2p/go-libp2p-swarm/testing" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/libp2p/go-libp2p/p2p/protocol/identify" + pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" - "github.com/libp2p/go-libp2p-peerstore/pstoremem" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + ggio "github.com/gogo/protobuf/io" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" ) func subtestIDService(t *testing.T) { @@ -36,8 +40,6 @@ func subtestIDService(t *testing.T) { h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) h2 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) - generatePeerRecord(t, h1) - generatePeerRecord(t, h2) h1p := h1.ID() h2p := h2.ID() @@ -217,11 +219,18 @@ func testHasPublicKey(t *testing.T, h host.Host, p peer.ID, shouldBe ic.PubKey) } } +func getSignedRecord(t *testing.T, h host.Host, p peer.ID) *record.Envelope { + cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) + require.True(t, ok) + rec := cab.GetPeerRecord(p) + return rec +} + // we're using BlankHost in our tests, which doesn't automatically generate peer records -// like BasicHost. This generates a record and puts it on the host's event bus, which -// will cause the identify service to start supporting new protocol versions that -// depend on peer records being available. -func generatePeerRecord(t *testing.T, h host.Host) { +// and emit address change events on the bus like BasicHost. +// This generates a record, puts it in the peerstore and emits an addr change event +// which will cause the identify service to push it to all peers it's connected to. +func emitAddrChangeEvt(t *testing.T, h host.Host) { t.Helper() key := h.Peerstore().PrivKey(h.ID()) @@ -236,7 +245,13 @@ func generatePeerRecord(t *testing.T, h host.Host) { if err != nil { t.Fatalf("error generating peer record: %s", err) } - evt := event.EvtLocalAddressesUpdated{SignedPeerRecord: signed} + + cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) + require.True(t, ok) + _, err = cab.ConsumePeerRecord(signed, peerstore.PermanentAddrTTL) + require.NoError(t, err) + + evt := event.EvtLocalAddressesUpdated{} emitter, err := h.EventBus().Emitter(new(event.EvtLocalAddressesUpdated), eventbus.Stateful) if err != nil { t.Fatal(err) @@ -407,72 +422,83 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) { // add two new protocols in h2 and wait for identify to send deltas. h2.SetStreamHandler(protocol.ID("foo"), func(_ network.Stream) {}) h2.SetStreamHandler(protocol.ID("bar"), func(_ network.Stream) {}) - <-time.After(500 * time.Millisecond) // check that h1 now knows about h2's new protocols. - protos, err = h1.Peerstore().GetProtocols(h2.ID()) - if err != nil { - t.Fatal(err) - } - have := make(map[string]struct{}, len(protos)) - for _, p := range protos { - have[p] = struct{}{} - } + require.Eventually(t, func() bool { + protos, err = h1.Peerstore().GetProtocols(h2.ID()) + if err != nil { + return false + } + have := make(map[string]struct{}, len(protos)) + for _, p := range protos { + have[p] = struct{}{} + } - if _, ok := have["foo"]; !ok { - t.Fatalf("expected peer 1 to know that peer 2 now speaks protocol 'foo', known: %v", protos) - } - if _, ok := have["bar"]; !ok { - t.Fatalf("expected peer 1 to know that peer 2 now speaks protocol 'bar', known: %v", protos) - } + _, okfoo := have["foo"] + _, okbar := have["bar"] + return okfoo && okbar + }, 5*time.Second, 500*time.Millisecond) // remove one of the newly added protocols from h2, and wait for identify to send the delta. h2.RemoveStreamHandler(protocol.ID("bar")) - <-time.After(500 * time.Millisecond) - // check that h1 now has forgotten about h2's bar protocol. - protos, err = h1.Peerstore().GetProtocols(h2.ID()) - if err != nil { - t.Fatal(err) - } - have = make(map[string]struct{}, len(protos)) - for _, p := range protos { - have[p] = struct{}{} - } - if _, ok := have["foo"]; !ok { - t.Fatalf("expected peer 1 to know that peer 2 now speaks protocol 'foo', known: %v", protos) - } - if _, ok := have["bar"]; ok { - t.Fatalf("expected peer 1 to have forgotten that peer 2 spoke protocol 'bar', known: %v", protos) - } + require.Eventually(t, func() bool { + protos, err = h1.Peerstore().GetProtocols(h2.ID()) + if err != nil { + return false + } + have := make(map[string]struct{}, len(protos)) + for _, p := range protos { + have[p] = struct{}{} + } + + _, okfoo := have["foo"] + _, okbar := have["bar"] + return okfoo && !okbar + }, 5*time.Second, 500*time.Millisecond) // make sure that h1 emitted events in the eventbus for h2's protocol updates. - evts := make([]event.EvtPeerProtocolsUpdated, 3) done := make(chan struct{}) + + var lk sync.Mutex + var added []string + var removed []string + var success bool + go func() { - evts[0] = (<-sub.Out()).(event.EvtPeerProtocolsUpdated) - evts[1] = (<-sub.Out()).(event.EvtPeerProtocolsUpdated) - evts[2] = (<-sub.Out()).(event.EvtPeerProtocolsUpdated) + defer close(done) + for { + select { + case <-time.After(5 * time.Second): + return + case e, ok := <-sub.Out(): + if !ok { + return + } + evt := e.(event.EvtPeerProtocolsUpdated) + lk.Lock() + added = append(added, protocol.ConvertToStrings(evt.Added)...) + removed = append(removed, protocol.ConvertToStrings(evt.Removed)...) + sort.Strings(added) + sort.Strings(removed) + if reflect.DeepEqual(added, []string{"bar", "foo"}) && + reflect.DeepEqual(removed, []string{"bar"}) { + success = true + lk.Unlock() + return + } + lk.Unlock() + } + } + close(done) }() - select { - case <-done: - case <-time.After(1 * time.Second): - t.Fatalf("timed out while consuming events from subscription") - } - - added := protocol.ConvertToStrings(append(evts[0].Added, append(evts[1].Added, evts[2].Added...)...)) - removed := protocol.ConvertToStrings(append(evts[0].Removed, append(evts[1].Removed, evts[2].Removed...)...)) - sort.Strings(added) - sort.Strings(removed) + <-done - if !reflect.DeepEqual(added, []string{"bar", "foo"}) { - t.Fatalf("expected to have received updates for added protos") - } - if !reflect.DeepEqual(removed, []string{"bar"}) { - t.Fatalf("expected to have received updates for removed protos") - } + lk.Lock() + defer lk.Unlock() + require.True(t, success, "did not get correct peer protocol updated events") } // TestIdentifyDeltaWhileIdentifyingConn tests that the host waits to push delta updates if an identify is ongoing. @@ -495,12 +521,12 @@ func TestIdentifyDeltaWhileIdentifyingConn(t *testing.T) { block := make(chan struct{}) handler := func(s network.Stream) { <-block - go helpers.FullClose(s) + w := ggio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Identify{Protocols: h1.Mux().Protocols()}) + helpers.FullClose(s) } h1.RemoveStreamHandler(identify.ID) - h1.RemoveStreamHandler(identify.LegacyID) h1.SetStreamHandler(identify.ID, handler) - h1.SetStreamHandler(identify.LegacyID, handler) // from h2 connect to h1. if err := h2.Connect(ctx, peer.AddrInfo{ID: h1.ID(), Addrs: h1.Addrs()}); err != nil { @@ -511,7 +537,6 @@ func TestIdentifyDeltaWhileIdentifyingConn(t *testing.T) { conn := h2.Network().ConnsToPeer(h1.ID())[0] go func() { ids2.IdentifyConn(conn) - ids2.IdentifyConn(conn) }() <-time.After(500 * time.Millisecond) @@ -544,6 +569,88 @@ func TestIdentifyDeltaWhileIdentifyingConn(t *testing.T) { } } +func TestIdentifyPushOnAddrChange(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + h2 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + + h1p := h1.ID() + h2p := h2.ID() + + ids1 := identify.NewIDService(h1) + ids2 := identify.NewIDService(h2) + defer ids1.Close() + defer ids2.Close() + + testKnowsAddrs(t, h1, h2p, []ma.Multiaddr{}) // nothing + testKnowsAddrs(t, h2, h1p, []ma.Multiaddr{}) // nothing + + h2pi := h2.Peerstore().PeerInfo(h2p) + require.NoError(t, h1.Connect(ctx, h2pi)) + require.Len(t, h1.Network().ConnsToPeer(h2p), 1) + require.Len(t, h2.Network().ConnsToPeer(h1p), 1) + + // wait for identify to complete and assert current addresses + ids1.IdentifyConn(h1.Network().ConnsToPeer(h2p)[0]) + ids2.IdentifyConn(h2.Network().ConnsToPeer(h1p)[0]) + + testKnowsAddrs(t, h1, h2p, h2.Peerstore().Addrs(h2p)) + testKnowsAddrs(t, h2, h1p, h1.Peerstore().Addrs(h1p)) + + // change addr on host 1 and ensure host2 gets a push + lad := ma.StringCast("/ip4/127.0.0.1/tcp/1234") + require.NoError(t, h1.Network().Listen(lad)) + require.Contains(t, h1.Addrs(), lad) + emitAddrChangeEvt(t, h1) + + require.Eventually(t, func() bool { + addrs := h2.Peerstore().Addrs(h1p) + for _, ad := range addrs { + if ad.Equal(lad) { + return true + } + } + return false + }, 5*time.Second, 500*time.Millisecond) + require.NotNil(t, getSignedRecord(t, h2, h1p)) + + // change addr on host2 and ensure host 1 gets a pus + lad = ma.StringCast("/ip4/127.0.0.1/tcp/1235") + require.NoError(t, h2.Network().Listen(lad)) + require.Contains(t, h2.Addrs(), lad) + emitAddrChangeEvt(t, h2) + + require.Eventually(t, func() bool { + addrs := h1.Peerstore().Addrs(h2p) + for _, ad := range addrs { + if ad.Equal(lad) { + return true + } + } + return false + }, 5*time.Second, 500*time.Millisecond) + require.NotNil(t, getSignedRecord(t, h1, h2p)) + + // change addr on host2 again + lad2 := ma.StringCast("/ip4/127.0.0.1/tcp/1236") + require.NoError(t, h2.Network().Listen(lad2)) + require.Contains(t, h2.Addrs(), lad2) + emitAddrChangeEvt(t, h2) + + require.Eventually(t, func() bool { + addrs := h1.Peerstore().Addrs(h2p) + for _, ad := range addrs { + if ad.Equal(lad2) { + return true + } + } + return false + }, 5*time.Second, 500*time.Millisecond) + require.NotNil(t, getSignedRecord(t, h1, h2p)) +} + func TestUserAgent(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -581,8 +688,7 @@ func TestUserAgent(t *testing.T) { } } -// make sure that we still support older peers using "legacy" versions of identify -func TestCompatibilityWithPeersThatDoNotSupportSignedAddrs(t *testing.T) { +func TestSendPushIfDeltaNotSupported(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -591,40 +697,44 @@ func TestCompatibilityWithPeersThatDoNotSupportSignedAddrs(t *testing.T) { defer h2.Close() defer h1.Close() - ids := identify.NewIDService(h1) + ids1 := identify.NewIDService(h1) ids2 := identify.NewIDService(h2) + defer func() { + ids1.Close() + ids2.Close() + }() - defer ids.Close() - defer ids2.Close() - - // generate initial peer record only for h1. this will cause h1 to enable - // the new protocols, but h2 will still use legacy protos - generatePeerRecord(t, h1) - - h2p := h2.ID() - h2pi := h2.Peerstore().PeerInfo(h2p) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal(err) - } - - h1t2c := h1.Network().ConnsToPeer(h2p) - if len(h1t2c) == 0 { - t.Fatal("should have a conn here") - } - - ids.IdentifyConn(h1t2c[0]) - // the IDService should be opened automatically, by the network. - // what we should see now is that both peers know about each others listen addresses. - t.Log("test peer1 has peer2 addrs correctly") - testKnowsAddrs(t, h1, h2p, h2.Peerstore().Addrs(h2p)) // has them - testHasCertifiedAddrs(t, h1, h2p, []ma.Multiaddr{}) // should not have signed addrs - - // double check that it works when both peers support the new protos - // enable new protos for h2 by generating a peer record - generatePeerRecord(t, h2) - - // if we re-identify, h1 should now have certified addrs for h2 - ids.IdentifyConn(h1t2c[0]) - t.Log("test peer1 has peer2 certified addrs correctly") - testHasCertifiedAddrs(t, h1, h2p, h2.Peerstore().Addrs(h2p)) + err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) + require.NoError(t, err) + + // wait for them to Identify each other + ids1.IdentifyConn(h1.Network().ConnsToPeer(h2.ID())[0]) + ids2.IdentifyConn(h2.Network().ConnsToPeer(h1.ID())[0]) + + // h1 knows h2 speaks Delta + sup, err := h1.Peerstore().SupportsProtocols(h2.ID(), []string{identify.IDDelta}...) + require.NoError(t, err) + require.Equal(t, []string{identify.IDDelta}, sup) + + // h2 stops supporting Delta and that information flows to h1 + h2.RemoveStreamHandler(identify.IDDelta) + + require.Eventually(t, func() bool { + sup, err := h1.Peerstore().SupportsProtocols(h2.ID(), []string{identify.IDDelta}...) + return err == nil && len(sup) == 0 + }, 5*time.Second, 500*time.Millisecond) + + // h1 starts listening on a new protocol and h2 finds out about that through a push + h1.SetStreamHandler("rand", func(network.Stream) {}) + require.Eventually(t, func() bool { + sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...) + return err == nil && len(sup) == 1 && sup[0] == "rand" + }, 5*time.Second, 500*time.Millisecond) + + // h1 stops listening on a protocol and h2 finds out about it via a push + h1.RemoveStreamHandler("rand") + require.Eventually(t, func() bool { + sup, err := h2.Peerstore().SupportsProtocols(h1.ID(), []string{"rand"}...) + return err == nil && len(sup) == 0 + }, 5*time.Second, 500*time.Millisecond) } diff --git a/p2p/protocol/identify/peer_loop.go b/p2p/protocol/identify/peer_loop.go new file mode 100644 index 0000000000..7fc4021d4b --- /dev/null +++ b/p2p/protocol/identify/peer_loop.go @@ -0,0 +1,268 @@ +package identify + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/helpers" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + + pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" + + ggio "github.com/gogo/protobuf/io" +) + +var errProtocolNotSupported = errors.New("protocol not supported") +var isTesting = false + +type peerHandler struct { + ids *IDService + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + pid peer.ID + + msgMu sync.RWMutex + idMsgSnapshot *pb.Identify + + pushCh chan struct{} + deltaCh chan struct{} + evalTestCh chan func() // for testing +} + +func newPeerHandler(pid peer.ID, ids *IDService, initState *pb.Identify) *peerHandler { + ph := &peerHandler{ + ids: ids, + pid: pid, + + idMsgSnapshot: initState, + + pushCh: make(chan struct{}, 1), + deltaCh: make(chan struct{}, 1), + } + + if isTesting { + ph.evalTestCh = make(chan func()) + } + + return ph +} + +func (ph *peerHandler) start() { + ctx, cancel := context.WithCancel(context.Background()) + ph.ctx = ctx + ph.cancel = cancel + + ph.wg.Add(1) + go ph.loop() +} + +func (ph *peerHandler) close() error { + ph.cancel() + ph.wg.Wait() + return nil +} + +// per peer loop for pushing updates +func (ph *peerHandler) loop() { + defer ph.wg.Done() + + for { + select { + // our listen addresses have changed, send an IDPush. + case <-ph.pushCh: + if err := ph.sendPush(); err != nil { + log.Warnw("failed to send Identify Push", "peer", ph.pid, "error", err) + } + + case <-ph.deltaCh: + if err := ph.sendDelta(); err != nil { + log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err) + } + + case fnc := <-ph.evalTestCh: + fnc() + + case <-ph.ctx.Done(): + return + } + } +} + +func (ph *peerHandler) sendDelta() error { + mes := ph.mkDelta() + if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) { + return nil + } + + // send a push if the peer does not support the Delta protocol. + if !ph.peerSupportsProtos([]string{IDDelta}) { + log.Debugw("will send push as peer does not support delta", "peer", ph.pid) + if err := ph.sendPush(); err != nil { + return fmt.Errorf("failed to send push on delta message: %w", err) + } + return nil + } + + ph.msgMu.Lock() + // update our identify snapshot for this peer by applying the delta to it + ph.applyDelta(mes) + ph.msgMu.Unlock() + + ds, err := ph.openStream([]string{IDDelta}) + if err != nil { + return fmt.Errorf("failed to open delta stream: %w", err) + } + + if err := ph.sendMessage(ds, &pb.Identify{Delta: mes}); err != nil { + return fmt.Errorf("failed to send delta message, %w", err) + } + return nil +} + +func (ph *peerHandler) sendPush() error { + dp, err := ph.openStream([]string{IDPush}) + if err == errProtocolNotSupported { + log.Debugw("not sending push as peer does not support protocol", "peer", ph.pid) + return nil + } + if err != nil { + return fmt.Errorf("failed to open push stream: %w", err) + } + + conn := dp.Conn() + mes := &pb.Identify{} + ph.ids.populateMessage(mes, ph.pid, conn.LocalMultiaddr(), conn.RemoteMultiaddr()) + + ph.msgMu.Lock() + ph.idMsgSnapshot = mes + ph.msgMu.Unlock() + + if err := ph.sendMessage(dp, mes); err != nil { + return fmt.Errorf("failed to send push message: %w", err) + } + return nil +} + +func (ph *peerHandler) applyDelta(mes *pb.Delta) { + for _, p1 := range mes.RmProtocols { + for j, p2 := range ph.idMsgSnapshot.Protocols { + if p2 == p1 { + ph.idMsgSnapshot.Protocols[j] = ph.idMsgSnapshot.Protocols[len(ph.idMsgSnapshot.Protocols)-1] + ph.idMsgSnapshot.Protocols = ph.idMsgSnapshot.Protocols[:len(ph.idMsgSnapshot.Protocols)-1] + } + } + } + + for _, p := range mes.AddedProtocols { + ph.idMsgSnapshot.Protocols = append(ph.idMsgSnapshot.Protocols, p) + } +} + +func (ph *peerHandler) openStream(protos []string) (network.Stream, error) { + // wait for the other peer to send us an Identify response on "all" connections we have with it + // so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol + // if we know for a fact that it dosen't support the protocol. + conns := ph.ids.Host.Network().ConnsToPeer(ph.pid) + for _, c := range conns { + select { + case <-ph.ids.IdentifyWait(c): + case <-ph.ctx.Done(): + return nil, ph.ctx.Err() + } + } + + if !ph.peerSupportsProtos(protos) { + return nil, errProtocolNotSupported + } + + // negotiate a stream without opening a new connection as we "should" already have a connection. + ctx, cancel := context.WithTimeout(ph.ctx, 30*time.Second) + defer cancel() + ctx = network.WithNoDial(ctx, "should already have connection") + + // newstream will open a stream on the first protocol the remote peer supports from the among + // the list of protocols passed to it. + s, err := ph.ids.Host.NewStream(ctx, ph.pid, protocol.ConvertFromStrings(protos)...) + if err != nil { + return nil, err + } + + return s, err +} + +// returns true if the peer supports atleast one of the given protocols +func (ph *peerHandler) peerSupportsProtos(protos []string) bool { + conns := ph.ids.Host.Network().ConnsToPeer(ph.pid) + for _, c := range conns { + select { + case <-ph.ids.IdentifyWait(c): + case <-ph.ctx.Done(): + return false + } + } + + pstore := ph.ids.Host.Peerstore() + + if sup, err := pstore.SupportsProtocols(ph.pid, protos...); err == nil && len(sup) == 0 { + return false + } + return true +} + +func (ph *peerHandler) mkDelta() *pb.Delta { + old := ph.idMsgSnapshot.GetProtocols() + curr := ph.ids.Host.Mux().Protocols() + + oldProtos := make(map[string]struct{}, len(old)) + currProtos := make(map[string]struct{}, len(curr)) + + for _, proto := range old { + oldProtos[proto] = struct{}{} + } + + for _, proto := range curr { + currProtos[proto] = struct{}{} + } + + var added []string + var removed []string + + // has it been added ? + for p := range currProtos { + if _, ok := oldProtos[p]; !ok { + added = append(added, p) + } + } + + // has it been removed ? + for p := range oldProtos { + if _, ok := currProtos[p]; !ok { + removed = append(removed, p) + } + } + + return &pb.Delta{ + AddedProtocols: added, + RmProtocols: removed, + } +} + +func (ph *peerHandler) sendMessage(s network.Stream, mes *pb.Identify) error { + defer helpers.FullClose(s) + c := s.Conn() + if err := ggio.NewDelimitedWriter(s).WriteMsg(mes); err != nil { + return err + + } + log.Debugw("sent identify update", "protocol", s.Protocol(), "peer", c.RemotePeer(), + "peer address", c.RemoteMultiaddr()) + return nil +} diff --git a/p2p/protocol/identify/peer_loop_test.go b/p2p/protocol/identify/peer_loop_test.go new file mode 100644 index 0000000000..6eb3bd6056 --- /dev/null +++ b/p2p/protocol/identify/peer_loop_test.go @@ -0,0 +1,115 @@ +package identify + +import ( + "context" + "testing" + + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + + blhost "github.com/libp2p/go-libp2p-blankhost" + swarmt "github.com/libp2p/go-libp2p-swarm/testing" + pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" + + "github.com/stretchr/testify/require" +) + +func doeval(t *testing.T, ph *peerHandler, f func()) { + done := make(chan struct{}, 1) + ph.evalTestCh <- func() { + f() + done <- struct{}{} + } + <-done +} + +func TestMakeApplyDelta(t *testing.T) { + isTesting = true + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + defer h1.Close() + ids1 := NewIDService(h1) + ph := newPeerHandler(h1.ID(), ids1, &pb.Identify{}) + ph.start() + defer ph.close() + + m1 := ph.mkDelta() + require.NotNil(t, m1) + // all the Id protocols must have been added + require.NotEmpty(t, m1.AddedProtocols) + doeval(t, ph, func() { + ph.applyDelta(m1) + }) + + h1.SetStreamHandler("p1", func(network.Stream) {}) + m2 := ph.mkDelta() + require.Len(t, m2.AddedProtocols, 1) + require.Contains(t, m2.AddedProtocols, "p1") + require.Empty(t, m2.RmProtocols) + doeval(t, ph, func() { + ph.applyDelta(m2) + }) + + h1.SetStreamHandler("p2", func(network.Stream) {}) + h1.SetStreamHandler("p3", func(stream network.Stream) {}) + m3 := ph.mkDelta() + require.Len(t, m3.AddedProtocols, 2) + require.Contains(t, m3.AddedProtocols, "p2") + require.Contains(t, m3.AddedProtocols, "p3") + require.Empty(t, m3.RmProtocols) + doeval(t, ph, func() { + ph.applyDelta(m3) + }) + + h1.RemoveStreamHandler("p3") + m4 := ph.mkDelta() + require.Empty(t, m4.AddedProtocols) + require.Len(t, m4.RmProtocols, 1) + require.Contains(t, m4.RmProtocols, "p3") + doeval(t, ph, func() { + ph.applyDelta(m4) + }) + + h1.RemoveStreamHandler("p2") + h1.RemoveStreamHandler("p1") + m5 := ph.mkDelta() + require.Empty(t, m5.AddedProtocols) + require.Len(t, m5.RmProtocols, 2) + require.Contains(t, m5.RmProtocols, "p2") + require.Contains(t, m5.RmProtocols, "p1") +} + +func TestHandlerClose(t *testing.T) { + isTesting = true + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + defer h1.Close() + ids1 := NewIDService(h1) + ph := newPeerHandler(h1.ID(), ids1, nil) + ph.start() + + require.NoError(t, ph.close()) +} + +func TestPeerSupportsProto(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) + defer h1.Close() + ids1 := NewIDService(h1) + + rp := peer.ID("test") + ph := newPeerHandler(rp, ids1, nil) + require.NoError(t, h1.Peerstore().AddProtocols(rp, "test")) + require.True(t, ph.peerSupportsProtos([]string{"test"})) + require.False(t, ph.peerSupportsProtos([]string{"random"})) + + // remove support for protocol and check + require.NoError(t, h1.Peerstore().RemoveProtocols(rp, "test")) + require.False(t, ph.peerSupportsProtos([]string{"test"})) +}