Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
crypto: decryptReuse
  • Loading branch information
cheggaaa committed Mar 23, 2026
commit 20ab7a31d6786accd928b9fa7e9c997d2c413226
10 changes: 6 additions & 4 deletions commonspace/object/tree/objecttree/objecttree.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
ot.tree.IterateSkip(id, iterate)
return
}
var buf []byte
decrypt := func(c *Change) (decrypted []byte, err error) {
// the change is not encrypted
if c.ReadKeyId == "" {
Expand All @@ -664,7 +665,8 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
err = fmt.Errorf("no data in change %s", c.Id)
return
}
decrypted, err = readKey.Decrypt(c.Data)
buf, err = readKey.DecryptReuse(buf, c.Data)
decrypted = buf
return
}

Expand All @@ -679,9 +681,9 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
return iterate(c)
}

var decrypted []byte
decrypted, err = decrypt(c)
if err != nil {
decrypted, decErr := decrypt(c)
if decErr != nil {
err = decErr
return false
}

Expand Down
120 changes: 120 additions & 0 deletions commonspace/object/tree/objecttree/objecttree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/anyproto/any-sync/commonspace/object/acl/recordverifier"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/crypto"
)

var ctx = context.Background()
Expand Down Expand Up @@ -1957,3 +1958,122 @@ func TestObjectTree(t *testing.T) {
require.Equal(t, objTree.Heads(), otherTree.Heads())
})
}

func TestObjectTree_IterateFrom_DecryptReuse(t *testing.T) {
aclList, _ := prepareAclList(t)
exec := list.NewAclExecutor("spaceId")
err := exec.Execute("a.init::a")
require.NoError(t, err)
aAccount := exec.ActualAccounts()["a"]

store := createStore(ctx, t)
root, err := CreateObjectTreeRoot(ObjectTreeCreatePayload{
PrivKey: aAccount.Keys.SignKey,
ChangeType: "changeType",
ChangePayload: nil,
SpaceId: "spaceId",
IsEncrypted: true,
}, aAccount.Acl)
require.NoError(t, err)
headStorage, err := headstorage.New(ctx, store)
require.NoError(t, err)
treeStorage, err := CreateStorage(ctx, root, headStorage, store)
require.NoError(t, err)
tree, err := BuildKeyFilterableObjectTree(treeStorage, aAccount.Acl)
require.NoError(t, err)

_ = aclList
messages := []string{"first", "second", "third", "fourth", "fifth"}
for _, msg := range messages {
_, err = tree.AddContent(ctx, SignableChangeContent{
Data: []byte(msg),
Key: aAccount.Keys.SignKey,
IsSnapshot: false,
ShouldBeEncrypted: true,
DataType: mockDataType,
})
require.NoError(t, err)
}

var decryptedData [][]byte
err = tree.IterateRoot(func(change *Change, decrypted []byte) (any, error) {
cp := make([]byte, len(decrypted))
copy(cp, decrypted)
return string(cp), nil
}, func(change *Change) bool {
if change.Model != nil {
if s, ok := change.Model.(string); ok {
decryptedData = append(decryptedData, []byte(s))
}
}
return true
})
require.NoError(t, err)
require.Equal(t, len(messages), len(decryptedData))
for i, msg := range messages {
require.Equal(t, msg, string(decryptedData[i]))
}
}

func benchmarkIterateFrom(b *testing.B, numChanges int) {
key, err := crypto.NewRandomAES()
if err != nil {
b.Fatal(err)
}
keyId := "key1"

tr := new(Tree)
tr.AddFast(newSnapshot("root", ""))

for i := 0; i < numChanges; i++ {
id := fmt.Sprint(i + 1)
prevId := "root"
if i > 0 {
prevId = fmt.Sprint(i)
}
data := []byte(fmt.Sprintf("change data payload number %d with some content", i))
encrypted, encErr := key.Encrypt(data)
if encErr != nil {
b.Fatal(encErr)
}
ch := newChange(id, "root", prevId)
ch.ReadKeyId = keyId
ch.Data = encrypted
tr.Add(ch)
}

ot := &objectTree{
id: "root",
tree: tr,
keys: map[string]crypto.SymKey{keyId: key},
}
convert := func(change *Change, decrypted []byte) (any, error) {
return string(decrypted), nil
}
iterate := func(change *Change) bool {
return true
}

// clear cached models before benchmark
resetModels := func() {
tr.IterateSkip("root", func(c *Change) bool {
c.Model = nil
return true
})
}

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
resetModels()
_ = ot.IterateFrom("root", convert, iterate)
}
}

