Skip to content

Commit 61cfd12

Browse files
provider: AllEntries and AllValues helpers
1 parent 74c899f commit 61cfd12

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

provider/internal/helpers/trie.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package helpers
2+
3+
import (
4+
"github.com/probe-lab/go-libdht/kad"
5+
"github.com/probe-lab/go-libdht/kad/trie"
6+
)
7+
8+
// AllEntries returns all entries (key + value) stored in the trie `t` sorted
9+
// by their keys in the supplied `order`.
10+
func AllEntries[K0 kad.Key[K0], K1 kad.Key[K1], D any](t *trie.Trie[K0, D], order K1) []*trie.Entry[K0, D] {
11+
return allEntriesAtDepth(t, order, 0)
12+
}
13+
14+
func allEntriesAtDepth[K0 kad.Key[K0], K1 kad.Key[K1], D any](t *trie.Trie[K0, D], order K1, depth int) []*trie.Entry[K0, D] {
15+
if t == nil || t.IsEmptyLeaf() {
16+
return nil
17+
}
18+
if t.IsNonEmptyLeaf() {
19+
return []*trie.Entry[K0, D]{{Key: *t.Key(), Data: t.Data()}}
20+
}
21+
b := int(order.Bit(depth))
22+
return append(allEntriesAtDepth(t.Branch(b), order, depth+1),
23+
allEntriesAtDepth(t.Branch(1-b), order, depth+1)...)
24+
}
25+
26+
// AllValues returns all values stored in the trie `t` sorted by their keys in
27+
// the supplied `order`.
28+
func AllValues[K0 kad.Key[K0], K1 kad.Key[K1], D any](t *trie.Trie[K0, D], order K1) []D {
29+
entries := AllEntries(t, order)
30+
out := make([]D, len(entries))
31+
for i, entry := range entries {
32+
out[i] = entry.Data
33+
}
34+
return out
35+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package helpers
2+
3+
import (
4+
"testing"
5+
6+
"github.com/probe-lab/go-libdht/kad/key/bitstr"
7+
"github.com/probe-lab/go-libdht/kad/trie"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestAllEntries(t *testing.T) {
13+
tr := trie.New[bitstr.Key, string]()
14+
elements := []struct {
15+
key bitstr.Key
16+
fruit string
17+
}{
18+
{
19+
key: bitstr.Key("000"), // convert to Key if it’s a distinct type
20+
fruit: "apple",
21+
},
22+
{
23+
key: bitstr.Key("010"),
24+
fruit: "banana",
25+
},
26+
{
27+
key: bitstr.Key("101"),
28+
fruit: "cherry",
29+
},
30+
{
31+
key: bitstr.Key("111"),
32+
fruit: "durian",
33+
},
34+
}
35+
36+
for _, e := range elements {
37+
tr.Add(e.key, e.fruit)
38+
}
39+
40+
// Test in 0 -> 1 order
41+
entries := AllEntries(tr, bitstr.Key("000"))
42+
require.Equal(t, len(elements), len(entries))
43+
for i := range entries {
44+
require.Equal(t, entries[i].Key, elements[i].key)
45+
require.Equal(t, entries[i].Data, elements[i].fruit)
46+
}
47+
48+
// Test in reverse order (1 -> 0)
49+
entries = AllEntries(tr, bitstr.Key("111"))
50+
require.Equal(t, len(elements), len(entries))
51+
for i := range entries {
52+
require.Equal(t, entries[i].Key, elements[len(elements)-1-i].key)
53+
require.Equal(t, entries[i].Data, elements[len(elements)-1-i].fruit)
54+
}
55+
}

0 commit comments

Comments
 (0)