Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Revert "fix: allocation handling for string in deserialize"
This reverts commit 4894e5e.
  • Loading branch information
proost committed Feb 6, 2026
commit 2a59f114871027c9b068bbfca39350c9cf3f2da8
50 changes: 16 additions & 34 deletions tuple/include/array_of_strings_sketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,11 @@

namespace datasketches {

// Collects the allocator-aware type aliases used for arrays of strings.
// Every alias is derived from the single Allocator template argument by
// rebinding it to the element type that the corresponding container stores.
template<typename Allocator>
struct array_of_strings_types {
private:
  // Shorthand: Allocator rebound to allocate objects of type T.
  template<typename T>
  using rebound = typename std::allocator_traits<Allocator>::template rebind_alloc<T>;

public:
  // Allocator for the characters held by each string.
  using string_allocator = rebound<char>;
  // String type whose character storage comes from string_allocator.
  using string_type = std::basic_string<char, std::char_traits<char>, string_allocator>;
  // Allocator for the string_type elements held by the array.
  using array_allocator = rebound<string_type>;
  // Array of string_type elements allocated through array_allocator.
  using array_of_strings = array<string_type, array_allocator>;
};

// default update policy for an array of strings
template<typename Allocator = std::allocator<char>>
template<typename Allocator = std::allocator<std::string>>
class default_array_of_strings_update_policy {
public:
using string_allocator = typename array_of_strings_types<Allocator>::string_allocator;
using string_type = typename array_of_strings_types<Allocator>::string_type;
using array_allocator = typename array_of_strings_types<Allocator>::array_allocator;
using array_of_strings = typename array_of_strings_types<Allocator>::array_of_strings;
using array_of_strings = array<std::string, Allocator>;

explicit default_array_of_strings_update_policy(const Allocator& allocator = Allocator());

Expand All @@ -59,12 +48,9 @@ class default_array_of_strings_update_policy {

// serializer/deserializer for an array of strings
// Requirements: all strings must be valid UTF-8 and array size must be <= 127.
template<typename Allocator = std::allocator<char>>
template<typename Allocator = std::allocator<std::string>>
struct default_array_of_strings_serde {
using string_allocator = typename array_of_strings_types<Allocator>::string_allocator;
using string_type = typename array_of_strings_types<Allocator>::string_type;
using array_allocator = typename array_of_strings_types<Allocator>::array_allocator;
using array_of_strings = typename array_of_strings_types<Allocator>::array_of_strings;
using array_of_strings = array<std::string, Allocator>;
using summary_allocator = typename std::allocator_traits<Allocator>::template rebind_alloc<array_of_strings>;

explicit default_array_of_strings_serde(const Allocator& allocator = Allocator());
Expand All @@ -80,29 +66,27 @@ struct default_array_of_strings_serde {
summary_allocator summary_allocator_;
static void check_num_nodes(uint8_t num_nodes);
static uint32_t compute_total_bytes(const array_of_strings& item);
static void check_utf8(const string_type& value);
static void check_utf8(const std::string& value);
};

/**
* Hashes an array of strings using ArrayOfStrings-compatible hashing.
*/
template<typename Allocator = std::allocator<char>>
uint64_t hash_array_of_strings_key(const typename array_of_strings_types<Allocator>::array_of_strings& key);
template<typename Allocator = std::allocator<std::string>>
uint64_t hash_array_of_strings_key(const array<std::string, Allocator>& key);

/**
* Extended class of compact_tuple_sketch for array of strings
* Requirements: all strings must be valid UTF-8 and array size must be <= 127.
*/
template<typename Allocator = std::allocator<char>>
template<typename Allocator = std::allocator<std::string>>
class compact_array_of_strings_tuple_sketch:
public compact_tuple_sketch<
typename array_of_strings_types<Allocator>::array_of_strings,
typename std::allocator_traits<Allocator>::template rebind_alloc<
typename array_of_strings_types<Allocator>::array_of_strings
>
array<std::string, Allocator>,
typename std::allocator_traits<Allocator>::template rebind_alloc<array<std::string, Allocator>>
> {
public:
using array_of_strings = typename array_of_strings_types<Allocator>::array_of_strings;
using array_of_strings = array<std::string, Allocator>;
using summary_allocator = typename std::allocator_traits<Allocator>::template rebind_alloc<array_of_strings>;
using Base = compact_tuple_sketch<array_of_strings, summary_allocator>;
using vector_bytes = typename Base::vector_bytes;
Expand Down Expand Up @@ -149,15 +133,13 @@ class compact_array_of_strings_tuple_sketch:
/**
* Convenience alias for update_tuple_sketch for array of strings
*/
template<typename Allocator = std::allocator<char>,
template<typename Allocator = std::allocator<std::string>,
typename Policy = default_array_of_strings_update_policy<Allocator>>
using update_array_of_strings_tuple_sketch = update_tuple_sketch<
typename array_of_strings_types<Allocator>::array_of_strings,
typename array_of_strings_types<Allocator>::array_of_strings,
array<std::string, Allocator>,
array<std::string, Allocator>,
Policy,
typename std::allocator_traits<Allocator>::template rebind_alloc<
typename array_of_strings_types<Allocator>::array_of_strings
>
typename std::allocator_traits<Allocator>::template rebind_alloc<array<std::string, Allocator>>
>;

/**
Expand All @@ -166,7 +148,7 @@ using update_array_of_strings_tuple_sketch = update_tuple_sketch<
* @param ordered optional flag to specify if an ordered sketch should be produced
* @return compact array of strings sketch
*/
template<typename Allocator = std::allocator<char>, typename Policy = default_array_of_strings_update_policy<Allocator>>
template<typename Allocator = std::allocator<std::string>, typename Policy = default_array_of_strings_update_policy<Allocator>>
compact_array_of_strings_tuple_sketch<Allocator> compact_array_of_strings_sketch(
const update_array_of_strings_tuple_sketch<Allocator, Policy>& sketch, bool ordered = true);

Expand Down
32 changes: 13 additions & 19 deletions tuple/include/array_of_strings_sketch_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,15 @@ default_array_of_strings_update_policy<Allocator>::default_array_of_strings_upda

// Creates and returns an array_of_strings constructed with size 0.
// NOTE(review): this body contains two consecutive return statements that look
// like the removed and added sides of a diff hunk. As written, only the first
// return is reachable; confirm which variant is intended.
template<typename Allocator>
auto default_array_of_strings_update_policy<Allocator>::create() const -> array_of_strings {
// Empty string used as the fill value, built with an allocator derived from allocator_.
const string_type empty{string_allocator(allocator_)};
return array_of_strings(0, empty, array_allocator(allocator_));
// Unreachable as written (see note above): size-0 array filled from a "" literal.
return array_of_strings(0, "", allocator_);
}

// Replaces the contents of `array` with a copy of `input`.
// @param array destination; overwritten with a newly constructed array of
//              input.size() elements, then filled element by element
// @param input source array whose elements are copy-assigned into `array`
// NOTE(review): `array` is assigned twice in a row below, which looks like the
// removed and added sides of a diff hunk. As written, the second assignment
// replaces the result of the first; confirm which variant is intended.
template<typename Allocator>
void default_array_of_strings_update_policy<Allocator>::update(
array_of_strings& array, const array_of_strings& input
) const {
const auto length = static_cast<size_t>(input.size());
// Empty string used as the fill value, built with an allocator derived from allocator_.
const string_type empty{string_allocator(allocator_)};
// The length is narrowed to uint8_t for the constructor; presumably safe because
// array size is required to be <= 127 elsewhere in this header — TODO confirm.
array = array_of_strings(static_cast<uint8_t>(length), empty, array_allocator(allocator_));
array = array_of_strings(static_cast<uint8_t>(length), "", allocator_);
// Copy each input element into the freshly sized destination.
for (size_t i = 0; i < length; ++i) array[i] = input[i];
}

Expand All @@ -52,18 +50,16 @@ void default_array_of_strings_update_policy<Allocator>::update(
array_of_strings& array, const array_of_strings* input
) const {
if (input == nullptr) {
const string_type empty{string_allocator(allocator_)};
array = array_of_strings(0, empty, array_allocator(allocator_));
array = array_of_strings(0, "", allocator_);
return;
}
const auto length = static_cast<size_t>(input->size());
const string_type empty{string_allocator(allocator_)};
array = array_of_strings(static_cast<uint8_t>(length), empty, array_allocator(allocator_));
array = array_of_strings(static_cast<uint8_t>(length), "", allocator_);
for (size_t i = 0; i < length; ++i) array[i] = (*input)[i];
}

template<typename Allocator>
uint64_t hash_array_of_strings_key(const typename array_of_strings_types<Allocator>::array_of_strings& key) {
uint64_t hash_array_of_strings_key(const array<std::string, Allocator>& key) {
// Matches Java Util.PRIME for ArrayOfStrings key hashing.
static constexpr uint64_t STRING_ARR_HASH_SEED = 0x7A3CCA71ULL;
XXHash64 hasher(STRING_ARR_HASH_SEED);
Expand Down Expand Up @@ -128,7 +124,7 @@ void default_array_of_strings_serde<Allocator>::serialize(
const uint8_t num_nodes = static_cast<uint8_t>(items[i].size());
write(os, total_bytes);
write(os, num_nodes);
const string_type* data = items[i].data();
const std::string* data = items[i].data();
for (uint8_t j = 0; j < num_nodes; ++j) {
check_utf8(data[j]);
const uint32_t length = static_cast<uint32_t>(data[j].size());
Expand All @@ -148,12 +144,11 @@ void default_array_of_strings_serde<Allocator>::deserialize(
const uint8_t num_nodes = read<uint8_t>(is);
if (!is) throw std::runtime_error("array_of_strings stream read failed");
check_num_nodes(num_nodes);
const string_type empty{string_allocator(allocator_)};
array_of_strings array(num_nodes, empty, array_allocator(allocator_));
array_of_strings array(num_nodes, "", allocator_);
for (uint8_t j = 0; j < num_nodes; ++j) {
const uint32_t length = read<uint32_t>(is);
if (!is) throw std::runtime_error("array_of_strings stream read failed");
string_type value(length, '\0', string_allocator(allocator_));
std::string value(length, '\0');
if (length != 0) {
is.read(&value[0], length);
if (!is) throw std::runtime_error("array_of_strings stream read failed");
Expand All @@ -179,7 +174,7 @@ size_t default_array_of_strings_serde<Allocator>::serialize(
check_memory_size(bytes_written + total_bytes, capacity);
bytes_written += copy_to_mem(total_bytes, ptr8 + bytes_written);
bytes_written += copy_to_mem(num_nodes, ptr8 + bytes_written);
const string_type* data = items[i].data();
const std::string* data = items[i].data();
for (uint8_t j = 0; j < num_nodes; ++j) {
check_utf8(data[j]);
const uint32_t length = static_cast<uint32_t>(data[j].size());
Expand Down Expand Up @@ -207,12 +202,11 @@ size_t default_array_of_strings_serde<Allocator>::deserialize(
uint8_t num_nodes;
bytes_read += copy_from_mem(ptr8 + bytes_read, num_nodes);
check_num_nodes(num_nodes);
const string_type empty{string_allocator(allocator_)};
array_of_strings array(num_nodes, empty, array_allocator(allocator_));
array_of_strings array(num_nodes, "", allocator_);
for (uint8_t j = 0; j < num_nodes; ++j) {
uint32_t length;
bytes_read += copy_from_mem(ptr8 + bytes_read, length);
string_type value(length, '\0', string_allocator(allocator_));
std::string value(length, '\0');
if (length != 0) {
bytes_read += copy_from_mem(ptr8 + bytes_read, &value[0], length);
}
Expand Down Expand Up @@ -242,15 +236,15 @@ uint32_t default_array_of_strings_serde<Allocator>::compute_total_bytes(const ar
const auto count = item.size();
check_num_nodes(static_cast<uint8_t>(count));
size_t total = sizeof(uint32_t) + sizeof(uint8_t) + count * sizeof(uint32_t);
const string_type* data = item.data();
const std::string* data = item.data();
for (uint32_t j = 0; j < count; ++j) {
total += data[j].size();
}
return static_cast<uint32_t>(total);
}

template<typename Allocator>
void default_array_of_strings_serde<Allocator>::check_utf8(const string_type& value) {
void default_array_of_strings_serde<Allocator>::check_utf8(const std::string& value) {
if (!utf8::is_valid(value.begin(), value.end())) {
throw std::runtime_error("array_of_strings contains invalid UTF-8");
}
Expand Down
11 changes: 2 additions & 9 deletions tuple/test/aos_sketch_deserialize_from_java_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,13 @@
* under the License.
*/

#include <algorithm>
#include <catch2/catch.hpp>
#include <fstream>
#include <vector>

#include "array_of_strings_sketch.hpp"

namespace datasketches {
using types = array_of_strings_types<std::allocator<char>>;
using string_type = types::string_type;

static bool equals_string(const string_type& lhs, const std::string& rhs) {
return lhs.size() == rhs.size() && std::equal(lhs.begin(), lhs.end(), rhs.begin());
}
// assume the binary sketches for this test have been generated by datasketches-java code
// in the subdirectory called "java" in the root directory of this project
static std::string testBinaryInputPath = std::string(TEST_BINARY_INPUT_PATH) + "../../java/";
Expand Down Expand Up @@ -200,7 +193,7 @@ namespace datasketches {
if (entry.second.size() != expected.size()) continue;
bool equal = true;
for (size_t j = 0; j < expected.size(); ++j) {
if (!equals_string(entry.second[j], expected[j])) {
if (entry.second[j] != expected[j]) {
equal = false;
break;
}
Expand Down Expand Up @@ -255,7 +248,7 @@ namespace datasketches {
if (entry.second.size() != expected.size()) continue;
bool equal = true;
for (size_t j = 0; j < expected.size(); ++j) {
if (!equals_string(entry.second[j], expected[j])) {
if (entry.second[j] != expected[j]) {
equal = false;
break;
}
Expand Down
58 changes: 20 additions & 38 deletions tuple/test/aos_sketch_serialize_for_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,13 @@
namespace datasketches {

using aos_sketch = update_array_of_strings_tuple_sketch<>;
using types = array_of_strings_types<std::allocator<char>>;
using array_of_strings = types::array_of_strings;
using string_allocator = types::string_allocator;
using string_type = types::string_type;
using array_allocator = types::array_allocator;
using array_of_strings = array<std::string>;

static array_of_strings make_array(std::initializer_list<std::string> items) {
const string_type empty{string_allocator()};
array_of_strings array(static_cast<uint8_t>(items.size()), empty, array_allocator());
array_of_strings array(static_cast<uint8_t>(items.size()), "");
size_t i = 0;
for (const auto& item: items) {
array[static_cast<uint8_t>(i)] = string_type(item.data(), item.size(), string_allocator());
array[static_cast<uint8_t>(i)] = item;
++i;
}
return array;
Expand All @@ -48,13 +43,10 @@ TEST_CASE("aos sketch generate one value", "[serialize_for_java]") {
for (const unsigned n: n_arr) {
auto sketch = aos_sketch::builder().build();
for (unsigned i = 0; i < n; ++i) {
const string_type empty{string_allocator()};
array_of_strings key(1, empty, array_allocator());
const std::string key_value = std::to_string(i);
key[0] = string_type(key_value.data(), key_value.size(), string_allocator());
array_of_strings value(1, empty, array_allocator());
const std::string value_str = "value" + std::to_string(i);
value[0] = string_type(value_str.data(), value_str.size(), string_allocator());
array_of_strings key(1, "");
key[0] = std::to_string(i);
array_of_strings value(1, "");
value[0] = "value" + std::to_string(i);
sketch.update(hash_array_of_strings_key(key), value);
}
REQUIRE(sketch.is_empty() == (n == 0));
Expand All @@ -69,17 +61,12 @@ TEST_CASE("aos sketch generate three values", "[serialize_for_java]") {
for (const unsigned n: n_arr) {
auto sketch = aos_sketch::builder().build();
for (unsigned i = 0; i < n; ++i) {
const string_type empty{string_allocator()};
array_of_strings key(1, empty, array_allocator());
const std::string key_value = std::to_string(i);
key[0] = string_type(key_value.data(), key_value.size(), string_allocator());
array_of_strings value(3, empty, array_allocator());
const std::string value_a = "a" + std::to_string(i);
const std::string value_b = "b" + std::to_string(i);
const std::string value_c = "c" + std::to_string(i);
value[0] = string_type(value_a.data(), value_a.size(), string_allocator());
value[1] = string_type(value_b.data(), value_b.size(), string_allocator());
value[2] = string_type(value_c.data(), value_c.size(), string_allocator());
array_of_strings key(1, "");
key[0] = std::to_string(i);
array_of_strings value(3, "");
value[0] = "a" + std::to_string(i);
value[1] = "b" + std::to_string(i);
value[2] = "c" + std::to_string(i);
sketch.update(hash_array_of_strings_key(key), value);
}
REQUIRE(sketch.is_empty() == (n == 0));
Expand All @@ -95,10 +82,9 @@ TEST_CASE("aos sketch generate non-empty no entries", "[serialize_for_java]") {
.set_resize_factor(resize_factor::X8)
.set_p(0.01f)
.build();
const string_type empty{string_allocator()};
array_of_strings key(1, empty, array_allocator());
array_of_strings key(1, "");
key[0] = "key1";
array_of_strings value(1, empty, array_allocator());
array_of_strings value(1, "");
value[0] = "value1";
sketch.update(hash_array_of_strings_key(key), value);
REQUIRE_FALSE(sketch.is_empty());
Expand All @@ -112,15 +98,11 @@ TEST_CASE("aos sketch generate multi key strings", "[serialize_for_java]") {
for (const unsigned n: n_arr) {
auto sketch = aos_sketch::builder().build();
for (unsigned i = 0; i < n; ++i) {
const string_type empty{string_allocator()};
array_of_strings key(2, empty, array_allocator());
const std::string key0 = "key" + std::to_string(i);
const std::string key1 = "subkey" + std::to_string(i % 10);
key[0] = string_type(key0.data(), key0.size(), string_allocator());
key[1] = string_type(key1.data(), key1.size(), string_allocator());
array_of_strings value(1, empty, array_allocator());
const std::string value_str = "value" + std::to_string(i);
value[0] = string_type(value_str.data(), value_str.size(), string_allocator());
array_of_strings key(2, "");
key[0] = "key" + std::to_string(i);
key[1] = "subkey" + std::to_string(i % 10);
array_of_strings value(1, "");
value[0] = "value" + std::to_string(i);
sketch.update(hash_array_of_strings_key(key), value);
}
REQUIRE(sketch.is_empty() == (n == 0));
Expand Down
Loading