func BenchmarkObjectTree_IterateFrom_Encrypted100(b *testing.B) {
benchmarkIterateFrom(b, 100)
}

func BenchmarkObjectTree_IterateFrom_Encrypted1000(b *testing.B) {
benchmarkIterateFrom(b, 1000)
}
18 changes: 18 additions & 0 deletions util/crypto/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ func (k *AESKey) Decrypt(ciphertext []byte) ([]byte, error) {
return plain, nil
}

// DecryptReuse is like Decrypt but reuses dst's underlying array to avoid allocation.
func (k *AESKey) DecryptReuse(dst, ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(k.raw[:KeyBytes])
if err != nil {
return nil, err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := ciphertext[:NonceBytes]
plain, err := aesgcm.Open(dst[:0], nonce, ciphertext[NonceBytes:], nil)
if err != nil {
return nil, err
}
return plain, nil
}

// Marshall marshalls the key into proto
func (k *AESKey) Marshall() ([]byte, error) {
msg := &cryptoproto.Key{
Expand Down
92 changes: 92 additions & 0 deletions util/crypto/aes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package crypto

import (
"crypto/rand"
"testing"

"github.com/stretchr/testify/require"
)

func TestAESKey_DecryptReuse(t *testing.T) {
key := NewAES()
messages := [][]byte{
[]byte("hello world"),
[]byte("short"),
[]byte("a much longer message that should test buffer growth behavior properly"),
[]byte("back to short"),
}

var buf []byte
for _, msg := range messages {
encrypted, err := key.Encrypt(msg)
require.NoError(t, err)

buf, err = key.DecryptReuse(buf, encrypted)
require.NoError(t, err)
require.Equal(t, msg, buf)

// verify matches regular Decrypt
plain, err := key.Decrypt(encrypted)
require.NoError(t, err)
require.Equal(t, plain, buf)
}
}

func TestAESKey_DecryptReuse_NilDst(t *testing.T) {
key := NewAES()
msg := []byte("test message")
encrypted, err := key.Encrypt(msg)
require.NoError(t, err)

result, err := key.DecryptReuse(nil, encrypted)
require.NoError(t, err)
require.Equal(t, msg, result)
}

func TestAESKey_DecryptReuse_BufferReuse(t *testing.T) {
key := NewAES()
msg := make([]byte, 256)
_, err := rand.Read(msg)
require.NoError(t, err)
encrypted, err := key.Encrypt(msg)
require.NoError(t, err)

// first call allocates
buf, err := key.DecryptReuse(nil, encrypted)
require.NoError(t, err)
firstPtr := &buf[:cap(buf)][cap(buf)-1]

// second call with same-size message should reuse underlying array
buf, err = key.DecryptReuse(buf, encrypted)
require.NoError(t, err)
secondPtr := &buf[:cap(buf)][cap(buf)-1]

require.Equal(t, firstPtr, secondPtr, "expected buffer reuse")
}

func BenchmarkAESKey_Decrypt(b *testing.B) {
key := NewAES()
msg := make([]byte, 1024)
rand.Read(msg)
encrypted, _ := key.Encrypt(msg)

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = key.Decrypt(encrypted)
}
}

func BenchmarkAESKey_DecryptReuse(b *testing.B) {
key := NewAES()
msg := make([]byte, 1024)
rand.Read(msg)
encrypted, _ := key.Encrypt(msg)

var buf []byte
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
buf, _ = key.DecryptReuse(buf, encrypted)
}
}
2 changes: 2 additions & 0 deletions util/crypto/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type SymKey interface {

// Decrypt decrypts the message and returns the result
Decrypt(message []byte) ([]byte, error)
// DecryptReuse is like Decrypt but reuses dst's underlying array to avoid allocation.
DecryptReuse(dst, message []byte) ([]byte, error)
// Encrypt encrypts the message and returns the result
Encrypt(message []byte) ([]byte, error)
// Marshall wraps key in proto encoding and marshalls it
Expand Down
Loading