159 changes: 158 additions & 1 deletion provider/provider.go
@@ -1,7 +1,23 @@
package provider

import (
"context"
"fmt"
"sync"
"sync/atomic"

"github.com/filecoin-project/go-clock"
logging "github.com/ipfs/go-log/v2"
kb "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
ma "github.com/multiformats/go-multiaddr"
mh "github.com/multiformats/go-multihash"

"github.com/probe-lab/go-libdht/kad/key"
"github.com/probe-lab/go-libdht/kad/key/bitstr"

pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace"
)

// DHTProvider is an interface for providing keys to a DHT swarm. It holds a
@@ -41,8 +57,149 @@ type DHTProvider interface {

var _ DHTProvider = &SweepingProvider{}

var logger = logging.Logger("dht/SweepingProvider")

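// KadClosestPeersRouter is implemented by routers able to return the closest
// DHT servers to a given kademlia key.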
type KadClosestPeersRouter interface {
GetClosestPeers(context.Context, string) ([]peer.ID, error)
}

type SweepingProvider struct {
// TODO: implement me
// TODO: complete me
peerid peer.ID
router KadClosestPeersRouter

clock clock.Clock

maxProvideConnsPerWorker int

msgSender pb.MessageSender
getSelfAddrs func() []ma.Multiaddr
addLocalRecord func(mh.Multihash) error
}

// FIXME: remove me
func (s *SweepingProvider) SatisfyLinter() {
s.vanillaProvide([]byte{})
s.closestPeersToKey("")
}

// vanillaProvide provides a single key to the network without any
// optimization. It should only be used for a small number of keys (typically
// 1 or 2), because exploring the keyspace would add too much overhead for so
// few keys.
func (s *SweepingProvider) vanillaProvide(k mh.Multihash) (bitstr.Key, error) {
// Add the provider record to the local provider store.
if err := s.addLocalRecord(k); err != nil {
return "", err
}
// Get peers to which the record will be allocated.
peers, err := s.router.GetClosestPeers(context.Background(), string(k))
if err != nil {
return "", err
}
coveredPrefix, _ := keyspace.ShortestCoveredPrefix(bitstr.Key(key.BitString(keyspace.MhToBit256(k))), peers)
addrInfo := peer.AddrInfo{ID: s.peerid, Addrs: s.getSelfAddrs()}
keysAllocations := make(map[peer.ID][]mh.Multihash)
for _, p := range peers {
keysAllocations[p] = []mh.Multihash{k}
}
return coveredPrefix, s.sendProviderRecords(keysAllocations, addrInfo)
}
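
// Illustrative sketch: provideFew is a hypothetical helper showing how a
// caller holding only a couple of multihashes could rely on vanillaProvide
// instead of a full keyspace sweep.
func (s *SweepingProvider) provideFew(keys ...mh.Multihash) error {
	for _, k := range keys {
		// Each key is advertised independently; the covered prefix returned
		// by vanillaProvide is ignored since no region bookkeeping is done.
		if _, err := s.vanillaProvide(k); err != nil {
			return fmt.Errorf("vanilla provide of %s failed: %w", k.B58String(), err)
		}
	}
	return nil
}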

// closestPeersToKey returns the closest DHT servers to the provided key. It
// derives a valid peer ID sharing a long common prefix with the key (the
// preimage isn't random, it is taken from a static list of preimages) and
// asks the router for the peers closest to that peer ID.
func (s *SweepingProvider) closestPeersToKey(k bitstr.Key) ([]peer.ID, error) {
p, _ := kb.GenRandPeerIDWithCPL(keyspace.KeyToBytes(k), kb.PeerIDPreimageMaxCpl)
return s.router.GetClosestPeers(context.Background(), string(p))
}
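
// Illustrative sketch: provideRegion is a hypothetical helper showing how a
// region provide could chain closestPeersToKey and sendProviderRecords. The
// naive "every peer gets every key" allocation is an assumption made for
// brevity; an actual sweep would allocate each key only to its k closest
// peers.
func (s *SweepingProvider) provideRegion(prefix bitstr.Key, keys []mh.Multihash) error {
	peers, err := s.closestPeersToKey(prefix)
	if err != nil {
		return err
	}
	allocations := make(map[peer.ID][]mh.Multihash, len(peers))
	for _, p := range peers {
		allocations[p] = keys
	}
	addrInfo := peer.AddrInfo{ID: s.peerid, Addrs: s.getSelfAddrs()}
	return s.sendProviderRecords(allocations, addrInfo)
}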

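// minimalRegionReachablePeersRatio is the minimal fraction of a keyspace
// region's peers that must be successfully provided to for the region
// provide to be considered a success.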
const minimalRegionReachablePeersRatio float32 = 0.2

type provideJob struct {
pid peer.ID
keys []mh.Multihash
}

// sendProviderRecords sends the allocated keys to each of the given peer IDs.
// Upon failure to provide a key or to connect to a peer, it does NOT retry.
//
// It returns an error if the records could not be sent to a minimal fraction
// (minimalRegionReachablePeersRatio) of the peers. In that case the region
// reprovide is considered failed and the caller is responsible for trying
// again. This also allows detecting whether we are offline.
func (s *SweepingProvider) sendProviderRecords(keysAllocations map[peer.ID][]mh.Multihash, addrInfo peer.AddrInfo) error {
nPeers := len(keysAllocations)
if nPeers == 0 {
return nil
}
startTime := s.clock.Now()
errCount := atomic.Uint32{}
nWorkers := s.maxProvideConnsPerWorker
jobChan := make(chan provideJob, nWorkers)

wg := sync.WaitGroup{}
wg.Add(nWorkers)
for range nWorkers {
go func() {
pmes := genProvideMessage(addrInfo)
defer wg.Done()
for job := range jobChan {
err := s.provideKeysToPeer(job.pid, job.keys, pmes)
if err != nil {
errCount.Add(1)
}
}
}()
}

for p, keys := range keysAllocations {
jobChan <- provideJob{p, keys}
}
close(jobChan)
wg.Wait()

errCountLoaded := int(errCount.Load())
logger.Infof("sent provider records to peers in %s, errors %d/%d", s.clock.Since(startTime), errCountLoaded, len(keysAllocations))

if errCountLoaded == nPeers || errCountLoaded > int(float32(nPeers)*(1-minimalRegionReachablePeersRatio)) {
return fmt.Errorf("unable to provide to enough peers (%d/%d)", nPeers-errCountLoaded, nPeers)
}
return nil
}
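
// Illustrative sketch: regionProvideFailed is a hypothetical predicate
// restating the threshold used in sendProviderRecords. For example, with
// minimalRegionReachablePeersRatio = 0.2 and a region of 20 peers, up to 16
// unreachable peers are tolerated; 17 failures (or all 20) mark the region
// provide as failed so that the caller retries it.
func regionProvideFailed(nPeers, errCount int) bool {
	return errCount == nPeers || errCount > int(float32(nPeers)*(1-minimalRegionReachablePeersRatio))
}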

// genProvideMessage generates a new provide message with the supplied
// AddrInfo. The message key is left empty; it is set later, just before each
// send.
func genProvideMessage(addrInfo peer.AddrInfo) *pb.Message {
pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, []byte{}, 0)
pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{addrInfo})
return pmes
}

// maxConsecutiveProvideFailuresAllowed is the maximum number of consecutive
// provides that are allowed to fail to the same remote peer before cancelling
// all pending requests to this peer.
const maxConsecutiveProvideFailuresAllowed = 2

// provideKeysToPeer performs the network operation to advertise to the given
// DHT server (p) that we serve all the given keys.
func (s *SweepingProvider) provideKeysToPeer(p peer.ID, keys []mh.Multihash, pmes *pb.Message) error {
errCount := 0
for _, k := range keys {
pmes.Key = k
err := s.msgSender.SendMessage(context.Background(), p, pmes)
if err != nil {
errCount++

if errCount == len(keys) || errCount > maxConsecutiveProvideFailuresAllowed {
return fmt.Errorf("failed to provide to %s: %w", p, err)
}
} else if errCount > 0 {
// Reset error count
errCount = 0
}
}
return nil
}

// ProvideOnce sends provider records for the specified keys to the DHT swarm
118 changes: 118 additions & 0 deletions provider/provider_test.go
@@ -0,0 +1,118 @@
package provider

import (
"context"
"crypto/sha256"
"errors"
"strconv"
"testing"

"github.com/ipfs/go-test/random"
kb "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
mh "github.com/multiformats/go-multihash"

"github.com/probe-lab/go-libdht/kad/key/bit256"
"github.com/probe-lab/go-libdht/kad/trie"

pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace"

"github.com/stretchr/testify/require"
)

func genMultihashes(n int) []mh.Multihash {
mhs := make([]mh.Multihash, n)
var err error
for i := range n {
mhs[i], err = mh.Sum([]byte(strconv.Itoa(i)), mh.SHA2_256, -1)
if err != nil {
panic(err)
}
}
return mhs
}

var _ pb.MessageSender = (*mockMsgSender)(nil)

type mockMsgSender struct {
sendMessageFunc func(ctx context.Context, p peer.ID, m *pb.Message) error
}

func (ms *mockMsgSender) SendRequest(ctx context.Context, p peer.ID, m *pb.Message) (*pb.Message, error) {
// Unused
return nil, nil
}

func (ms *mockMsgSender) SendMessage(ctx context.Context, p peer.ID, m *pb.Message) error {
if ms.sendMessageFunc == nil {
return nil
}
return ms.sendMessageFunc(ctx, p, m)
}

func TestProvideKeysToPeer(t *testing.T) {
msgCount := 0
msgSender := &mockMsgSender{
sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error {
msgCount++
return errors.New("error")
},
}
prov := SweepingProvider{
msgSender: msgSender,
}

nKeys := 16
pid, err := peer.Decode("12BoooooPEER")
require.NoError(t, err)
mhs := genMultihashes(nKeys)
pmes := &pb.Message{}

// All ADD_PROVIDER RPCs fail: expect an error after
// maxConsecutiveProvideFailuresAllowed+1 attempts.
err = prov.provideKeysToPeer(pid, mhs, pmes)
require.Error(t, err)
require.Equal(t, maxConsecutiveProvideFailuresAllowed+1, msgCount)

// Fail only every third request (~33%): consecutive failures never reach the
// limit, so the operation should be considered a success.
msgCount = 0
msgSender.sendMessageFunc = func(ctx context.Context, p peer.ID, m *pb.Message) error {
msgCount++
if msgCount%3 == 0 {
return errors.New("error")
}
return nil
}
err = prov.provideKeysToPeer(pid, mhs, pmes)
require.NoError(t, err)
require.Equal(t, nKeys, msgCount)
}

func TestKeysAllocationsToPeers(t *testing.T) {
nKeys := 1024
nPeers := 128
replicationFactor := 10

mhs := genMultihashes(nKeys)
keysTrie := trie.New[bit256.Key, mh.Multihash]()
for _, c := range mhs {
keysTrie.Add(keyspace.MhToBit256(c), c)
}
peers := random.Peers(nPeers)
peersTrie := trie.New[bit256.Key, peer.ID]()
for i := range peers {
peersTrie.Add(keyspace.PeerIDToBit256(peers[i]), peers[i])
}
keysAllocations := keyspace.AllocateToKClosest(keysTrie, peersTrie, replicationFactor)

for _, c := range mhs {
k := sha256.Sum256(c)
closestPeers := kb.SortClosestPeers(peers, k[:])
for _, p := range closestPeers[:replicationFactor] {
require.Contains(t, keysAllocations[p], c)
}
for _, p := range closestPeers[replicationFactor:] {
require.NotContains(t, keysAllocations[p], c)
}
}
}