diff --git a/amino/defaults.go b/amino/defaults.go index a979183e2..34a7fce89 100644 --- a/amino/defaults.go +++ b/amino/defaults.go @@ -48,6 +48,11 @@ const ( // find the multiaddress associated with the returned peer id. DefaultProviderAddrTTL = 24 * time.Hour + // DefaultReprovideInterval is the default interval at which the keys should + // be reprovided to the DHT swarm to ensure there are enough live records in + // the swarm. + DefaultReprovideInterval = 22 * time.Hour + // DefaultMaxPeersPerIPGroup is the maximal number of peers with addresses in // the same IP group allowed in the routing table. Once this limit is // reached, newly discovered peers with addresses in the same IP group will diff --git a/dht.go b/dht.go index 749582162..f1d9773d5 100644 --- a/dht.go +++ b/dht.go @@ -855,6 +855,11 @@ func (dht *IpfsDHT) RoutingTable() *kb.RoutingTable { return dht.routingTable } +// BucketSize returns the size of the DHT's routing table buckets. +func (dht *IpfsDHT) BucketSize() int { + return dht.bucketSize +} + // Close calls Process Close. func (dht *IpfsDHT) Close() error { dht.cancel() @@ -897,6 +902,11 @@ func (dht *IpfsDHT) Host() host.Host { return dht.host } +// MessageSender returns the DHT's message sender. +func (dht *IpfsDHT) MessageSender() pb.MessageSender { + return dht.msgSender +} + // Ping sends a ping message to the passed peer and waits for a response. func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error { ctx, span := internal.StartSpan(ctx, "IpfsDHT.Ping", trace.WithAttributes(attribute.Stringer("PeerID", p))) @@ -932,6 +942,16 @@ func (dht *IpfsDHT) maybeAddAddrs(p peer.ID, addrs []ma.Multiaddr, ttl time.Dura dht.peerstore.AddAddrs(p, dht.filterAddrs(addrs), ttl) } +// FilteredAddrs returns the set of addresses that this DHT instance +// advertises to the swarm, after applying the configured addrFilter. +// +// For example: +// - In a public DHT, local and loopback addresses are filtered out. +// - In a LAN DHT, only loopback addresses are filtered out. +func (dht *IpfsDHT) FilteredAddrs() []ma.Multiaddr { + return dht.filterAddrs(dht.host.Addrs()) +} + func (dht *IpfsDHT) filterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { if f := dht.addrFilter; f != nil { return f(addrs) diff --git a/dht_test.go b/dht_test.go index f1a0af389..35bc042e7 100644 --- a/dht_test.go +++ b/dht_test.go @@ -1738,19 +1738,20 @@ func TestFindClosestPeers(t *testing.T) { nDHTs := 30 dhts := setupDHTS(t, ctx, nDHTs) defer func() { - for i := 0; i < nDHTs; i++ { + for i := range nDHTs { dhts[i].Close() defer dhts[i].host.Close() } }() t.Logf("connecting %d dhts in a ring", nDHTs) - for i := 0; i < nDHTs; i++ { + for i := range nDHTs { connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)]) } querier := dhts[1] - peers, err := querier.GetClosestPeers(ctx, "foo") + queryStr := "foo" + peers, err := querier.GetClosestPeers(ctx, queryStr) if err != nil { t.Fatal(err) } @@ -1758,6 +1759,12 @@ func TestFindClosestPeers(t *testing.T) { if len(peers) < querier.beta { t.Fatalf("got wrong number of peers (got %d, expected at least %d)", len(peers), querier.beta) } + + queryKey := kb.ConvertKey(queryStr) + sortedPeers := kb.SortClosestPeers(peers, queryKey) + for i := range len(sortedPeers) { + require.Equal(t, sortedPeers[i], peers[i]) + } } func TestFixLowPeers(t *testing.T) { diff --git a/fullrt/dht.go b/fullrt/dht.go index f140828b5..189a3a457 100644 --- a/fullrt/dht.go +++ b/fullrt/dht.go @@ -150,6 +150,7 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful EnableProviders: true, EnableValues: true, ProtocolPrefix: protocolPrefix, + MsgSenderBuilder: net.NewMessageSenderImpl, } if err := dhtcfg.Apply(fullrtcfg.dhtOpts...); err != nil { @@ -163,7 +164,7 @@ func NewFullRT(h host.Host, protocolPrefix protocol.ID, options ...Option) (*Ful return nil, err } - ms := net.NewMessageSenderImpl(h, amino.Protocols) + ms := dhtcfg.MsgSenderBuilder(h, amino.Protocols) protoMessenger, err := dht_pb.NewProtocolMessenger(ms) if err != nil { return nil, err @@ -303,6 +304,10 @@ func (dht *FullRT) Host() host.Host { return dht.h } +func (dht *FullRT) MessageSender() dht_pb.MessageSender { + return dht.messageSender +} + func (dht *FullRT) runCrawler(ctx context.Context) { defer dht.wg.Done() t := time.NewTicker(dht.crawlerInterval) diff --git a/go.mod b/go.mod index 7ae3c41ec..eb38e4436 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,22 @@ module github.com/libp2p/go-libp2p-kad-dht -go 1.24 +go 1.24.0 require ( + github.com/gammazero/deque v1.1.0 github.com/google/gopacket v1.1.19 github.com/google/uuid v1.6.0 + github.com/guillaumemichel/reservedpool v0.2.0 github.com/hashicorp/golang-lru v1.0.2 github.com/ipfs/boxo v0.33.1 github.com/ipfs/go-cid v0.5.0 - github.com/ipfs/go-datastore v0.8.2 + github.com/ipfs/go-datastore v0.8.4 github.com/ipfs/go-detect-race v0.0.1 - github.com/ipfs/go-log/v2 v2.8.0 - github.com/ipfs/go-test v0.2.2 + github.com/ipfs/go-dsqueue v0.0.5 + github.com/ipfs/go-log/v2 v2.8.1 + github.com/ipfs/go-test v0.2.3 github.com/libp2p/go-libp2p v0.43.0 - github.com/libp2p/go-libp2p-kbucket v0.7.0 + github.com/libp2p/go-libp2p-kbucket v0.8.0 github.com/libp2p/go-libp2p-record v0.3.1 github.com/libp2p/go-libp2p-routing-helpers v0.7.5 github.com/libp2p/go-libp2p-testing v0.12.0 @@ -25,11 +28,12 @@ require ( github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multihash v0.2.3 github.com/multiformats/go-multistream v0.6.1 - github.com/stretchr/testify v1.10.0 + github.com/probe-lab/go-libdht v0.2.1 + github.com/stretchr/testify v1.11.1 github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 - go.opentelemetry.io/otel v1.37.0 - go.opentelemetry.io/otel/metric v1.37.0 - go.opentelemetry.io/otel/trace v1.37.0 + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/metric v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 gonum.org/v1/gonum v0.16.0 @@ -50,6 +54,7 @@ require ( github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/huin/goupnp v1.3.0 // indirect github.com/ipfs/go-block-format v0.2.2 // indirect github.com/ipld/go-ipld-prime v0.21.0 // indirect diff --git a/go.sum b/go.sum index b13d1915f..a068f875f 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiD github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/gammazero/deque v1.1.0 h1:OyiyReBbnEG2PP0Bnv1AASLIYvyKqIFN5xfl1t8oGLo= +github.com/gammazero/deque v1.1.0/go.mod h1:JVrR+Bj1NMQbPnYclvDlvSX0nVGReLrQZ0aUMuWLctg= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= @@ -108,11 +110,15 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/guillaumemichel/reservedpool v0.2.0 h1:q73gtdMFJHtW+dDJ/fwtk34p7JprQv8fJSK7dEjf8Sw= +github.com/guillaumemichel/reservedpool v0.2.0/go.mod h1:sXSDIaef81TFdAJglsCFCMfgF5E5Z5xK1tFhjDhvbUc= github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfmKFJ6tItnhQ67kU= github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= @@ -127,19 +133,21 @@ github.com/ipfs/go-cid v0.5.0 h1:goEKKhaGm0ul11IHA7I6p1GmKz8kEYniqFopaB5Otwg= github.com/ipfs/go-cid v0.5.0/go.mod h1:0L7vmeNXpQpUS9vt+yEARkJ8rOg43DF3iPgn4GIN0mk= github.com/ipfs/go-datastore v0.1.0/go.mod h1:d4KVXhMt913cLBEI/PXAy6ko+W7e9AhyAKBGh803qeE= github.com/ipfs/go-datastore v0.1.1/go.mod h1:w38XXW9kVFNp57Zj5knbKWM2T+KOZCGDRVNdgPHtbHw= -github.com/ipfs/go-datastore v0.8.2 h1:Jy3wjqQR6sg/LhyY0NIePZC3Vux19nLtg7dx0TVqr6U= -github.com/ipfs/go-datastore v0.8.2/go.mod h1:W+pI1NsUsz3tcsAACMtfC+IZdnQTnC/7VfPoJBQuts0= +github.com/ipfs/go-datastore v0.8.4 h1:vXEsd76T3KIOSKXizjhmS3ICGMl+oOSjpLSxE3v8/Wc= +github.com/ipfs/go-datastore v0.8.4/go.mod h1:uT77w/XEGrvJWwHgdrMr8bqCN6ZTW9gzmi+3uK+ouHg= github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk= github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps= github.com/ipfs/go-ds-badger v0.0.7/go.mod h1:qt0/fWzZDoPW6jpQeqUjR5kBfhDNB65jd9YlmAvpQBk= github.com/ipfs/go-ds-leveldb v0.1.0/go.mod h1:hqAW8y4bwX5LWcCtku2rFNX3vjDZCy5LZCg+cSZvYb8= +github.com/ipfs/go-dsqueue v0.0.5 h1:TUOk15TlCJ/NKV8Yk2W5wgkEjDa44Nem7a7FGIjsMNU= +github.com/ipfs/go-dsqueue v0.0.5/go.mod h1:i/jAlpZjBbQJLioN+XKbFgnd+u9eAhGZs9IrqIzTd9g= github.com/ipfs/go-ipfs-delay v0.0.0-20181109222059-70721b86a9a8/go.mod h1:8SP1YXK1M1kXuc4KJZINY3TQQ03J2rwBG9QfXmbRPrw= github.com/ipfs/go-ipfs-util v0.0.1/go.mod h1:spsl5z8KUnrve+73pOhSVZND1SIxPW5RyBCNzQxlJBc= github.com/ipfs/go-log v0.0.1/go.mod h1:kL1d2/hzSpI0thNYjiKfjanbVNU+IIGA/WnNESY9leM= -github.com/ipfs/go-log/v2 v2.8.0 h1:SptNTPJQV3s5EF4FdrTu/yVdOKfGbDgn1EBZx4til2o= -github.com/ipfs/go-log/v2 v2.8.0/go.mod h1:2LEEhdv8BGubPeSFTyzbqhCqrwqxCbuTNTLWqgNAipo= -github.com/ipfs/go-test v0.2.2 h1:1yjYyfbdt1w93lVzde6JZ2einh3DIV40at4rVoyEcE8= -github.com/ipfs/go-test v0.2.2/go.mod h1:cmLisgVwkdRCnKu/CFZOk2DdhOcwghr5GsHeqwexoRA= +github.com/ipfs/go-log/v2 v2.8.1 h1:Y/X36z7ASoLJaYIJAL4xITXgwf7RVeqb1+/25aq/Xk0= +github.com/ipfs/go-log/v2 v2.8.1/go.mod h1:NyhTBcZmh2Y55eWVjOeKf8M7e4pnJYM3yDZNxQBWEEY= +github.com/ipfs/go-test v0.2.3 h1:Z/jXNAReQFtCYyn7bsv/ZqUwS6E7iIcSpJ2CuzCvnrc= +github.com/ipfs/go-test v0.2.3/go.mod h1:QW8vSKkwYvWFwIZQLGQXdkt9Ud76eQXRQ9Ao2H+cA1o= github.com/ipld/go-ipld-prime v0.21.0 h1:n4JmcpOlPDIxBcY037SVfpd1G+Sj1nKZah0m6QH9C2E= github.com/ipld/go-ipld-prime v0.21.0/go.mod h1:3RLqy//ERg/y5oShXXdx5YIp50cFGOanyMctpPjsvxQ= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= @@ -195,8 +203,8 @@ github.com/libp2p/go-libp2p-asn-util v0.4.1/go.mod h1:d/NI6XZ9qxw67b4e+NgpQexCIi github.com/libp2p/go-libp2p-core v0.2.4/go.mod h1:STh4fdfa5vDYr0/SzYYeqnt+E6KfEV5VxfIrm0bcI0g= github.com/libp2p/go-libp2p-core v0.3.0/go.mod h1:ACp3DmS3/N64c2jDzcV429ukDpicbL6+TrrxANBjPGw= github.com/libp2p/go-libp2p-kbucket v0.3.1/go.mod h1:oyjT5O7tS9CQurok++ERgc46YLwEpuGoFq9ubvoUOio= -github.com/libp2p/go-libp2p-kbucket v0.7.0 h1:vYDvRjkyJPeWunQXqcW2Z6E93Ywx7fX0jgzb/dGOKCs= -github.com/libp2p/go-libp2p-kbucket v0.7.0/go.mod h1:blOINGIj1yiPYlVEX0Rj9QwEkmVnz3EP8LK1dRKBC6g= +github.com/libp2p/go-libp2p-kbucket v0.8.0 h1:QAK7RzKJpYe+EuSEATAaaHYMYLkPDGC18m9jxPLnU8s= +github.com/libp2p/go-libp2p-kbucket v0.8.0/go.mod h1:JMlxqcEyKwO6ox716eyC0hmiduSWZZl6JY93mGaaqc4= github.com/libp2p/go-libp2p-peerstore v0.1.4/go.mod h1:+4BDbDiiKf4PzpANZDAT+knVdLxvqh7hXOujessqdzs= github.com/libp2p/go-libp2p-record v0.3.1 h1:cly48Xi5GjNw5Wq+7gmjfBiG9HCzQVkiZOUZ8kUl+Fg= github.com/libp2p/go-libp2p-record v0.3.1/go.mod h1:T8itUkLcWQLCYMqtX7Th6r7SexyUJpIyPgks757td/E= @@ -346,6 +354,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polydawn/refmt v0.89.0 h1:ADJTApkvkeBZsN0tBTx8QjpD9JkmxbKp0cxfr9qszm4= github.com/polydawn/refmt v0.89.0/go.mod h1:/zvteZs/GwLtCgZ4BL6CBsk9IKIlexP43ObX9AxTqTw= +github.com/probe-lab/go-libdht v0.2.1 h1:oBCsKBvS/OVirTO5+BT6/AOocWjdqwpfSfkTfBjUPJE= +github.com/probe-lab/go-libdht v0.2.1/go.mod h1:q+WlGiqs/UIRfdhw9Gmc+fPoAYlOim7VvXTjOI6KJmQ= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= @@ -421,8 +431,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= @@ -446,12 +456,12 @@ go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= -go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= -go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= -go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= -go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= -go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/dig v1.19.0 h1:BACLhebsYdpQ7IROQ1AGPjrXcP5dF80U3gKoFzbaq/4= go.uber.org/dig v1.19.0/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= go.uber.org/fx v1.24.0 h1:wE8mruvpg2kiiL1Vqd0CC+tr0/24XIB10Iwp2lLWzkg= diff --git a/lookup_optim.go b/lookup_optim.go index 0973c56b3..f32fd5de9 100644 --- a/lookup_optim.go +++ b/lookup_optim.go @@ -242,7 +242,7 @@ func (os *optimisticState) stopFn(qps *qpeerset.QueryPeerset) bool { func (os *optimisticState) putProviderRecord(pid peer.ID) { err := os.dht.protoMessenger.PutProviderAddrs(os.putCtx, pid, []byte(os.key), peer.AddrInfo{ ID: os.dht.self, - Addrs: os.dht.filterAddrs(os.dht.host.Addrs()), + Addrs: os.dht.FilteredAddrs(), }) os.peerStatesLk.Lock() if err != nil { diff --git a/provider/buffered/options.go b/provider/buffered/options.go new file mode 100644 index 000000000..11f9f4a96 --- /dev/null +++ b/provider/buffered/options.go @@ -0,0 +1,65 @@ +// Package buffered provides a buffered provider implementation that queues operations +// and processes them in batches for improved performance. +package buffered + +import "time" + +const ( + // DefaultDsName is the default datastore namespace for the buffered provider. + DefaultDsName = "bprov" // for buffered provider + // DefaultBatchSize is the default number of operations to process in a single batch. + DefaultBatchSize = 1 << 10 + // DefaultIdleWriteTime is the default duration to wait before flushing pending operations. + DefaultIdleWriteTime = time.Minute +) + +// config contains all options for the buffered provider. +type config struct { + dsName string + batchSize int + idleWriteTime time.Duration +} + +// Option is a function that configures the buffered provider. +type Option func(*config) + +// getOpts creates a config and applies Options to it. +func getOpts(opts []Option) config { + cfg := config{ + dsName: DefaultDsName, + batchSize: DefaultBatchSize, + idleWriteTime: DefaultIdleWriteTime, + } + + for _, opt := range opts { + opt(&cfg) + } + return cfg +} + +// WithDsName sets the datastore namespace for the buffered provider. +// If name is empty, the option is ignored. +func WithDsName(name string) Option { + return func(c *config) { + if len(name) > 0 { + c.dsName = name + } + } +} + +// WithBatchSize sets the number of operations to process in a single batch. +// If n is zero or negative, the option is ignored. +func WithBatchSize(n int) Option { + return func(c *config) { + if n > 0 { + c.batchSize = n + } + } +} + +// WithIdleWriteTime sets the duration to wait before flushing pending operations. +func WithIdleWriteTime(d time.Duration) Option { + return func(c *config) { + c.idleWriteTime = d + } +} diff --git a/provider/buffered/provider.go b/provider/buffered/provider.go new file mode 100644 index 000000000..5f335a03a --- /dev/null +++ b/provider/buffered/provider.go @@ -0,0 +1,269 @@ +package buffered + +import ( + "errors" + "sync" + + "github.com/ipfs/go-datastore" + "github.com/ipfs/go-dsqueue" + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p-kad-dht/provider" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal" + mh "github.com/multiformats/go-multihash" +) + +var logger = logging.Logger(provider.LoggerName) + +const ( + // provideOnceOp represents a one-time provide operation. + provideOnceOp byte = iota + // startProvidingOp represents starting continuous providing. + startProvidingOp + // forceStartProvidingOp represents forcefully starting providing (overrides existing). + forceStartProvidingOp + // stopProvidingOp represents stopping providing. + stopProvidingOp + // lastOp is used for array sizing. + lastOp +) + +var _ internal.Provider = (*SweepingProvider)(nil) + +// buffered.SweepingProvider is a wrapper around a SweepingProvider buffering +// requests, to allow core operations to return instantly. Operations are +// queued and processed asynchronously in batches for improved performance. +type SweepingProvider struct { + closeOnce sync.Once + done chan struct{} + closed chan struct{} + + newItems chan struct{} + provider internal.Provider + queue *dsqueue.DSQueue + batchSize int +} + +// New creates a new SweepingProvider that wraps the given provider with +// buffering capabilities. Operations are queued and processed asynchronously +// in batches for improved performance. +func New(prov internal.Provider, ds datastore.Batching, opts ...Option) *SweepingProvider { + cfg := getOpts(opts) + s := &SweepingProvider{ + done: make(chan struct{}), + closed: make(chan struct{}), + + newItems: make(chan struct{}, 1), + provider: prov, + queue: dsqueue.New(ds, cfg.dsName, + dsqueue.WithDedupCacheSize(0), // disable deduplication + dsqueue.WithIdleWriteTime(cfg.idleWriteTime), + ), + batchSize: cfg.batchSize, + } + go s.worker() + return s +} + +// Close stops the provider and releases all resources. +// +// It waits for the worker goroutine to finish processing current operations +// and closes the underneath provider. The queue current state is persisted on +// the datastore. +func (s *SweepingProvider) Close() error { + var err error + s.closeOnce.Do(func() { + close(s.closed) + err = errors.Join(s.queue.Close(), s.provider.Close()) + <-s.done + }) + return err +} + +// toBytes serializes an operation and multihash into a byte slice for storage. +func toBytes(op byte, key mh.Multihash) []byte { + return append([]byte{op}, key...) +} + +// fromBytes deserializes a byte slice back into an operation and multihash. +func fromBytes(data []byte) (byte, mh.Multihash, error) { + op := data[0] + h, err := mh.Cast(data[1:]) + return op, h, err +} + +// getOperations processes a batch of dequeued operations and groups them by +// type. +// +// It discards multihashes from the `StopProviding` operation if +// `StartProviding` was called after `StopProviding` for the same multihash. +func getOperations(dequeued [][]byte) ([][]mh.Multihash, error) { + stopProv := make(map[string]struct{}) + ops := [lastOp - 1][]mh.Multihash{} // don't store stop ops + + for _, bs := range dequeued { + op, h, err := fromBytes(bs) + if err != nil { + return nil, err + } + switch op { + case provideOnceOp: + ops[provideOnceOp] = append(ops[provideOnceOp], h) + case startProvidingOp, forceStartProvidingOp: + delete(stopProv, string(h)) + ops[op] = append(ops[op], h) + case stopProvidingOp: + stopProv[string(h)] = struct{}{} + } + } + stopOps := make([]mh.Multihash, 0, len(stopProv)) + for hstr := range stopProv { + stopOps = append(stopOps, mh.Multihash(hstr)) + } + return append(ops[:], stopOps), nil +} + +// executeOperation executes a provider operation on the underlying provider +// with the given multihashes, logging any errors encountered. +func executeOperation(f func(...mh.Multihash) error, keys []mh.Multihash) { + if len(keys) == 0 { + return + } + if err := f(keys...); err != nil { + logger.Warn(err) + } +} + +// worker processes operations from the queue in batches. +// It runs in a separate goroutine and continues until the provider is closed. +func (s *SweepingProvider) worker() { + defer close(s.done) + var emptyQueue bool + for { + if emptyQueue { + select { + case <-s.closed: + return + case <-s.newItems: + } + emptyQueue = false + } else { + select { + case <-s.closed: + return + case <-s.newItems: + default: + } + } + + res, err := s.queue.GetN(s.batchSize) + if err != nil { + logger.Warnf("BufferedSweepingProvider unable to dequeue: %v", err) + continue + } + if len(res) < s.batchSize { + // Queue was fully drained. + emptyQueue = true + } + ops, err := getOperations(res) + if err != nil { + logger.Warnf("BufferedSweepingProvider unable to parse dequeued item: %v", err) + continue + } + // Execute the 4 kinds of queued provider operations on the underlying + // provider. + + // Process `StartProviding` (force=true) ops first, so that if + // `StartProviding` (force=false) is called after, there is no need to + // enqueue the multihash a second time to the provide queue. + executeOperation(func(keys ...mh.Multihash) error { return s.provider.StartProviding(true, keys...) }, ops[forceStartProvidingOp]) + executeOperation(func(keys ...mh.Multihash) error { return s.provider.StartProviding(false, keys...) }, ops[startProvidingOp]) + executeOperation(s.provider.ProvideOnce, ops[provideOnceOp]) + // Process `StopProviding` last, so that multihashes that should have been + // provided, and then stopped provided in the same batch are provided only + // once. Don't `StopProviding` multihashes, for which `StartProviding` has + // been called after `StopProviding`. + executeOperation(s.provider.StopProviding, ops[stopProvidingOp]) + } +} + +// enqueue adds operations to the queue for asynchronous processing. +func (s *SweepingProvider) enqueue(op byte, keys ...mh.Multihash) error { + for _, h := range keys { + if err := s.queue.Put(toBytes(op, h)); err != nil { + return err + } + } + select { + case s.newItems <- struct{}{}: + default: + } + return nil +} + +// ProvideOnce enqueues multihashes for which the provider will send provider +// records out only once to the DHT swarm. It does NOT take the responsibility +// to reprovide these keys. +// +// Returns immediately after enqueuing the keys, the actual provide operation +// happens asynchronously. Returns an error if the multihashes couldn't be +// enqueued. +func (s *SweepingProvider) ProvideOnce(keys ...mh.Multihash) error { + return s.enqueue(provideOnceOp, keys...) +} + +// StartProviding adds the supplied keys to the queue of keys that will be +// provided to the DHT swarm unless they were already provided in the past. The +// keys will be periodically reprovided until StopProviding is called for the +// same keys or the keys are removed from the Keystore. +// +// If force is true, the keys are provided to the DHT swarm regardless of +// whether they were already being reprovided in the past. +// +// Returns immediately after enqueuing the keys, the actual provide operation +// happens asynchronously. Returns an error if the multihashes couldn't be +// enqueued. +func (s *SweepingProvider) StartProviding(force bool, keys ...mh.Multihash) error { + op := startProvidingOp + if force { + op = forceStartProvidingOp + } + return s.enqueue(op, keys...) +} + +// StopProviding adds the supplied multihashes to the BufferedSweepingProvider +// queue, to stop reproviding the given keys to the DHT swarm. +// +// The node stops being referred as a provider when the provider records in the +// DHT swarm expire. +// +// Returns immediately after enqueuing the keys, the actual provide operation +// happens asynchronously. Returns an error if the multihashes couldn't be +// enqueued. +func (s *SweepingProvider) StopProviding(keys ...mh.Multihash) error { + return s.enqueue(stopProvidingOp, keys...) +} + +// Clear clears the all the keys from the provide queue and returns the number +// of keys that were cleared. +// +// The keys are not deleted from the keystore, so they will continue to be +// reprovided as scheduled. +func (s *SweepingProvider) Clear() int { + return s.provider.Clear() +} + +// RefreshSchedule scans the KeyStore for any keys that are not currently +// scheduled for reproviding. If such keys are found, it schedules their +// associated keyspace region to be reprovided. +// +// This function doesn't remove prefixes that have no keys from the schedule. +// This is done automatically during the reprovide operation if a region has no +// keys. +// +// Returns an error if the provider is closed or if the node is currently +// Offline (either never bootstrapped, or disconnected since more than +// `OfflineDelay`). The schedule depends on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) RefreshSchedule() error { + return s.provider.RefreshSchedule() +} diff --git a/provider/buffered/provider_test.go b/provider/buffered/provider_test.go new file mode 100644 index 000000000..3c66ae56b --- /dev/null +++ b/provider/buffered/provider_test.go @@ -0,0 +1,283 @@ +//go:build go1.25 +// +build go1.25 + +package buffered + +import ( + "bytes" + "sync" + "testing" + "testing/synctest" + "time" + + "github.com/ipfs/go-datastore" + "github.com/ipfs/go-test/random" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal" + mh "github.com/multiformats/go-multihash" +) + +var _ internal.Provider = (*fakeProvider)(nil) + +type fakeProvider struct { + mu sync.Mutex + provideOnceCalls [][]mh.Multihash + startProvidingCalls []startProvidingCall + stopProvidingCalls [][]mh.Multihash + + // Signal when operations are processed + processed chan struct{} +} + +type startProvidingCall struct { + force bool + keys []mh.Multihash +} + +func (f *fakeProvider) ProvideOnce(keys ...mh.Multihash) error { + f.mu.Lock() + defer f.mu.Unlock() + if len(keys) > 0 { + f.provideOnceCalls = append(f.provideOnceCalls, keys) + if f.processed != nil { + select { + case f.processed <- struct{}{}: + default: + } + } + } + return nil +} + +func (f *fakeProvider) StartProviding(force bool, keys ...mh.Multihash) error { + f.mu.Lock() + defer f.mu.Unlock() + if len(keys) > 0 { + f.startProvidingCalls = append(f.startProvidingCalls, startProvidingCall{ + force: force, + keys: keys, + }) + if f.processed != nil { + select { + case f.processed <- struct{}{}: + default: + } + } + } + return nil +} + +func (f *fakeProvider) StopProviding(keys ...mh.Multihash) error { + f.mu.Lock() + defer f.mu.Unlock() + if len(keys) > 0 { + f.stopProvidingCalls = append(f.stopProvidingCalls, keys) + if f.processed != nil { + select { + case f.processed <- struct{}{}: + default: + } + } + } + return nil +} + +func (f *fakeProvider) Clear() int { + // Unused + return 0 +} + +func (f *fakeProvider) RefreshSchedule() error { + // Unused + return nil +} + +func (f *fakeProvider) Close() error { + // Unused + return nil +} + +func newFakeProvider() *fakeProvider { + return &fakeProvider{ + processed: make(chan struct{}, 10), // Buffered channel for test signaling + } +} + +func TestQueueingMechanism(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + fake := newFakeProvider() + ds := datastore.NewMapDatastore() + provider := New(fake, ds, + WithDsName("test1"), + WithIdleWriteTime(time.Millisecond), + WithBatchSize(10)) + defer provider.Close() + + keys := random.Multihashes(3) + + // Queue various operations + if err := provider.ProvideOnce(keys[0]); err != nil { + t.Fatalf("ProvideOnce failed: %v", err) + } + if err := provider.StartProviding(false, keys[1]); err != nil { + t.Fatalf("StartProviding failed: %v", err) + } + if err := provider.StartProviding(true, keys[2]); err != nil { + t.Fatalf("StartProviding (force) failed: %v", err) + } + if err := provider.StopProviding(keys[0]); err != nil { + t.Fatalf("StopProviding failed: %v", err) + } + + // Wait for operations to be processed by expecting 4 signals + for i := 0; i < 4; i++ { + select { + case <-fake.processed: + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for operation %d to be processed", i+1) + } + } + + // Verify all operations were dequeued and processed + if len(fake.provideOnceCalls) != 1 { + t.Errorf("Expected 1 ProvideOnce call, got %d", len(fake.provideOnceCalls)) + } else if len(fake.provideOnceCalls[0]) != 1 || !bytes.Equal(fake.provideOnceCalls[0][0], keys[0]) { + t.Errorf("Expected ProvideOnce call with keys[0], got %v", fake.provideOnceCalls[0]) + } + + if len(fake.startProvidingCalls) != 2 { + t.Errorf("Expected 2 StartProviding calls, got %d", len(fake.startProvidingCalls)) + } else { + // Check that we have one force=true call and one force=false call + foundForce := false + foundRegular := false + for _, call := range fake.startProvidingCalls { + if call.force { + foundForce = true + if len(call.keys) != 1 || !bytes.Equal(call.keys[0], keys[2]) { + t.Errorf("Expected force StartProviding call with keys[2], got %v", call.keys) + } + } else { + foundRegular = true + if len(call.keys) != 1 || !bytes.Equal(call.keys[0], keys[1]) { + t.Errorf("Expected regular StartProviding call with keys[1], got %v", call.keys) + } + } + } + if !foundForce { + t.Errorf("Expected to find a StartProviding call with force=true") + } + if !foundRegular { + t.Errorf("Expected to find a StartProviding call with force=false") + } + } + + if len(fake.stopProvidingCalls) != 1 { + t.Errorf("Expected 1 StopProviding call, got %d", len(fake.stopProvidingCalls)) + } else if len(fake.stopProvidingCalls[0]) != 1 || !bytes.Equal(fake.stopProvidingCalls[0][0], keys[0]) { + t.Errorf("Expected StopProviding call with keys[0], got %v", fake.stopProvidingCalls[0]) + } + }) +} + +func TestStartProvidingAfterStopProvidingRemovesStopOperation(t *testing.T) { + // Test the core logic directly by calling getOperations with known data + t.Run("DirectTest", func(t *testing.T) { + key := random.Multihashes(1)[0] + + // Create batch data that simulates StopProviding followed by StartProviding + stopData := toBytes(stopProvidingOp, key) + startData := toBytes(startProvidingOp, key) + + dequeued := [][]byte{stopData, startData} + ops, err := getOperations(dequeued) // We need to create this helper + if err != nil { + t.Fatalf("getOperations failed: %v", err) + } + + // StartProviding should be present + if len(ops[startProvidingOp]) != 1 || !bytes.Equal(ops[startProvidingOp][0], key) { + t.Errorf("Expected StartProviding operation with key, got %v", ops[startProvidingOp]) + } + + // StopProviding should be canceled (empty) + if len(ops[stopProvidingOp]) != 0 { + t.Errorf("Expected StopProviding operations to be canceled, got %v", ops[stopProvidingOp]) + } + }) +} + +func TestMultipleOperationsOnSameKey(t *testing.T) { + // Test the core batch processing logic directly + t.Run("DirectTest", func(t *testing.T) { + key := random.Multihashes(1)[0] + + // Create batch data with multiple operations on same key + ops := [][]byte{ + toBytes(stopProvidingOp, key), // StopProviding + toBytes(forceStartProvidingOp, key), // StartProviding(force=true) + toBytes(stopProvidingOp, key), // StopProviding again + toBytes(startProvidingOp, key), // StartProviding(force=false) + } + + processed, err := getOperations(ops) + if err != nil { + t.Fatalf("getOperations failed: %v", err) + } + + // Should have 2 StartProviding operations + if len(processed[startProvidingOp]) != 1 { + t.Errorf("Expected 1 StartProviding (force=false) operation, got %d", len(processed[startProvidingOp])) + } + if len(processed[forceStartProvidingOp]) != 1 { + t.Errorf("Expected 1 StartProviding (force=true) operation, got %d", len(processed[forceStartProvidingOp])) + } + + // StopProviding should be canceled (empty) because StartProviding operations were in same batch + if len(processed[stopProvidingOp]) != 0 { + t.Errorf("Expected 0 StopProviding operations (should be canceled), got %d", len(processed[stopProvidingOp])) + } + }) +} + +func TestBatchProcessing(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + fake := newFakeProvider() + ds := datastore.NewMapDatastore() + provider := New(fake, ds, + WithDsName("test4"), + WithBatchSize(3), // Process 3 operations at once + WithIdleWriteTime(time.Second)) + defer provider.Close() + + // Queue multiple keys - total of 3 operations (2 from ProvideOnce + 1 from StartProviding) + keys := random.Multihashes(3) + + if err := provider.ProvideOnce(keys[0], keys[1]); err != nil { + t.Fatalf("ProvideOnce failed: %v", err) + } + if err := provider.StartProviding(false, keys[2]); err != nil { + t.Fatalf("StartProviding failed: %v", err) + } + synctest.Wait() + + // Close to ensure all operations are flushed + provider.Close() + + // Verify operations were batched correctly + totalProvideOnceCalls := 0 + for _, call := range fake.provideOnceCalls { + totalProvideOnceCalls += len(call) + } + if totalProvideOnceCalls != 2 { + t.Errorf("Expected 2 total keys in ProvideOnce calls, got %d", totalProvideOnceCalls) + } + + totalStartProvidingCalls := 0 + for _, call := range fake.startProvidingCalls { + totalStartProvidingCalls += len(call.keys) + } + if totalStartProvidingCalls != 1 { + t.Errorf("Expected 1 total key in StartProviding calls, got %d", totalStartProvidingCalls) + } + }) +} diff --git a/provider/dual/options.go b/provider/dual/options.go new file mode 100644 index 000000000..e657aeacb --- /dev/null +++ b/provider/dual/options.go @@ -0,0 +1,334 @@ +package dual + +import ( + "errors" + "fmt" + "time" + + ds "github.com/ipfs/go-datastore" + "github.com/libp2p/go-libp2p-kad-dht/amino" + "github.com/libp2p/go-libp2p-kad-dht/dual" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + "github.com/libp2p/go-libp2p-kad-dht/provider" + "github.com/libp2p/go-libp2p-kad-dht/provider/keystore" +) + +const ( + lanID uint8 = iota + wanID +) + +type config struct { + keystore keystore.Keystore + + reprovideInterval [2]time.Duration // [0] = LAN, [1] = WAN + maxReprovideDelay [2]time.Duration + + offlineDelay [2]time.Duration + connectivityCheckOnlineInterval [2]time.Duration + connectivityCheckOfflineInterval [2]time.Duration + + maxWorkers [2]int + dedicatedPeriodicWorkers [2]int + dedicatedBurstWorkers [2]int + maxProvideConnsPerWorker [2]int + + msgSenders [2]pb.MessageSender +} + +type Option func(opt *config) error + +func (cfg *config) apply(opts ...Option) error { + for i, o := range opts { + if err := o(cfg); err != nil { + return fmt.Errorf("dual dht provider option %d failed: %w", i, err) + } + } + return nil +} + +func (cfg *config) resolveDefaults(d *dual.DHT) { + if cfg.msgSenders[lanID] == nil { + cfg.msgSenders[lanID] = d.LAN.MessageSender() + } + if cfg.msgSenders[wanID] == nil { + cfg.msgSenders[wanID] = d.WAN.MessageSender() + } +} + +func (c *config) validate() error { + if c.dedicatedPeriodicWorkers[lanID]+c.dedicatedBurstWorkers[lanID] > c.maxWorkers[lanID] { + return errors.New("provider config: total dedicated workers exceed max workers") + } + if c.dedicatedPeriodicWorkers[wanID]+c.dedicatedBurstWorkers[wanID] > c.maxWorkers[wanID] { + return errors.New("provider config: total dedicated workers exceed max workers") + } + return nil +} + +var DefaultConfig = func(cfg *config) error { + var err error + cfg.keystore, err = keystore.NewKeystore(ds.NewMapDatastore()) + if err != nil { + return err + } + + cfg.reprovideInterval = [2]time.Duration{amino.DefaultReprovideInterval, amino.DefaultReprovideInterval} + cfg.maxReprovideDelay = [2]time.Duration{provider.DefaultMaxReprovideDelay, provider.DefaultMaxReprovideDelay} + + cfg.offlineDelay = [2]time.Duration{provider.DefaultOfflineDelay, provider.DefaultOfflineDelay} + cfg.connectivityCheckOnlineInterval = [2]time.Duration{provider.DefaultConnectivityCheckOnlineInterval, provider.DefaultConnectivityCheckOnlineInterval} + + cfg.maxWorkers = [2]int{4, 4} + cfg.dedicatedPeriodicWorkers = [2]int{2, 2} + cfg.dedicatedBurstWorkers = [2]int{1, 1} + cfg.maxProvideConnsPerWorker = [2]int{20, 20} + + return nil +} + +func WithKeystore(ks keystore.Keystore) Option { + return func(cfg *config) error { + if ks == nil { + return errors.New("provider config: keystore cannot be nil") + } + cfg.keystore = ks + return nil + } +} + +func withReprovideInterval(reprovideInterval time.Duration, dhts ...uint8) Option { + return func(cfg *config) error { + if reprovideInterval <= 0 { + return fmt.Errorf("reprovide interval must be positive, got %s", reprovideInterval) + } + for _, dht := range dhts { + cfg.reprovideInterval[dht] = reprovideInterval + } + return nil + } +} + +func WithReprovideInterval(reprovideInterval time.Duration) Option { + return withReprovideInterval(reprovideInterval, lanID, wanID) +} + +func WithReprovideIntervalLAN(reprovideInterval time.Duration) Option { + return withReprovideInterval(reprovideInterval, lanID) +} + +func WithReprovideIntervalWAN(reprovideInterval time.Duration) Option { + return withReprovideInterval(reprovideInterval, wanID) +} + +func withMaxReprovideDelay(maxReprovideDelay time.Duration, dhts ...uint8) Option { + return func(cfg *config) error { + if maxReprovideDelay <= 0 { + return fmt.Errorf("max reprovide delay must be positive, got %s", maxReprovideDelay) + } + for _, dht := range dhts { + cfg.maxReprovideDelay[dht] = maxReprovideDelay + } + return nil + } +} + +func WithMaxReprovideDelay(maxReprovideDelay time.Duration) Option { + return withMaxReprovideDelay(maxReprovideDelay, lanID, wanID) +} + +func WithMaxReprovideDelayLAN(maxReprovideDelay time.Duration) Option { + return withMaxReprovideDelay(maxReprovideDelay, lanID) +} + +func WithMaxReprovideDelayWAN(maxReprovideDelay time.Duration) Option { + return withMaxReprovideDelay(maxReprovideDelay, wanID) +} + +func withOfflineDelay(offlineDelay time.Duration, dhts ...uint8) Option { + return func(cfg *config) error { + if offlineDelay < 0 { + return fmt.Errorf("invalid offline delay %s", offlineDelay) + } + for _, dht := range dhts { + cfg.offlineDelay[dht] = offlineDelay + } + return nil + } +} + +func WithOfflineDelay(offlineDelay time.Duration) Option { + return withOfflineDelay(offlineDelay, lanID, wanID) +} + +func WithOfflineDelayLAN(offlineDelay time.Duration) Option { + return withOfflineDelay(offlineDelay, lanID) +} + +func WithOfflineDelayWAN(offlineDelay time.Duration) Option { + return withOfflineDelay(offlineDelay, wanID) +} + +func withConnectivityCheckOnlineInterval(onlineInterval time.Duration, dhts ...uint8) Option { + return func(cfg *config) error { + if onlineInterval <= 0 { + return fmt.Errorf("invalid connectivity check online interval %s", onlineInterval) + } + for _, dht := range dhts { + cfg.connectivityCheckOnlineInterval[dht] = onlineInterval + } + return nil + } +} + +func WithConnectivityCheckOnlineInterval(onlineInterval time.Duration) Option { + return withConnectivityCheckOnlineInterval(onlineInterval, lanID, wanID) +} + +func WithConnectivityCheckOnlineIntervalLAN(onlineInterval time.Duration) Option { + return withConnectivityCheckOnlineInterval(onlineInterval, lanID) +} + +func WithConnectivityCheckOnlineIntervalWAN(onlineInterval time.Duration) Option { + return withConnectivityCheckOnlineInterval(onlineInterval, wanID) +} + +func withConnectivityCheckOfflineInterval(offlineInterval time.Duration, dhts ...uint8) Option { + return func(cfg *config) error { + if offlineInterval <= 0 { + return fmt.Errorf("invalid connectivity check offline interval %s", offlineInterval) + } + for _, dht := range dhts { + cfg.connectivityCheckOfflineInterval[dht] = offlineInterval + } + return nil + } +} + +func WithConnectivityCheckOfflineInterval(offlineInterval time.Duration) Option { + return withConnectivityCheckOfflineInterval(offlineInterval, lanID, wanID) +} + +func WithConnectivityCheckOfflineIntervalLAN(offlineInterval time.Duration) Option { + return withConnectivityCheckOfflineInterval(offlineInterval, lanID) +} + +func WithConnectivityCheckOfflineIntervalWAN(offlineInterval time.Duration) Option { + return withConnectivityCheckOfflineInterval(offlineInterval, wanID) +} + +func withMaxWorkers(maxWorkers int, dhts ...uint8) Option { + return func(cfg *config) error { + if maxWorkers <= 0 { + return fmt.Errorf("invalid max workers %d", maxWorkers) + } + for _, dht := range dhts { + cfg.maxWorkers[dht] = maxWorkers + } + return nil + } +} + +func WithMaxWorkers(maxWorkers int) Option { + return withMaxWorkers(maxWorkers, lanID, wanID) +} + +func WithMaxWorkersLAN(maxWorkers int) Option { + return withMaxWorkers(maxWorkers, lanID) +} + +func WithMaxWorkersWAN(maxWorkers int) Option { + return withMaxWorkers(maxWorkers, wanID) +} + +func withDedicatedPeriodicWorkers(dedicatedPeriodicWorkers int, dhts ...uint8) Option { + return func(cfg *config) error { + if dedicatedPeriodicWorkers < 0 { + return fmt.Errorf("invalid dedicated periodic workers %d", dedicatedPeriodicWorkers) + } + for _, dht := range dhts { + cfg.dedicatedPeriodicWorkers[dht] = dedicatedPeriodicWorkers + } + return nil + } +} + +func WithDedicatedPeriodicWorkers(dedicatedPeriodicWorkers int) Option { + return withDedicatedPeriodicWorkers(dedicatedPeriodicWorkers, lanID, wanID) +} + +func WithDedicatedPeriodicWorkersLAN(dedicatedPeriodicWorkers int) Option { + return withDedicatedPeriodicWorkers(dedicatedPeriodicWorkers, lanID) +} + +func WithDedicatedPeriodicWorkersWAN(dedicatedPeriodicWorkers int) Option { + return withDedicatedPeriodicWorkers(dedicatedPeriodicWorkers, wanID) +} + +func withDedicatedBurstWorkers(dedicatedBurstWorkers int, dhts ...uint8) Option { + return func(cfg *config) error { + if dedicatedBurstWorkers < 0 { + return fmt.Errorf("invalid dedicated burst workers %d", dedicatedBurstWorkers) + } + for _, dht := range dhts { + cfg.dedicatedBurstWorkers[dht] = dedicatedBurstWorkers + } + return nil + } +} + +func WithDedicatedBurstWorkers(dedicatedBurstWorkers int) Option { + return withDedicatedBurstWorkers(dedicatedBurstWorkers, lanID, wanID) +} + +func WithDedicatedBurstWorkersLAN(dedicatedBurstWorkers int) Option { + return withDedicatedBurstWorkers(dedicatedBurstWorkers, lanID) +} + +func WithDedicatedBurstWorkersWAN(dedicatedBurstWorkers int) Option { + return withDedicatedBurstWorkers(dedicatedBurstWorkers, wanID) +} + +func withMaxProvideConnsPerWorker(maxProvideConnsPerWorker int, dhts ...uint8) Option { + return func(cfg *config) error { + if maxProvideConnsPerWorker <= 0 { + return fmt.Errorf("invalid max provide conns per worker %d", maxProvideConnsPerWorker) + } + for _, dht := range dhts { + cfg.maxProvideConnsPerWorker[dht] = maxProvideConnsPerWorker + } + return nil + } +} + +func WithMaxProvideConnsPerWorker(maxProvideConnsPerWorker int) Option { + return withMaxProvideConnsPerWorker(maxProvideConnsPerWorker, lanID, wanID) +} + +func WithMaxProvideConnsPerWorkerLAN(maxProvideConnsPerWorker int) Option { + return withMaxProvideConnsPerWorker(maxProvideConnsPerWorker, lanID) +} + +func WithMaxProvideConnsPerWorkerWAN(maxProvideConnsPerWorker int) Option { + return withMaxProvideConnsPerWorker(maxProvideConnsPerWorker, wanID) +} + +func withMessageSender(msgSender pb.MessageSender, dhts ...uint8) Option { + return func(cfg *config) error { + if msgSender == nil { + return errors.New("provider config: message sender cannot be nil") + } + for _, dht := range dhts { + cfg.msgSenders[dht] = msgSender + } + return nil + } +} + +func WithMessageSenderLAN(msgSender pb.MessageSender) Option { + return withMessageSender(msgSender, lanID) +} + +func WithMessageSenderWAN(msgSender pb.MessageSender) Option { + return withMessageSender(msgSender, wanID) +} diff --git a/provider/dual/provider.go b/provider/dual/provider.go new file mode 100644 index 000000000..6a1ac463c --- /dev/null +++ b/provider/dual/provider.go @@ -0,0 +1,202 @@ +package dual + +import ( + "context" + "errors" + "fmt" + + "github.com/ipfs/go-cid" + dht "github.com/libp2p/go-libp2p-kad-dht" + "github.com/libp2p/go-libp2p-kad-dht/dual" + "github.com/libp2p/go-libp2p-kad-dht/provider" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal" + "github.com/libp2p/go-libp2p-kad-dht/provider/keystore" + mh "github.com/multiformats/go-multihash" +) + +var _ internal.Provider = (*SweepingProvider)(nil) + +// SweepingProvider manages provides and reprovides for both DHT swarms (LAN +// and WAN) in the dual DHT setup. +type SweepingProvider struct { + dht *dual.DHT + LAN *provider.SweepingProvider + WAN *provider.SweepingProvider + keystore keystore.Keystore +} + +// New creates a new SweepingProvider that manages provides and reprovides for +// both DHT swarms (LAN and WAN) in a dual DHT setup. +func New(d *dual.DHT, opts ...Option) (*SweepingProvider, error) { + if d == nil || d.LAN == nil || d.WAN == nil { + return nil, errors.New("cannot create sweeping provider for nil dual DHT") + } + + var cfg config + err := cfg.apply(append([]Option{DefaultConfig}, opts...)...) + if err != nil { + return nil, err + } + cfg.resolveDefaults(d) + err = cfg.validate() + if err != nil { + return nil, err + } + + sweepingProviders := make([]*provider.SweepingProvider, 2) + for i, dht := range []*dht.IpfsDHT{d.LAN, d.WAN} { + if dht == nil { + continue + } + dhtOpts := []provider.Option{ + provider.WithPeerID(dht.PeerID()), + provider.WithReplicationFactor(dht.BucketSize()), + provider.WithSelfAddrs(dht.FilteredAddrs), + provider.WithRouter(dht), + provider.WithAddLocalRecord(func(h mh.Multihash) error { + return dht.Provide(dht.Context(), cid.NewCidV1(cid.Raw, h), false) + }), + provider.WithKeystore(cfg.keystore), + provider.WithMessageSender(cfg.msgSenders[i]), + provider.WithReprovideInterval(cfg.reprovideInterval[i]), + provider.WithMaxReprovideDelay(cfg.maxReprovideDelay[i]), + provider.WithOfflineDelay(cfg.offlineDelay[i]), + provider.WithConnectivityCheckOnlineInterval(cfg.connectivityCheckOnlineInterval[i]), + provider.WithMaxWorkers(cfg.maxWorkers[i]), + provider.WithDedicatedPeriodicWorkers(cfg.dedicatedPeriodicWorkers[i]), + provider.WithDedicatedBurstWorkers(cfg.dedicatedBurstWorkers[i]), + provider.WithMaxProvideConnsPerWorker(cfg.maxProvideConnsPerWorker[i]), + } + sweepingProviders[i], err = provider.New(dhtOpts...) + if err != nil { + return nil, err + } + } + + return &SweepingProvider{ + dht: d, + LAN: sweepingProviders[0], + WAN: sweepingProviders[1], + keystore: cfg.keystore, + }, nil +} + +// runOnBoth runs the provided function on both the LAN and WAN providers in +// parallel and waits for both to complete. +func (s *SweepingProvider) runOnBoth(f func(*provider.SweepingProvider) error) error { + errCh := make(chan error, 1) + go func() { + err := f(s.LAN) + if err != nil { + err = fmt.Errorf("LAN provider: %w", err) + } + errCh <- err + }() + err := f(s.WAN) + if err != nil { + err = fmt.Errorf("WAN provider: %w", err) + } + lanErr := <-errCh + return errors.Join(lanErr, err) +} + +// Close stops both DHT providers and releases associated resources. +func (s *SweepingProvider) Close() error { + return s.runOnBoth(func(p *provider.SweepingProvider) error { + return p.Close() + }) +} + +// ProvideOnce sends provider records for the specified keys to both DHT swarms +// only once. It does not automatically reprovide those keys afterward. +// +// Add the supplied multihashes to the provide queues, and return right after. +// The provide operation happens asynchronously. +// +// Returns an error if the keys couldn't be added to the provide queue. This +// can happen if the provider is closed or if the node is currently Offline +// (either never bootstrapped, or disconnected since more than `OfflineDelay`). +// The schedule and provide queue depend on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) ProvideOnce(keys ...mh.Multihash) error { + return s.runOnBoth(func(p *provider.SweepingProvider) error { + return p.ProvideOnce(keys...) + }) +} + +// StartProviding ensures keys are periodically advertised to both DHT swarms. +// +// If the `keys` aren't currently being reprovided, they are added to the +// queue to be provided to the DHT swarm as soon as possible, and scheduled +// to be reprovided periodically. If `force` is set to true, all keys are +// provided to the DHT swarm, regardless of whether they were already being +// reprovided in the past. `keys` keep being reprovided until `StopProviding` +// is called. +// +// This operation is asynchronous, it returns as soon as the `keys` are added +// to the provide queue, and provides happens asynchronously. +// +// Returns an error if the keys couldn't be added to the provide queue. This +// can happen if the provider is closed or if the node is currently Offline +// (either never bootstrapped, or disconnected since more than `OfflineDelay`). +// The schedule and provide queue depend on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) StartProviding(force bool, keys ...mh.Multihash) error { + ctx := context.Background() + newKeys, err := s.keystore.Put(ctx, keys...) + if err != nil { + return fmt.Errorf("failed to store multihashes: %w", err) + } + + s.runOnBoth(func(p *provider.SweepingProvider) error { + return p.AddToSchedule(newKeys...) + }) + + if !force { + keys = newKeys + } + + return s.ProvideOnce(keys...) +} + +// StopProviding stops reproviding the given keys to both DHT swarms. The node +// stops being referred as a provider when the provider records in the DHT +// swarms expire. +// +// Remove the `keys` from the schedule and return immediately. Valid records +// can remain in the DHT swarms up to the provider record TTL after calling +// `StopProviding`. +func (s *SweepingProvider) StopProviding(keys ...mh.Multihash) error { + err := s.keystore.Delete(context.Background(), keys...) + if err != nil { + return fmt.Errorf("failed to stop providing keys: %w", err) + } + return nil +} + +// Clear clears the all the keys from the provide queues of both DHTs and +// returns the number of keys that were cleared (sum of both queues). +// +// The keys are not deleted from the keystore, so they will continue to be +// reprovided as scheduled. +func (s *SweepingProvider) Clear() int { + return s.LAN.Clear() + s.WAN.Clear() +} + +// RefreshSchedule scans the Keystore for any keys that are not currently +// scheduled for reproviding. If such keys are found, it schedules their +// associated keyspace region to be reprovided for both DHT providers. +// +// This function doesn't remove prefixes that have no keys from the schedule. +// This is done automatically during the reprovide operation if a region has no +// keys. +// +// Returns an error if the provider is closed or if the node is currently +// Offline (either never bootstrapped, or disconnected since more than +// `OfflineDelay`). The schedule depends on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) RefreshSchedule() error { + return s.runOnBoth(func(p *provider.SweepingProvider) error { + return p.RefreshSchedule() + }) +} diff --git a/provider/internal/connectivity/connectivity.go b/provider/internal/connectivity/connectivity.go new file mode 100644 index 000000000..495b9a32a --- /dev/null +++ b/provider/internal/connectivity/connectivity.go @@ -0,0 +1,216 @@ +package connectivity + +import ( + "sync" + "sync/atomic" + "time" +) + +const ( + initialBackoffDelay = 100 * time.Millisecond + maxBackoffDelay = time.Minute +) + +// ConnectivityChecker provides a thread-safe way to verify the connectivity of +// a node, and triggers wake-up callbacks when the node changes connectivity +// state. The `checkFunc` callback used to verify network connectivity is user +// supplied. +// +// State Machine starting in OFFLINE state (when `Start()` is called) +// 1. OFFLINE state: +// - Calls `checkFunc` with exponential backoff until node is found ONLINE. +// - Calls to `TriggerCheck()` are ignored while OFFLINE. +// - When `checkFunc` returns true, state changes to ONLINE and +// `onOnline()` callback is called. +// 2. ONLINE state: +// - Calls to `TriggerCheck()` will call `checkFunc` only if at least +// `onlineCheckInterval` has passed since the last check. +// - If `TriggerCheck()` returns false, switch state to DISCONNECTED. +// 3. DISCONNECTED state: +// - Calls `checkFunc` with exponential backoff until node is found ONLINE. +// - Calls to `TriggerCheck()` are ignored while DISCONNECTED. +// - When `checkFunc` returns true, state changes to ONLINE and +// `onOnline()` callback is called. +// - After `offlineDelay` has passed in DISCONNECTED state, state changes +// to OFFLINE and `onOffline()` callback is called. +type ConnectivityChecker struct { + done chan struct{} + closed bool + closeOnce sync.Once + mutex sync.Mutex + + online atomic.Bool + + lastCheck time.Time + onlineCheckInterval time.Duration // minimum check interval when online + + checkFunc func() bool // function to check whether node is online + + onOffline func() + onOnline func() + offlineDelay time.Duration +} + +// New creates a new ConnectivityChecker instance. +func New(checkFunc func() bool, opts ...Option) (*ConnectivityChecker, error) { + var cfg config + err := cfg.apply(append([]Option{DefaultConfig}, opts...)...) + if err != nil { + return nil, err + } + c := &ConnectivityChecker{ + done: make(chan struct{}), + checkFunc: checkFunc, + onlineCheckInterval: cfg.onlineCheckInterval, + onOffline: cfg.onOffline, + onOnline: cfg.onOnline, + offlineDelay: cfg.offlineDelay, + } + return c, nil +} + +// SetCallbacks sets the onOnline and onOffline callbacks after construction. +// This allows breaking circular dependencies during initialization. +// +// SetCallbacks must be called before Start(). +func (c *ConnectivityChecker) SetCallbacks(onOnline, onOffline func()) { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.closed { + return + } + c.onOnline = onOnline + c.onOffline = onOffline +} + +// Start the ConnectivityChecker in Offline state, by begining connectivity +// probes, until the node is found Online. +// +// If SetCallbacks() is used, Start() must be called after SetCallbacks(). +func (c *ConnectivityChecker) Start() { + c.mutex.Lock() + // Start probing until the node comes online + go func() { + defer c.mutex.Unlock() + + if c.probe() { + // Node is already online + return + } + // Wait for node to come online + c.probeLoop(true) + }() +} + +// Close stops any running connectivity checks and prevents future ones. +func (c *ConnectivityChecker) Close() error { + c.closeOnce.Do(func() { + close(c.done) + c.mutex.Lock() + c.closed = true + c.mutex.Unlock() + }) + return nil +} + +// IsOnline returns true if the node is currently online, false otherwise. +func (c *ConnectivityChecker) IsOnline() bool { + return c.online.Load() +} + +// TriggerCheck triggers an asynchronous connectivity check. +// +// * If a check is already running, does nothing. +// * If a check was already performed within the last `onlineCheckInterval`, does nothing. +// * If after running the check the node is still online, update the last check timestamp. +// * If the node is found offline, enter the loop: +// - Perform connectivity check every `offlineCheckInterval`. +// - Exit if context is cancelled, or ConnectivityChecker is closed. +// - When node is found back online, run the `backOnlineNotify` callback. +func (c *ConnectivityChecker) TriggerCheck() { + if !c.mutex.TryLock() { + return // check already in progress + } + if c.closed { + c.mutex.Unlock() + return + } + if c.online.Load() && time.Since(c.lastCheck) < c.onlineCheckInterval { + c.mutex.Unlock() + return // last check was too recent + } + + go func() { + defer c.mutex.Unlock() + + if c.checkFunc() { + c.lastCheck = time.Now() + return + } + + // Online -> Disconnected + c.online.Store(false) + + // Start periodic checks until node comes back Online + c.probeLoop(false) + }() +} + +// probeLoop runs connectivity probes with exponential backoff until the node +// comes back Online, or the ConnectivityChecker is closed. +func (c *ConnectivityChecker) probeLoop(init bool) { + var offlineC <-chan time.Time + if !init { + if c.offlineDelay == 0 { + if c.onOffline != nil { + // Online -> Offline + c.onOffline() + } + } else { + offlineTimer := time.NewTimer(c.offlineDelay) + defer offlineTimer.Stop() + offlineC = offlineTimer.C + } + } + + delay := initialBackoffDelay + timer := time.NewTimer(delay) + defer timer.Stop() + for { + select { + case <-c.done: + return + case <-timer.C: + if c.probe() { + return + } + delay = min(2*delay, maxBackoffDelay) + timer.Reset(delay) + case <-offlineC: + // Disconnected -> Offline + if c.onOffline != nil { + c.onOffline() + } + } + } +} + +// probe runs the connectivity check function once, and if the node is found +// Online, updates the state and runs the onOnline callback. +func (c *ConnectivityChecker) probe() bool { + if c.checkFunc() { + select { + case <-c.done: + default: + // Node is back Online. + c.online.Store(true) + + c.lastCheck = time.Now() + if c.onOnline != nil { + c.onOnline() + } + } + return true + } + return false +} diff --git a/provider/internal/connectivity/connectivity_test.go b/provider/internal/connectivity/connectivity_test.go new file mode 100644 index 000000000..c30de2f2c --- /dev/null +++ b/provider/internal/connectivity/connectivity_test.go @@ -0,0 +1,409 @@ +//go:build go1.25 +// +build go1.25 + +package connectivity + +import ( + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/stretchr/testify/require" +) + +var ( + onlineCheckFunc = func() bool { return true } + offlineCheckFunc = func() bool { return false } +) + +func TestNewConnectiviyChecker(t *testing.T) { + t.Run("initial state is offline", func(t *testing.T) { + connChecker, err := New(onlineCheckFunc) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + }) + + t.Run("start online", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + onlineChan := make(chan struct{}) + onOnline := func() { close(onlineChan) } + + connChecker, err := New(onlineCheckFunc, + WithOnOnline(onOnline), + ) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + + connChecker.Start() + + <-onlineChan // wait for onOnline to be run + synctest.Wait() + + require.True(t, connChecker.IsOnline()) + }) + }) + + t.Run("start offline", func(t *testing.T) { + onlineCount, offlineCount := atomic.Int32{}, atomic.Int32{} + onOnline := func() { onlineCount.Add(1) } + onOffline := func() { offlineCount.Add(1) } + + connChecker, err := New(offlineCheckFunc, + WithOnOnline(onOnline), + WithOnOffline(onOffline), + ) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + + connChecker.Start() + + require.False(t, connChecker.mutex.TryLock()) // node probing until it comes online + + require.False(t, connChecker.IsOnline()) + require.Equal(t, int32(0), onlineCount.Load()) + require.Equal(t, int32(0), offlineCount.Load()) + }) +} + +func TestStateTransitions(t *testing.T) { + t.Run("offline to online", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + checkInterval := time.Second + offlineDelay := time.Minute + + online := atomic.Bool{} // start offline + checkFunc := func() bool { return online.Load() } + + onlineChan, offlineChan := make(chan struct{}), make(chan struct{}) + onOnline := func() { close(onlineChan) } + onOffline := func() { close(offlineChan) } + + connChecker, err := New(checkFunc, + WithOfflineDelay(offlineDelay), + WithOnlineCheckInterval(checkInterval), + WithOnOnline(onOnline), + WithOnOffline(onOffline), + ) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + connChecker.Start() + + time.Sleep(initialBackoffDelay) + + online.Store(true) + + <-onlineChan // wait for onOnline to be run + require.True(t, connChecker.IsOnline()) + select { + case <-offlineChan: + require.FailNow(t, "onOffline shouldn't have been called") + default: + } + }) + }) + + t.Run("online to disconnected to offline", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + checkInterval := time.Second + offlineDelay := time.Minute + + online := atomic.Bool{} + online.Store(true) + checkFunc := func() bool { return online.Load() } + + onlineChan, offlineChan := make(chan struct{}), make(chan struct{}) + onOnline := func() { close(onlineChan) } + onOffline := func() { close(offlineChan) } + + connChecker, err := New(checkFunc, + WithOfflineDelay(offlineDelay), + WithOnlineCheckInterval(checkInterval), + WithOnOnline(onOnline), + WithOnOffline(onOffline), + ) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + connChecker.Start() + + <-onlineChan // wait for onOnline to be run + require.True(t, connChecker.IsOnline()) + require.Equal(t, time.Now(), connChecker.lastCheck) + + online.Store(false) + // Cannot trigger check yet + connChecker.TriggerCheck() + require.True(t, connChecker.mutex.TryLock()) // node still online + connChecker.mutex.Unlock() + + time.Sleep(checkInterval - time.Millisecond) + connChecker.TriggerCheck() + require.True(t, connChecker.mutex.TryLock()) // node still online + connChecker.mutex.Unlock() + + time.Sleep(time.Millisecond) + connChecker.TriggerCheck() + require.False(t, connChecker.mutex.TryLock()) + + synctest.Wait() + + require.False(t, connChecker.IsOnline()) + select { + case <-offlineChan: + require.FailNow(t, "onOffline shouldn't have been called") + default: // Disconnected but not Offline + } + + connChecker.TriggerCheck() // noop since Disconnected + require.False(t, connChecker.mutex.TryLock()) + + time.Sleep(offlineDelay) + + require.False(t, connChecker.IsOnline()) + <-offlineChan // wait for callback to be run + + connChecker.TriggerCheck() // noop since Offline + require.False(t, connChecker.mutex.TryLock()) + }) + }) + + t.Run("remain online", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + checkInterval := time.Second + offlineDelay := time.Minute + + online := atomic.Bool{} + online.Store(true) + checkCount := atomic.Int32{} + checkFunc := func() bool { checkCount.Add(1); return online.Load() } + + onlineChan, offlineChan := make(chan struct{}), make(chan struct{}) + onOnline := func() { close(onlineChan) } + onOffline := func() { close(offlineChan) } + + connChecker, err := New(checkFunc, + WithOfflineDelay(offlineDelay), + WithOnlineCheckInterval(checkInterval), + WithOnOnline(onOnline), + WithOnOffline(onOffline), + ) + require.NoError(t, err) + defer connChecker.Close() + + require.False(t, connChecker.IsOnline()) + connChecker.Start() + + <-onlineChan + + require.True(t, connChecker.IsOnline()) + require.Equal(t, int32(1), checkCount.Load()) + require.Equal(t, time.Now(), connChecker.lastCheck) + + connChecker.TriggerCheck() // recent check, should be no-op + synctest.Wait() + require.Equal(t, int32(1), checkCount.Load()) + + time.Sleep(checkInterval - 1) + connChecker.TriggerCheck() // recent check, should be no-op + synctest.Wait() + require.Equal(t, int32(1), checkCount.Load()) + + time.Sleep(1) + connChecker.TriggerCheck() // checkInterval has passed, new check is run + synctest.Wait() + require.Equal(t, int32(2), checkCount.Load()) + require.Equal(t, time.Now(), connChecker.lastCheck) + + time.Sleep(checkInterval) + connChecker.TriggerCheck() // checkInterval has passed, new check is run + synctest.Wait() + require.Equal(t, int32(3), checkCount.Load()) + require.Equal(t, time.Now(), connChecker.lastCheck) + }) + }) +} + +func TestSetCallbacks(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Callbacks MUST be set before calling Start() + oldOnlineCount, oldOfflineCount, newOnlineCount, newOfflineCount := atomic.Int32{}, atomic.Int32{}, atomic.Int32{}, atomic.Int32{} + onlineChan, offlineChan := make(chan struct{}), make(chan struct{}) + oldOnOnline := func() { oldOnlineCount.Add(1); close(onlineChan) } + oldOnOffline := func() { oldOfflineCount.Add(1); close(offlineChan) } + newOnOnline := func() { newOnlineCount.Add(1); close(onlineChan) } + newOnOffline := func() { newOfflineCount.Add(1); close(offlineChan) } + + checkInterval := time.Second + online := atomic.Bool{} + online.Store(true) + checkFunc := func() bool { return online.Load() } + + connChecker, err := New(checkFunc, + WithOnOnline(oldOnOnline), + WithOnOffline(oldOnOffline), + WithOfflineDelay(0), + WithOnlineCheckInterval(checkInterval), + ) + require.NoError(t, err) + defer connChecker.Close() + + connChecker.SetCallbacks(newOnOnline, newOnOffline) + + connChecker.Start() + + <-onlineChan // wait for newOnOnline to be called + require.True(t, connChecker.IsOnline()) + require.Equal(t, int32(0), oldOnlineCount.Load()) + require.Equal(t, int32(1), newOnlineCount.Load()) + + // Wait until we can perform a new check + time.Sleep(checkInterval) + + // Go offline + online.Store(false) + connChecker.TriggerCheck() + require.False(t, connChecker.mutex.TryLock()) // node probing until it comes online + + <-offlineChan // wait for newOnOffline to be called + require.False(t, connChecker.IsOnline()) + require.Equal(t, int32(0), oldOfflineCount.Load()) + require.Equal(t, int32(1), newOfflineCount.Load()) + }) +} + +func TestExponentialBackoff(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + checkCount := atomic.Int32{} + checkFunc := func() bool { checkCount.Add(1); return false } + connChecker, err := New(checkFunc) + require.NoError(t, err) + defer connChecker.Close() + + connChecker.Start() + require.False(t, connChecker.mutex.TryLock()) // node probing until it comes online + require.False(t, connChecker.IsOnline()) + + // Exponential backoff increase + expectedWait := initialBackoffDelay + expectedChecks := int32(1) // initial check + for expectedWait < maxBackoffDelay { + synctest.Wait() + require.Equal(t, expectedChecks, checkCount.Load()) + time.Sleep(expectedWait) + expectedChecks++ + expectedWait *= 2 + } + + // Reached max backoff delay + synctest.Wait() + require.Equal(t, expectedChecks, checkCount.Load()) + + time.Sleep(maxBackoffDelay) + expectedChecks++ + synctest.Wait() + require.Equal(t, expectedChecks, checkCount.Load()) + + time.Sleep(3 * maxBackoffDelay) + expectedChecks += 3 + synctest.Wait() + require.Equal(t, expectedChecks, checkCount.Load()) + }) +} + +func TestInvalidOptions(t *testing.T) { + t.Run("negative online check interval", func(t *testing.T) { + _, err := New(onlineCheckFunc, WithOnlineCheckInterval(-1)) + require.Error(t, err) + }) + + t.Run("negative offline delay", func(t *testing.T) { + _, err := New(onlineCheckFunc, WithOfflineDelay(-1*time.Hour)) + require.Error(t, err) + }) +} + +func TestClose(t *testing.T) { + t.Run("close while offline", func(t *testing.T) { + connChecker, err := New(offlineCheckFunc) + require.NoError(t, err) + defer connChecker.Close() + + connChecker.Start() + require.False(t, connChecker.mutex.TryLock()) // node probing until it comes online + require.False(t, connChecker.IsOnline()) + + err = connChecker.Close() + require.NoError(t, err) + + require.True(t, connChecker.mutex.TryLock()) + connChecker.mutex.Unlock() + }) + + t.Run("close while online", func(t *testing.T) { + onlineChan := make(chan struct{}) + onOnline := func() { close(onlineChan) } + connChecker, err := New(onlineCheckFunc, + WithOnOnline(onOnline), + ) + require.NoError(t, err) + defer connChecker.Close() + + connChecker.Start() + <-onlineChan + require.True(t, connChecker.IsOnline()) + + connChecker.Close() + }) + + t.Run("SetCallbacks after Close", func(t *testing.T) { + onlineChan, offlineChan := make(chan struct{}), make(chan struct{}) + onOnline := func() { close(onlineChan) } + onOffline := func() { close(offlineChan) } + + connChecker, err := New(offlineCheckFunc) + require.NoError(t, err) + defer connChecker.Close() + + require.Nil(t, connChecker.onOffline) + require.Nil(t, connChecker.onOnline) + + connChecker.Close() + connChecker.SetCallbacks(onOnline, onOffline) + + // Assert that callbacks were NOT set + require.Nil(t, connChecker.onOffline) + require.Nil(t, connChecker.onOnline) + }) + + t.Run("TriggerCheck after Close", func(t *testing.T) { + connChecker, err := New(offlineCheckFunc) + require.NoError(t, err) + defer connChecker.Close() + + connChecker.Start() + require.False(t, connChecker.mutex.TryLock()) // node probing until it comes online + require.False(t, connChecker.IsOnline()) + + err = connChecker.Close() + require.NoError(t, err) + + require.True(t, connChecker.mutex.TryLock()) + connChecker.mutex.Unlock() + + connChecker.TriggerCheck() // noop since closed + + require.True(t, connChecker.mutex.TryLock()) + connChecker.mutex.Unlock() + require.False(t, connChecker.IsOnline()) + }) +} diff --git a/provider/internal/connectivity/options.go b/provider/internal/connectivity/options.go new file mode 100644 index 000000000..a67fe2a6a --- /dev/null +++ b/provider/internal/connectivity/options.go @@ -0,0 +1,69 @@ +package connectivity + +import ( + "fmt" + "time" +) + +type config struct { + onlineCheckInterval time.Duration // minimum check interval when online + + offlineDelay time.Duration + + onOffline func() + onOnline func() +} + +func (cfg *config) apply(opts ...Option) error { + for i, o := range opts { + if err := o(cfg); err != nil { + return fmt.Errorf("reprovider dht option %d failed: %w", i, err) + } + } + return nil +} + +type Option func(opt *config) error + +var DefaultConfig = func(cfg *config) error { + cfg.onlineCheckInterval = 1 * time.Minute + cfg.offlineDelay = 2 * time.Hour + return nil +} + +// WithOnlineCheckInterval sets the minimum interval between online checks. +// This is used to throttle the number of connectivity checks when the node is +// online. +func WithOnlineCheckInterval(d time.Duration) Option { + return func(cfg *config) error { + if d <= 0 { + return fmt.Errorf("online check interval must be positive, got %s", d) + } + cfg.onlineCheckInterval = d + return nil + } +} + +func WithOfflineDelay(d time.Duration) Option { + return func(cfg *config) error { + if d < 0 { + return fmt.Errorf("offline delay must be non-negative, got %s", d) + } + cfg.offlineDelay = d + return nil + } +} + +func WithOnOffline(f func()) Option { + return func(cfg *config) error { + cfg.onOffline = f + return nil + } +} + +func WithOnOnline(f func()) Option { + return func(cfg *config) error { + cfg.onOnline = f + return nil + } +} diff --git a/provider/internal/interface.go b/provider/internal/interface.go new file mode 100644 index 000000000..0c63e0aee --- /dev/null +++ b/provider/internal/interface.go @@ -0,0 +1,14 @@ +package internal + +import ( + mh "github.com/multiformats/go-multihash" +) + +type Provider interface { + StartProviding(force bool, keys ...mh.Multihash) error + StopProviding(keys ...mh.Multihash) error + ProvideOnce(keys ...mh.Multihash) error + Clear() int + RefreshSchedule() error + Close() error +} diff --git a/provider/internal/keyspace/key.go b/provider/internal/keyspace/key.go new file mode 100644 index 000000000..4b5c42646 --- /dev/null +++ b/provider/internal/keyspace/key.go @@ -0,0 +1,191 @@ +package keyspace + +import ( + "cmp" + "crypto/sha256" + "slices" + + kb "github.com/libp2p/go-libp2p-kbucket" + "github.com/libp2p/go-libp2p/core/peer" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad" + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" +) + +// KeyLen is the length of a 256-bit kademlia identifier in bits. +const KeyLen = bit256.KeyLen * 8 // 256 + +// MhToBit256 converts a multihash to a its 256-bit kademlia identifier by +// hashing it with SHA-256. +func MhToBit256(h mh.Multihash) bit256.Key { + hash := sha256.Sum256(h) + return bit256.NewKey(hash[:]) +} + +// PeerIDToBit256 converts a peer.ID to a its 256-bit kademlia identifier by +// hashing it with SHA-256. +func PeerIDToBit256(id peer.ID) bit256.Key { + hash := sha256.Sum256([]byte(id)) + return bit256.NewKey(hash[:]) +} + +// FlipLastBit flips the last bit of the given key. +func FlipLastBit(k bitstr.Key) bitstr.Key { + if len(k) == 0 { + return k + } + flipped := byte('0' + '1' - k[len(k)-1]) + return k[:len(k)-1] + bitstr.Key(flipped) +} + +// FirstFullKeyWithPrefix returns to closest 256-bit key to order, starting +// with the given k as a prefix. +func FirstFullKeyWithPrefix[K kad.Key[K]](k bitstr.Key, order K) bitstr.Key { + kLen := k.BitLen() + if kLen > KeyLen { + return k[:KeyLen] + } + return k + bitstr.Key(key.BitString(order))[kLen:] +} + +// IsBitstrPrefix returns true if k0 is a prefix of k1. +func IsBitstrPrefix(k0 bitstr.Key, k1 bitstr.Key) bool { + return len(k0) <= len(k1) && k0 == k1[:len(k0)] +} + +// IsPrefix returns true if k0 is a prefix of k1 +func IsPrefix[K0 kad.Key[K0], K1 kad.Key[K1]](k0 K0, k1 K1) bool { + if k0.BitLen() > k1.BitLen() { + return false + } + for i := range k0.BitLen() { + if k0.Bit(i) != k1.Bit(i) { + return false + } + } + return true +} + +const initMask = (byte(1) << 7) // 0x80 + +// KeyToBytes converts a kad.Key to a byte slice. If the provided key has a +// size that isn't a multiple of 8, right pad the resulting byte with 0s. +func KeyToBytes[K kad.Key[K]](k K) []byte { + bitLen := k.BitLen() + byteLen := (bitLen + 7) / 8 + b := make([]byte, byteLen) + + byteIndex := 0 + mask := initMask + by := byte(0) + + for i := range bitLen { + if k.Bit(i) == 1 { + by |= mask + } + mask >>= 1 + + if mask == 0 { + b[byteIndex] = by + byteIndex++ + by = 0 + mask = initMask + } + } + if mask != initMask { + b[byteIndex] = by + } + return b +} + +// ShortestCoveredPrefix takes as input the `target` key and the list of +// closest peers to this key. It returns a prefix of `requested` that is +// covered by these peers, along with the peers matching this prefix. +// +// We say that a set of peers fully "covers" a prefix of the global keyspace, +// if all the peers matching this prefix are included in the set. +// +// If every peer shares the same CPL to `target`, then no deeper zone is +// covered, we learn that the adjacent sibling branch is empty. In this case we +// return the prefix one bit deeper (`minCPL+1`) and an empty peer list. +func ShortestCoveredPrefix(target bitstr.Key, peers []peer.ID) (bitstr.Key, []peer.ID) { + if len(peers) == 0 { + return target, peers + } + // Sort the peers by their distance to the requested key. + peers = kb.SortClosestPeers(peers, KeyToBytes(target)) + + minCpl := target.BitLen() // key bitlen + coveredCpl := 0 + lastCoveredPeerIndex := 0 + for i, p := range peers { + cpl := key.CommonPrefixLength(target, PeerIDToBit256(p)) + if cpl < minCpl { + coveredCpl = cpl + 1 + lastCoveredPeerIndex = i + minCpl = cpl + } + } + return target[:coveredCpl], peers[:lastCoveredPeerIndex] +} + +// ExtendBinaryPrefix returns all bitstrings of length n that start with prefix. +// Example: prefix="1101", n=6 -> ["110100", "110101", "110110", "110111"]. +func ExtendBinaryPrefix(prefix bitstr.Key, n int) []bitstr.Key { + extraBits := n - len(prefix) + if n < 0 || extraBits < 0 { + return nil + } + + extLen := 1 << extraBits // 2^extraBits + rd := make([]bitstr.Key, 0, extLen) + wr := make([]bitstr.Key, 1, extLen) + wr[0] = prefix + + // Iteratively append bits until reaching length n. + for range extraBits { + rd, wr = wr, rd[:0] + for _, s := range rd { + wr = append(wr, s+"0", s+"1") + } + } + return wr +} + +// SiblingPrefixes returns the prefixes of the sibling subtrees along the path +// to key. Together with the subtree under `key` itself, these prefixes +// partition the keyspace. +// +// For key "1100" it returns: ["0", "10", "111", "1101"]. +func SiblingPrefixes(key bitstr.Key) []bitstr.Key { + complements := make([]bitstr.Key, len(key)) + for i := range key { + complements[i] = FlipLastBit(key[:i+1]) + } + return complements +} + +// PrefixAndKeys is a struct that holds a prefix and the multihashes whose +// kademlia identifier share the same prefix. +type PrefixAndKeys struct { + Prefix bitstr.Key + Keys []mh.Multihash +} + +// SortPrefixesBySize sorts the prefixes by the number of keys they contain, +// largest first. +func SortPrefixesBySize(prefixes map[bitstr.Key][]mh.Multihash) []PrefixAndKeys { + out := make([]PrefixAndKeys, 0, len(prefixes)) + for prefix, keys := range prefixes { + if keys != nil { + out = append(out, PrefixAndKeys{Prefix: prefix, Keys: keys}) + } + } + slices.SortFunc(out, func(a, b PrefixAndKeys) int { + return cmp.Compare(len(b.Keys), len(a.Keys)) + }) + return out +} diff --git a/provider/internal/keyspace/key_test.go b/provider/internal/keyspace/key_test.go new file mode 100644 index 000000000..bc65db1c1 --- /dev/null +++ b/provider/internal/keyspace/key_test.go @@ -0,0 +1,260 @@ +package keyspace + +import ( + "crypto/rand" + "strconv" + "strings" + "testing" + + "github.com/ipfs/go-test/random" + kb "github.com/libp2p/go-libp2p-kbucket" + "github.com/libp2p/go-libp2p/core/peer" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + + "github.com/stretchr/testify/require" +) + +func TestFlipLastBit(t *testing.T) { + require.Equal(t, FlipLastBit(""), bitstr.Key("")) + require.Equal(t, FlipLastBit("0"), bitstr.Key("1")) + require.Equal(t, FlipLastBit("1"), bitstr.Key("0")) + require.Equal(t, FlipLastBit("00"), bitstr.Key("01")) + require.Equal(t, FlipLastBit("00000000"), bitstr.Key("00000001")) +} + +func TestFirstFullKeyWithPrefix(t *testing.T) { + zeroKey := bitstr.Key(strings.Repeat("0", KeyLen)) + oneKey := bitstr.Key(strings.Repeat("1", KeyLen)) + + require.Equal(t, zeroKey, FirstFullKeyWithPrefix(bitstr.Key(""), zeroKey)) + require.Equal(t, zeroKey, FirstFullKeyWithPrefix(bitstr.Key("0"), zeroKey)) + require.Equal(t, bitstr.Key("000"+strings.Repeat("1", KeyLen-3)), FirstFullKeyWithPrefix(bitstr.Key("000"), oneKey)) + require.Equal(t, zeroKey, FirstFullKeyWithPrefix(zeroKey, zeroKey)) + require.Equal(t, oneKey, FirstFullKeyWithPrefix(oneKey, zeroKey)) + require.Equal(t, zeroKey, FirstFullKeyWithPrefix(zeroKey+"1", zeroKey)) +} + +func TestIsPrefix(t *testing.T) { + require.True(t, IsPrefix(bitstr.Key(""), bitstr.Key(""))) + require.True(t, IsPrefix(bitstr.Key(""), bitstr.Key("1"))) + require.True(t, IsPrefix(bitstr.Key("0"), bitstr.Key("0"))) + require.True(t, IsPrefix(bitstr.Key("0"), bitstr.Key("01"))) + require.True(t, IsPrefix(bitstr.Key("1"), bitstr.Key("11"))) + require.True(t, IsPrefix(bitstr.Key("0"), bitstr.Key("00000000"))) + require.True(t, IsPrefix(bitstr.Key("0101010"), bitstr.Key("01010100"))) + require.True(t, IsPrefix(bitstr.Key("0101010"), bitstr.Key("01010101"))) + + require.False(t, IsPrefix(bitstr.Key("1"), bitstr.Key(""))) + require.False(t, IsPrefix(bitstr.Key("1"), bitstr.Key("0"))) + require.False(t, IsPrefix(bitstr.Key("0"), bitstr.Key("1"))) + require.False(t, IsPrefix(bitstr.Key("00"), bitstr.Key("0"))) +} + +func TestIsBitstrPrefix(t *testing.T) { + fullKey := bitstr.Key("000") + require.True(t, IsBitstrPrefix(bitstr.Key(""), fullKey)) + require.True(t, IsBitstrPrefix(bitstr.Key("0"), fullKey)) + require.True(t, IsBitstrPrefix(bitstr.Key("00"), fullKey)) + require.True(t, IsBitstrPrefix(bitstr.Key("000"), fullKey)) + require.False(t, IsBitstrPrefix(bitstr.Key("1"), fullKey)) + require.False(t, IsBitstrPrefix(bitstr.Key("01"), fullKey)) + require.False(t, IsBitstrPrefix(bitstr.Key("001"), fullKey)) + require.False(t, IsBitstrPrefix(bitstr.Key("0000"), fullKey)) +} + +func TestKeyToBytes(t *testing.T) { + nKeys := 1 << 8 + buf := make([]byte, 32) + for range nKeys { + if _, err := rand.Read(buf); err != nil { + t.Fatal(err) + } + b256 := bit256.NewKey(buf) + bstr := bitstr.Key(key.BitString(b256)) + require.Equal(t, buf, KeyToBytes(b256)) + require.Equal(t, buf, KeyToBytes(bstr)) + } +} + +func TestKeyToBytesPadding(t *testing.T) { + k := bitstr.Key("") + bs := KeyToBytes(k) + require.Equal(t, []byte{}, bs) + + k = bitstr.Key("1") + bs = KeyToBytes(k) + require.Equal(t, []byte{0b10000000}, bs) + + k = bitstr.Key("0") + bs = KeyToBytes(k) + require.Equal(t, []byte{0b00000000}, bs) + + k = bitstr.Key("111111") // 6 ones + bs = KeyToBytes(k) + require.Equal(t, []byte{0b11111100}, bs) + + k = bitstr.Key("00000000") // 8 zeros + bs = KeyToBytes(k) + require.Equal(t, []byte{0b00000000}, bs) + + k = bitstr.Key("11111111") // 8 ones + bs = KeyToBytes(k) + require.Equal(t, []byte{0b11111111}, bs) + + k = bitstr.Key("000000000") // 9 zeros + bs = KeyToBytes(k) + require.Equal(t, []byte{0b00000000, 0b00000000}, bs) + + k = bitstr.Key("111111111") // 9 ones + bs = KeyToBytes(k) + require.Equal(t, []byte{0b11111111, 0b10000000}, bs) +} + +func TestShortestCoveredPrefix(t *testing.T) { + // All keys share CPL of 5, except one sharing a CPL of 4 + var target [32]byte + _, err := rand.Read(target[:]) + require.NoError(t, err) + targetBitstr := bitstr.Key(key.BitString(bit256.NewKey(target[:]))) + + cpl := 5 + nPeers := 16 + peers := make([]peer.ID, nPeers) + for i := range peers { + peers[i], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl)) + require.NoError(t, err) + } + + // This is a corner case. + // All peers share exactly `cpl` bits with the target, meaning that the + // prefix with `cpl+1` bits has been fully covered and contains 0 peers. No + // peers match this covered prefix. + prefix, coveredPeers := ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, 0) + require.Equal(t, targetBitstr[:cpl+1], prefix) + + // Last peer has a lower CPL + peers[len(peers)-1], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl-1)) + require.NoError(t, err) + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, len(peers)-1) + require.Equal(t, targetBitstr[:cpl], prefix) + peers[len(peers)-1], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl)) + require.NoError(t, err) + + // First peer has a lower CPL + peers[0], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl-1)) + require.NoError(t, err) + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, len(peers)-1) + require.Equal(t, targetBitstr[:cpl], prefix) + + // First peer has a much lower CPL + peers[0], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl-3)) + require.NoError(t, err) + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, len(peers)-1) + require.Equal(t, targetBitstr[:cpl-2], prefix) + + // First peer has a higher CPL + peers[0], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl+1)) + require.NoError(t, err) + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, 1) + require.Equal(t, targetBitstr[:cpl+1], prefix) + + // First peer has a much higher CPL + peers[0], err = kb.GenRandPeerIDWithCPL(target[:], uint(cpl+3)) + require.NoError(t, err) + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, 1) + require.Equal(t, targetBitstr[:cpl+1], prefix) + + // Test with random peer ids + nIterations := 64 + for range nIterations { + minCpl := KeyLen + largestCplCount := 0 + peers = random.Peers(nPeers) + peers = kb.SortClosestPeers(peers, target[:]) + for i := range peers { + cpl = kb.CommonPrefixLen(kb.ConvertPeerID(peers[i]), target[:]) + if cpl < minCpl { + minCpl = cpl + largestCplCount = 1 + } else { + largestCplCount++ + } + } + prefix, coveredPeers = ShortestCoveredPrefix(targetBitstr, peers) + require.Len(t, coveredPeers, len(peers)-largestCplCount) + require.Equal(t, targetBitstr[:minCpl+1], prefix) + } + + // Test without supplying peers + bstrTarget := bitstr.Key("110101111") + prefix, coveredPeers = ShortestCoveredPrefix(bstrTarget, nil) + require.Equal(t, bstrTarget, prefix) + require.Empty(t, coveredPeers) +} + +func TestExtendBinaryPrefix(t *testing.T) { + prefix := bitstr.Key("") + l := 1 + require.Equal(t, []bitstr.Key{"0", "1"}, ExtendBinaryPrefix(prefix, l)) + prefix = bitstr.Key("1101") + l = 6 + require.Equal(t, []bitstr.Key{"110100", "110101", "110110", "110111"}, ExtendBinaryPrefix(prefix, l)) +} + +func TestSiblingPrefixes(t *testing.T) { + k := bitstr.Key("") + require.Empty(t, SiblingPrefixes(k)) + k = bitstr.Key("0") + require.Equal(t, []bitstr.Key{"1"}, SiblingPrefixes(k)) + k = bitstr.Key("1") + require.Equal(t, []bitstr.Key{"0"}, SiblingPrefixes(k)) + k = bitstr.Key("00") + require.Equal(t, []bitstr.Key{"1", "01"}, SiblingPrefixes(k)) + k = bitstr.Key("000") + require.Equal(t, []bitstr.Key{"1", "01", "001"}, SiblingPrefixes(k)) + k = bitstr.Key("1100") + require.Equal(t, []bitstr.Key{"0", "10", "111", "1101"}, SiblingPrefixes(k)) +} + +func genMultihashes(n int) []mh.Multihash { + mhs := make([]mh.Multihash, n) + for i := range mhs { + h, err := mh.Sum([]byte(strconv.Itoa(i)), mh.SHA2_256, -1) + if err != nil { + panic(err) + } + mhs[i], err = mh.Encode(h, mh.SHA2_256) + if err != nil { + panic(err) + } + } + return mhs +} + +func TestSortPrefixesBySize(t *testing.T) { + prefixLen := 6 + allocations := make(map[bitstr.Key][]mh.Multihash, 1< depth+1 { + for _, siblingPrefix := range SiblingPrefixes(k)[depth+1:] { + gaps = append(gaps, siblingPrefix[depth:]) + } + } + } else { + gaps = append(gaps, bstr) + } + } else { + for _, gap := range trieGapsAtDepth(b, depth+1) { + gaps = append(gaps, bstr+gap) + } + } + } + return gaps +} + +// mapMerge merges all key-value pairs from the source map into the destination +// map. Values from the source are appended to existing slices in the +// destination. +func mapMerge[K comparable, V any](dst, src map[K][]V) { + for k, vs := range src { + dst[k] = append(dst[k], vs...) + } +} + +// AllocateToKClosest distributes items from the items trie to the k closest +// destinations in the dests trie based on XOR distance between their keys. +// +// The algorithm uses the trie structure to efficiently compute proximity +// without explicit distance calculations. Items are allocated to destinations +// by traversing both tries simultaneously and selecting the k destinations +// with the smallest XOR distance to each item's key. +// +// Returns a map where each destination value is associated with all items +// allocated to it. If k is 0 or either trie is empty, returns an empty map. +func AllocateToKClosest[K kad.Key[K], V0 any, V1 comparable](items *trie.Trie[K, V0], dests *trie.Trie[K, V1], k int) map[V1][]V0 { + return allocateToKClosestAtDepth(items, dests, k, 0) +} + +// allocateToKClosestAtDepth performs the recursive allocation algorithm at a specific +// trie depth. At each depth, it processes both branches (0 and 1) of the trie, +// determining which destinations are closest to the items based on matching bit +// patterns at the current depth. +// +// The algorithm prioritizes destinations in the same branch as items (smaller XOR +// distance) and recursively processes deeper levels when more granular distance +// calculations are needed to select exactly k destinations. +// +// Parameters: +// - items: trie containing items to be allocated +// - dests: trie containing destination candidates +// - k: maximum number of destinations to allocate each item to +// - depth: current bit depth in the trie traversal +// +// Returns a map of destination values to their allocated items. +func allocateToKClosestAtDepth[K kad.Key[K], V0 any, V1 comparable](items *trie.Trie[K, V0], dests *trie.Trie[K, V1], k, depth int) map[V1][]V0 { + m := make(map[V1][]V0) + if k == 0 { + return m + } + for i := range 2 { + // Assign all items from branch i + + matchingItemsBranch := items.Branch(i) + matchingItems := AllValues(matchingItemsBranch, bit256.ZeroKey()) + if len(matchingItems) == 0 { + if !items.IsNonEmptyLeaf() || int((*items.Key()).Bit(depth)) != i { + // items' current branch is empty, skip it + continue + } + // items' current branch contains a single leaf + matchingItems = []V0{items.Data()} + matchingItemsBranch = items + } + + matchingDestsBranch := dests.Branch(i) + otherDestsBranch := dests.Branch(1 - i) + matchingDests := AllValues(matchingDestsBranch, bit256.ZeroKey()) + otherDests := AllValues(otherDestsBranch, bit256.ZeroKey()) + if dests.IsLeaf() { + // Single key (leaf) in dests + if dests.IsNonEmptyLeaf() { + if int((*dests.Key()).Bit(depth)) == i { + // Leaf matches current branch + matchingDests = []V1{dests.Data()} + matchingDestsBranch = dests + } else { + // Leaf matches other branch + otherDests = []V1{dests.Data()} + otherDestsBranch = dests + } + } else { + // Empty leaf, no dests to allocate items. + return m + } + } + + if nMatchingDests := len(matchingDests); nMatchingDests <= k { + // Allocate matching items to the matching dests branch + for _, dest := range matchingDests { + m[dest] = append(m[dest], matchingItems...) + } + if nMatchingDests == k || len(otherDests) == 0 { + // Items were assigned to all k dests, or other branch is empty. + continue + } + + nMissingDests := k - nMatchingDests + if len(otherDests) <= nMissingDests { + // Other branch contains at most the missing number of dests to be + // allocated to. Allocate matching items to the other dests branch. + for _, dest := range otherDests { + m[dest] = append(m[dest], matchingItems...) + } + } else { + // Other branch contains more than the missing number of dests, go one + // level deeper to assign matching items to the closest dests. + allocs := allocateToKClosestAtDepth(matchingItemsBranch, otherDestsBranch, nMissingDests, depth+1) + mapMerge(m, allocs) + } + } else { + // Number of matching dests is larger than k, go one level deeper. + allocs := allocateToKClosestAtDepth(matchingItemsBranch, matchingDestsBranch, k, depth+1) + mapMerge(m, allocs) + } + } + return m +} + +// Region represents a subtrie of the complete DHT keyspace. +// +// - Prefix is the identifier of the subtrie. +// - Peers contains all the network peers matching this region. +// - Keys contains all the keys provided by the local node matching this +// region. +type Region struct { + Prefix bitstr.Key + Peers *trie.Trie[bit256.Key, peer.ID] + Keys *trie.Trie[bit256.Key, mh.Multihash] +} + +// RegionsFromPeers returns the keyspace regions of size at least `regionSize` +// from the given `peers` sorted according to `order` along with the Common +// Prefix shared by all peers. +func RegionsFromPeers(peers []peer.ID, regionSize int, order bit256.Key) ([]Region, bitstr.Key) { + if len(peers) == 0 { + return []Region{}, "" + } + peersTrie := trie.New[bit256.Key, peer.ID]() + minCpl := KeyLen + firstPeerKey := PeerIDToBit256(peers[0]) + for _, p := range peers { + k := PeerIDToBit256(p) + peersTrie.Add(k, p) + minCpl = min(minCpl, firstPeerKey.CommonPrefixLength(k)) + } + commonPrefix := bitstr.Key(key.BitString(firstPeerKey)[:minCpl]) + regions := extractMinimalRegions(peersTrie, commonPrefix, regionSize, order) + return regions, commonPrefix +} + +// extractMinimalRegions returns the list of all non-overlapping subtries of +// `t` having at least `size` elements, sorted according to `order`. Every +// element is included in exactly one region. +func extractMinimalRegions(t *trie.Trie[bit256.Key, peer.ID], path bitstr.Key, size int, order bit256.Key) []Region { + if t.IsEmptyLeaf() { + return nil + } + branch0, branch1 := t.Branch(0), t.Branch(1) + if branch0 != nil && branch1 != nil && branch0.Size() >= size && branch1.Size() >= size { + b := int(order.Bit(len(path))) + return append(extractMinimalRegions(t.Branch(b), path+bitstr.Key(byte('0'+b)), size, order), + extractMinimalRegions(t.Branch(1-b), path+bitstr.Key(byte('1'-b)), size, order)...) + } + return []Region{{Prefix: path, Peers: t}} +} + +// AssignKeysToRegions assigns the provided keys to the regions based on their +// kademlia identifier key. +func AssignKeysToRegions(regions []Region, keys []mh.Multihash) []Region { + for i := range regions { + regions[i].Keys = trie.New[bit256.Key, mh.Multihash]() + } + for _, k := range keys { + h := MhToBit256(k) + for i, r := range regions { + if IsPrefix(r.Prefix, h) { + regions[i].Keys.Add(h, k) + break + } + } + } + return regions +} diff --git a/provider/internal/keyspace/trie_test.go b/provider/internal/keyspace/trie_test.go new file mode 100644 index 000000000..8ea1108b0 --- /dev/null +++ b/provider/internal/keyspace/trie_test.go @@ -0,0 +1,807 @@ +package keyspace + +import ( + "crypto/rand" + "fmt" + "sort" + "testing" + + "github.com/ipfs/go-test/random" + kb "github.com/libp2p/go-libp2p-kbucket" + "github.com/libp2p/go-libp2p/core/peer" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + + "github.com/stretchr/testify/require" +) + +func TestAllEntries(t *testing.T) { + tr := trie.New[bitstr.Key, string]() + elements := []struct { + key bitstr.Key + fruit string + }{ + { + key: bitstr.Key("000"), + fruit: "apple", + }, + { + key: bitstr.Key("010"), + fruit: "banana", + }, + { + key: bitstr.Key("101"), + fruit: "cherry", + }, + { + key: bitstr.Key("111"), + fruit: "durian", + }, + } + + for _, e := range elements { + tr.Add(e.key, e.fruit) + } + + // Test in 0 -> 1 order + entries := AllEntries(tr, bitstr.Key("000")) + require.Equal(t, len(elements), len(entries)) + for i := range entries { + require.Equal(t, entries[i].Key, elements[i].key) + require.Equal(t, entries[i].Data, elements[i].fruit) + } + + // Test in reverse order (1 -> 0) + entries = AllEntries(tr, bitstr.Key("111")) + require.Equal(t, len(elements), len(entries)) + for i := range entries { + require.Equal(t, entries[i].Key, elements[len(elements)-1-i].key) + require.Equal(t, entries[i].Data, elements[len(elements)-1-i].fruit) + } +} + +func TestFindPrefixOfKey(t *testing.T) { + tr := trie.New[bitstr.Key, struct{}]() + + keys := []bitstr.Key{ + "00", + "10", + } + for _, k := range keys { + tr.Add(k, struct{}{}) + } + + match, ok := FindPrefixOfKey(tr, bitstr.Key("00")) + require.True(t, ok) + require.Equal(t, bitstr.Key("00"), match) + + match, ok = FindPrefixOfKey(tr, bitstr.Key("10000000")) + require.True(t, ok) + require.Equal(t, bitstr.Key("10"), match) + + _, ok = FindPrefixOfKey(tr, bitstr.Key("01")) + require.False(t, ok) + _, ok = FindPrefixOfKey(tr, bitstr.Key("11000000")) + require.False(t, ok) +} + +func TestFindPrefixOfTooShortKey(t *testing.T) { + tr := trie.New[bitstr.Key, struct{}]() + keys := []bitstr.Key{ + "0000", + "0001", + "0010", + "0011", + } + for _, k := range keys { + tr.Add(k, struct{}{}) + } + _, ok := FindPrefixOfKey(tr, bitstr.Key("000")) + require.False(t, ok) +} + +func TestFindSubtrie(t *testing.T) { + keys := []bitstr.Key{ + "0000", + "0001", + "0010", + "0100", + "0111", + "1010", + "1011", + "1101", + "1110", + } + tr := trie.New[bitstr.Key, struct{}]() + + _, ok := FindSubtrie(tr, bitstr.Key("0000")) + require.False(t, ok) + + for _, k := range keys { + tr.Add(k, struct{}{}) + } + + subtrie, ok := FindSubtrie(tr, bitstr.Key("")) + require.True(t, ok) + require.Equal(t, tr, subtrie) + require.Equal(t, 9, subtrie.Size()) + + subtrie, ok = FindSubtrie(tr, bitstr.Key("0")) + require.True(t, ok) + require.Equal(t, tr.Branch(0), subtrie) + require.Equal(t, 5, subtrie.Size()) + + subtrie, ok = FindSubtrie(tr, bitstr.Key("1")) + require.True(t, ok) + require.Equal(t, tr.Branch(1), subtrie) + require.Equal(t, 4, subtrie.Size()) + + subtrie, ok = FindSubtrie(tr, bitstr.Key("000")) + require.True(t, ok) + require.Equal(t, tr.Branch(0).Branch(0).Branch(0), subtrie) + require.Equal(t, 2, subtrie.Size()) + + subtrie, ok = FindSubtrie(tr, bitstr.Key("0000")) + require.True(t, ok) + require.Equal(t, tr.Branch(0).Branch(0).Branch(0).Branch(0), subtrie) + require.Equal(t, 1, subtrie.Size()) + require.True(t, subtrie.IsNonEmptyLeaf()) + + subtrie, ok = FindSubtrie(tr, bitstr.Key("111")) + require.True(t, ok) + require.Equal(t, tr.Branch(1).Branch(1).Branch(1), subtrie) + require.Equal(t, 1, subtrie.Size()) + require.True(t, subtrie.IsNonEmptyLeaf()) + + _, ok = FindSubtrie(tr, bitstr.Key("100")) + require.False(t, ok) + _, ok = FindSubtrie(tr, bitstr.Key("1001")) + require.False(t, ok) + _, ok = FindSubtrie(tr, bitstr.Key("00000")) + require.False(t, ok) +} + +func TestNextNonEmptyLeafFullTrie(t *testing.T) { + bitlen := 4 + + tr := trie.New[bitstr.Key, any]() + nKeys := 1 << bitlen + binaryKeys := make([]bitstr.Key, 0, nKeys) + for i := range nKeys { + binary := fmt.Sprintf("%0*b", bitlen, i) + k := bitstr.Key(binary) + tr.Add(k, struct{}{}) + binaryKeys = append(binaryKeys, k) + } + + order := binaryKeys[0] + t.Run("OrderZero", func(t *testing.T) { + for i, k := range binaryKeys { + nextKey := NextNonEmptyLeaf(tr, k, order).Key + require.Equal(t, binaryKeys[(i+1)%nKeys], nextKey) + } + }) + + t.Run("Cycle", func(t *testing.T) { + initialKey := binaryKeys[0] + k := initialKey + for range binaryKeys { + k = NextNonEmptyLeaf(tr, k, order).Key + } + require.Equal(t, initialKey, k) + }) + + order = binaryKeys[nKeys-1] + t.Run("CustomOrder", func(t *testing.T) { + for i, k := range binaryKeys { + nextKey := NextNonEmptyLeaf(tr, k, order).Key + require.Equal(t, binaryKeys[(i-1+nKeys)%nKeys], nextKey) + } + }) +} + +func TestNextNonEmptyLeafSparseTrie(t *testing.T) { + bitlen := 10 + sparsity := 4 + + tr := trie.New[bitstr.Key, any]() + nKeys := 1 << (bitlen - sparsity) + binaryKeys := make([]bitstr.Key, 0, nKeys) + suffix := (1 << sparsity) - 1 + for i := range nKeys { + binary := fmt.Sprintf("%0*b", bitlen, i*(1<= 0 { + // `prefix` has superstrings in the queue. Remove them all and insert + // `prefix` in the queue at the location of the first removed superstring. + q.queue.Insert(firstRemovedIndex, prefix) + // Add `prefix` to prefixes trie. + q.prefixes.Add(prefix, struct{}{}) + } else if _, ok := keyspace.FindPrefixOfKey(q.prefixes, prefix); !ok { + // No prefixes nor superstrings of `prefix` found in the queue. + q.queue.PushBack(prefix) + q.prefixes.Add(prefix, struct{}{}) + } +} + +// Pop removes and returns the first prefix from the queue. +func (q *prefixQueue) Pop() (bitstr.Key, bool) { + if q.queue.Len() == 0 { + return bitstr.Key(""), false + } + // Dequeue the first prefix from the queue. + prefix := q.queue.PopFront() + // Remove the prefix from the prefixes trie. + q.prefixes.Remove(prefix) + + return prefix, true +} + +// Remove removes a prefix or all its superstrings from the queue, if any. +func (q *prefixQueue) Remove(prefix bitstr.Key) bool { + return q.removeSuperstrings(prefix) >= 0 +} + +// Returns the number of prefixes in the queue. +func (q *prefixQueue) Size() int { + return q.queue.Len() +} + +// Clear removes all keys from the queue and returns the number of keys that +// were removed. +func (q *prefixQueue) Clear() int { + size := q.Size() + + q.queue.Clear() + *q.prefixes = trie.Trie[bitstr.Key, struct{}]{} + + return size +} + +// removeSuperstrings finds all superstrings of `prefix` in the trie, removes +// them from the queue, and returns the index at which the first removal +// occurred, or -1 if none. +func (q *prefixQueue) removeSuperstrings(prefix bitstr.Key) int { + subtrie, ok := keyspace.FindSubtrie(q.prefixes, prefix) + if !ok { + return -1 + } + entries := keyspace.AllEntries(subtrie, bit256.ZeroKey()) + toRemove := make([]bitstr.Key, len(entries)) + for i, e := range entries { + toRemove[i] = e.Key + } + return q.removePrefixesFromQueue(toRemove) +} + +// removeSubtrieFromQueue removes all keys in the provided subtrie from q.queue +// and q.prefixes. Returns the position of the first removed key in the queue. +func (q *prefixQueue) removePrefixesFromQueue(prefixes []bitstr.Key) int { + indexes := make([]int, 0, len(prefixes)) + for _, prefix := range prefixes { + // Remove elements from the queue that are superstrings of `prefix`. + q.prefixes.Remove(prefix) + // Find indexes of the superstrings in the queue. + index := q.queue.Index(func(element bitstr.Key) bool { return element == prefix }) + if index >= 0 { + indexes = append(indexes, index) + } + } + // Sort indexes to remove in descending order so that we can remove them + // without affecting the indexes of the remaining elements. + slices.Sort(indexes) + slices.Reverse(indexes) + // Remove items in the queue that are prefixes of `prefix` + for _, index := range indexes { + q.queue.Remove(index) + } + return indexes[len(indexes)-1] // return the position of the first removed key +} diff --git a/provider/internal/queue/provide.go b/provider/internal/queue/provide.go new file mode 100644 index 000000000..ed2227e71 --- /dev/null +++ b/provider/internal/queue/provide.go @@ -0,0 +1,185 @@ +package queue + +import ( + "sync" + + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" +) + +// ProvideQueue is a thread-safe queue storing multihashes about to be provided +// to a Kademlia DHT, allowing smart batching. +// +// The queue groups keys by their kademlia identifier prefixes, so that keys +// that should be allocated to the same DHT peers are dequeued together from +// the queue, for efficient batch providing. +// +// The insertion order of prefixes is preserved, but not for keys. Inserting +// keys matching a prefix that is already in the queue inserts the keys at the +// position of the existing prefix. +// +// ProvideQueue allows dequeuing the first prefix of the queue, with all +// matching keys or dequeuing all keys matching a requested prefix. +type ProvideQueue struct { + mu sync.Mutex + + queue prefixQueue + keys *trie.Trie[bit256.Key, mh.Multihash] // used to store keys in the queue +} + +// NewProvideQueue creates a new ProvideQueue instance. +func NewProvideQueue() *ProvideQueue { + return &ProvideQueue{ + queue: prefixQueue{prefixes: trie.New[bitstr.Key, struct{}]()}, + keys: trie.New[bit256.Key, mh.Multihash](), + } +} + +// Enqueue adds the supplied keys to the queue under the given prefix. +// +// If the prefix already sits in the queue, supplied keys join the queue at the +// position of the existing prefix. If the queue contains prefixes that are +// superstrings of the supplied prefix, all keys matching the supplied prefix +// are consolidated at the position of the first matching superstring in the +// queue. +// +// If supplied prefix doesn't exist yet in the queue, add it at the end. +// +// Supplied keys MUST match the supplied prefix. +func (q *ProvideQueue) Enqueue(prefix bitstr.Key, keys ...mh.Multihash) { + if len(keys) == 0 { + return + } + q.mu.Lock() + defer q.mu.Unlock() + + // Enqueue the prefix in the queue if required. + q.queue.Push(prefix) + + // Add keys to the keys trie. + for _, h := range keys { + q.keys.Add(keyspace.MhToBit256(h), h) + } +} + +// Dequeue pops the first prefix of the queue along with all matching keys. +// +// The prefix and keys are removed from the queue. If the queue is empty, +// return false and the empty prefix. +func (q *ProvideQueue) Dequeue() (bitstr.Key, []mh.Multihash, bool) { + q.mu.Lock() + defer q.mu.Unlock() + prefix, ok := q.queue.Pop() + if !ok { + return prefix, nil, false + } + + // Get all keys that match the prefix. + subtrie, _ := keyspace.FindSubtrie(q.keys, prefix) + keys := keyspace.AllValues(subtrie, bit256.ZeroKey()) + + // Remove the keys from the keys trie. + keyspace.PruneSubtrie(q.keys, prefix) + + return prefix, keys, true +} + +// DequeueMatching returns keys matching the given prefix from the queue. +// +// The keys and prefix are removed from the queue. If the queue is empty, or +// supplied prefix doesn't match any keys, an empty slice is returned. +func (q *ProvideQueue) DequeueMatching(prefix bitstr.Key) []mh.Multihash { + q.mu.Lock() + defer q.mu.Unlock() + + subtrie, ok := keyspace.FindSubtrie(q.keys, prefix) + if !ok { + // No keys matching the prefix. + return nil + } + keys := keyspace.AllValues(subtrie, bit256.ZeroKey()) + + // Remove the keys from the keys trie. + keyspace.PruneSubtrie(q.keys, prefix) + + // Remove prefix and its superstrings from queue if any. + removed := q.queue.Remove(prefix) + if !removed { + // prefix and superstrings not in queue. + if shorterPrefix, ok := keyspace.FindPrefixOfKey(q.queue.prefixes, prefix); ok { + // prefix is a superstring of some other shorter prefix in the queue. + // Leave it in the queue, unless the shorter prefix doesn't have any + // matching keys left. + if _, ok := keyspace.FindSubtrie(q.keys, shorterPrefix); !ok { + // No keys matching shorterPrefix, remove shorterPrefix from queue. + q.queue.Remove(shorterPrefix) + } + } + } + return keys +} + +// Remove removes the supplied keys from the queue. +// +// If this operation removes the last keys for prefixes in the queue, remove +// the prefixes from the queue. +func (q *ProvideQueue) Remove(keys ...mh.Multihash) { + q.mu.Lock() + defer q.mu.Unlock() + + matchingPrefixes := make(map[bitstr.Key]struct{}) + + // Remove keys from the keys trie. + for _, h := range keys { + k := keyspace.MhToBit256(h) + q.keys.Remove(k) + if prefix, ok := keyspace.FindPrefixOfKey(q.queue.prefixes, k); ok { + // Get the trie leaf matching the key, if any. + matchingPrefixes[prefix] = struct{}{} + } + } + + // For matching prefixes, if no more keys are matching, remove them from + // queue. + prefixesToRemove := make([]bitstr.Key, 0) + for prefix := range matchingPrefixes { + if _, ok := keyspace.FindSubtrie(q.keys, prefix); !ok { + prefixesToRemove = append(prefixesToRemove, prefix) + } + } + if len(prefixesToRemove) > 0 { + q.queue.removePrefixesFromQueue(prefixesToRemove) + } +} + +// IsEmpty returns true if the queue is empty. +func (q *ProvideQueue) IsEmpty() bool { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Size() == 0 +} + +// Size returns the number of regions containing at least one key in the queue. +func (q *ProvideQueue) Size() int { + q.mu.Lock() + defer q.mu.Unlock() + return q.keys.Size() +} + +// Clear removes all keys from the queue and returns the number of keys that +// were removed. +func (q *ProvideQueue) Clear() int { + q.mu.Lock() + defer q.mu.Unlock() + size := q.keys.Size() + + q.queue.Clear() + *q.keys = trie.Trie[bit256.Key, mh.Multihash]{} + + return size +} diff --git a/provider/internal/queue/provide_test.go b/provider/internal/queue/provide_test.go new file mode 100644 index 000000000..21ab8e76b --- /dev/null +++ b/provider/internal/queue/provide_test.go @@ -0,0 +1,280 @@ +package queue + +import ( + "testing" + + "github.com/ipfs/go-test/random" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + "github.com/stretchr/testify/require" + + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" +) + +func genMultihashesMatchingPrefix(prefix bitstr.Key, n int) []mh.Multihash { + mhs := make([]mh.Multihash, 0, n) + for i := 0; len(mhs) < n; i++ { + h := random.Multihashes(1)[0] + k := keyspace.MhToBit256(h) + if keyspace.IsPrefix(prefix, k) { + mhs = append(mhs, h) + } + } + return mhs +} + +func TestProvideEnqueueSimple(t *testing.T) { + nMultihashesPerPrefix := 1 << 4 + + q := NewProvideQueue() + + // Enqueue no multihash + q.Enqueue(bitstr.Key("1010")) + require.Equal(t, q.Size(), 0) + + prefixes := []bitstr.Key{ + "000", + "001", + "010", + "011", + "10", + } + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + } + + // Verify prefixes are in the queue + require.Equal(t, len(prefixes), q.queue.prefixes.Size()) + require.Equal(t, len(prefixes), q.queue.queue.Len()) + for _, prefix := range prefixes { + require.GreaterOrEqual(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefix }), 0) + ok, _ := trie.Find(q.queue.prefixes, prefix) + require.True(t, ok) + } + // Verify the count of multihashes matches + require.Equal(t, len(prefixes)*nMultihashesPerPrefix, q.Size()) +} + +func TestProvideEnqueueOverlapping(t *testing.T) { + nMultihashesPerPrefix := 1 << 4 + + q := NewProvideQueue() + + prefixes := []bitstr.Key{ + "000", + "0000", + } + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + } + + require.Equal(t, 1, q.queue.prefixes.Size()) // Only shortest prefix should remain + require.Equal(t, 1, q.queue.queue.Len()) + require.GreaterOrEqual(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefixes[0] }), 0) // "000" is in queue + require.Negative(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefixes[1] })) // "0000" is NOT in queue + + // Verify the count of multihashes matches + require.Equal(t, len(prefixes)*nMultihashesPerPrefix, q.Size()) + + prefixes = []bitstr.Key{ + "1111", + "111", + } + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + } + + require.Equal(t, 2, q.queue.prefixes.Size()) // only "000" and "111" should remain + require.Equal(t, 2, q.queue.queue.Len()) + require.GreaterOrEqual(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefixes[1] }), 0) // "111" is in queue + require.Negative(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefixes[0] })) // "1111" is NOT in queue + + // Verify the count of multihashes matches + require.Equal(t, 2*len(prefixes)*nMultihashesPerPrefix, q.Size()) +} + +func TestProvideDequeue(t *testing.T) { + nMultihashesPerPrefix := 1 << 4 + q := NewProvideQueue() + prefixes := []bitstr.Key{ + "100", + "001", + "010", + "11", + "000", + } + mhMap := make(map[bitstr.Key][]mh.Multihash) + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + mhMap[prefix] = mhs + } + require.Equal(t, q.queue.prefixes.Size(), len(prefixes)) + require.Equal(t, q.queue.queue.Len(), len(prefixes)) + require.Equal(t, q.Size(), len(prefixes)*nMultihashesPerPrefix) + + for i := 0; !q.IsEmpty(); i++ { + prefix, mhs, ok := q.Dequeue() + require.True(t, ok) + require.Equal(t, prefixes[i], prefix) + require.ElementsMatch(t, mhMap[prefix], mhs) + require.Negative(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == prefix })) // prefix not in queue anymore + require.False(t, q.queue.prefixes.Remove(prefix)) + require.Equal(t, q.Size(), (len(prefixes)-i-1)*nMultihashesPerPrefix) + } + + prefix, mhs, ok := q.Dequeue() + require.False(t, ok) // Queue is empty + require.Equal(t, bitstr.Key(""), prefix) + require.Empty(t, mhs) +} + +func TestProvideDequeueMatching(t *testing.T) { + nMultihashesPerPrefix := 1 << 4 + q := NewProvideQueue() + prefixes := []bitstr.Key{ + "0010", + "100", + "010", + "0011", + "11", + "000", + } + mhMap := make(map[bitstr.Key][]mh.Multihash) + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + mhMap[prefix] = mhs + } + require.Equal(t, q.queue.prefixes.Size(), len(prefixes)) + require.Equal(t, q.queue.queue.Len(), len(prefixes)) + require.Equal(t, q.Size(), len(prefixes)*nMultihashesPerPrefix) + + // Prefix not in queue. + mhs := q.DequeueMatching(bitstr.Key("101")) + require.Empty(t, mhs) + + mhs = q.DequeueMatching(bitstr.Key("010")) + require.ElementsMatch(t, mhMap[bitstr.Key("010")], mhs) + require.Equal(t, 5, q.queue.queue.Len()) + require.Equal(t, 5, q.queue.prefixes.Size()) + // Verify queue order didn't change + require.Equal(t, 0, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("0010") })) + require.Equal(t, 1, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("100") })) + require.Equal(t, 2, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("0011") })) + require.Equal(t, 3, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("11") })) + require.Equal(t, 4, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("000") })) + + mhs = q.DequeueMatching(bitstr.Key("001")) + require.ElementsMatch(t, append(mhMap[bitstr.Key("0010")], mhMap[bitstr.Key("0011")]...), mhs) + require.Equal(t, 3, q.queue.queue.Len()) + require.Equal(t, 3, q.queue.prefixes.Size()) + // Verify queue order didn't change + require.Equal(t, 0, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("100") })) + require.Equal(t, 1, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("11") })) + require.Equal(t, 2, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("000") })) + + // Prefix not in queue. + mhs = q.DequeueMatching(bitstr.Key("011")) + require.Empty(t, mhs) + + // Partial prefix + mhs0 := q.DequeueMatching(bitstr.Key("110")) + if len(mhs0) > 0 { + require.Equal(t, 3, q.queue.queue.Len()) + require.Equal(t, 3, q.queue.prefixes.Size()) + require.Equal(t, 0, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("100") })) + require.Equal(t, 1, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("11") })) + require.Equal(t, 2, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("000") })) + } + mhs1 := q.DequeueMatching(bitstr.Key("111")) + require.Equal(t, 2, q.queue.queue.Len()) + require.Equal(t, 2, q.queue.prefixes.Size()) + require.Equal(t, 0, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("100") })) + require.Equal(t, 1, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("000") })) + require.ElementsMatch(t, append(mhs0, mhs1...), mhMap[bitstr.Key("11")]) + + prefix, mhs, ok := q.Dequeue() + require.True(t, ok) + require.Equal(t, bitstr.Key("100"), prefix) + require.ElementsMatch(t, mhMap[bitstr.Key("100")], mhs) + + mhs = q.DequeueMatching(bitstr.Key("000")) + require.ElementsMatch(t, mhMap[bitstr.Key("000")], mhs) + + require.Equal(t, 0, q.queue.queue.Len()) + require.True(t, q.IsEmpty()) + + mhs = q.DequeueMatching(bitstr.Key("000")) + require.Empty(t, mhs) +} + +func TestProvideRemove(t *testing.T) { + nMultihashesPerPrefix := 1 << 2 + q := NewProvideQueue() + prefixes := []bitstr.Key{ + "0010", + "100", + "010", + } + mhMap := make(map[bitstr.Key][]mh.Multihash) + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + mhMap[prefix] = mhs + } + require.Equal(t, len(prefixes), q.queue.prefixes.Size()) + require.Equal(t, len(prefixes), q.queue.queue.Len()) + require.Equal(t, len(prefixes)*nMultihashesPerPrefix, q.Size()) + + q.Remove(mhMap[bitstr.Key("0010")][:2]...) + require.Equal(t, len(prefixes)*nMultihashesPerPrefix-2, q.Size()) + require.Equal(t, q.queue.queue.At(0), bitstr.Key("0010")) + + q.Remove(mhMap[bitstr.Key("100")]...) + require.Equal(t, len(prefixes)*nMultihashesPerPrefix-6, q.Size()) + require.Equal(t, q.queue.queue.At(1), bitstr.Key("010")) + + q.Remove(mhMap[bitstr.Key("0010")][2]) + require.Equal(t, len(prefixes)*nMultihashesPerPrefix-7, q.Size()) + require.Equal(t, q.queue.queue.At(0), bitstr.Key("0010")) + + q.Remove(append([]mh.Multihash{mhMap[bitstr.Key("0010")][3]}, mhMap[bitstr.Key("010")][1:3]...)...) + require.Equal(t, 2, q.Size()) + require.Equal(t, q.queue.queue.At(0), bitstr.Key("010")) +} + +func TestProvideClearQueue(t *testing.T) { + nMultihashesPerPrefix := 1 << 4 + q := NewProvideQueue() + require.True(t, q.IsEmpty()) + prefixes := []bitstr.Key{ + "000", + "001", + "010", + "011", + "10", + } + for _, prefix := range prefixes { + mhs := genMultihashesMatchingPrefix(prefix, nMultihashesPerPrefix) + q.Enqueue(prefix, mhs...) + } + + require.False(t, q.IsEmpty()) + require.Equal(t, q.queue.prefixes.Size(), len(prefixes)) + require.Equal(t, q.queue.queue.Len(), len(prefixes)) + require.Equal(t, q.Size(), len(prefixes)*nMultihashesPerPrefix) + + cleared := q.Clear() + require.Equal(t, len(prefixes)*nMultihashesPerPrefix, cleared) + require.True(t, q.IsEmpty()) + + require.True(t, q.keys.IsEmptyLeaf()) + require.True(t, q.queue.prefixes.IsEmptyLeaf()) + require.Equal(t, 0, q.queue.queue.Len()) +} diff --git a/provider/internal/queue/reprovide.go b/provider/internal/queue/reprovide.go new file mode 100644 index 000000000..4c868f0e9 --- /dev/null +++ b/provider/internal/queue/reprovide.go @@ -0,0 +1,70 @@ +package queue + +import ( + "sync" + + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" +) + +// ReprovideQueue is a thread-safe queue storing non-overlapping, unique +// kademlia keyspace prefixes in the order they were enqueued. +type ReprovideQueue struct { + mu sync.Mutex + queue prefixQueue +} + +// New creates a new ReprovideQueue instance. +func NewReprovideQueue() *ReprovideQueue { + return &ReprovideQueue{queue: prefixQueue{prefixes: trie.New[bitstr.Key, struct{}]()}} +} + +// Enqueue adds the supplied prefix to the queue. +// +// If the prefix is already in the queue, this is a no-op. +// +// If the queue contains superstrings of the supplied prefix, insert the +// supplied prefix at the position of the first superstring in the queue, and +// remove all superstrings from the queue. The prefixes are consolidated around +// the shortest prefix. +func (q *ReprovideQueue) Enqueue(prefix bitstr.Key) { + q.mu.Lock() + defer q.mu.Unlock() + q.queue.Push(prefix) +} + +// Dequeue removes and returns the first prefix from the queue. +func (q *ReprovideQueue) Dequeue() (bitstr.Key, bool) { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Pop() +} + +// Remove removes a prefix or all its superstrings from the queue, if any. +func (q *ReprovideQueue) Remove(prefix bitstr.Key) bool { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Remove(prefix) +} + +// IsEmpty returns true if the queue is empty. +func (q *ReprovideQueue) IsEmpty() bool { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Size() == 0 +} + +// Size returns the number of prefixes currently in the queue. +func (q *ReprovideQueue) Size() int { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Size() +} + +// Clear removes all prefixes from the queue and returns the number of removed +// prefixes. +func (q *ReprovideQueue) Clear() int { + q.mu.Lock() + defer q.mu.Unlock() + return q.queue.Clear() +} diff --git a/provider/internal/queue/reprovide_test.go b/provider/internal/queue/reprovide_test.go new file mode 100644 index 000000000..fcf0d90fc --- /dev/null +++ b/provider/internal/queue/reprovide_test.go @@ -0,0 +1,112 @@ +package queue + +import ( + "testing" + + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + "github.com/stretchr/testify/require" +) + +func TestReprovideEnqueue(t *testing.T) { + q := NewReprovideQueue() + + q.Enqueue("000") + require.Equal(t, 1, q.Size()) + q.Enqueue("101") + require.Equal(t, 2, q.Size()) + q.Enqueue("000") + require.Equal(t, 2, q.Size()) // Duplicate prefix, size should not change + q.Enqueue("001") + require.Equal(t, 3, q.Size()) + q.Enqueue("1000") + require.Equal(t, 4, q.Size()) + q.Enqueue("1001") + require.Equal(t, 5, q.Size()) + + q.Enqueue("10") + require.Equal(t, 3, q.Size()) // "10" consolidates "1000", "1001" and "101" + require.Equal(t, bitstr.Key("000"), q.queue.queue.At(0)) + require.Equal(t, bitstr.Key("10"), q.queue.queue.At(1)) // "10" has taken the place of "101" + require.Equal(t, bitstr.Key("001"), q.queue.queue.At(2)) +} + +func TestReprovideDequeue(t *testing.T) { + keys := []bitstr.Key{ + "10", + "001", + "00001", + "00000", + "111", + "1101", + } + + q := NewReprovideQueue() + for _, k := range keys { + q.Enqueue(k) + } + require.False(t, q.IsEmpty()) + + for i, k := range keys { + dequeued, ok := q.Dequeue() + require.True(t, ok) + require.Equal(t, k, dequeued) + require.Equal(t, len(keys)-i-1, q.Size()) + } + + require.True(t, q.IsEmpty()) + _, ok := q.Dequeue() + require.False(t, ok) + + require.True(t, q.IsEmpty()) +} + +func TestReprovideRemove(t *testing.T) { + keys := []bitstr.Key{ + "10", + "001", + "00001", + "00000", + "111", + "1101", + } + q := NewReprovideQueue() + for _, k := range keys { + q.Enqueue(k) + } + + ok := q.Remove("001") + require.True(t, ok) + require.Negative(t, q.queue.queue.Index(func(k bitstr.Key) bool { return k == bitstr.Key("001") })) // not in queue + ok, _ = trie.Find(q.queue.prefixes, bitstr.Key("001")) // not in trie + require.False(t, ok) + + ok = q.Remove("1111") // not in queue + require.False(t, ok) + + require.Equal(t, len(keys)-1, q.Size()) + // Test order + require.Equal(t, keys[0], q.queue.queue.At(0)) + require.Equal(t, keys[2], q.queue.queue.At(1)) +} + +func TestReprovideClearQueue(t *testing.T) { + keys := []bitstr.Key{ + "10", + "001", + "00001", + "00000", + "111", + "1101", + } + q := NewReprovideQueue() + for _, k := range keys { + q.Enqueue(k) + } + + cleared := q.Clear() + require.Equal(t, len(keys), cleared) + require.True(t, q.IsEmpty()) + require.Equal(t, 0, q.queue.prefixes.Size()) + require.Equal(t, 0, q.queue.queue.Len()) +} diff --git a/provider/keystore/keystore.go b/provider/keystore/keystore.go new file mode 100644 index 000000000..63eca79ff --- /dev/null +++ b/provider/keystore/keystore.go @@ -0,0 +1,440 @@ +package keystore + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "strings" + + ds "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/namespace" + "github.com/ipfs/go-datastore/query" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" +) + +var ErrClosed = errors.New("keystore is closed") + +// Keystore provides thread-safe storage and retrieval of multihashes, indexed +// by their kademlia 256-bit identifier. +type Keystore interface { + Put(context.Context, ...mh.Multihash) ([]mh.Multihash, error) + Get(context.Context, bitstr.Key) ([]mh.Multihash, error) + ContainsPrefix(context.Context, bitstr.Key) (bool, error) + Delete(context.Context, ...mh.Multihash) error + Empty(context.Context) error + Size(context.Context) (int, error) + Close() error +} + +// operation types for the worker goroutine +type opType uint8 + +const ( + opPut opType = iota + opGet + opContainsPrefix + opDelete + opEmpty + opSize + lastOp +) + +// operation request sent to worker goroutine +type operation struct { + op opType + ctx context.Context + keys []mh.Multihash + prefix bitstr.Key + response chan<- operationResponse +} + +// response from worker goroutine +type operationResponse struct { + multihashes []mh.Multihash + found bool + size int + err error +} + +// keystore indexes multihashes by their kademlia identifier. +type keystore struct { + ds ds.Batching + prefixBits int + batchSize int + + // worker goroutine communication + requests chan operation + close chan struct{} + done chan struct{} +} + +// NewKeystore creates a new Keystore backed by the provided datastore. +func NewKeystore(d ds.Batching, opts ...Option) (Keystore, error) { + cfg, err := getOpts(opts) + if err != nil { + return nil, err + } + ks := &keystore{ + ds: namespace.Wrap(d, ds.NewKey(cfg.path)), + prefixBits: cfg.prefixBits, + batchSize: cfg.batchSize, + requests: make(chan operation), + close: make(chan struct{}), + done: make(chan struct{}), + } + go ks.worker() + + return ks, nil +} + +// dsKey returns the datastore key for the provided binary key. +// +// The function creates a hierarchical datastore key by expanding bits into +// path components (`0` or `1`) separated by `/`, and optionally a +// base64URL-encoded suffix. +// +// Full keys (256-bit): +// The first `prefixBits` bits become individual path components, and the +// remaining bytes (after prefixBits/8) are base64URL encoded as the final +// component. Example: "/0/0/0/0/1/1/1/1/AAAA...A==" +// +// Prefix keys (<256-bit): +// If the key is shorter than 256-bits, only the available bits (up to +// `prefixBits`) become path components. No base64URL suffix is added. This +// creates a prefix that can be used in datastore queries to find all matching +// full keys. +// +// If the prefix is longer than `prefixBits`, only the first `prefixBits` bits +// are used, allowing the returned key to serve as a query prefix for the +// datastore. +func dsKey[K kad.Key[K]](k K, prefixBits int) ds.Key { + b := strings.Builder{} + l := k.BitLen() + for i := range min(prefixBits, l) { + b.WriteRune(rune('0' + k.Bit(i))) + b.WriteRune('/') + } + if l == keyspace.KeyLen { + b.WriteString(base64.URLEncoding.EncodeToString(keyspace.KeyToBytes(k)[prefixBits/8:])) + } + return ds.NewKey(b.String()) +} + +// decodeKey reconstructs a 256-bit binary key from a hierarchical datastore key string. +// +// This function reverses the process of dsKey, converting a datastore key back into +// its original binary representation by parsing the individual bit components and +// base64URL-encoded suffix. +// +// The input datastore key format is expected to be: +// "/bit0/bit1/.../bitN/base64url_suffix" +// +// Returns the reconstructed 256-bit key or an error if base64URL decoding fails. +func (s *keystore) decodeKey(dsk string) (bit256.Key, error) { + bs := make([]byte, 32) + // Extract individual bits from odd positions (skip '/' separators) + for i := range s.prefixBits { + if dsk[2*i+1] == '1' { + bs[i/8] |= byte(1) << (7 - i%8) + } + } + // Decode base64URL suffix and append to remaining bytes + decoded, err := base64.URLEncoding.DecodeString(dsk[2*(s.prefixBits)+1:]) + if err != nil { + return bit256.Key{}, err + } + copy(bs[s.prefixBits/8:], decoded) + return bit256.NewKey(bs), nil +} + +// worker processes operations sequentially in a single goroutine +func (s *keystore) worker() { + defer close(s.done) + + for { + select { + case <-s.close: + return + case op := <-s.requests: + switch op.op { + case opPut: + newKeys, err := s.put(op.ctx, op.keys) + op.response <- operationResponse{multihashes: newKeys, err: err} + + case opGet: + keys, err := s.get(op.ctx, op.prefix) + op.response <- operationResponse{multihashes: keys, err: err} + + case opContainsPrefix: + found, err := s.containsPrefix(op.ctx, op.prefix) + op.response <- operationResponse{found: found, err: err} + + case opDelete: + err := s.delete(op.ctx, op.keys) + op.response <- operationResponse{err: err} + + case opEmpty: + err := empty(op.ctx, s.ds, s.batchSize) + op.response <- operationResponse{err: err} + + case opSize: + size, err := s.size(op.ctx) + op.response <- operationResponse{size: size, err: err} + + default: + op.response <- operationResponse{err: fmt.Errorf("unknown operation %d", op.op)} + } + } + } +} + +// put stores the provided keys while assuming s.lk is already held, and +// returns the keys that weren't present already in the keystore. +func (s *keystore) put(ctx context.Context, keys []mh.Multihash) ([]mh.Multihash, error) { + seen := make(map[bit256.Key]struct{}, len(keys)) + b, err := s.ds.Batch(ctx) + if err != nil { + return nil, err + } + newKeys := make([]mh.Multihash, 0, len(keys)) + + for _, h := range keys { + k := keyspace.MhToBit256(h) + if _, ok := seen[k]; ok { + continue + } + seen[k] = struct{}{} + dsk := dsKey(k, s.prefixBits) + ok, err := s.ds.Has(ctx, dsk) + if err != nil { + return nil, err + } + if !ok { + if err := b.Put(ctx, dsk, h); err != nil { + return nil, err + } + newKeys = append(newKeys, h) + } + } + if err := b.Commit(ctx); err != nil { + return nil, err + } + return newKeys, nil +} + +// get returns all keys whose bit256 representation matches the provided +// prefix. +func (s *keystore) get(ctx context.Context, prefix bitstr.Key) ([]mh.Multihash, error) { + out := make([]mh.Multihash, 0) + longPrefix := prefix.BitLen() > s.prefixBits + + dsk := dsKey(prefix, s.prefixBits).String() + q := query.Query{Prefix: dsk} + for r, err := range ds.QueryIter(ctx, s.ds, q) { + if err != nil { + return nil, err + } + // Depending on prefix length, filter out non matching keys + if longPrefix { + k, err := s.decodeKey(r.Key) + if err != nil { + return nil, err + } + if !keyspace.IsPrefix(prefix, k) { + continue + } + } + out = append(out, mh.Multihash(r.Value)) + } + + return out, nil +} + +// containsPrefix reports whether the Keystore currently holds at least one +// multihash whose kademlia identifier (bit256.Key) starts with the provided +// bit-prefix. +func (s *keystore) containsPrefix(ctx context.Context, prefix bitstr.Key) (bool, error) { + dsk := dsKey(prefix, s.prefixBits).String() + q := query.Query{Prefix: dsk, KeysOnly: true} + longPrefix := prefix.BitLen() > s.prefixBits + if !longPrefix { + // Exact match on hex character, only one possible match + q.Limit = 1 + } + for r, err := range ds.QueryIter(ctx, s.ds, q) { + if err != nil { + return false, err + } + if !longPrefix { + return true, nil + } + k, err := s.decodeKey(r.Key) + if err != nil { + return false, err + } + if keyspace.IsPrefix(prefix, k) { + return true, nil + } + } + return false, nil +} + +// empty deletes all entries under the datastore prefix, assuming s.lk is +// already held. +func empty(ctx context.Context, d ds.Batching, batchSize int) error { + batch, err := d.Batch(ctx) + if err != nil { + return err + } + var writeCount int + q := query.Query{KeysOnly: true} + for res, err := range ds.QueryIter(ctx, d, q) { + if ctx.Err() != nil { + return ctx.Err() + } + if writeCount >= batchSize { + writeCount = 0 + if err = batch.Commit(ctx); err != nil { + return fmt.Errorf("cannot commit keystore updates: %w", err) + } + } + if err != nil { + return fmt.Errorf("cannot read query result from keystore: %w", err) + } + if err = batch.Delete(ctx, ds.NewKey(res.Key)); err != nil { + return fmt.Errorf("cannot delete key from keystore: %w", err) + } + writeCount++ + } + if err = batch.Commit(ctx); err != nil { + return fmt.Errorf("cannot commit keystore updates: %w", err) + } + if err = d.Sync(ctx, ds.NewKey("")); err != nil { + return fmt.Errorf("cannot sync datastore: %w", err) + } + return nil +} + +// delete removes the given keys from datastore. +func (s *keystore) delete(ctx context.Context, keys []mh.Multihash) error { + b, err := s.ds.Batch(ctx) + if err != nil { + return err + } + for _, h := range keys { + dsk := dsKey(keyspace.MhToBit256(h), s.prefixBits) + err := b.Delete(ctx, dsk) + if err != nil { + return err + } + } + return b.Commit(ctx) +} + +// size returns the number of keys currently stored in the Keystore. +func (s *keystore) size(ctx context.Context) (size int, err error) { + q := query.Query{KeysOnly: true} + for _, err = range ds.QueryIter(ctx, s.ds, q) { + if err != nil { + return + } + size++ + } + return +} + +// executeOperation sends an operation request to the worker goroutine and +// waits for the response. It handles the communication protocol and returns +// the results based on the operation type. +func (s *keystore) executeOperation(op opType, ctx context.Context, keys []mh.Multihash, prefix bitstr.Key) ([]mh.Multihash, int, bool, error) { + response := make(chan operationResponse, 1) + select { + case s.requests <- operation{ + op: op, + ctx: ctx, + keys: keys, + prefix: prefix, + response: response, + }: + case <-ctx.Done(): + return nil, 0, false, ctx.Err() + case <-s.close: + return nil, 0, false, ErrClosed + } + + select { + case resp := <-response: + return resp.multihashes, resp.size, resp.found, resp.err + case <-ctx.Done(): + return nil, 0, false, ctx.Err() + } +} + +// Put stores the provided keys in the underlying datastore, grouping them by +// the first prefixLen bits. It returns only the keys that were not previously +// persisted in the datastore (i.e., newly added keys). +func (s *keystore) Put(ctx context.Context, keys ...mh.Multihash) ([]mh.Multihash, error) { + if len(keys) == 0 { + return nil, nil + } + newKeys, _, _, err := s.executeOperation(opPut, ctx, keys, "") + return newKeys, err +} + +// Get returns all keys whose bit256 representation matches the provided +// prefix. +func (s *keystore) Get(ctx context.Context, prefix bitstr.Key) ([]mh.Multihash, error) { + keys, _, _, err := s.executeOperation(opGet, ctx, nil, prefix) + return keys, err +} + +// ContainsPrefix reports whether the Keystore currently holds at least one +// multihash whose kademlia identifier (bit256.Key) starts with the provided +// bit-prefix. +func (s *keystore) ContainsPrefix(ctx context.Context, prefix bitstr.Key) (bool, error) { + _, _, found, err := s.executeOperation(opContainsPrefix, ctx, nil, prefix) + return found, err +} + +// Empty deletes all entries under the datastore prefix. +func (s *keystore) Empty(ctx context.Context) error { + _, _, _, err := s.executeOperation(opEmpty, ctx, nil, "") + return err +} + +// Delete removes the given keys from datastore. +func (s *keystore) Delete(ctx context.Context, keys ...mh.Multihash) error { + if len(keys) == 0 { + return nil + } + _, _, _, err := s.executeOperation(opDelete, ctx, keys, "") + return err +} + +// Size returns the number of keys currently stored in the Keystore. +// +// The size is obtained by iterating over all keys in the underlying +// datastore, so it may be expensive for large stores. +func (s *keystore) Size(ctx context.Context) (int, error) { + _, size, _, err := s.executeOperation(opSize, ctx, nil, "") + return size, err +} + +// Close shuts down the worker goroutine and releases resources. +func (s *keystore) Close() error { + select { + case <-s.close: + // Already closed + return nil + default: + close(s.close) + <-s.done + } + return nil +} diff --git a/provider/keystore/keystore_test.go b/provider/keystore/keystore_test.go new file mode 100644 index 000000000..7671bb194 --- /dev/null +++ b/provider/keystore/keystore_test.go @@ -0,0 +1,284 @@ +package keystore + +import ( + "context" + "crypto/rand" + "strings" + "testing" + + ds "github.com/ipfs/go-datastore" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + + "github.com/ipfs/go-test/random" + "github.com/stretchr/testify/require" +) + +func TestKeystorePutAndGet(t *testing.T) { + t.Run("Keystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeyStorePutAndGetImpl(t, store) + }) + + t.Run("ResettableKeystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewResettableKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeyStorePutAndGetImpl(t, store) + }) +} + +func testKeyStorePutAndGetImpl(t *testing.T, store Keystore) { + mhs := make([]mh.Multihash, 6) + for i := range mhs { + h, err := mh.Sum([]byte{byte(i)}, mh.SHA2_256, -1) + require.NoError(t, err) + mhs[i] = h + } + + added, err := store.Put(context.Background(), mhs...) + require.NoError(t, err) + require.Len(t, added, len(mhs)) + + added, err = store.Put(context.Background(), mhs...) + require.NoError(t, err) + require.Empty(t, added) + + for _, h := range mhs { + prefix := bitstr.Key(key.BitString(keyspace.MhToBit256(h))[:6]) + got, err := store.Get(context.Background(), prefix) + if err != nil { + t.Fatal(err) + } + found := false + for _, m := range got { + if string(m) == string(h) { + found = true + break + } + } + require.True(t, found, "expected to find multihash %v for prefix %s", h, prefix) + } + + p := bitstr.Key(key.BitString(keyspace.MhToBit256(mhs[0]))[:3]) + res, err := store.Get(context.Background(), p) + require.NoError(t, err) + require.NotEmpty(t, res, "expected results for prefix %s", p) + + longPrefix := bitstr.Key(key.BitString(keyspace.MhToBit256(mhs[0]))[:15]) + res, err = store.Get(context.Background(), longPrefix) + if err != nil { + t.Fatal(err) + } + for _, h := range res { + bs := bitstr.Key(key.BitString(keyspace.MhToBit256(h))) + require.True(t, keyspace.IsPrefix(longPrefix, bs), "returned hash does not match long prefix") + } +} + +func genMultihashesMatchingPrefix(prefix bitstr.Key, n int) []mh.Multihash { + mhs := make([]mh.Multihash, 0, n) + for i := 0; len(mhs) < n; i++ { + h := random.Multihashes(1)[0] + k := keyspace.MhToBit256(h) + if keyspace.IsPrefix(prefix, k) { + mhs = append(mhs, h) + } + } + return mhs +} + +func TestKeyStoreContainsPrefix(t *testing.T) { + t.Run("Keystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreContainsPrefixImpl(t, store) + }) + + t.Run("ResettableKeystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewResettableKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreContainsPrefixImpl(t, store) + }) +} + +func testKeystoreContainsPrefixImpl(t *testing.T, store Keystore) { + ctx := context.Background() + + ok, err := store.ContainsPrefix(ctx, bitstr.Key("0000")) + require.NoError(t, err) + require.False(t, ok) + + generated := genMultihashesMatchingPrefix(bitstr.Key(strings.Repeat("0", 10)), 1) + require.True(t, keyspace.IsPrefix(bitstr.Key("0000"), keyspace.MhToBit256(generated[0]))) + store.Put(ctx, generated...) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key("0")) + require.NoError(t, err) + require.True(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key("0000")) + require.NoError(t, err) + require.True(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key(strings.Repeat("0", 6))) + require.NoError(t, err) + require.True(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key(strings.Repeat("0", 10))) + require.NoError(t, err) + require.True(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key("1")) + require.NoError(t, err) + require.False(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key("0001")) + require.NoError(t, err) + require.False(t, ok) + + ok, err = store.ContainsPrefix(ctx, bitstr.Key(strings.Repeat("0", 8)+"1")) + require.NoError(t, err) + require.False(t, ok) +} + +func TestKeystoreDelete(t *testing.T) { + t.Run("Keystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreDeleteImpl(t, store) + }) + + t.Run("ResettableKeystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewResettableKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreDeleteImpl(t, store) + }) +} + +func testKeystoreDeleteImpl(t *testing.T, store Keystore) { + mhs := random.Multihashes(3) + for i := range mhs { + h, err := mh.Sum([]byte{byte(i)}, mh.SHA2_256, -1) + require.NoError(t, err) + mhs[i] = h + } + _, err := store.Put(context.Background(), mhs...) + require.NoError(t, err) + + delPrefix := bitstr.Key(key.BitString(keyspace.MhToBit256(mhs[0]))[:6]) + err = store.Delete(context.Background(), mhs[0]) + require.NoError(t, err) + + res, err := store.Get(context.Background(), delPrefix) + require.NoError(t, err) + for _, h := range res { + require.NotEqual(t, string(h), string(mhs[0]), "expected deleted hash to be gone") + } + + // other hashes should still be retrievable + otherPrefix := bitstr.Key(key.BitString(keyspace.MhToBit256(mhs[1]))[:6]) + res, err = store.Get(context.Background(), otherPrefix) + require.NoError(t, err) + require.NotEmpty(t, res, "expected remaining hashes for other prefix") +} + +func TestKeystoreSize(t *testing.T) { + t.Run("Keystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreSizeImpl(t, store) + }) + + t.Run("ResettableKeystore", func(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + store, err := NewResettableKeystore(ds) + require.NoError(t, err) + defer store.Close() + + testKeystoreSizeImpl(t, store) + }) +} + +func testKeystoreSizeImpl(t *testing.T, store Keystore) { + ctx := context.Background() + + mhs0 := random.Multihashes(128) + _, err := store.Put(ctx, mhs0...) + require.NoError(t, err) + + size, err := store.Size(ctx) + require.NoError(t, err) + require.Equal(t, len(mhs0), size) + + nKeys := 1 << 12 + batches := 1 << 6 + for range batches { + mhs1 := random.Multihashes(nKeys / batches) + _, err = store.Put(ctx, mhs1...) + require.NoError(t, err) + } + + size, err = store.Size(ctx) + require.NoError(t, err) + require.Equal(t, len(mhs0)+nKeys, size) +} + +func TestDsKey(t *testing.T) { + s := keystore{ + prefixBits: 8, + } + + k := bit256.ZeroKey() + dsk := dsKey(k, s.prefixBits) + expectedPrefix := "/0/0/0/0/0/0/0/0/" + require.Equal(t, expectedPrefix, dsk.String()[:len(expectedPrefix)]) + + s.prefixBits = 16 + + b := [32]byte{} + for range 1024 { + _, err := rand.Read(b[:]) + require.NoError(t, err) + k := bit256.NewKey(b[:]) + + sdk := dsKey(k, s.prefixBits) + require.Equal(t, s.prefixBits+1, strings.Count(sdk.String(), "/")) + decoded, err := s.decodeKey(sdk.String()) + require.NoError(t, err) + require.Equal(t, k, decoded) + } +} diff --git a/provider/keystore/options.go b/provider/keystore/options.go new file mode 100644 index 000000000..2d45b594a --- /dev/null +++ b/provider/keystore/options.go @@ -0,0 +1,74 @@ +package keystore + +import "fmt" + +type config struct { + path string + prefixBits int + batchSize int +} + +// Options for configuring a Keystore. +type Option func(*config) error + +const ( + DefaultPath = "/provider/keystore" + DefaultBatchSize = 1 << 14 + DefaultPrefixBits = 16 +) + +// getOpts creates a config and applies Options to it. +func getOpts(opts []Option) (config, error) { + cfg := config{ + path: DefaultPath, + prefixBits: DefaultPrefixBits, + batchSize: DefaultBatchSize, + } + + for i, opt := range opts { + if err := opt(&cfg); err != nil { + return config{}, fmt.Errorf("option %d error: %s", i, err) + } + } + return cfg, nil +} + +// WithDatastorePath sets the datastore prefix under which multihashes are +// stored. +func WithDatastorePath(path string) Option { + return func(cfg *config) error { + if path == "" { + return fmt.Errorf("datastore name cannot be empty") + } + cfg.path = path + return nil + } +} + +// WithPrefixBits sets how many bits from binary keys become individual path +// components in datastore keys. Higher values create deeper hierarchies but +// enable more granular prefix queries. +// +// Must be a multiple of 8 between 0 and 256 (inclusive) to align with byte +// boundaries. +func WithPrefixBits(prefixBits int) Option { + return func(cfg *config) error { + if prefixBits < 0 || prefixBits > 256 || prefixBits%8 != 0 { + return fmt.Errorf("invalid prefix bits %d, must be a non-negative multiple of 8 less or equal to 256", prefixBits) + } + cfg.prefixBits = prefixBits + return nil + } +} + +// WithBatchSize defines the maximal number of keys per batch when reading or +// writing to the datastore. It is typically used in Empty() and ResetCids(). +func WithBatchSize(size int) Option { + return func(cfg *config) error { + if size <= 0 { + return fmt.Errorf("invalid batch size %d", size) + } + cfg.batchSize = size + return nil + } +} diff --git a/provider/keystore/resettable_keystore.go b/provider/keystore/resettable_keystore.go new file mode 100644 index 000000000..9500c1b47 --- /dev/null +++ b/provider/keystore/resettable_keystore.go @@ -0,0 +1,266 @@ +package keystore + +import ( + "context" + "errors" + + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/namespace" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + mh "github.com/multiformats/go-multihash" +) + +var ErrResetInProgress = errors.New("reset already in progress") + +const ( + opStart opType = iota + opCleanup +) + +type resetOp struct { + op opType + success bool + response chan<- error +} + +// ResettableKeystore is a Keystore implementation that supports atomic reset +// operations using a dual-datastore architecture. It maintains two separate +// datastores (primary and alternate) where only one is active at any time, +// enabling atomic replacement of all stored keys without interrupting +// concurrent operations. +// +// Architecture: +// - Primary datastore: Currently active storage for all read/write operations +// - Alternate datastore: Standby storage used during reset operations +// - The datastores use "/0" and "/1" namespace suffixes and can be swapped +// +// Reset Operation Flow: +// 1. New keys from reset are written to the alternate (inactive) datastore +// 2. Concurrent Put operations are automatically duplicated to both datastores +// to maintain consistency during the transition +// 3. Once all reset keys are written, the datastores are atomically swapped +// 4. The old datastore (now alternate) is cleaned up +// +// Thread Safety: +// - All operations are processed sequentially by a single worker goroutine +// - Reset operations are non-blocking for concurrent reads and writes +// - Only one reset operation can be active at a time +// +// The reset operation allows complete replacement of stored multihashes +// without data loss or service interruption, making it suitable for +// scenarios requiring periodic full dataset updates. +type ResettableKeystore struct { + keystore + + altDs ds.Batching + resetInProgress bool + resetOps chan resetOp // reset operations that must be run in main go routine +} + +var _ Keystore = (*ResettableKeystore)(nil) + +// NewResettableKeystore creates a new ResettableKeystore backed by the +// provided datastore. It automatically adds "/0" and "/1" suffixes to the +// configured datastore path to create two alternate storage locations for +// atomic reset operations. +func NewResettableKeystore(d ds.Batching, opts ...Option) (*ResettableKeystore, error) { + cfg, err := getOpts(opts) + if err != nil { + return nil, err + } + + rks := &ResettableKeystore{ + keystore: keystore{ + ds: namespace.Wrap(d, ds.NewKey(cfg.path+"/0")), + prefixBits: cfg.prefixBits, + batchSize: cfg.batchSize, + requests: make(chan operation), + close: make(chan struct{}), + done: make(chan struct{}), + }, + altDs: namespace.Wrap(d, ds.NewKey(cfg.path+"/1")), + resetOps: make(chan resetOp), + } + + // start worker goroutine + go rks.worker() + + return rks, nil +} + +// worker processes operations sequentially in a single goroutine for ResettableKeystore +func (s *ResettableKeystore) worker() { + defer close(s.done) + + for { + select { + case <-s.close: + return + case op := <-s.requests: + switch op.op { + case opPut: + newKeys, err := s.put(op.ctx, op.keys) + op.response <- operationResponse{multihashes: newKeys, err: err} + + case opGet: + keys, err := s.get(op.ctx, op.prefix) + op.response <- operationResponse{multihashes: keys, err: err} + + case opContainsPrefix: + found, err := s.containsPrefix(op.ctx, op.prefix) + op.response <- operationResponse{found: found, err: err} + + case opDelete: + err := s.delete(op.ctx, op.keys) + op.response <- operationResponse{err: err} + + case opEmpty: + err := empty(op.ctx, s.ds, s.batchSize) + op.response <- operationResponse{err: err} + + case opSize: + size, err := s.size(op.ctx) + op.response <- operationResponse{size: size, err: err} + } + case op := <-s.resetOps: + s.handleResetOp(op) + } + } +} + +// resettablePutLocked handles put operations for ResettableKeystore, with special +// handling during reset operations. +func (s *ResettableKeystore) put(ctx context.Context, keys []mh.Multihash) ([]mh.Multihash, error) { + if s.resetInProgress { + // Reset is in progress, write to alternate datastore in addition to + // current datastore + s.altPut(ctx, keys) + } + return s.keystore.put(ctx, keys) +} + +// altPut writes the given multihashes to the alternate datastore. +func (s *ResettableKeystore) altPut(ctx context.Context, keys []mh.Multihash) error { + b, err := s.altDs.Batch(ctx) + if err != nil { + return err + } + for _, h := range keys { + dsk := dsKey(keyspace.MhToBit256(h), s.prefixBits) + if err := b.Put(ctx, dsk, h); err != nil { + return err + } + } + return b.Commit(ctx) +} + +// handleResetOp processes reset operations that need to happen synchronously. +func (s *ResettableKeystore) handleResetOp(op resetOp) { + if op.op == opStart { + if s.resetInProgress { + op.response <- ErrResetInProgress + return + } + if err := empty(context.Background(), s.altDs, s.batchSize); err != nil { + op.response <- err + return + } + s.resetInProgress = true + op.response <- nil + return + } + + // Cleanup operation + if op.success { + // Swap the active datastore. + oldDs := s.ds + s.ds = s.altDs + s.altDs = oldDs + } + // Empty the unused datastore. + s.resetInProgress = false + op.response <- empty(context.Background(), s.altDs, s.batchSize) +} + +// ResetCids atomically replaces all stored keys with the CIDs received from +// keysChan. The operation is thread-safe and non-blocking for concurrent reads +// and writes. +// +// During the reset: +// - New keys from keysChan are written to an alternate storage location +// - Concurrent Put operations are duplicated to both current and alternate +// locations +// - Once all keys are processed, storage locations are atomically swapped +// - The old storage location is cleaned up +// +// Returns ErrResetInProgress if another reset operation is already running. +// The operation can be cancelled via context, which will clean up partial +// state. +func (s *ResettableKeystore) ResetCids(ctx context.Context, keysChan <-chan cid.Cid) error { + if keysChan == nil { + return nil + } + + opsChan := make(chan error) + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.done: + return ErrClosed + case s.resetOps <- resetOp{op: opStart, response: opsChan}: + select { + case err := <-opsChan: + if err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } + + var success bool + + defer func() { + // Cleanup before returning on success and failure + select { + case s.resetOps <- resetOp{op: opCleanup, success: success, response: opsChan}: + <-opsChan + case <-s.done: + // Safe not to go through the worker since we are done, and we need to + // cleanup + empty(context.Background(), s.altDs, s.batchSize) + } + }() + + keys := make([]mh.Multihash, 0) + + // Read all the keys from the channel and write them to the altDs +loop: + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.done: + return ErrClosed + case c, ok := <-keysChan: + if !ok { + break loop + } + keys = append(keys, c.Hash()) + if len(keys) >= s.batchSize { + if err := s.altPut(ctx, keys); err != nil { + return err + } + keys = keys[:0] + } + } + } + // Put final batch + if err := s.altPut(ctx, keys); err != nil { + return err + } + success = true + + return nil +} diff --git a/provider/keystore/resettable_keystore_test.go b/provider/keystore/resettable_keystore_test.go new file mode 100644 index 000000000..0d9a1f12a --- /dev/null +++ b/provider/keystore/resettable_keystore_test.go @@ -0,0 +1,72 @@ +package keystore + +import ( + "context" + "testing" + + "github.com/ipfs/go-cid" + ds "github.com/ipfs/go-datastore" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + + "github.com/stretchr/testify/require" +) + +func TestKeystoreReset(t *testing.T) { + ds := ds.NewMapDatastore() + defer ds.Close() + + store, err := NewResettableKeystore(ds) + require.NoError(t, err) + defer store.Close() + + first := make([]mh.Multihash, 2) + for i := range first { + h, err := mh.Sum([]byte{byte(i)}, mh.SHA2_256, -1) + require.NoError(t, err) + first[i] = h + } + _, err = store.Put(context.Background(), first...) + require.NoError(t, err) + + secondChan := make(chan cid.Cid, 2) + second := make([]mh.Multihash, 2) + for i := range 2 { + h, err := mh.Sum([]byte{byte(i + 10)}, mh.SHA2_256, -1) + require.NoError(t, err) + second[i] = h + secondChan <- cid.NewCidV1(cid.Raw, h) + } + close(secondChan) + + err = store.ResetCids(context.Background(), secondChan) + require.NoError(t, err) + + // old hashes should not be present + for _, h := range first { + prefix := bitstr.Key(key.BitString(keyspace.MhToBit256(h))[:6]) + got, err := store.Get(context.Background(), prefix) + require.NoError(t, err) + for _, m := range got { + require.NotEqual(t, string(m), string(h), "expected old hash %v to be removed", h) + } + } + + // new hashes should be retrievable + for _, h := range second { + prefix := bitstr.Key(key.BitString(keyspace.MhToBit256(h))[:6]) + got, err := store.Get(context.Background(), prefix) + require.NoError(t, err) + found := false + for _, m := range got { + if string(m) == string(h) { + found = true + break + } + } + require.True(t, found, "expected hash %v after reset", h) + } +} diff --git a/provider/options.go b/provider/options.go new file mode 100644 index 000000000..5d3ffda99 --- /dev/null +++ b/provider/options.go @@ -0,0 +1,283 @@ +package provider + +import ( + "errors" + "fmt" + "time" + + "github.com/libp2p/go-libp2p-kad-dht/amino" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + "github.com/libp2p/go-libp2p-kad-dht/provider/keystore" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mh "github.com/multiformats/go-multihash" +) + +const ( + // DefaultMaxReprovideDelay is the default maximum delay allowed when + // reproviding a region. The interval between 2 reprovides of the same region + // is at most ReprovideInterval+MaxReprovideDelay. This variable is necessary + // since regions can grow and shrink depending on the network churn. + DefaultMaxReprovideDelay = 1 * time.Hour + + // DefaultOfflineDelay is the default delay after which a disconnected node + // is considered as Offline. + DefaultOfflineDelay = 2 * time.Hour + // DefaultConnectivityCheckOnlineInterval is the default minimum interval for + // checking whether the node is still online. Such a check is performed when + // a network operation fails, and the ConnectivityCheckOnlineInterval limits + // how often such a check is performed. + DefaultConnectivityCheckOnlineInterval = 1 * time.Minute +) + +type config struct { + replicationFactor int + reprovideInterval time.Duration + maxReprovideDelay time.Duration + + offlineDelay time.Duration + connectivityCheckOnlineInterval time.Duration + + peerid peer.ID + router KadClosestPeersRouter + + keystore keystore.Keystore + + msgSender pb.MessageSender + selfAddrs func() []ma.Multiaddr + addLocalRecord func(mh.Multihash) error + + maxWorkers int + dedicatedPeriodicWorkers int + dedicatedBurstWorkers int + maxProvideConnsPerWorker int +} + +func (cfg *config) apply(opts ...Option) error { + for i, o := range opts { + if err := o(cfg); err != nil { + return fmt.Errorf("reprovider dht option %d failed: %w", i, err) + } + } + return nil +} + +func (c *config) validate() error { + if len(c.peerid) == 0 { + return errors.New("reprovider config: peer id is required") + } + if c.router == nil { + return errors.New("reprovider config: router is required") + } + if c.msgSender == nil { + return errors.New("reprovider config: message sender is required") + } + if c.selfAddrs == nil { + return errors.New("reprovider config: self addrs func is required") + } + if c.dedicatedPeriodicWorkers+c.dedicatedBurstWorkers > c.maxWorkers { + return errors.New("reprovider config: total dedicated workers exceed max workers") + } + return nil +} + +type Option func(opt *config) error + +var DefaultConfig = func(cfg *config) error { + cfg.replicationFactor = amino.DefaultBucketSize + cfg.reprovideInterval = amino.DefaultReprovideInterval + cfg.maxReprovideDelay = DefaultMaxReprovideDelay + cfg.offlineDelay = DefaultOfflineDelay + cfg.connectivityCheckOnlineInterval = DefaultConnectivityCheckOnlineInterval + + cfg.maxWorkers = 4 + cfg.dedicatedPeriodicWorkers = 2 + cfg.dedicatedBurstWorkers = 1 + cfg.maxProvideConnsPerWorker = 20 + + cfg.addLocalRecord = func(mh mh.Multihash) error { return nil } + + return nil +} + +// WithReplicationFactor sets the replication factor for provider records. It +// means that during provide and reprovide operations, each provider records is +// allocated to the ReplicationFactor closest peers in the DHT swarm. +func WithReplicationFactor(n int) Option { + return func(cfg *config) error { + if n <= 0 { + return errors.New("reprovider config: replication factor must be a positive integer") + } + cfg.replicationFactor = n + return nil + } +} + +// WithReprovideInterval sets the interval at which regions are reprovided. +func WithReprovideInterval(d time.Duration) Option { + return func(cfg *config) error { + if d <= 0 { + return errors.New("reprovider config: reprovide interval must be greater than 0") + } + cfg.reprovideInterval = d + return nil + } +} + +// WithMaxReprovideDelay sets the maximum delay allowed when reproviding a +// region. The interval between 2 reprovides of the same region is at most +// ReprovideInterval+MaxReprovideDelay. +// +// This parameter is necessary since regions can grow and shrink depending on +// the network churn. +func WithMaxReprovideDelay(d time.Duration) Option { + return func(cfg *config) error { + if d <= 0 { + return errors.New("reprovider config: max reprovide delay must be greater than 0") + } + cfg.maxReprovideDelay = d + return nil + } +} + +// WithOfflineDelay sets the delay after which a disconnected node is +// considered as offline. When a node cannot connect to peers, it is set to +// `Disconnected`, and after `OfflineDelay` it still cannot connect to peers, +// its state changes to `Offline`. +func WithOfflineDelay(d time.Duration) Option { + return func(cfg *config) error { + if d < 0 { + return errors.New("reprovider config: offline delay must be non-negative") + } + cfg.offlineDelay = d + return nil + } +} + +// WithConnectivityCheckOnlineInterval sets the minimal interval for checking +// whether the node is still online. Such a check is performed when a network +// operation fails, and the ConnectivityCheckOnlineInterval limits how often +// such a check is performed. +func WithConnectivityCheckOnlineInterval(d time.Duration) Option { + return func(cfg *config) error { + cfg.connectivityCheckOnlineInterval = d + return nil + } +} + +// WithPeerID sets the peer ID of the node running the provider. +func WithPeerID(p peer.ID) Option { + return func(cfg *config) error { + cfg.peerid = p + return nil + } +} + +// WithRouter sets the router used to find closest peers in the DHT. +func WithRouter(r KadClosestPeersRouter) Option { + return func(cfg *config) error { + cfg.router = r + return nil + } +} + +// WithMessageSender sets the message sender used to send messages out to the +// DHT swarm. +func WithMessageSender(m pb.MessageSender) Option { + return func(cfg *config) error { + cfg.msgSender = m + return nil + } +} + +// WithSelfAddrs sets the function that returns the self addresses of the node. +// These addresses are written in the provider records advertised by the node. +func WithSelfAddrs(f func() []ma.Multiaddr) Option { + return func(cfg *config) error { + cfg.selfAddrs = f + return nil + } +} + +// WithAddLocalRecord sets the function that adds a provider record to the +// local provider record store. +func WithAddLocalRecord(f func(mh.Multihash) error) Option { + return func(cfg *config) error { + if f == nil { + return errors.New("reprovider config: add local record function cannot be nil") + } + cfg.addLocalRecord = f + return nil + } +} + +// WithMaxWorkers sets the maximum number of workers that can be used for +// provide and reprovide jobs. The job of a worker is to explore a region of +// the keyspace and (re)provide the keys matching the region to the closest +// peers. +// +// You can configure a number of workers dedicated to periodic jobs, and a +// number of workers dedicated to burst jobs. MaxWorkers should be greater or +// equal to DedicatedPeriodicWorkers+DedicatedBurstWorkers. The additional +// workers that aren't dedicated to specific jobs can be used for either job +// type where needed. +func WithMaxWorkers(n int) Option { + return func(cfg *config) error { + if n < 0 { + return errors.New("reprovider config: max workers must be non-negative") + } + cfg.maxWorkers = n + return nil + } +} + +// WithDedicatedPeriodicWorkers sets the number of workers dedicated to +// periodic region reprovides. +func WithDedicatedPeriodicWorkers(n int) Option { + return func(cfg *config) error { + if n < 0 { + return errors.New("reprovider config: dedicated periodic workers must be non-negative") + } + cfg.dedicatedPeriodicWorkers = n + return nil + } +} + +// WithDedicatedBurstWorkers sets the number of workers dedicated to burst +// operations. Burst operations consist in work that isn't scheduled +// beforehands, such as initial provides and catching up with reproviding after +// the node went offline for a while. +func WithDedicatedBurstWorkers(n int) Option { + return func(cfg *config) error { + if n < 0 { + return errors.New("reprovider config: dedicated burst workers must be non-negative") + } + cfg.dedicatedBurstWorkers = n + return nil + } +} + +// WithMaxProvideConnsPerWorker sets the maximum number of connections to +// distinct peers that can be opened by a single worker during a provide +// operation. +func WithMaxProvideConnsPerWorker(n int) Option { + return func(cfg *config) error { + if n <= 0 { + return errors.New("reprovider config: max provide conns per worker must be greater than 0") + } + cfg.maxProvideConnsPerWorker = n + return nil + } +} + +// WithKeystore defines the Keystore used to keep track of the keys that need +// to be reprovided. +func WithKeystore(ks keystore.Keystore) Option { + return func(cfg *config) error { + if ks == nil { + return errors.New("reprovider config: multihash store cannot be nil") + } + cfg.keystore = ks + return nil + } +} diff --git a/provider/provider.go b/provider/provider.go new file mode 100644 index 000000000..423f1fe47 --- /dev/null +++ b/provider/provider.go @@ -0,0 +1,1532 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "sync/atomic" + "time" + + pool "github.com/guillaumemichel/reservedpool" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + + ds "github.com/ipfs/go-datastore" + logging "github.com/ipfs/go-log/v2" + "github.com/ipfs/go-test/random" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mh "github.com/multiformats/go-multihash" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/connectivity" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/queue" + "github.com/libp2p/go-libp2p-kad-dht/provider/keystore" + kb "github.com/libp2p/go-libp2p-kbucket" +) + +const ( + // maxPrefixSize is the maximum size of a prefix used to define a keyspace + // region. + maxPrefixSize = 24 + + // approxPrefixLenGCPCount is the number of GetClosestPeers calls run by the + // approxPrefixLen function. This function makes GetClosestPeers requests to + // get an estimate of the network size, to set the initial keyspace region + // prefix length. A high number increases the precision of the measurement, + // but adds network load and latency before the initial provide request can + // be performed. + approxPrefixLenGCPCount = 4 + // defaultPrefixLenValidity is the validity of the cached average region + // prefix length computed from the schedule. It allows to avoid recomputing + // the average everytime we need to average prefix length. + defaultPrefixLenValidity = 5 * time.Minute + + // retryInterval is the interval at which the provider tries to perform any + // previously failed work (provide or reprovide). + retryInterval = 5 * time.Minute + + // individualProvideThreshold is the threshold for the number of keys to + // trigger a region exploration. If the number of keys to provide for a + // region is less or equal to the threshold, the keys will be individually + // provided. + individualProvideThreshold = 2 + + // maxExplorationPrefixSearches is the maximum number of GetClosestPeers + // operations that are allowed to explore a prefix, preventing an infinite + // loop, since the exit condition depends on the network topology. + // A lower bound estimate on the number of fresh peers returned by GCP is + // replicationFactor/2. Hence, 64 GCP are expected to return at least + // 32*replicatonFactor peers, which should be more than enough, even if the + // supplied prefix is too short. + maxExplorationPrefixSearches = 64 + + // maxConsecutiveProvideFailuresAllowed is the maximum number of consecutive + // provides that are allowed to fail to the same remote peer before cancelling + // all pending requests to this peer. + maxConsecutiveProvideFailuresAllowed = 2 + // minimalRegionReachablePeersRatio is the minimum ratio of reachable peers + // in a region for the provide to be considered a success. + minimalRegionReachablePeersRatio float32 = 0.2 +) + +var ( + // ErrClosed is returned when the provider is closed. + ErrClosed = errors.New("provider: closed") + // ErrOffline is returned when the provider is offline, and cannot process + // the request. When a node is offline, operations on the keystore are + // performed as usual, but keys aren't added to provide queue nor advertised + // to the network. + ErrOffline = errors.New("provider: offline") +) + +const LoggerName = "dht/provider" + +var logger = logging.Logger(LoggerName) + +type KadClosestPeersRouter interface { + GetClosestPeers(context.Context, string) ([]peer.ID, error) +} + +type workerType uint8 + +const ( + periodicWorker workerType = iota + burstWorker +) + +var _ internal.Provider = (*SweepingProvider)(nil) + +type SweepingProvider struct { + done chan struct{} + ctx context.Context + cancelCtx context.CancelFunc + closeOnce sync.Once + wg sync.WaitGroup + cleanupFuncs []func() error + + peerid peer.ID + order bit256.Key + router KadClosestPeersRouter + + connectivity *connectivity.ConnectivityChecker + + keystore keystore.Keystore + + replicationFactor int + + provideQueue *queue.ProvideQueue + provideRunning sync.Mutex + reprovideQueue *queue.ReprovideQueue + lateReprovideRunning sync.Mutex + + workerPool *pool.Pool[workerType] + maxProvideConnsPerWorker int + + cycleStart time.Time + reprovideInterval time.Duration + maxReprovideDelay time.Duration + + schedule *trie.Trie[bitstr.Key, time.Duration] + scheduleLk sync.Mutex + scheduleCursor bitstr.Key + scheduleTimer *time.Timer + scheduleTimerStartedAt time.Time + + ongoingReprovides *trie.Trie[bitstr.Key, struct{}] + ongoingReprovidesLk sync.Mutex + + cachedAvgPrefixLen int + avgPrefixLenLk sync.Mutex + approxPrefixLenRunning sync.Mutex + lastAvgPrefixLen time.Time + avgPrefixLenValidity time.Duration + + msgSender pb.MessageSender + getSelfAddrs func() []ma.Multiaddr + addLocalRecord func(mh.Multihash) error + + provideCounter metric.Int64Counter +} + +// New creates a new SweepingProvider instance with the supplied options. +func New(opts ...Option) (*SweepingProvider, error) { + var cfg config + err := cfg.apply(append([]Option{DefaultConfig}, opts...)...) + if err != nil { + return nil, err + } + cleanupFuncs := []func() error{} + + if cfg.keystore == nil { + // Setup KeyStore if missing + mapDs := ds.NewMapDatastore() + cleanupFuncs = append(cleanupFuncs, mapDs.Close) + cfg.keystore, err = keystore.NewKeystore(mapDs) + if err != nil { + cleanup(cleanupFuncs) + return nil, err + } + } + cleanupFuncs = append(cleanupFuncs, cfg.keystore.Close) + if err := cfg.validate(); err != nil { + cleanup(cleanupFuncs) + return nil, err + } + meter := otel.Meter("github.com/libp2p/go-libp2p-kad-dht/provider") + providerCounter, err := meter.Int64Counter( + "total_provide_count", + metric.WithDescription("Number of successful provides since node is running"), + ) + if err != nil { + cleanup(cleanupFuncs) + return nil, err + } + ctx, cancelCtx := context.WithCancel(context.Background()) + connChecker, err := connectivity.New( + func() bool { + peers, err := cfg.router.GetClosestPeers(ctx, string(cfg.peerid)) + return err == nil && len(peers) > 0 + }, + connectivity.WithOfflineDelay(cfg.offlineDelay), + connectivity.WithOnlineCheckInterval(cfg.connectivityCheckOnlineInterval), + ) + if err != nil { + cancelCtx() + cleanup(cleanupFuncs) + return nil, err + } + cleanupFuncs = append(cleanupFuncs, connChecker.Close) + + prov := &SweepingProvider{ + done: make(chan struct{}), + ctx: ctx, + cancelCtx: cancelCtx, + cleanupFuncs: cleanupFuncs, + + router: cfg.router, + peerid: cfg.peerid, + order: keyspace.PeerIDToBit256(cfg.peerid), + + connectivity: connChecker, + + replicationFactor: cfg.replicationFactor, + reprovideInterval: cfg.reprovideInterval, + maxReprovideDelay: cfg.maxReprovideDelay, + + workerPool: pool.New(cfg.maxWorkers, map[workerType]int{ + periodicWorker: cfg.dedicatedPeriodicWorkers, + burstWorker: cfg.dedicatedBurstWorkers, + }), + maxProvideConnsPerWorker: cfg.maxProvideConnsPerWorker, + + avgPrefixLenValidity: defaultPrefixLenValidity, + cachedAvgPrefixLen: -1, + + cycleStart: time.Now(), + + msgSender: cfg.msgSender, + getSelfAddrs: cfg.selfAddrs, + addLocalRecord: cfg.addLocalRecord, + + keystore: cfg.keystore, + + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + + provideQueue: queue.NewProvideQueue(), + reprovideQueue: queue.NewReprovideQueue(), + + ongoingReprovides: trie.New[bitstr.Key, struct{}](), + + provideCounter: providerCounter, + } + // Set up callbacks after both provider and connectivity checker are initialized + // This breaks the circular dependency between connectivity, onOnline, and approxPrefixLen + prov.connectivity.SetCallbacks(prov.onOnline, prov.onOffline) + prov.connectivity.Start() + + // Don't need to start schedule timer yet + prov.scheduleTimer.Stop() + + prov.cleanupFuncs = append(prov.cleanupFuncs, prov.workerPool.Close) + + prov.wg.Add(1) + go prov.run() + + return prov, nil +} + +func (s *SweepingProvider) run() { + defer s.wg.Done() + + logger.Debug("Starting SweepingProvider") + retryTicker := time.NewTicker(retryInterval) + defer retryTicker.Stop() + + for { + select { + case <-s.done: + return + case <-retryTicker.C: + if s.connectivity.IsOnline() { + s.catchupPendingWork() + } + case <-s.scheduleTimer.C: + s.handleReprovide() + } + } +} + +// Close stops the provider and releases all resources. +func (s *SweepingProvider) Close() error { + var err error + s.closeOnce.Do(func() { + close(s.done) + s.cancelCtx() + s.wg.Wait() + s.approxPrefixLenRunning.Lock() + _ = struct{}{} // cannot have empty critical section + s.approxPrefixLenRunning.Unlock() + + s.scheduleTimer.Stop() + err = cleanup(s.cleanupFuncs) + }) + return err +} + +func cleanup(funcs []func() error) error { + var errs []error + for i := len(funcs) - 1; i >= 0; i-- { // LIFO: last-added is cleaned up first + if f := funcs[i]; f != nil { + if err := f(); err != nil { + errs = append(errs, err) + } + } + } + return errors.Join(errs...) +} + +func (s *SweepingProvider) closed() bool { + select { + case <-s.done: + return true + default: + return false + } +} + +// scheduleNextReprovideNoLock makes sure the scheduler wakes up in +// `timeUntilReprovide` to reprovide the region identified by `prefix`. +func (s *SweepingProvider) scheduleNextReprovideNoLock(prefix bitstr.Key, timeUntilReprovide time.Duration) { + s.scheduleCursor = prefix + s.scheduleTimer.Reset(timeUntilReprovide) + s.scheduleTimerStartedAt = time.Now() +} + +func (s *SweepingProvider) reschedulePrefix(prefix bitstr.Key) { + s.scheduleLk.Lock() + s.schedulePrefixNoLock(prefix, true) + s.scheduleLk.Unlock() +} + +// schedulePrefixNoLock adds the supplied prefix to the schedule, unless +// already present. +// +// If `justReprovided` is true, it will schedule the next reprovide at most +// s.reprovideInterval+s.maxReprovideDelay in the future, allowing the +// reprovide to be delayed of at most maxReprovideDelay. +// +// If the supplied prefix is the next prefix to be reprovided, update the +// schedule cursor and timer. +func (s *SweepingProvider) schedulePrefixNoLock(prefix bitstr.Key, justReprovided bool) { + nextReprovideTime := s.reprovideTimeForPrefix(prefix) + if justReprovided { + // Schedule next reprovide given that the prefix was just reprovided on + // schedule. In the case the next reprovide time should be delayed due to a + // growth in the number of network peers matching the prefix, don't delay + // more than s.maxReprovideDelay. + nextReprovideTime = min(nextReprovideTime, s.currentTimeOffset()+s.reprovideInterval+s.maxReprovideDelay) + } + // If schedule contains keys starting with prefix, remove them to avoid + // overlap. + if _, ok := keyspace.FindPrefixOfKey(s.schedule, prefix); ok { + // Already scheduled. + return + } + // Unschedule superstrings in schedule if any. + s.unscheduleSubsumedPrefixesNoLock(prefix) + + s.schedule.Add(prefix, nextReprovideTime) + + // Check if the prefix that was just added is the next one to be reprovided. + if s.schedule.IsNonEmptyLeaf() { + // The prefix we insterted is the only element in the schedule. + timeUntilPrefixReprovide := s.timeUntil(nextReprovideTime) + s.scheduleNextReprovideNoLock(prefix, timeUntilPrefixReprovide) + return + } + followingKey := keyspace.NextNonEmptyLeaf(s.schedule, prefix, s.order).Key + if followingKey == s.scheduleCursor { + // The key following prefix is the schedule cursor. + timeUntilPrefixReprovide := s.timeUntil(nextReprovideTime) + _, scheduledAlarm := trie.Find(s.schedule, s.scheduleCursor) + if timeUntilPrefixReprovide < s.timeUntil(scheduledAlarm) { + s.scheduleNextReprovideNoLock(prefix, timeUntilPrefixReprovide) + } + } +} + +// unscheduleSubsumedPrefixes removes all superstrings of `prefix` that are +// scheduled in the future. Assumes that the schedule lock is held. +func (s *SweepingProvider) unscheduleSubsumedPrefixesNoLock(prefix bitstr.Key) { + // Pop prefixes scheduled in the future being covered by the explored peers. + keyspace.PruneSubtrie(s.schedule, prefix) + + // If we removed s.scheduleCursor from schedule, select the next one + if keyspace.IsBitstrPrefix(prefix, s.scheduleCursor) { + next := keyspace.NextNonEmptyLeaf(s.schedule, s.scheduleCursor, s.order) + if next == nil { + s.scheduleNextReprovideNoLock(prefix, s.reprovideInterval) + } else { + timeUntilReprovide := s.timeUntil(next.Data) + s.scheduleNextReprovideNoLock(next.Key, timeUntilReprovide) + logger.Debugf("next scheduled prefix now is %s", s.scheduleCursor) + } + } +} + +// currentTimeOffset returns the current time offset in the reprovide cycle. +func (s *SweepingProvider) currentTimeOffset() time.Duration { + return s.timeOffset(time.Now()) +} + +// timeOffset returns the time offset in the reprovide cycle for the given +// time. +func (s *SweepingProvider) timeOffset(t time.Time) time.Duration { + return t.Sub(s.cycleStart) % s.reprovideInterval +} + +// timeUntil returns the time left (duration) until the given time offset. +func (s *SweepingProvider) timeUntil(d time.Duration) time.Duration { + return s.timeBetween(s.currentTimeOffset(), d) +} + +// timeBetween returns the duration between the two provided offsets, assuming +// it is no more than s.reprovideInterval. +func (s *SweepingProvider) timeBetween(from, to time.Duration) time.Duration { + return (to-from+s.reprovideInterval-1)%s.reprovideInterval + 1 +} + +// reprovideTimeForPrefix calculates the scheduled time offset for reproviding +// keys associated with a given prefix based on its bitstring prefix. The +// function maps the given binary prefix to a fraction of the overall reprovide +// interval (s.reprovideInterval), such that keys with prefixes closer to a +// configured target s.order (in XOR distance) are scheduled earlier and those +// further away later in the cycle. +// +// For any prefix of bit length n, the function generates 2^n distinct +// reprovide times that evenly partition the entire reprovide interval. The +// process first truncates s.order to n bits and then XORs it with the provided +// prefix. The resulting binary string is converted to an integer, +// corresponding to the index of the 2^n possible reprovide times to use for +// the prefix. +// +// This method ensures a deterministic and evenly distributed reprovide +// schedule, where the temporal position within the cycle is based on the +// binary representation of the key's prefix. +func (s *SweepingProvider) reprovideTimeForPrefix(prefix bitstr.Key) time.Duration { + if len(prefix) == 0 { + // Empty prefix: all reprovides occur at the beginning of the cycle. + return 0 + } + if len(prefix) > maxPrefixSize { + // Truncate the prefix to the maximum allowed size to avoid overly fine + // slicing of time. + prefix = prefix[:maxPrefixSize] + } + // Number of possible bitstrings of the same length as prefix. + maxInt := int64(1 << len(prefix)) + // XOR the prefix with the order key to reorder the schedule: keys "close" to + // s.order are scheduled first in the cycle, and those "far" from it are + // scheduled later. + order := bitstr.Key(key.BitString(s.order)[:len(prefix)]) + k := prefix.Xor(order) + val, _ := strconv.ParseInt(string(k), 2, 64) + // Calculate the time offset as a fraction of the overall reprovide interval. + return time.Duration(int64(s.reprovideInterval) * val / maxInt) +} + +// approxPrefixLen makes a few GetClosestPeers calls to get an estimate +// of the prefix length to be used in the network. +// +// This function blocks until GetClosestPeers succeeds or the provider is +// closed. No provide operation can happen until this function returns. +func (s *SweepingProvider) approxPrefixLen() { + cplSum := atomic.Int32{} + cplSamples := atomic.Int32{} + wg := sync.WaitGroup{} + wg.Add(approxPrefixLenGCPCount) + for range approxPrefixLenGCPCount { + go func() { + defer wg.Done() + randomMh := random.Multihashes(1)[0] + for { + if s.closed() || !s.connectivity.IsOnline() { + return + } + peers, err := s.router.GetClosestPeers(s.ctx, string(randomMh)) + if err != nil { + logger.Infof("GetClosestPeers failed during prefix len approximation measurement: %s", err) + } else { + if len(peers) < 2 { + return // Ignore result if less than 2 other peers in DHT. + } + cpl := keyspace.KeyLen + firstPeerKey := keyspace.PeerIDToBit256(peers[0]) + for _, p := range peers[1:] { + cpl = min(cpl, key.CommonPrefixLength(firstPeerKey, keyspace.PeerIDToBit256(p))) + } + cplSum.Add(int32(cpl)) + cplSamples.Add(1) + return + } + + s.connectivity.TriggerCheck() + time.Sleep(time.Second) // retry every second until success + } + }() + } + wg.Wait() + + nSamples := cplSamples.Load() + s.avgPrefixLenLk.Lock() + defer s.avgPrefixLenLk.Unlock() + if nSamples == 0 { + // At most 2 other peers in the DHT -> single region of prefix len 0 + s.cachedAvgPrefixLen = 0 + } else { + s.cachedAvgPrefixLen = int(cplSum.Load() / nSamples) + } + logger.Debugf("prefix len approximation is %d", s.cachedAvgPrefixLen) + s.lastAvgPrefixLen = time.Now() +} + +// getAvgPrefixLenNoLock returns the average prefix length of all scheduled +// prefixes. +// +// Hangs until the first measurement is done if the average prefix length is +// missing. +func (s *SweepingProvider) getAvgPrefixLenNoLock() (int, error) { + s.avgPrefixLenLk.Lock() + defer s.avgPrefixLenLk.Unlock() + + if s.cachedAvgPrefixLen == -1 { + return -1, ErrOffline + } + + if s.lastAvgPrefixLen.Add(s.avgPrefixLenValidity).After(time.Now()) { + // Return cached value if it is still valid. + return s.cachedAvgPrefixLen, nil + } + prefixLenSum := 0 + if !s.schedule.IsEmptyLeaf() { + // Take average prefix length of all scheduled prefixes. + scheduleEntries := keyspace.AllEntries(s.schedule, s.order) + for _, entry := range scheduleEntries { + prefixLenSum += len(entry.Key) + } + s.cachedAvgPrefixLen = prefixLenSum / len(scheduleEntries) + s.lastAvgPrefixLen = time.Now() + } + return s.cachedAvgPrefixLen, nil +} + +// vanillaProvide provides a single key to the network without any +// optimization. It should be used for providing a small number of keys +// (typically 1 or 2), because exploring the keyspace would add too much +// overhead for a small number of keys. +func (s *SweepingProvider) vanillaProvide(k mh.Multihash) (bitstr.Key, error) { + // Add provider record to local provider store. + s.addLocalRecord(k) + // Get peers to which the record will be allocated. + peers, err := s.router.GetClosestPeers(s.ctx, string(k)) + if err != nil { + return "", err + } + coveredPrefix, _ := keyspace.ShortestCoveredPrefix(bitstr.Key(key.BitString(keyspace.MhToBit256(k))), peers) + addrInfo := peer.AddrInfo{ID: s.peerid, Addrs: s.getSelfAddrs()} + keysAllocations := make(map[peer.ID][]mh.Multihash) + for _, p := range peers { + keysAllocations[p] = []mh.Multihash{k} + } + return coveredPrefix, s.sendProviderRecords(keysAllocations, addrInfo) +} + +// exploreSwarm finds all peers whose kademlia identifier matches `prefix` in +// the DHT swarm, and organizes them in keyspace regions. +// +// A region is identified by a keyspace prefix, and contains all the peers +// matching this prefix. A region always has at least s.replicationFactor +// peers. Regions are non-overlapping. +// +// If there less than s.replicationFactor peers match `prefix`, explore +// shorter prefixes until at least s.replicationFactor peers are included in +// the region. +// +// The returned `coveredPrefix` represents the keyspace prefix covered by all +// returned regions combined. It is different to the supplied `prefix` if there +// aren't enough peers matching `prefix`. +func (s *SweepingProvider) exploreSwarm(prefix bitstr.Key) (regions []keyspace.Region, coveredPrefix bitstr.Key, err error) { + peers, err := s.closestPeersToPrefix(prefix) + if err != nil { + return nil, "", fmt.Errorf("exploreSwarm '%s': %w", prefix, err) + } + if len(peers) == 0 { + return nil, "", fmt.Errorf("no peers found when exploring prefix %s", prefix) + } + regions, coveredPrefix = keyspace.RegionsFromPeers(peers, s.replicationFactor, s.order) + return regions, coveredPrefix, nil +} + +// closestPeersToPrefix returns at least s.replicationFactor peers +// corresponding to the branch of the network peers trie matching the provided +// prefix. In the case there aren't enough peers matching the provided prefix, +// it will find and return the closest peers to the prefix, even if they don't +// exactly match it. +func (s *SweepingProvider) closestPeersToPrefix(prefix bitstr.Key) ([]peer.ID, error) { + allClosestPeers := make(map[peer.ID]struct{}) + + nextPrefix := prefix + startTime := time.Now() + coveredPrefixesStack := []bitstr.Key{} + + i := 0 + // Go down the trie to fully cover prefix. +exploration: + for { + if i == maxExplorationPrefixSearches { + return nil, errors.New("closestPeersToPrefix needed more than maxPrefixSearches iterations") + } + if !s.connectivity.IsOnline() { + return nil, errors.New("provider: node is offline") + } + i++ + fullKey := keyspace.FirstFullKeyWithPrefix(nextPrefix, s.order) + closestPeers, err := s.closestPeersToKey(fullKey) + if err != nil { + // We only get an err if something really bad happened, e.g no peers in + // routing table, invalid key, etc. + return nil, err + } + if len(closestPeers) == 0 { + return nil, errors.New("dht lookup did not return any peers") + } + coveredPrefix, coveredPeers := keyspace.ShortestCoveredPrefix(fullKey, closestPeers) + for _, p := range coveredPeers { + allClosestPeers[p] = struct{}{} + } + + coveredPrefixLen := len(coveredPrefix) + if i == 1 { + if coveredPrefixLen <= len(prefix) && coveredPrefix == prefix[:coveredPrefixLen] && len(allClosestPeers) >= s.replicationFactor { + // Exit early if the prefix is fully covered at the first request and + // we have enough (at least replicationFactor) peers. + break exploration + } + } else { + latestPrefix := coveredPrefixesStack[len(coveredPrefixesStack)-1] + for coveredPrefixLen <= len(latestPrefix) && coveredPrefix[:coveredPrefixLen-1] == latestPrefix[:coveredPrefixLen-1] { + // Pop latest prefix from stack, because current prefix is + // complementary. + // e.g latestPrefix=0010, currentPrefix=0011. latestPrefix is + // replaced by 001, unless 000 was also in the stack, etc. + coveredPrefixesStack = coveredPrefixesStack[:len(coveredPrefixesStack)-1] + coveredPrefix = coveredPrefix[:len(coveredPrefix)-1] + coveredPrefixLen = len(coveredPrefix) + + if len(coveredPrefixesStack) == 0 { + if coveredPrefixLen <= len(prefix) && len(allClosestPeers) >= s.replicationFactor { + break exploration + } + // Not enough peers -> add coveredPrefix to stack and continue. + break + } + if coveredPrefixLen == 0 { + logger.Error("coveredPrefixLen==0, coveredPrefixStack ", coveredPrefixesStack) + break exploration + } + latestPrefix = coveredPrefixesStack[len(coveredPrefixesStack)-1] + } + } + // Push coveredPrefix to stack + coveredPrefixesStack = append(coveredPrefixesStack, coveredPrefix) + // Flip last bit of last covered prefix + nextPrefix = keyspace.FlipLastBit(coveredPrefixesStack[len(coveredPrefixesStack)-1]) + } + + peers := make([]peer.ID, 0, len(allClosestPeers)) + for p := range allClosestPeers { + peers = append(peers, p) + } + logger.Debugf("Region %s exploration required %d requests to discover %d peers in %s", prefix, i, len(allClosestPeers), time.Since(startTime)) + return peers, nil +} + +// closestPeersToKey returns a valid peer ID sharing a long common prefix with +// the provided key. Note that the returned peer IDs aren't random, they are +// taken from a static list of preimages. +func (s *SweepingProvider) closestPeersToKey(k bitstr.Key) ([]peer.ID, error) { + p, _ := kb.GenRandPeerIDWithCPL(keyspace.KeyToBytes(k), kb.PeerIDPreimageMaxCpl) + return s.router.GetClosestPeers(s.ctx, string(p)) +} + +type provideJob struct { + pid peer.ID + keys []mh.Multihash +} + +// sendProviderRecords manages reprovides for all given peer ids and allocated +// keys. Upon failure to reprovide a key, or to connect to a peer, it will NOT +// retry. +// +// Returns an error if we were unable to reprovide keys to a given threshold of +// peers. In this case, the region reprovide is considered failed and the +// caller is responsible for trying again. This allows detecting if we are +// offline. +func (s *SweepingProvider) sendProviderRecords(keysAllocations map[peer.ID][]mh.Multihash, addrInfo peer.AddrInfo) error { + nPeers := len(keysAllocations) + if nPeers == 0 { + return nil + } + startTime := time.Now() + errCount := atomic.Uint32{} + nWorkers := s.maxProvideConnsPerWorker + jobChan := make(chan provideJob, nWorkers) + + wg := sync.WaitGroup{} + wg.Add(nWorkers) + for range nWorkers { + go func() { + pmes := genProvideMessage(addrInfo) + defer wg.Done() + for job := range jobChan { + err := s.provideKeysToPeer(job.pid, job.keys, pmes) + if err != nil { + errCount.Add(1) + } + } + }() + } + + for p, keys := range keysAllocations { + jobChan <- provideJob{p, keys} + } + close(jobChan) + wg.Wait() + + errCountLoaded := int(errCount.Load()) + logger.Infof("sent provider records to peers in %s, errors %d/%d", time.Since(startTime), errCountLoaded, len(keysAllocations)) + + if errCountLoaded == nPeers || errCountLoaded > int(float32(nPeers)*(1-minimalRegionReachablePeersRatio)) { + return fmt.Errorf("unable to provide to enough peers (%d/%d)", nPeers-errCountLoaded, nPeers) + } + return nil +} + +// genProvideMessage generates a new provide message with the supplied +// AddrInfo. The message contains no keys, as they will be set later before +// sending the message. +func genProvideMessage(addrInfo peer.AddrInfo) *pb.Message { + pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, []byte{}, 0) + pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{addrInfo}) + return pmes +} + +// provideKeysToPeer performs the network operation to advertise to the given +// DHT server (p) that we serve all the given keys. +func (s *SweepingProvider) provideKeysToPeer(p peer.ID, keys []mh.Multihash, pmes *pb.Message) error { + errCount := 0 + for _, mh := range keys { + pmes.Key = mh + err := s.msgSender.SendMessage(s.ctx, p, pmes) + if err != nil { + errCount++ + + if errCount == len(keys) || errCount > maxConsecutiveProvideFailuresAllowed { + return fmt.Errorf("failed to provide to %s: %s", p, err.Error()) + } + } else if errCount > 0 { + // Reset error count + errCount = 0 + } + } + return nil +} + +// handleReprovide advances the reprovider schedule and (asynchronously) +// reprovides the region at the current schedule cursor. +// +// Behavior: +// - Determines the next region to reprovide based on the current cursor and +// the schedule, reprovides the region under the cursor, and moves the cursor +// to the next region. +// - Programs the schedule timer (alarm) for the next region’s reprovide +// time. When the timer fires, this method must be invoked again. +// - If the node has been blocked past the reprovide interval or if one or +// more regions’ times are already in the past, those regions are added to +// the reprovide queue for catch-up and a connectivity check is triggered. +// - If the node is currently offline, it skips the immediate reprovide of +// the current region and enqueues it to the reprovide queue for later. +// - If the node is online it removes the current region from the reprovide +// queue (if present) and starts an asynchronous batch reprovide using a +// periodic worker. +func (s *SweepingProvider) handleReprovide() { + s.scheduleLk.Lock() + currentPrefix := s.scheduleCursor + // Get next prefix to reprovide, and set timer for it. + next := keyspace.NextNonEmptyLeaf(s.schedule, currentPrefix, s.order) + + if next == nil { + // Schedule is empty, don't reprovide anything. + s.scheduleLk.Unlock() + return + } + + var nextPrefix bitstr.Key + var timeUntilNextReprovide time.Duration + if next.Key == currentPrefix { + // There is a single prefix in the schedule. + nextPrefix = currentPrefix + timeUntilNextReprovide = s.timeUntil(s.reprovideTimeForPrefix(currentPrefix)) + } else { + currentTimeOffset := s.currentTimeOffset() + timeSinceTimerRunning := s.timeBetween(s.timeOffset(s.scheduleTimerStartedAt), currentTimeOffset) + timeSinceTimerUntilNext := s.timeBetween(s.timeOffset(s.scheduleTimerStartedAt), next.Data) + + if s.scheduleTimerStartedAt.Add(s.reprovideInterval).Before(time.Now()) { + // Alarm was programmed more than reprovideInterval ago, which means that + // no regions has been reprovided since. Add all regions to the reprovide + // queue. This only happens if the main thread gets blocked for more than + // reprovideInterval. + nextKeyFound := false + scheduleEntries := keyspace.AllEntries(s.schedule, s.order) + next = scheduleEntries[0] + for _, entry := range scheduleEntries { + // Add all regions from the schedule to the reprovide queue. The next + // region to be scheduled for reprovide is the one immediately + // following the current time offset in the schedule. + if !nextKeyFound && entry.Data > currentTimeOffset { + next = entry + nextKeyFound = true + } + s.reprovideQueue.Enqueue(entry.Key) + } + // Don't reprovide any region now, but schedule the next one. All regions + // are expected to be reprovided when the provider is catching up with + // failed regions. + s.scheduleNextReprovideNoLock(next.Key, s.timeUntil(next.Data)) + s.scheduleLk.Unlock() + return + } + if timeSinceTimerUntilNext < timeSinceTimerRunning { + // next is scheduled in the past. While next is in the past, add next to + // failedRegions and take nextLeaf as next. + count := 0 + scheduleSize := s.schedule.Size() + for timeSinceTimerUntilNext < timeSinceTimerRunning && count < scheduleSize { + prefix := next.Key + s.reprovideQueue.Enqueue(prefix) + next = keyspace.NextNonEmptyLeaf(s.schedule, next.Key, s.order) + timeSinceTimerUntilNext = s.timeBetween(s.timeOffset(s.scheduleTimerStartedAt), next.Data) + count++ + } + } + + // next is in the future + nextPrefix = next.Key + timeUntilNextReprovide = s.timeUntil(next.Data) + } + + s.scheduleNextReprovideNoLock(nextPrefix, timeUntilNextReprovide) + s.scheduleLk.Unlock() + + // If we are offline, don't try to reprovide region. + if !s.connectivity.IsOnline() { + s.reprovideQueue.Enqueue(currentPrefix) + return + } + + // Remove prefix that is about to be reprovided from the reprovide queue if + // present. + s.reprovideQueue.Remove(currentPrefix) + + s.wg.Add(1) + go func() { + if err := s.workerPool.Acquire(periodicWorker); err == nil { + s.batchReprovide(currentPrefix, true) + s.workerPool.Release(periodicWorker) + } + s.wg.Done() + }() +} + +// handleProvide provides supplied keys to the network if needed and schedules +// the keys to be reprovided if needed. +func (s *SweepingProvider) handleProvide(force, reprovide bool, keys ...mh.Multihash) error { + if len(keys) == 0 { + return nil + } + if reprovide { + // Add keys to list of keys to be reprovided. Returned keys are deduplicated + // newly added keys. + newKeys, err := s.keystore.Put(s.ctx, keys...) + if err != nil { + return fmt.Errorf("couldn't add keys to keystore: %w", err) + } + if !force { + keys = newKeys + } + } + + if s.isOffline() { + return ErrOffline + } + prefixes, err := s.groupAndScheduleKeysByPrefix(keys, reprovide) + if err != nil { + return err + } + if len(prefixes) == 0 { + return nil + } + // Sort prefixes by number of keys. + sortedPrefixesAndKeys := keyspace.SortPrefixesBySize(prefixes) + // Add keys to the provide queue. + for _, prefixAndKeys := range sortedPrefixesAndKeys { + s.provideQueue.Enqueue(prefixAndKeys.Prefix, prefixAndKeys.Keys...) + } + + s.wg.Add(1) + go s.provideLoop() + return nil +} + +// groupAndScheduleKeysByPrefix groups the supplied keys by their prefixes as +// present in the schedule, and if `schedule` is set to true, add these +// prefixes to the schedule to be reprovided. +func (s *SweepingProvider) groupAndScheduleKeysByPrefix(keys []mh.Multihash, schedule bool) (map[bitstr.Key][]mh.Multihash, error) { + seen := make(map[string]struct{}) + prefixTrie := trie.New[bitstr.Key, struct{}]() + prefixes := make(map[bitstr.Key][]mh.Multihash) + avgPrefixLen := -1 + + s.scheduleLk.Lock() + defer s.scheduleLk.Unlock() + for _, h := range keys { + k := keyspace.MhToBit256(h) + kStr := string(keyspace.KeyToBytes(k)) + // Don't add duplicates + if _, ok := seen[kStr]; ok { + continue + } + seen[kStr] = struct{}{} + + if prefix, ok := keyspace.FindPrefixOfKey(prefixTrie, k); ok { + prefixes[prefix] = append(prefixes[prefix], h) + continue + } + + prefix, inSchedule := keyspace.FindPrefixOfKey(s.schedule, k) + if !inSchedule { + if avgPrefixLen == -1 { + var err error + avgPrefixLen, err = s.getAvgPrefixLenNoLock() + if err != nil { + return nil, err + } + } + prefix = bitstr.Key(key.BitString(k)[:avgPrefixLen]) + if schedule { + s.schedulePrefixNoLock(prefix, false) + } + } + mhs := []mh.Multihash{h} + if subtrie, ok := keyspace.FindSubtrie(prefixTrie, prefix); ok { + // If prefixes already contains superstrings of prefix, consolidate the + // keys to prefix. + for _, entry := range keyspace.AllEntries(subtrie, s.order) { + mhs = append(mhs, prefixes[entry.Key]...) + delete(prefixes, entry.Key) + } + keyspace.PruneSubtrie(prefixTrie, prefix) + } + prefixTrie.Add(prefix, struct{}{}) + prefixes[prefix] = mhs + } + return prefixes, nil +} + +func (s *SweepingProvider) isOffline() bool { + s.avgPrefixLenLk.Lock() + defer s.avgPrefixLenLk.Unlock() + return s.cachedAvgPrefixLen == -1 +} + +func (s *SweepingProvider) onOffline() { + s.provideQueue.Clear() + + s.avgPrefixLenLk.Lock() + s.cachedAvgPrefixLen = -1 // Invalidate cached avg prefix len. + s.avgPrefixLenLk.Unlock() +} + +func (s *SweepingProvider) onOnline() { + if s.closed() { + return + } + + s.avgPrefixLenLk.Lock() + cachedAvgPrefixLen := s.cachedAvgPrefixLen + s.avgPrefixLenLk.Unlock() + + if cachedAvgPrefixLen == -1 { + // Provider was previously Offline (not Disconnected). + // Run prefix length measurement, and refresh schedule afterwards. + if !s.approxPrefixLenRunning.TryLock() { + return + } + s.approxPrefixLen() + s.approxPrefixLenRunning.Unlock() + + s.RefreshSchedule() + } + + s.catchupPendingWork() +} + +// catchupPendingWork is called when the provider comes back online after being offline. +// +// 1. Try again to reprovide regions that failed to be reprovided on time. +// 2. Try again to provide keys that failed to be provided. +// +// This function is guarded by s.lateReprovideRunning, ensuring the function +// cannot be called again while it is working on reproviding late regions. +func (s *SweepingProvider) catchupPendingWork() { + if s.closed() || !s.lateReprovideRunning.TryLock() { + return + } + s.wg.Add(2) + go func() { + // Reprovide late regions if any. + s.reprovideLateRegions() + s.lateReprovideRunning.Unlock() + + // Provides are handled after reprovides, because keys pending to be + // provided will be provided as part of a region reprovide if they belong + // to that region. Hence, the provideLoop will use less resources if run + // after the reprovides. + + // Restart provide loop if it was stopped. + s.provideLoop() + }() +} + +// provideLoop is the loop providing keys to the DHT swarm as long as the +// provide queue isn't empty. +// +// The s.provideRunning mutex prevents concurrent executions of the loop. +func (s *SweepingProvider) provideLoop() { + defer s.wg.Done() + if !s.provideRunning.TryLock() { + // Ensure that only one goroutine is running the provide loop at a time. + return + } + defer s.provideRunning.Unlock() + + for !s.provideQueue.IsEmpty() { + if s.closed() { + // Exit loop if provider is closed. + return + } + if !s.connectivity.IsOnline() { + // Don't try to provide if node is offline. + return + } + // Block until we can acquire a worker from the pool. + err := s.workerPool.Acquire(burstWorker) + if err != nil { + // Provider was closed while waiting for a worker. + return + } + prefix, keys, ok := s.provideQueue.Dequeue() + if ok { + s.wg.Add(1) + go func(prefix bitstr.Key, keys []mh.Multihash) { + s.batchProvide(prefix, keys) + s.workerPool.Release(burstWorker) + s.wg.Done() + }(prefix, keys) + } else { + s.workerPool.Release(burstWorker) + } + } +} + +// reprovideLateRegions is the loop reproviding regions that failed to be +// reprovided on time. It returns once the reprovide queue is empty. +func (s *SweepingProvider) reprovideLateRegions() { + defer s.wg.Done() + for !s.reprovideQueue.IsEmpty() { + if s.closed() { + // Exit loop if provider is closed. + return + } + if !s.connectivity.IsOnline() { + // Don't try to reprovide a region if node is offline. + return + } + // Block until we can acquire a worker from the pool. + err := s.workerPool.Acquire(burstWorker) + if err != nil { + // Provider was closed while waiting for a worker. + return + } + prefix, ok := s.reprovideQueue.Dequeue() + if ok { + s.wg.Add(1) + go func(prefix bitstr.Key) { + s.batchReprovide(prefix, false) + s.workerPool.Release(burstWorker) + s.wg.Done() + }(prefix) + } else { + s.workerPool.Release(burstWorker) + } + } +} + +func (s *SweepingProvider) batchProvide(prefix bitstr.Key, keys []mh.Multihash) { + if len(keys) == 0 { + return + } + addrInfo, ok := s.selfAddrInfo() + if !ok { + // Don't provide if the node doesn't have a valid address to include in the + // provider record. + return + } + if len(keys) <= individualProvideThreshold { + // Don't fully explore the region, execute simple DHT provides for these + // keys. It isn't worth it to fully explore a region for just a few keys. + s.individualProvide(prefix, keys, false, false) + return + } + + regions, coveredPrefix, err := s.exploreSwarm(prefix) + if err != nil { + s.failedProvide(prefix, keys, fmt.Errorf("provide '%s': %w", prefix, err)) + return + } + logger.Debugf("provide: requested prefix '%s' (len %d), prefix covered '%s' (len %d)", prefix, len(prefix), coveredPrefix, len(coveredPrefix)) + + // Add any key matching the covered prefix from the provide queue to the + // current provide batch. + extraKeys := s.provideQueue.DequeueMatching(coveredPrefix) + keys = append(keys, extraKeys...) + regions = keyspace.AssignKeysToRegions(regions, keys) + + if !s.provideRegions(regions, addrInfo, false, false) { + logger.Warnf("failed to provide any region for prefix %s", prefix) + } +} + +func (s *SweepingProvider) batchReprovide(prefix bitstr.Key, periodicReprovide bool) { + addrInfo, ok := s.selfAddrInfo() + if !ok { + // Don't provide if the node doesn't have a valid address to include in the + // provider record. + return + } + + // Load keys matching prefix from the keystore. + keys, err := s.keystore.Get(s.ctx, prefix) + if err != nil { + s.failedReprovide(prefix, fmt.Errorf("couldn't reprovide, error when loading keys: %s", err)) + if periodicReprovide { + s.reschedulePrefix(prefix) + } + return + } + if len(keys) == 0 { + logger.Infof("No keys to reprovide for prefix %s", prefix) + return + } + if len(keys) <= individualProvideThreshold { + // Don't fully explore the region, execute simple DHT provides for these + // keys. It isn't worth it to fully explore a region for just a few keys. + s.individualProvide(prefix, keys, true, periodicReprovide) + return + } + + regions, coveredPrefix, err := s.exploreSwarm(prefix) + if err != nil { + s.failedReprovide(prefix, fmt.Errorf("reprovide '%s': %w", prefix, err)) + if periodicReprovide { + s.reschedulePrefix(prefix) + } + return + } + logger.Debugf("reprovide: requested prefix '%s' (len %d), prefix covered '%s' (len %d)", prefix, len(prefix), coveredPrefix, len(coveredPrefix)) + + regions = s.claimRegionReprovide(regions) + + // Remove all keys matching coveredPrefix from provide queue. No need to + // provide them anymore since they are about to be reprovided. + s.provideQueue.DequeueMatching(coveredPrefix) + // Remove covered prefix from the reprovide queue, so since we are about the + // reprovide the region. + s.reprovideQueue.Remove(coveredPrefix) + + // When reproviding a region, remove all scheduled regions starting with + // the currently covered prefix. + s.scheduleLk.Lock() + s.unscheduleSubsumedPrefixesNoLock(coveredPrefix) + s.scheduleLk.Unlock() + + if len(coveredPrefix) < len(prefix) { + // Covered prefix is shorter than the requested one, load all the keys + // matching the covered prefix from the keystore. + keys, err = s.keystore.Get(s.ctx, coveredPrefix) + if err != nil { + err = fmt.Errorf("couldn't reprovide, error when loading keys: %s", err) + s.failedReprovide(prefix, err) + if periodicReprovide { + s.reschedulePrefix(prefix) + } + } + } + regions = keyspace.AssignKeysToRegions(regions, keys) + + if !s.provideRegions(regions, addrInfo, true, periodicReprovide) { + logger.Warnf("failed to reprovide any region for prefix %s", prefix) + } +} + +func (s *SweepingProvider) failedProvide(prefix bitstr.Key, keys []mh.Multihash, err error) { + logger.Warn(err) + // Put keys back to the provide queue. + s.provideQueue.Enqueue(prefix, keys...) + + s.connectivity.TriggerCheck() +} + +func (s *SweepingProvider) failedReprovide(prefix bitstr.Key, err error) { + logger.Warn(err) + // Put prefix in the reprovide queue. + s.reprovideQueue.Enqueue(prefix) + + s.connectivity.TriggerCheck() +} + +// selfAddrInfo returns the current peer.AddrInfo to be used in the provider +// records sent to remote peers. +// +// If the node currently has no valid multiaddress, return an empty AddrInfo +// and false. +func (s *SweepingProvider) selfAddrInfo() (peer.AddrInfo, bool) { + addrs := s.getSelfAddrs() + if len(addrs) == 0 { + logger.Warn("provider: no self addresses available for providing keys") + return peer.AddrInfo{}, false + } + return peer.AddrInfo{ID: s.peerid, Addrs: addrs}, true +} + +// individualProvide provides the keys sharing the same prefix to the network +// without exploring the associated keyspace regions. It performs "normal" DHT +// provides for the supplied keys, handles failures and schedules next +// reprovide is necessary. +func (s *SweepingProvider) individualProvide(prefix bitstr.Key, keys []mh.Multihash, reprovide bool, periodicReprovide bool) { + if len(keys) == 0 { + return + } + + var provideErr error + if len(keys) == 1 { + coveredPrefix, err := s.vanillaProvide(keys[0]) + if err == nil { + s.provideCounter.Add(s.ctx, 1) + } else if !reprovide { + // Put the key back in the provide queue. + s.failedProvide(prefix, keys, fmt.Errorf("individual provide failed for prefix '%s', %w", prefix, err)) + } + provideErr = err + if periodicReprovide { + // Schedule next reprovide for the prefix that was actually covered by + // the GCP, otherwise we may schedule a reprovide for a prefix too short + // or too long. + s.reschedulePrefix(coveredPrefix) + } + } else { + wg := sync.WaitGroup{} + success := atomic.Bool{} + for _, key := range keys { + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.vanillaProvide(key) + if err == nil { + s.provideCounter.Add(s.ctx, 1) + success.Store(true) + } else if !reprovide { + // Individual provide failed, put key back in provide queue. + s.failedProvide(prefix, []mh.Multihash{key}, err) + } + }() + } + wg.Wait() + + if !success.Load() { + // Only errors if all provides failed. + provideErr = fmt.Errorf("all individual provides failed for prefix %s", prefix) + } + if periodicReprovide { + s.reschedulePrefix(prefix) + } + } + if reprovide && provideErr != nil { + s.failedReprovide(prefix, provideErr) + } +} + +// provideRegions contains common logic to batchProvide() and batchReprovide(). +// It iterate over supplied regions, and allocates the regions provider records +// to the appropriate DHT servers. +func (s *SweepingProvider) provideRegions(regions []keyspace.Region, addrInfo peer.AddrInfo, reprovide, periodicReprovide bool) bool { + errCount := 0 + for _, r := range regions { + allKeys := keyspace.AllValues(r.Keys, s.order) + if len(allKeys) == 0 { + if reprovide { + s.releaseRegionReprovide(r.Prefix) + } + continue + } + // Add keys to local provider store + for _, h := range allKeys { + s.addLocalRecord(h) + } + keysAllocations := keyspace.AllocateToKClosest(r.Keys, r.Peers, s.replicationFactor) + err := s.sendProviderRecords(keysAllocations, addrInfo) + if reprovide { + s.releaseRegionReprovide(r.Prefix) + if periodicReprovide { + s.reschedulePrefix(r.Prefix) + } + } + if err != nil { + errCount++ + err = fmt.Errorf("cannot send provider records for region %s: %s", r.Prefix, err) + if reprovide { + s.failedReprovide(r.Prefix, err) + } else { // provide operation + s.failedProvide(r.Prefix, keyspace.AllValues(r.Keys, s.order), err) + } + continue + } + s.provideCounter.Add(s.ctx, int64(len(allKeys))) + } + // If at least 1 regions was provided, we don't consider it a failure. + return errCount < len(regions) +} + +// claimRegionReprovide checks if the region is already being reprovided by +// another thread. If not it marks the region as being currently reprovided. +func (s *SweepingProvider) claimRegionReprovide(regions []keyspace.Region) []keyspace.Region { + out := regions[:0] + s.ongoingReprovidesLk.Lock() + defer s.ongoingReprovidesLk.Unlock() + for _, r := range regions { + if r.Peers.IsEmptyLeaf() { + continue + } + if _, ok := keyspace.FindPrefixOfKey(s.ongoingReprovides, r.Prefix); !ok { + // Prune superstrings of r.Prefix if any + keyspace.PruneSubtrie(s.ongoingReprovides, r.Prefix) + out = append(out, r) + s.ongoingReprovides.Add(r.Prefix, struct{}{}) + } + } + return out +} + +// releaseRegionReprovide marks the region as no longer being reprovided. +func (s *SweepingProvider) releaseRegionReprovide(prefix bitstr.Key) { + s.ongoingReprovidesLk.Lock() + defer s.ongoingReprovidesLk.Unlock() + s.ongoingReprovides.Remove(prefix) +} + +// ProvideOnce only sends provider records for the given keys out to the DHT +// swarm. It does NOT take the responsibility to reprovide these keys. +// +// Returns an error if the keys couldn't be added to the provide queue. This +// can happen if the provider is closed or if the node is currently Offline +// (either never bootstrapped, or disconnected since more than `OfflineDelay`). +// The schedule and provide queue depend on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) ProvideOnce(keys ...mh.Multihash) error { + if s.closed() { + return ErrClosed + } + return s.handleProvide(true, false, keys...) +} + +// StartProviding provides the given keys to the DHT swarm unless they were +// already provided in the past. The keys will be periodically reprovided until +// StopProviding is called for the same keys or user defined garbage collection +// deletes the keys. +// +// Returns an error if the keys couldn't be added to the provide queue. This +// can happen if the provider is closed or if the node is currently Offline +// (either never bootstrapped, or disconnected since more than `OfflineDelay`). +// The schedule and provide queue depend on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) StartProviding(force bool, keys ...mh.Multihash) error { + if s.closed() { + return ErrClosed + } + return s.handleProvide(force, true, keys...) +} + +// StopProviding stops reproviding the given keys to the DHT swarm. The node +// stops being referred as a provider when the provider records in the DHT +// swarm expire. +func (s *SweepingProvider) StopProviding(keys ...mh.Multihash) error { + if s.closed() { + return ErrClosed + } + err := s.keystore.Delete(s.ctx, keys...) + if err != nil { + err = fmt.Errorf("failed to stop providing keys: %w", err) + } + s.provideQueue.Remove(keys...) + return err +} + +// Clear clears the all the keys from the provide queue and returns the number +// of keys that were cleared. +// +// The keys are not deleted from the keystore, so they will continue to be +// reprovided as scheduled. +func (s *SweepingProvider) Clear() int { + if s.closed() { + return 0 + } + return s.provideQueue.Clear() +} + +// ProvideState encodes the current relationship between this node and `key`. +type ProvideState uint8 + +const ( + StateUnknown ProvideState = iota // we have no record of the key + StateQueued // key is queued to be provided + StateProvided // key was provided at least once +) + +// ProvideStatus reports the provider’s view of a key. +// +// When `state == StateProvided`, `lastProvide` is the wall‑clock time of the +// most recent successful provide operation (UTC). +// For `StateQueued` or `StateUnknown`, `lastProvide` is the zero `time.Time`. +func (s *SweepingProvider) ProvideStatus(key mh.Multihash) (state ProvideState, lastProvide time.Time) { + // TODO: implement me + return StateUnknown, time.Time{} +} + +// AddToSchedule makes sure the prefixes associated with the supplied keys are +// scheduled to be reprovided. +// +// Returns an error if the provider is closed or if the node is currently +// Offline (either never bootstrapped, or disconnected since more than +// `OfflineDelay`). The schedule depends on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) AddToSchedule(keys ...mh.Multihash) error { + if s.closed() { + return ErrClosed + } + if s.isOffline() { + return ErrOffline + } + _, err := s.groupAndScheduleKeysByPrefix(keys, true) + return err +} + +// RefreshSchedule scans the KeyStore for any keys that are not currently +// scheduled for reproviding. If such keys are found, it schedules their +// associated keyspace region to be reprovided. +// +// This function doesn't remove prefixes that have no keys from the schedule. +// This is done automatically during the reprovide operation if a region has no +// keys. +// +// Returns an error if the provider is closed or if the node is currently +// Offline (either never bootstrapped, or disconnected since more than +// `OfflineDelay`). The schedule depends on the network size, hence recent +// network connectivity is essential. +func (s *SweepingProvider) RefreshSchedule() error { + if s.closed() { + return ErrClosed + } + // Look for prefixes not included in the schedule + s.scheduleLk.Lock() + prefixLen, err := s.getAvgPrefixLenNoLock() + if err != nil { + s.scheduleLk.Unlock() + return err + } + gaps := keyspace.TrieGaps(s.schedule) + s.scheduleLk.Unlock() + + missing := make([]bitstr.Key, 0, len(gaps)) + for _, p := range gaps { + if len(p) >= prefixLen { + missing = append(missing, p) + } else { + missing = append(missing, keyspace.ExtendBinaryPrefix(p, prefixLen)...) + } + } + + // Only keep the missing prefixes for which there are keys in the KeyStore. + toInsert := make([]bitstr.Key, 0) + for _, p := range missing { + ok, err := s.keystore.ContainsPrefix(s.ctx, p) + if err != nil { + logger.Warnf("couldn't refresh schedule for prefix %s: %s", p, err) + } + if ok { + toInsert = append(toInsert, p) + } + } + if len(toInsert) == 0 { + return nil + } + + // Insert prefixes into the schedule + s.scheduleLk.Lock() + for _, p := range toInsert { + s.schedulePrefixNoLock(p, false) + } + s.scheduleLk.Unlock() + return nil +} diff --git a/provider/provider_test.go b/provider/provider_test.go new file mode 100644 index 000000000..0fa74acb8 --- /dev/null +++ b/provider/provider_test.go @@ -0,0 +1,1424 @@ +//go:build go1.25 +// +build go1.25 + +package provider + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/guillaumemichel/reservedpool" + ds "github.com/ipfs/go-datastore" + logging "github.com/ipfs/go-log/v2" + "github.com/ipfs/go-test/random" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mh "github.com/multiformats/go-multihash" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" + + "github.com/probe-lab/go-libdht/kad/key" + "github.com/probe-lab/go-libdht/kad/key/bit256" + "github.com/probe-lab/go-libdht/kad/key/bitstr" + "github.com/probe-lab/go-libdht/kad/trie" + + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/connectivity" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/keyspace" + "github.com/libp2p/go-libp2p-kad-dht/provider/internal/queue" + "github.com/libp2p/go-libp2p-kad-dht/provider/keystore" + kb "github.com/libp2p/go-libp2p-kbucket" + + "github.com/stretchr/testify/require" +) + +func genMultihashes(n int) []mh.Multihash { + mhs := make([]mh.Multihash, n) + var err error + for i := range n { + mhs[i], err = mh.Sum([]byte(strconv.Itoa(i)), mh.SHA2_256, -1) + if err != nil { + panic(err) + } + } + return mhs +} + +// genBalancedMultihashes generates 2^exponent multihashes, with balanced +// prefixes, in a random order. +// +// e.g genBalancedMultihashes(3) will generate 8 multihashes, with each +// kademlia identifier starting with a distinct prefix (000, 001, 010, ..., +// 111) of len 3. +func genBalancedMultihashes(exponent int) []mh.Multihash { + n := 1 << exponent + mhs := make([]mh.Multihash, 0, n) + seen := make(map[bitstr.Key]struct{}, n) + for i := 0; len(mhs) < n; i++ { + h, err := mh.Sum([]byte(strconv.Itoa(i)), mh.SHA2_256, -1) + if err != nil { + panic(err) + } + prefix := bitstr.Key(key.BitString(keyspace.MhToBit256(h))[:exponent]) + if _, ok := seen[prefix]; !ok { + mhs = append(mhs, h) + seen[prefix] = struct{}{} + } + } + return mhs +} + +func genMultihashesMatchingPrefix(prefix bitstr.Key, n int) []mh.Multihash { + mhs := make([]mh.Multihash, 0, n) + for i := 0; len(mhs) < n; i++ { + h := random.Multihashes(1)[0] + k := keyspace.MhToBit256(h) + if keyspace.IsPrefix(prefix, k) { + mhs = append(mhs, h) + } + } + return mhs +} + +var _ pb.MessageSender = (*mockMsgSender)(nil) + +type mockMsgSender struct { + sendMessageFunc func(ctx context.Context, p peer.ID, m *pb.Message) error +} + +func (ms *mockMsgSender) SendRequest(ctx context.Context, p peer.ID, m *pb.Message) (*pb.Message, error) { + // Unused + return nil, nil +} + +func (ms *mockMsgSender) SendMessage(ctx context.Context, p peer.ID, m *pb.Message) error { + if ms.sendMessageFunc == nil { + return nil + } + return ms.sendMessageFunc(ctx, p, m) +} + +var _ KadClosestPeersRouter = (*mockRouter)(nil) + +type mockRouter struct { + getClosestPeersFunc func(ctx context.Context, k string) ([]peer.ID, error) +} + +func (r *mockRouter) GetClosestPeers(ctx context.Context, k string) ([]peer.ID, error) { + if r.getClosestPeersFunc == nil { + return nil, nil + } + return r.getClosestPeersFunc(ctx, k) +} + +func TestProvideKeysToPeer(t *testing.T) { + msgCount := 0 + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgCount++ + return errors.New("error") + }, + } + prov := SweepingProvider{ + msgSender: msgSender, + } + + nKeys := 16 + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + mhs := genMultihashes(nKeys) + pmes := &pb.Message{} + + // All ADD_PROVIDER RPCs fail, return an error after reprovideInitialFailuresAllowed+1 attempts + err = prov.provideKeysToPeer(pid, mhs, pmes) + require.Error(t, err) + require.Equal(t, maxConsecutiveProvideFailuresAllowed+1, msgCount) + + // Only fail 33% of requests. The operation should be considered a success. + msgCount = 0 + msgSender.sendMessageFunc = func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgCount++ + if msgCount%3 == 0 { + return errors.New("error") + } + return nil + } + err = prov.provideKeysToPeer(pid, mhs, pmes) + require.NoError(t, err) + require.Equal(t, nKeys, msgCount) +} + +func TestKeysAllocationsToPeers(t *testing.T) { + nKeys := 1024 + nPeers := 128 + replicationFactor := 10 + + mhs := genMultihashes(nKeys) + keysTrie := trie.New[bit256.Key, mh.Multihash]() + for _, c := range mhs { + keysTrie.Add(keyspace.MhToBit256(c), c) + } + peers := random.Peers(nPeers) + peersTrie := trie.New[bit256.Key, peer.ID]() + for i := range peers { + peersTrie.Add(keyspace.PeerIDToBit256(peers[i]), peers[i]) + } + keysAllocations := keyspace.AllocateToKClosest(keysTrie, peersTrie, replicationFactor) + + for _, c := range mhs { + k := sha256.Sum256(c) + closestPeers := kb.SortClosestPeers(peers, k[:])[:replicationFactor] + for _, p := range closestPeers[:replicationFactor] { + require.Contains(t, keysAllocations[p], c) + } + for _, p := range closestPeers[replicationFactor:] { + require.NotContains(t, keysAllocations[p], c) + } + } +} + +func TestReprovideTimeForPrefixWithOrderZero(t *testing.T) { + s := SweepingProvider{ + reprovideInterval: 16 * time.Second, + order: bit256.ZeroKey(), + } + + require.Equal(t, 0*time.Second, s.reprovideTimeForPrefix("0")) + require.Equal(t, 8*time.Second, s.reprovideTimeForPrefix("1")) + require.Equal(t, 0*time.Second, s.reprovideTimeForPrefix("000")) + require.Equal(t, 8*time.Second, s.reprovideTimeForPrefix("1000")) + require.Equal(t, 10*time.Second, s.reprovideTimeForPrefix("1010")) + require.Equal(t, 15*time.Second, s.reprovideTimeForPrefix("1111")) +} + +func TestReprovideTimeForPrefixWithCustomOrder(t *testing.T) { + s := SweepingProvider{ + reprovideInterval: 16 * time.Second, + order: bit256.NewKey(bytes.Repeat([]byte{0xFF}, 32)), // 111...1 + } + + require.Equal(t, 0*time.Second, s.reprovideTimeForPrefix("1")) + require.Equal(t, 8*time.Second, s.reprovideTimeForPrefix("0")) + require.Equal(t, 0*time.Second, s.reprovideTimeForPrefix("111")) + require.Equal(t, 8*time.Second, s.reprovideTimeForPrefix("0111")) + require.Equal(t, 10*time.Second, s.reprovideTimeForPrefix("0101")) + require.Equal(t, 15*time.Second, s.reprovideTimeForPrefix("0000")) +} + +func TestClosestPeersToPrefixRandom(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + replicationFactor := 10 + nPeers := 128 + peers := random.Peers(nPeers) + peersTrie := trie.New[bit256.Key, peer.ID]() + for _, p := range peers { + peersTrie.Add(keyspace.PeerIDToBit256(p), p) + } + + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + sortedPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k)) + return sortedPeers[:min(replicationFactor, len(peers))], nil + }, + } + + r := SweepingProvider{ + router: router, + replicationFactor: replicationFactor, + connectivity: noopConnectivityChecker(), + } + r.connectivity.Start() + defer r.connectivity.Close() + + synctest.Wait() + require.True(t, r.connectivity.IsOnline()) + + for _, prefix := range []bitstr.Key{"", "0", "1", "00", "01", "10", "11", "000", "001", "010", "011", "100", "101", "110", "111"} { + closestPeers, err := r.closestPeersToPrefix(prefix) + require.NoError(t, err, "failed for prefix %s", prefix) + subtrieSize := 0 + currPrefix := prefix + // Reduce prefix if necessary as closestPeersToPrefix always returns at + // least replicationFactor peers if possible. + for { + subtrie, ok := keyspace.FindSubtrie(peersTrie, currPrefix) + require.True(t, ok) + subtrieSize = subtrie.Size() + if subtrieSize >= replicationFactor { + break + } + currPrefix = currPrefix[:len(currPrefix)-1] + } + require.Len(t, closestPeers, subtrieSize, "prefix: %s", prefix) + } + }) +} + +func TestGroupAndScheduleKeysByPrefix(t *testing.T) { + prov := SweepingProvider{ + order: bit256.ZeroKey(), + reprovideInterval: time.Hour, + + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + + cachedAvgPrefixLen: 3, + lastAvgPrefixLen: time.Now(), + } + + mhs00000 := genMultihashesMatchingPrefix("00000", 3) + mhs00000 = append(mhs00000, mhs00000[0]) + mhs1000 := genMultihashesMatchingPrefix("0100", 2) + + mhs := append(mhs00000, mhs1000...) + + prefixes, err := prov.groupAndScheduleKeysByPrefix(mhs, false) + require.NoError(t, err) + require.Len(t, prefixes, 2) + require.Contains(t, prefixes, bitstr.Key("000")) + require.Len(t, prefixes["000"], 3) // no duplicate entry + require.Contains(t, prefixes, bitstr.Key("010")) + require.Len(t, prefixes["010"], 2) + + // Schedule is still empty + require.True(t, prov.schedule.IsEmptyLeaf()) + + prefixes, err = prov.groupAndScheduleKeysByPrefix(mhs, true) + require.NoError(t, err) + require.Len(t, prefixes, 2) + require.Contains(t, prefixes, bitstr.Key("000")) + require.Len(t, prefixes["000"], 3) + require.Contains(t, prefixes, bitstr.Key("010")) + require.Len(t, prefixes["010"], 2) + + // Schedule now contains the 2 prefixes + require.Equal(t, 2, prov.schedule.Size()) + + // Manually add prefix to schedule + prov.schedule.Add(bitstr.Key("11111"), 0*time.Second) + mhs11111 := genMultihashesMatchingPrefix("11111", 4) + mhs1110 := genMultihashesMatchingPrefix("1110", 4) + mhs = append(mhs11111, mhs1110...) + prefixes, err = prov.groupAndScheduleKeysByPrefix(mhs, true) + require.NoError(t, err) + // All keys should be consolidated into "111" + require.Len(t, prefixes, 1) + require.Contains(t, prefixes, bitstr.Key("111")) + require.Len(t, prefixes["111"], 8) + + // "11111" is removed from schedule + found, _ := trie.Find(prov.schedule, bitstr.Key("11111")) + require.False(t, found) + found, _ = trie.Find(prov.schedule, bitstr.Key("111")) + require.True(t, found) + + prov.schedule.Add(bitstr.Key("10"), 0*time.Second) + + mhs1 := genMultihashesMatchingPrefix("10", 6) + prefixes, err = prov.groupAndScheduleKeysByPrefix(mhs1, true) + require.NoError(t, err) + require.Len(t, prefixes, 1) + require.Contains(t, prefixes, bitstr.Key("10")) + require.Len(t, prefixes["10"], 6) +} + +func noWarningsNorAbove(obsLogs *observer.ObservedLogs) bool { + return obsLogs.Filter(func(le observer.LoggedEntry) bool { + return le.Level >= zap.WarnLevel + }).Len() == 0 +} + +func takeAllContainsErr(obsLogs *observer.ObservedLogs, errStr string) bool { + for _, le := range obsLogs.TakeAll() { + if le.Level >= zap.WarnLevel && strings.Contains(le.Message, errStr) { + return true + } + } + return false +} + +func noopConnectivityChecker() *connectivity.ConnectivityChecker { + connChecker, err := connectivity.New(func() bool { return true }) + if err != nil { + panic(err) + } + return connChecker +} + +func provideCounter() metric.Int64Counter { + meter := otel.Meter("github.com/libp2p/go-libp2p-kad-dht/provider") + provideCounter, err := meter.Int64Counter( + "total_provide_count", + metric.WithDescription("Number of successful provides since node is running"), + ) + if err != nil { + panic(err) + } + return provideCounter +} + +func TestIndividualProvideSingle(t *testing.T) { + obsCore, obsLogs := observer.New(zap.WarnLevel) + logging.SetPrimaryCore(obsCore) + logging.SetAllLoggers(logging.LevelError) + logging.SetLogLevel(LoggerName, "warn") + + mhs := genMultihashes(1) + prefix := bitstr.Key("1011101111") + + closestPeers := []peer.ID{peer.ID("12BoooooPEER1"), peer.ID("12BoooooPEER2")} + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + return closestPeers, nil + }, + } + + advertisements := make(map[peer.ID]int, len(closestPeers)) + msgSenderLk := sync.Mutex{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + advertisements[p]++ + return nil + }, + } + r := SweepingProvider{ + router: router, + msgSender: msgSender, + reprovideInterval: time.Hour, + maxProvideConnsPerWorker: 2, + provideQueue: queue.NewProvideQueue(), + reprovideQueue: queue.NewReprovideQueue(), + connectivity: noopConnectivityChecker(), + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + getSelfAddrs: func() []ma.Multiaddr { return nil }, + addLocalRecord: func(mh mh.Multihash) error { return nil }, + provideCounter: provideCounter(), + } + + assertAdvertisementCount := func(n int) { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + for _, count := range advertisements { + require.Equal(t, n, count) + } + } + + // Providing no keys returns no error + r.individualProvide(prefix, nil, false, false) + require.True(t, noWarningsNorAbove(obsLogs)) + assertAdvertisementCount(0) + + // Providing a single key - success + r.individualProvide(prefix, mhs, false, false) + require.True(t, noWarningsNorAbove(obsLogs)) + assertAdvertisementCount(1) + + // Providing a single key - failure + gcpErr := errors.New("GetClosestPeers error") + router.getClosestPeersFunc = func(ctx context.Context, k string) ([]peer.ID, error) { + return nil, gcpErr + } + r.individualProvide(prefix, mhs, false, false) + require.True(t, takeAllContainsErr(obsLogs, gcpErr.Error())) + assertAdvertisementCount(1) + // Verify failed key ends up in the provide queue. + _, keys, ok := r.provideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, mhs, keys) + + // Reproviding a single key - failure + r.individualProvide(prefix, mhs, true, true) + require.True(t, takeAllContainsErr(obsLogs, gcpErr.Error())) + assertAdvertisementCount(1) + // Verify failed prefix ends up in the reprovide queue. + dequeued, ok := r.reprovideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, prefix, dequeued) +} + +func TestIndividualProvideMultiple(t *testing.T) { + obsCore, obsLogs := observer.New(zap.WarnLevel) + logging.SetPrimaryCore(obsCore) + logging.SetAllLoggers(logging.LevelError) + logging.SetLogLevel(LoggerName, "warn") + + ks := genMultihashes(2) + prefix := bitstr.Key("") + closestPeers := []peer.ID{peer.ID("12BoooooPEER1"), peer.ID("12BoooooPEER2")} + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + return closestPeers, nil + }, + } + advertisements := make(map[string]map[peer.ID]int, len(closestPeers)) + for _, k := range ks { + advertisements[string(k)] = make(map[peer.ID]int, len(closestPeers)) + } + msgSenderLk := sync.Mutex{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + _, k, err := mh.MHFromBytes(m.GetKey()) + require.NoError(t, err) + advertisements[string(k)][p]++ + return nil + }, + } + r := SweepingProvider{ + router: router, + msgSender: msgSender, + reprovideInterval: time.Hour, + maxProvideConnsPerWorker: 2, + provideQueue: queue.NewProvideQueue(), + reprovideQueue: queue.NewReprovideQueue(), + connectivity: noopConnectivityChecker(), + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + getSelfAddrs: func() []ma.Multiaddr { return nil }, + addLocalRecord: func(mh mh.Multihash) error { return nil }, + provideCounter: provideCounter(), + } + + assertAdvertisementCount := func(n int) { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + for _, peerAllocs := range advertisements { + for _, count := range peerAllocs { + require.Equal(t, n, count) + } + } + } + + // Providing two keys - success + r.individualProvide(prefix, ks, false, false) + require.True(t, noWarningsNorAbove(obsLogs)) + assertAdvertisementCount(1) + + // Providing two keys - failure + gcpErr := errors.New("GetClosestPeers error") + router.getClosestPeersFunc = func(ctx context.Context, k string) ([]peer.ID, error) { + return nil, gcpErr + } + r.individualProvide(prefix, ks, false, false) + require.True(t, takeAllContainsErr(obsLogs, gcpErr.Error())) + assertAdvertisementCount(1) + // Assert keys are added to provide queue + require.Equal(t, len(ks), r.provideQueue.Size()) + pendingKeys := []mh.Multihash{} + for !r.provideQueue.IsEmpty() { + _, keys, ok := r.provideQueue.Dequeue() + require.True(t, ok) + pendingKeys = append(pendingKeys, keys...) + } + require.ElementsMatch(t, pendingKeys, ks) + + // Reproviding two keys - failure + r.individualProvide(prefix, ks, true, true) + require.True(t, takeAllContainsErr(obsLogs, "all individual provides failed for prefix")) + assertAdvertisementCount(1) + // Assert prefix is added to reprovide queue. + dequeued, ok := r.reprovideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, prefix, dequeued) + + // Providing two keys - 1 success, 1 failure + lk := sync.Mutex{} + counter := 0 + router.getClosestPeersFunc = func(ctx context.Context, k string) ([]peer.ID, error) { + lk.Lock() + defer lk.Unlock() + counter++ + if counter%2 == 1 { + return nil, errors.New("GetClosestPeers error") + } + return closestPeers, nil + } + + r.individualProvide(prefix, ks, false, false) + require.True(t, takeAllContainsErr(obsLogs, gcpErr.Error())) + // Verify one key was now provided 2x, and other key only 1x since it just failed. + msgSenderLk.Lock() + require.Equal(t, 3, advertisements[string(ks[0])][closestPeers[0]]+advertisements[string(ks[1])][closestPeers[0]]) + require.Equal(t, 3, advertisements[string(ks[0])][closestPeers[1]]+advertisements[string(ks[1])][closestPeers[1]]) + msgSenderLk.Unlock() + + // Failed key was added to provide queue + require.Equal(t, 1, r.provideQueue.Size()) + _, pendingKeys, ok = r.provideQueue.Dequeue() + require.True(t, ok) + require.Len(t, pendingKeys, 1) + require.Contains(t, ks, pendingKeys[0]) + require.True(t, r.reprovideQueue.IsEmpty()) + require.True(t, r.provideQueue.IsEmpty()) + + r.individualProvide(prefix, ks, true, true) + require.True(t, noWarningsNorAbove(obsLogs)) + // Verify only one of the 2 keys was provided. Providing failed for the other. + msgSenderLk.Lock() + require.Equal(t, 4, advertisements[string(ks[0])][closestPeers[0]]+advertisements[string(ks[1])][closestPeers[0]]) + require.Equal(t, 4, advertisements[string(ks[0])][closestPeers[1]]+advertisements[string(ks[1])][closestPeers[1]]) + msgSenderLk.Unlock() + + // Failed key shouldn't be added to provide nor reprovide queue, since the + // reprovide didn't completely failed. + require.True(t, r.reprovideQueue.IsEmpty()) + require.True(t, r.provideQueue.IsEmpty()) +} + +func TestHandleReprovide(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + online := atomic.Bool{} + online.Store(true) + connectivityCheckInterval := time.Second + offlineDelay := time.Minute + connChecker, err := connectivity.New( + func() bool { return online.Load() }, + connectivity.WithOfflineDelay(offlineDelay), + connectivity.WithOnlineCheckInterval(connectivityCheckInterval), + ) + require.NoError(t, err) + defer connChecker.Close() + + prov := SweepingProvider{ + order: bit256.ZeroKey(), + + connectivity: connChecker, + + cycleStart: time.Now(), + scheduleTimer: time.NewTimer(time.Hour), + schedule: trie.New[bitstr.Key, time.Duration](), + + reprovideQueue: queue.NewReprovideQueue(), + workerPool: reservedpool.New[workerType](1, nil), // single worker + + reprovideInterval: time.Minute, + maxReprovideDelay: 5 * time.Second, + + getSelfAddrs: func() []ma.Multiaddr { return nil }, + } + prov.scheduleTimer.Stop() + connChecker.Start() + defer connChecker.Close() + + prefixes := []bitstr.Key{ + "00", + "10", + "11", + } + + // Empty schedule -> early return + prov.handleReprovide() + require.Zero(t, prov.scheduleCursor) + + // Single prefix in schedule + prov.schedule.Add(prefixes[0], prov.reprovideTimeForPrefix(prefixes[0])) + prov.scheduleCursor = prefixes[0] + prov.handleReprovide() + require.Equal(t, prefixes[0], prov.scheduleCursor) + + // Two prefixes in schedule + time.Sleep(1) // advance 1 tick into the reprovide cycle + prov.schedule.Add(prefixes[1], prov.reprovideTimeForPrefix(prefixes[1])) + prov.handleReprovide() // reprovides prefixes[0], set scheduleCursor to prefixes[1] + require.Equal(t, prefixes[1], prov.scheduleCursor) + + // Wait more than reprovideInterval to call handleReprovide again. + // All prefixes should be added to the reprovide queue. + time.Sleep(prov.reprovideInterval + 1) + require.True(t, prov.reprovideQueue.IsEmpty()) + prov.handleReprovide() + require.Equal(t, prefixes[1], prov.scheduleCursor) + + require.Equal(t, 2, prov.reprovideQueue.Size()) + dequeued, ok := prov.reprovideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, prefixes[0], dequeued) + dequeued, ok = prov.reprovideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, prefixes[1], dequeued) + require.True(t, prov.reprovideQueue.IsEmpty()) + + // Go in time past prefixes[1] and prefixes[2] + prov.schedule.Add(prefixes[2], prov.reprovideTimeForPrefix(prefixes[2])) + time.Sleep(3 * prov.reprovideInterval / 4) + // reprovides prefixes[1], add prefixes[2] to reprovide queue, set + // scheduleCursor to prefixes[0] + prov.handleReprovide() + require.Equal(t, prefixes[0], prov.scheduleCursor) + + require.Equal(t, 1, prov.reprovideQueue.Size()) + dequeued, ok = prov.reprovideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, prefixes[2], dequeued) + require.True(t, prov.reprovideQueue.IsEmpty()) + + time.Sleep(prov.reprovideInterval / 4) + + // Node goes offline -> prefixes are queued + online.Store(false) + prov.connectivity.TriggerCheck() + synctest.Wait() + require.False(t, prov.connectivity.IsOnline()) + require.True(t, prov.reprovideQueue.IsEmpty()) + prov.handleReprovide() + require.Equal(t, 1, prov.reprovideQueue.Size()) + + // Node comes back online + online.Store(true) + time.Sleep(connectivityCheckInterval) + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + }) +} + +func TestClose(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return []peer.ID{peer.ID("12BoooooPEER1"), peer.ID("12BoooooPEER2")}, nil + }, + } + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + if ctx.Err() != nil { + return ctx.Err() + } + return nil + }, + } + prov, err := New( + WithPeerID(pid), + WithRouter(router), + WithMessageSender(msgSender), + WithReplicationFactor(1), + + WithMaxWorkers(4), + WithDedicatedBurstWorkers(0), + WithDedicatedPeriodicWorkers(0), + + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + ) + require.NoError(t, err) + synctest.Wait() + + mhs := genMultihashes(128) + err = prov.StartProviding(false, mhs...) + require.NoError(t, err) + synctest.Wait() // wait for connectivity check + time.Sleep(prov.reprovideInterval / 2) // some keys should have been reprovided + synctest.Wait() + + err = prov.Close() + require.NoError(t, err) + synctest.Wait() + + newMh := random.Multihashes(1)[0] + + err = prov.StartProviding(false, newMh) + require.ErrorIs(t, err, ErrClosed) + err = prov.StopProviding(newMh) + require.ErrorIs(t, err, ErrClosed) + err = prov.ProvideOnce(newMh) + require.ErrorIs(t, err, ErrClosed) + require.Equal(t, 0, prov.Clear()) + + err = prov.workerPool.Acquire(burstWorker) + require.ErrorIs(t, err, reservedpool.ErrClosed) + }) +} + +func TestProvideOnce(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + + online := atomic.Bool{} // false, start offline + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + if online.Load() { + return []peer.ID{pid}, nil + } + return nil, errors.New("offline") + }, + } + provideCount := atomic.Int32{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + provideCount.Add(1) + return nil + }, + } + + checkInterval := time.Second + offlineDelay := time.Minute + + opts := []Option{ + WithPeerID(pid), + WithRouter(router), + WithMessageSender(msgSender), + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + WithOfflineDelay(offlineDelay), + WithConnectivityCheckOnlineInterval(checkInterval), + } + prov, err := New(opts...) + require.NoError(t, err) + defer prov.Close() + + h := genMultihashes(1)[0] + + // Node is offline, ProvideOne should error + err = prov.ProvideOnce(h) + require.ErrorIs(t, err, ErrOffline) + require.True(t, prov.provideQueue.IsEmpty()) + require.Equal(t, int32(0), provideCount.Load(), "should not have provided when offline 0") + + // Wait for provider to come online + online.Store(true) + time.Sleep(checkInterval) // trigger connectivity check + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + + // Set the provider as disconnected + online.Store(false) + synctest.Wait() + err = prov.ProvideOnce(h) + require.NoError(t, err) + synctest.Wait() // wait for ProvideOnce to finish + require.Equal(t, int32(0), provideCount.Load(), "should not have provided when offline 1") + // Ensure the key is in the provide queue + _, keys, ok := prov.provideQueue.Dequeue() + require.True(t, ok) + require.Equal(t, 1, len(keys)) + require.Equal(t, h, keys[0]) + + // Set the provider as online + online.Store(true) + time.Sleep(checkInterval) // trigger connectivity check + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + err = prov.ProvideOnce(h) + require.NoError(t, err) + synctest.Wait() + require.Equal(t, int32(1), provideCount.Load()) + }) +} + +func TestStartProvidingSingle(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + replicationFactor := 4 + h := genMultihashes(1)[0] + + reprovideInterval := time.Hour + + prefixLen := 4 + peers := make([]peer.ID, replicationFactor) + seen := make(map[peer.ID]struct{}, replicationFactor) + peers[0], err = peer.Decode("12BooooPEER1") + require.NoError(t, err) + kbKey := keyspace.KeyToBytes(keyspace.PeerIDToBit256(peers[0])) + for i := range peers[1:] { + p, err := kb.GenRandPeerIDWithCPL(kbKey, uint(prefixLen)) + require.NoError(t, err) + if _, ok := seen[p]; ok { + continue + } + seen[p] = struct{}{} + peers[i+1] = p + } + // Sort peers from closest to h, to furthest + slices.SortFunc(peers, func(a, b peer.ID) int { + targetKey := keyspace.MhToBit256(h) + return keyspace.PeerIDToBit256(a).Xor(targetKey).Compare(keyspace.PeerIDToBit256(b).Xor(targetKey)) + }) + + getClosestPeersCount := atomic.Int32{} + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + getClosestPeersCount.Add(1) + return peers, nil + }, + } + provideCount := atomic.Int32{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + provideCount.Add(1) + return nil + }, + } + checkInterval := time.Second + offlineDelay := time.Minute + opts := []Option{ + WithReplicationFactor(replicationFactor), + WithReprovideInterval(reprovideInterval), + WithPeerID(pid), + WithRouter(router), + WithMessageSender(msgSender), + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + WithOfflineDelay(offlineDelay), + WithConnectivityCheckOnlineInterval(checkInterval), + } + prov, err := New(opts...) + require.NoError(t, err) + defer prov.Close() + + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + prov.avgPrefixLenLk.Lock() + require.Greater(t, prov.cachedAvgPrefixLen, 0) // TODO: FLAKY + prov.avgPrefixLenLk.Unlock() + + err = prov.StartProviding(true, h) + require.NoError(t, err) + synctest.Wait() + require.Equal(t, int32(len(peers)), provideCount.Load()) + expectedGCPCount := 1 + approxPrefixLenGCPCount + 1 // 1 for initial, approxPrefixLenGCPCount for prefix length estimation, 1 for the provide + require.Equal(t, expectedGCPCount, int(getClosestPeersCount.Load())) + + // Verify reprovide is scheduled. + prefix := bitstr.Key(key.BitString(keyspace.MhToBit256(h))[:prefixLen]) + prov.scheduleLk.Lock() + require.Equal(t, 1, prov.schedule.Size()) + found, reprovideTime := trie.Find(prov.schedule, prefix) + if !found { + t.Log(prefix) + t.Log(keyspace.AllEntries(prov.schedule, prov.order)[0].Key) + require.FailNow(t, "prefix not inserted in schedule") + } + require.Equal(t, prov.reprovideTimeForPrefix(prefix), reprovideTime) + prov.scheduleLk.Unlock() + + // Try to provide the same key again -> noop + err = prov.StartProviding(false, h) + require.NoError(t, err) + synctest.Wait() + require.Equal(t, int32(len(peers)), provideCount.Load()) + require.Equal(t, expectedGCPCount, int(getClosestPeersCount.Load())) + + // Verify reprovide happens as scheduled. + time.Sleep(reprovideTime) + synctest.Wait() + expectedGCPCount++ // for the reprovide + require.Equal(t, 2*int32(len(peers)), provideCount.Load()) + require.Equal(t, expectedGCPCount, int(getClosestPeersCount.Load())) + + time.Sleep(reprovideInterval) + synctest.Wait() + expectedGCPCount++ // for the reprovide + require.Equal(t, 3*int32(len(peers)), provideCount.Load()) + require.Equal(t, expectedGCPCount, int(getClosestPeersCount.Load())) + + time.Sleep(reprovideInterval) + synctest.Wait() + expectedGCPCount++ // for the reprovide + require.Equal(t, 4*int32(len(peers)), provideCount.Load()) + require.Equal(t, expectedGCPCount, int(getClosestPeersCount.Load())) + }) +} + +const bitsPerByte = 8 + +func TestStartProvidingMany(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + + nKeysExponent := 10 + nKeys := 1 << nKeysExponent + mhs := genBalancedMultihashes(nKeysExponent) + + replicationFactor := 4 + peerPrefixBitlen := 6 + require.LessOrEqual(t, peerPrefixBitlen, bitsPerByte) + var nPeers byte = 1 << peerPrefixBitlen // 2**peerPrefixBitlen + peers := make([]peer.ID, nPeers) + for i := range nPeers { + b := i << (bitsPerByte - peerPrefixBitlen) + k := [32]byte{b} + peers[i], err = kb.GenRandPeerIDWithCPL(k[:], uint(peerPrefixBitlen)) + require.NoError(t, err) + } + + reprovideInterval := time.Hour + + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + sortedPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k)) + return sortedPeers[:min(replicationFactor, len(peers))], nil + }, + } + msgSenderLk := sync.Mutex{} + addProviderRpcs := make(map[string]map[peer.ID]int) // key -> peerid -> count + provideCount := atomic.Int32{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + _, k, err := mh.MHFromBytes(m.GetKey()) + require.NoError(t, err) + if _, ok := addProviderRpcs[string(k)]; !ok { + addProviderRpcs[string(k)] = make(map[peer.ID]int) + } + addProviderRpcs[string(k)][p]++ + provideCount.Add(1) + return nil + }, + } + opts := []Option{ + WithReprovideInterval(reprovideInterval), + WithReplicationFactor(replicationFactor), + WithMaxWorkers(1), + WithDedicatedBurstWorkers(0), + WithDedicatedPeriodicWorkers(0), + WithPeerID(pid), + WithRouter(router), + WithMessageSender(msgSender), + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + } + prov, err := New(opts...) + require.NoError(t, err) + defer prov.Close() + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + + err = prov.StartProviding(true, mhs...) + require.NoError(t, err) + synctest.Wait() + require.Equal(t, int32(len(mhs)*replicationFactor), provideCount.Load()) + + // Each key should have been provided at least once. + msgSenderLk.Lock() + require.Equal(t, nKeys, len(addProviderRpcs)) + for k, holders := range addProviderRpcs { + // Verify that all keys have been provided to exactly replicationFactor + // distinct peers. + require.Len(t, holders, replicationFactor) + for _, count := range holders { + require.Equal(t, 1, count) + } + // Verify provider records are assigned to the closest peers + closestPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k))[:replicationFactor] + for _, p := range closestPeers { + require.Contains(t, holders, p) + } + } + + step := 10 * time.Second + // Test reprovides, clear addProviderRpcs + clear(addProviderRpcs) + msgSenderLk.Unlock() + for range reprovideInterval / step { + time.Sleep(step) + } + synctest.Wait() + require.Equal(t, 2*int32(len(mhs)*replicationFactor), provideCount.Load(), "should have reprovided all keys at least once") + + msgSenderLk.Lock() + require.Equal(t, nKeys, len(addProviderRpcs)) + for k, holders := range addProviderRpcs { + // Verify that all keys have been provided to exactly replicationFactor + // distinct peers. + require.Len(t, holders, replicationFactor, key.BitString(keyspace.MhToBit256([]byte(k)))) + for _, count := range holders { + require.Equal(t, 1, count) + } + // Verify provider records are assigned to the closest peers + closestPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k))[:replicationFactor] + for _, p := range closestPeers { + require.Contains(t, holders, p) + } + } + + // Test reprovides again, clear addProviderRpcs + clear(addProviderRpcs) + msgSenderLk.Unlock() + for range reprovideInterval / step { + time.Sleep(step) + } + synctest.Wait() + require.Equal(t, 3*int32(len(mhs)*replicationFactor), provideCount.Load(), "should have reprovided all keys at least twice") + + msgSenderLk.Lock() + require.Equal(t, nKeys, len(addProviderRpcs)) + for k, holders := range addProviderRpcs { + // Verify that all keys have been provided to exactly replicationFactor + // distinct peers. + require.Len(t, holders, replicationFactor) + for _, count := range holders { + require.Equal(t, 1, count) + } + // Verify provider records are assigned to the closest peers + closestPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k))[:replicationFactor] + for _, p := range closestPeers { + require.Contains(t, holders, p) + } + } + msgSenderLk.Unlock() + }) +} + +func TestStartProvidingUnstableNetwork(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + + nKeysExponent := 10 + nKeys := 1 << nKeysExponent + mhs := genBalancedMultihashes(nKeysExponent) + + replicationFactor := 4 + peerPrefixBitlen := 6 + require.LessOrEqual(t, peerPrefixBitlen, bitsPerByte) + var nPeers byte = 1 << peerPrefixBitlen // 2**peerPrefixBitlen + peers := make([]peer.ID, nPeers) + for i := range nPeers { + b := i << (bitsPerByte - peerPrefixBitlen) + k := [32]byte{b} + peers[i], err = kb.GenRandPeerIDWithCPL(k[:], uint(peerPrefixBitlen)) + require.NoError(t, err) + } + + reprovideInterval := time.Hour + connectivityCheckInterval := time.Minute + offlineDelay := time.Hour + + routerOffline := atomic.Bool{} + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + if routerOffline.Load() { + return nil, errors.New("offline") + } + sortedPeers := kb.SortClosestPeers(peers, kb.ConvertKey(k)) + return sortedPeers[:min(replicationFactor, len(peers))], nil + }, + } + msgSenderLk := sync.Mutex{} + addProviderRpcs := make(map[string]map[peer.ID]int) // key -> peerid -> count + provideCount := atomic.Int32{} + msgSender := &mockMsgSender{ + sendMessageFunc: func(ctx context.Context, p peer.ID, m *pb.Message) error { + msgSenderLk.Lock() + defer msgSenderLk.Unlock() + if routerOffline.Load() { + return errors.New("offline") + } + _, k, err := mh.MHFromBytes(m.GetKey()) + require.NoError(t, err) + if _, ok := addProviderRpcs[string(k)]; !ok { + addProviderRpcs[string(k)] = make(map[peer.ID]int) + } + addProviderRpcs[string(k)][p]++ + provideCount.Add(1) + return nil + }, + } + opts := []Option{ + WithReprovideInterval(reprovideInterval), + WithReplicationFactor(replicationFactor), + WithMaxWorkers(1), + WithDedicatedBurstWorkers(0), + WithDedicatedPeriodicWorkers(0), + WithPeerID(pid), + WithRouter(router), + WithMessageSender(msgSender), + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + WithOfflineDelay(offlineDelay), + WithConnectivityCheckOnlineInterval(connectivityCheckInterval), + } + prov, err := New(opts...) + require.NoError(t, err) + defer prov.Close() + + synctest.Wait() + prov.avgPrefixLenLk.Lock() + require.Greater(t, prov.cachedAvgPrefixLen, 0) + prov.avgPrefixLenLk.Unlock() + + routerOffline.Store(true) + time.Sleep(connectivityCheckInterval) // wait for connectivity check to become available again + synctest.Wait() + + err = prov.StartProviding(true, mhs...) + require.NoError(t, err) + synctest.Wait() + require.Equal(t, int32(0), provideCount.Load(), "should not have provided when disconnected") + require.False(t, prov.connectivity.IsOnline()) + + routerOffline.Store(false) + time.Sleep(connectivityCheckInterval) // connectivity check triggered + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + + msgSenderLk.Lock() + require.Equal(t, nKeys, len(addProviderRpcs)) + for _, peers := range addProviderRpcs { + // Verify that all keys have been provided to exactly replicationFactor + // distinct peers. + require.Len(t, peers, replicationFactor) + } + msgSenderLk.Unlock() + }) +} + +func TestAddToSchedule(t *testing.T) { + prov := SweepingProvider{ + reprovideInterval: time.Hour, + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + + cachedAvgPrefixLen: 4, + avgPrefixLenValidity: time.Minute, + lastAvgPrefixLen: time.Now(), + } + + ok, _ := trie.Find(prov.schedule, "0000") + + require.False(t, ok) + keys := genMultihashesMatchingPrefix("0000", 4) + prov.AddToSchedule(keys...) + ok, _ = trie.Find(prov.schedule, "0000") + require.True(t, ok) + require.Equal(t, 1, prov.schedule.Size()) + + // Nothing should have changed + prov.AddToSchedule(keys...) + ok, _ = trie.Find(prov.schedule, "0000") + require.True(t, ok) + require.Equal(t, 1, prov.schedule.Size()) + + keys = append(keys, append(genMultihashesMatchingPrefix("0111", 1), genMultihashesMatchingPrefix("1000", 3)...)...) + prov.AddToSchedule(keys...) + require.Equal(t, 3, prov.schedule.Size()) + ok, _ = trie.Find(prov.schedule, "0000") + require.True(t, ok) + ok, _ = trie.Find(prov.schedule, "0111") + require.True(t, ok) + ok, _ = trie.Find(prov.schedule, "1000") + require.True(t, ok) +} + +func TestRefreshSchedule(t *testing.T) { + ctx := context.Background() + mapDs := ds.NewMapDatastore() + defer mapDs.Close() + ks, err := keystore.NewKeystore(mapDs) + require.NoError(t, err) + + prov := SweepingProvider{ + ctx: ctx, + keystore: ks, + + reprovideInterval: time.Hour, + schedule: trie.New[bitstr.Key, time.Duration](), + scheduleTimer: time.NewTimer(time.Hour), + + cachedAvgPrefixLen: 4, + avgPrefixLenValidity: time.Minute, + lastAvgPrefixLen: time.Now(), + } + + // Schedule is empty + require.Equal(t, 0, prov.schedule.Size()) + prov.RefreshSchedule() + require.Equal(t, 0, prov.schedule.Size()) + + // Add key to keystore + k := genMultihashesMatchingPrefix("00000", 1)[0] + ks.Put(ctx, k) + + // Refresh schedule should add the key to the schedule + require.Equal(t, 0, prov.schedule.Size()) + prov.RefreshSchedule() + require.Equal(t, 1, prov.schedule.Size()) + ok, _ := trie.Find(prov.schedule, bitstr.Key("0000")) + require.True(t, ok) + + // Add another key starting with same prefix to keystore + k = genMultihashesMatchingPrefix("00001", 1)[0] + ks.Put(ctx, k) + prov.RefreshSchedule() + require.Equal(t, 1, prov.schedule.Size()) + ok, _ = trie.Find(prov.schedule, bitstr.Key("0000")) + require.True(t, ok) + + // Add multiple keys and verify associated prefixes are scheduled. + newPrefixes := []bitstr.Key{"0100", "0110", "0111"} + keys := make([]mh.Multihash, 0, len(newPrefixes)) + for _, p := range newPrefixes { + keys = append(keys, genMultihashesMatchingPrefix(p, 1)...) + } + ks.Put(ctx, keys...) + prov.RefreshSchedule() + // Assert that only the prefixes containing matching keys in the KeyStore + // have been added to the schedule. + require.Equal(t, 1+len(newPrefixes), prov.schedule.Size()) + for _, p := range newPrefixes { + ok, _ = trie.Find(prov.schedule, p) + require.True(t, ok) + } +} + +func TestOperationsOffline(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + pid, err := peer.Decode("12BoooooPEER") + require.NoError(t, err) + + checkInterval := time.Second + offlineDelay := time.Minute + + online := atomic.Bool{} // false, start offline + + router := &mockRouter{ + getClosestPeersFunc: func(ctx context.Context, k string) ([]peer.ID, error) { + if online.Load() { + return []peer.ID{pid}, nil + } + return nil, errors.New("offline") + }, + } + opts := []Option{ + WithReprovideInterval(time.Hour), + WithReplicationFactor(1), + WithMaxWorkers(1), + WithDedicatedBurstWorkers(0), + WithDedicatedPeriodicWorkers(0), + WithPeerID(pid), + WithRouter(router), + WithMessageSender(&mockMsgSender{}), + WithSelfAddrs(func() []ma.Multiaddr { + addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/4001") + require.NoError(t, err) + return []ma.Multiaddr{addr} + }), + WithOfflineDelay(offlineDelay), + WithConnectivityCheckOnlineInterval(checkInterval), + } + prov, err := New(opts...) + require.NoError(t, err) + defer prov.Close() + + k := random.Multihashes(1)[0] + + // Not bootstrapped yet, OFFLINE + err = prov.ProvideOnce(k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StartProviding(false, k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StartProviding(true, k) + require.ErrorIs(t, err, ErrOffline) + err = prov.RefreshSchedule() + require.ErrorIs(t, err, ErrOffline) + err = prov.AddToSchedule(k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StopProviding(k) // no error for StopProviding + require.NoError(t, err) + + online.Store(true) + time.Sleep(checkInterval) // trigger connectivity check + synctest.Wait() + require.True(t, prov.connectivity.IsOnline()) + + // ONLINE, operations shouldn't error + err = prov.ProvideOnce(k) + require.NoError(t, err) + err = prov.StartProviding(false, k) + require.NoError(t, err) + err = prov.StartProviding(true, k) + require.NoError(t, err) + err = prov.RefreshSchedule() + require.NoError(t, err) + err = prov.AddToSchedule(k) + require.NoError(t, err) + err = prov.StopProviding(k) // no error for StopProviding + require.NoError(t, err) + + online.Store(false) + time.Sleep(checkInterval) // wait for connectivity check to finish + prov.connectivity.TriggerCheck() + synctest.Wait() + require.False(t, prov.connectivity.IsOnline()) + + // DISCONNECTED, operations shoudln't error until node is OFFLINE + err = prov.ProvideOnce(k) + require.NoError(t, err) + err = prov.StartProviding(false, k) + require.NoError(t, err) + err = prov.StartProviding(true, k) + require.NoError(t, err) + err = prov.RefreshSchedule() + require.NoError(t, err) + err = prov.AddToSchedule(k) + require.NoError(t, err) + err = prov.StopProviding(k) // no error for StopProviding + require.NoError(t, err) + + prov.provideQueue.Enqueue("0000", k) + require.Equal(t, 1, prov.provideQueue.Size()) + time.Sleep(offlineDelay) + synctest.Wait() + + // OFFLINE + // Verify that provide queue has been emptied by the onOffline callback + require.True(t, prov.provideQueue.IsEmpty()) + prov.avgPrefixLenLk.Lock() + require.Equal(t, -1, prov.cachedAvgPrefixLen) + prov.avgPrefixLenLk.Unlock() + + // All operations should error again + err = prov.ProvideOnce(k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StartProviding(false, k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StartProviding(true, k) + require.ErrorIs(t, err, ErrOffline) + err = prov.RefreshSchedule() + require.ErrorIs(t, err, ErrOffline) + err = prov.AddToSchedule(k) + require.ErrorIs(t, err, ErrOffline) + err = prov.StopProviding(k) // no error for StopProviding + require.NoError(t, err) + }) +} diff --git a/routing.go b/routing.go index 589bb5d20..1e4796f58 100644 --- a/routing.go +++ b/routing.go @@ -457,7 +457,7 @@ func (dht *IpfsDHT) classicProvide(ctx context.Context, keyMH multihash.Multihas logger.Debugf("putProvider(%s, %s)", internal.LoggableProviderRecordBytes(keyMH), p) err := dht.protoMessenger.PutProviderAddrs(ctx, p, keyMH, peer.AddrInfo{ ID: dht.self, - Addrs: dht.filterAddrs(dht.host.Addrs()), + Addrs: dht.FilteredAddrs(), }) if err != nil { logger.Debug(err)