From f14c0e563b75590e6c4d1a0c82dfbd7e833e23b5 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Tue, 16 Sep 2025 09:16:09 -0700 Subject: [PATCH] Enhance core string types with built-in manipulation methods by integrating common string operations, previously in extensions, as methods on the `StringValue` class. PiperOrigin-RevId: 807720233 --- MODULE.bazel | 2 +- common/BUILD | 3 + common/value.h | 70 ++++++ common/values/string_value.cc | 354 +++++++++++++++++++++++++++++ common/values/string_value.h | 42 ++++ common/values/string_value_test.cc | 48 ++++ extensions/BUILD | 4 +- extensions/strings.cc | 213 ++--------------- extensions/strings_test.cc | 2 + 9 files changed, 541 insertions(+), 197 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index c193f5f5b..19fc67613 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -63,7 +63,7 @@ bazel_dep( ) bazel_dep( name = "cel-spec", - version = "0.23.0", + version = "0.24.0", repo_name = "com_google_cel_spec", ) diff --git a/common/BUILD b/common/BUILD index d800b36be..dd41f145d 100644 --- a/common/BUILD +++ b/common/BUILD @@ -628,10 +628,12 @@ cc_library( ":native_type", ":optional_ref", ":type", + ":typeinfo", ":unknown", ":value_kind", "//base:attributes", "//common/internal:byte_string", + "//common/internal:reference_count", "//eval/internal:cel_value_equal", "//eval/public:cel_value", "//eval/public:message_wrapper", @@ -656,6 +658,7 @@ cc_library( "//internal:utf8", "//internal:well_known_types", "//runtime:runtime_options", + "//runtime/internal:errors", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", diff --git a/common/value.h b/common/value.h index 0e38646a7..d24490a93 100644 --- a/common/value.h +++ b/common/value.h @@ -39,6 +39,7 @@ #include "common/native_type.h" #include "common/optional_ref.h" #include "common/type.h" +#include "common/typeinfo.h" #include "common/value_kind.h" #include "common/values/bool_value.h" // IWYU pragma: 
export #include "common/values/bytes_value.h" // IWYU pragma: export @@ -2537,6 +2538,75 @@ ErrorValueAssign::operator()(absl::Status status) const { return common_internal::ImplicitlyConvertibleStatus(); } +inline absl::StatusOr StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR( + Join(list, descriptor_pool, message_factory, arena, &result)); + return result; +} + +inline absl::StatusOr StringValue::Split( + const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(Split(delimiter, limit, arena, &result)); + return result; +} + +inline absl::Status StringValue::Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return Split(delimiter, /*limit=*/-1, arena, result); +} + +inline absl::StatusOr StringValue::Split( + const StringValue& delimiter, google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + return Split(delimiter, /*limit=*/-1, arena); +} + +inline absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + Value result; + CEL_RETURN_IF_ERROR(Replace(needle, replacement, limit, arena, &result)); + return result; +} + +inline absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull 
result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + return Replace(needle, replacement, /*limit=*/-1, arena, result); +} + +inline absl::StatusOr StringValue::Replace( + const StringValue& needle, const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + return Replace(needle, replacement, /*limit=*/-1, arena); +} + namespace common_internal { template diff --git a/common/values/string_value.cc b/common/values/string_value.cc index ba065d275..5e6d954a8 100644 --- a/common/values/string_value.cc +++ b/common/values/string_value.cc @@ -13,24 +13,33 @@ // limitations under the License. #include +#include #include +#include #include +#include +#include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "absl/strings/cord.h" +#include "absl/strings/cord_buffer.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "common/internal/byte_string.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "internal/utf8.h" #include "internal/well_known_types.h" +#include "runtime/internal/errors.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/io/zero_copy_stream.h" @@ -219,4 +228,349 @@ bool StringValue::Contains(const StringValue& string) const { [&](const absl::Cord& rhs) -> bool { return Contains(rhs); })); } +namespace { + +bool LowerAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + size_t pos; + for (pos = 0; pos < in.size(); ++pos) { + if (absl::ascii_isupper(in[pos])) { + break; + } + } + if (pos == in.size()) { + return 
false; + } + out->resize(in.size()); + char* out_data = out->data(); + for (size_t i = 0; i < in.size(); ++i) { + out_data[i] = absl::ascii_tolower(in[i]); + } + return true; +} + +absl::Cord LowerAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos; + absl::Cord::CharIterator begin = in.char_begin(); + absl::Cord::CharIterator end = in.char_end(); + for (pos = 0; begin != end; ++pos, ++begin) { + if (absl::ascii_isupper(*begin)) { + break; + } + } + if (begin == end) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + size_t n = in.size() - pos; + bool first = true; + while (begin != end) { + absl::CordBuffer buffer = first + ? out.GetAppendBuffer(n) + : absl::CordBuffer::CreateWithDefaultLimit(n); + absl::Span data = buffer.available_up_to(n); + size_t i; + for (i = 0; i < data.size() && begin != end; ++i, ++begin) { + data[i] = absl::ascii_tolower(*begin); + } + buffer.IncreaseLengthBy(i); + out.Append(std::move(buffer)); + n -= i; + first = false; + } + return out; +} + +} // namespace + +StringValue StringValue::LowerAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(LowerAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((LowerAsciiImpl)(value_.GetLarge())); + } +} + +namespace { + +bool UpperAsciiImpl(absl::string_view in, std::string* absl_nonnull out) { + if (in.empty()) { + return false; + } + size_t pos; + for (pos = 0; pos < in.size(); ++pos) { + if (absl::ascii_islower(in[pos])) { + break; + } + } + if (pos == in.size()) { + return 
false; + } + out->resize(in.size()); + char* out_data = out->data(); + for (size_t i = 0; i < in.size(); ++i) { + out_data[i] = absl::ascii_toupper(in[i]); + } + return true; +} + +absl::Cord UpperAsciiImpl(const absl::Cord& in) { + if (in.empty()) { + return in; + } + size_t pos; + absl::Cord::CharIterator begin = in.char_begin(); + absl::Cord::CharIterator end = in.char_end(); + for (pos = 0; begin != end; ++pos, ++begin) { + if (absl::ascii_islower(*begin)) { + break; + } + } + if (begin == end) { + return in; + } + absl::Cord out = in.Subcord(0, pos); + size_t n = in.size() - pos; + bool first = true; + while (begin != end) { + absl::CordBuffer buffer = first + ? out.GetAppendBuffer(n) + : absl::CordBuffer::CreateWithDefaultLimit(n); + absl::Span data = buffer.available_up_to(n); + size_t i; + for (i = 0; i < data.size() && begin != end; ++i, ++begin) { + data[i] = absl::ascii_toupper(*begin); + } + buffer.IncreaseLengthBy(i); + out.Append(std::move(buffer)); + n -= i; + first = false; + } + return out; +} + +} // namespace + +StringValue StringValue::UpperAscii(google::protobuf::Arena* absl_nonnull arena) const { + ABSL_DCHECK(arena != nullptr); + + switch (value_.GetKind()) { + case common_internal::ByteStringKind::kSmall: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetSmall(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kMedium: { + std::string out; + if (!(UpperAsciiImpl)(value_.GetMedium(), &out)) { + return *this; + } + return StringValue::From(std::move(out), arena); + } + case common_internal::ByteStringKind::kLarge: + return StringValue::Wrap((UpperAsciiImpl)(value_.GetLarge())); + } +} + +absl::Status StringValue::Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { + 
ABSL_DCHECK(descriptor_pool != nullptr); + ABSL_DCHECK(message_factory != nullptr); + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + std::string joined; + + CEL_ASSIGN_OR_RETURN(auto iterator, list.NewIterator()); + + CEL_ASSIGN_OR_RETURN( + absl::optional element, + iterator->Next1(descriptor_pool, message_factory, arena)); + if (element) { + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + while (true) { + CEL_ASSIGN_OR_RETURN( + element, iterator->Next1(descriptor_pool, message_factory, arena)); + if (!element) { + break; + } + AppendToString(&joined); + if (auto string_element = element->AsString(); string_element) { + string_element->AppendToString(&joined); + } else { + *result = + ErrorValue(runtime_internal::CreateNoMatchingOverloadError("join")); + return absl::OkStatus(); + } + } + } + + if (joined.size() > common_internal::kSmallByteStringCapacity) { + joined.shrink_to_fit(); + } + + *result = StringValue::From(std::move(joined), arena); + return absl::OkStatus(); +} + +absl::Status StringValue::Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return an empty list. + *result = ListValue(); + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited splits. 
+ limit = std::numeric_limits<int64_t>::max(); + } + + std::vector<std::pair<size_t, size_t>> splits; + + size_t pos = 0; + const size_t len = value_.size(); + + while (pos < len && limit > 1) { + if (delimiter.IsEmpty()) { + if (pos >= len) { + break; + } + size_t char_len = 1; + value_.Visit(absl::Overload( + [&](absl::string_view s) { + char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + }, + [&](const absl::Cord& s) { + char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + })); + splits.push_back({pos, pos + char_len}); + pos += char_len; + --limit; + continue; + } + absl::optional next = value_.Find(delimiter.value_, pos); + if (!next) { + break; + } + splits.push_back(std::pair{pos, *next}); + pos = *next + delimiter.value_.size(); + --limit; + ABSL_DCHECK_LE(pos, len); + } + + if (splits.empty() || !delimiter.IsEmpty() || pos < len) { + splits.push_back(std::pair{pos, len}); + } + + auto builder = NewListValueBuilder(arena); + builder->Reserve(splits.size()); + for (const std::pair<size_t, size_t>& split : splits) { + builder->UnsafeAdd( + StringValue(value_.Substring(split.first, split.second))); + } + *result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +// Replaces up to `limit` occurrences of `needle` in this string with +// `replacement` and stores the resulting string, built on `arena`, in +// `*result`. A `limit` of 0 yields this string unchanged; a negative `limit` +// means unlimited replacements. An empty `needle` matches before every UTF-8 +// decoded character and, while `limit` remains, once more at the end of the +// string. Always returns `absl::OkStatus()`. +absl::Status StringValue::Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const { + ABSL_DCHECK(arena != nullptr); + ABSL_DCHECK(result != nullptr); + + if (limit == 0) { + // Per spec, when limit is 0 return the original string. + *result = *this; + return absl::OkStatus(); + } + if (limit < 0) { + // Per spec, when limit is negative treat it as unlimited replacements. 
+ limit = std::numeric_limits<int64_t>::max(); + } + + size_t pos = 0; + const size_t len = value_.size(); + const size_t needle_len = needle.value_.size(); + std::string res_str; + res_str.reserve(len); + + while (pos < len && limit > 0) { + if (needle.IsEmpty()) { + replacement.AppendToString(&res_str); + + size_t char_len = 0; + value_.Visit(absl::Overload( + [&](absl::string_view s) { + char_len = cel::internal::Utf8Decode(s.substr(pos), nullptr); + }, + [&](const absl::Cord& s) { + char_len = cel::internal::Utf8Decode( + s.Subcord(pos, len - pos).char_begin(), nullptr); + })); + // Substring() is called with [begin, end) offsets everywhere else in this + // function, so the end of the current character is pos + char_len. + value_.Substring(pos, pos + char_len).AppendToString(&res_str); + pos += char_len; + --limit; + continue; + } + absl::optional next = value_.Find(needle.value_, pos); + if (!next) { + break; + } + + value_.Substring(pos, *next).AppendToString(&res_str); + replacement.AppendToString(&res_str); + + pos = *next + needle_len; + --limit; + } + + if (needle.IsEmpty() && limit > 0) { + replacement.AppendToString(&res_str); + } + + if (pos < len) { + value_.Substring(pos, len).AppendToString(&res_str); + } + + *result = StringValue::From(std::move(res_str), arena); + return absl::OkStatus(); +} + } // namespace cel diff --git a/common/values/string_value.h b/common/values/string_value.h index f7dcfc8d1..205b54b9c 100644 --- a/common/values/string_value.h +++ b/common/values/string_value.h @@ -28,6 +28,7 @@ #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -208,6 +209,47 @@ class StringValue final : private common_internal::ValueMixin { bool Contains(const absl::Cord& string) const; bool Contains(const StringValue& string) const; + StringValue LowerAscii(google::protobuf::Arena* absl_nonnull arena) const; + + StringValue UpperAscii(google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Join(const ListValue& 
list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Join( + const ListValue& list, + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + google::protobuf::MessageFactory* absl_nonnull message_factory, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Split(const StringValue& delimiter, + google::protobuf::Arena* absl_nonnull arena) const; + + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, int64_t limit, + google::protobuf::Arena* absl_nonnull arena) const; + absl::Status Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena, + Value* absl_nonnull result) const; + absl::StatusOr Replace(const StringValue& needle, + const StringValue& replacement, + google::protobuf::Arena* absl_nonnull arena) const; + absl::optional TryFlat() const ABSL_ATTRIBUTE_LIFETIME_BOUND { return value_.TryFlat(); diff --git a/common/values/string_value_test.cc b/common/values/string_value_test.cc index 244fd3f7e..d92580b82 100644 --- a/common/values/string_value_test.cc +++ b/common/values/string_value_test.cc @@ -208,5 +208,53 @@ TEST_F(StringValueTest, Contains) { 
.Contains(StringValue(absl::Cord("string is large enough")))); } +TEST_F(StringValueTest, LowerAscii) { + EXPECT_EQ(StringValue("UPPER lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("upper lower").LowerAscii(arena()), "upper lower"); + EXPECT_EQ(StringValue(absl::Cord("upper lower")).LowerAscii(arena()), + "upper lower"); + EXPECT_EQ(StringValue("").LowerAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).LowerAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to lower case!"; + const std::string kLongLower = + "a long string with mixed case to test conversion to lower case!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).LowerAscii(arena()), + kLongLower); + std::string very_long_mixed(10000, 'A'); + std::string very_long_lower(10000, 'a'); + EXPECT_EQ( + StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .LowerAscii(arena()), + very_long_lower); +} + +TEST_F(StringValueTest, UpperAscii) { + EXPECT_EQ(StringValue("UPPER lower").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER lower")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("UPPER LOWER").UpperAscii(arena()), "UPPER LOWER"); + EXPECT_EQ(StringValue(absl::Cord("UPPER LOWER")).UpperAscii(arena()), + "UPPER LOWER"); + EXPECT_EQ(StringValue("").UpperAscii(arena()), ""); + EXPECT_EQ(StringValue(absl::Cord("")).UpperAscii(arena()), ""); + const std::string kLongMixed = + "A long STRING with MiXeD case to test conversion to UPPER case!"; + const std::string kLongUpper = + "A LONG STRING WITH MIXED CASE TO TEST CONVERSION TO UPPER CASE!"; + EXPECT_EQ(StringValue(absl::Cord(kLongMixed)).UpperAscii(arena()), + kLongUpper); + std::string very_long_mixed(10000, 'a'); + std::string very_long_upper(10000, 'A'); + EXPECT_EQ( + 
StringValue(absl::MakeFragmentedCord({very_long_mixed.substr(0, 5000), + very_long_mixed.substr(5000)})) + .UpperAscii(arena()), + very_long_upper); +} + } // namespace } // namespace cel diff --git a/extensions/BUILD b/extensions/BUILD index 52d25a888..696785c1c 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -511,16 +511,13 @@ cc_library( "//eval/public:cel_function_registry", "//eval/public:cel_options", "//internal:status_macros", - "//internal:utf8", "//runtime:function_adapter", "//runtime:function_registry", "//runtime:runtime_options", - "//runtime/internal:errors", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", @@ -536,6 +533,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:decl", + "//common:type", "//common:value", "//compiler:compiler_factory", "//compiler:standard_library", diff --git a/extensions/strings.cc b/extensions/strings.cc index 3f9c73a33..110ff0b37 100644 --- a/extensions/strings.cc +++ b/extensions/strings.cc @@ -14,18 +14,15 @@ #include "extensions/strings.h" -#include #include #include #include -#include #include #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/ascii.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "checker/internal/builtins_arena.h" @@ -37,10 +34,8 @@ #include "eval/public/cel_options.h" #include "extensions/formatting.h" #include "internal/status_macros.h" -#include "internal/utf8.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" -#include "runtime/internal/errors.h" #include "runtime/runtime_options.h" #include 
"google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -67,35 +62,7 @@ absl::StatusOr Join2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - std::string result; - CEL_ASSIGN_OR_RETURN(auto iterator, value.NewIterator()); - Value element; - if (iterator->HasNext()) { - CEL_RETURN_IF_ERROR( - iterator->Next(descriptor_pool, message_factory, arena, &element)); - if (auto string_element = element.AsString(); string_element) { - string_element->NativeValue(AppendToStringVisitor{result}); - } else { - return ErrorValue{ - runtime_internal::CreateNoMatchingOverloadError("join")}; - } - } - std::string separator_scratch; - absl::string_view separator_view = separator.NativeString(separator_scratch); - while (iterator->HasNext()) { - result.append(separator_view); - CEL_RETURN_IF_ERROR( - iterator->Next(descriptor_pool, message_factory, arena, &element)); - if (auto string_element = element.AsString(); string_element) { - string_element->NativeValue(AppendToStringVisitor{result}); - } else { - return ErrorValue{ - runtime_internal::CreateNoMatchingOverloadError("join")}; - } - } - result.shrink_to_fit(); - // We assume the original string was well-formed. 
- return StringValue(arena, std::move(result)); + return separator.Join(value, descriptor_pool, message_factory, arena); } absl::StatusOr Join1( @@ -103,117 +70,15 @@ absl::StatusOr Join1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Join2(value, StringValue{}, descriptor_pool, message_factory, arena); + return StringValue().Join(value, descriptor_pool, message_factory, arena); } -struct SplitWithEmptyDelimiter { - google::protobuf::Arena* absl_nonnull arena; - int64_t& limit; - ListValueBuilder& builder; - - absl::StatusOr operator()(absl::string_view string) const { - char32_t rune; - size_t count; - std::string buffer; - buffer.reserve(4); - while (!string.empty() && limit > 1) { - std::tie(rune, count) = internal::Utf8Decode(string); - buffer.clear(); - internal::Utf8Encode(buffer, rune); - CEL_RETURN_IF_ERROR( - builder.Add(StringValue(arena, absl::string_view(buffer)))); - --limit; - string.remove_prefix(count); - } - if (!string.empty()) { - CEL_RETURN_IF_ERROR(builder.Add(StringValue(arena, string))); - } - return std::move(builder).Build(); - } - - absl::StatusOr operator()(const absl::Cord& string) const { - auto begin = string.char_begin(); - auto end = string.char_end(); - char32_t rune; - size_t count; - std::string buffer; - while (begin != end && limit > 1) { - std::tie(rune, count) = internal::Utf8Decode(begin); - buffer.clear(); - internal::Utf8Encode(buffer, rune); - CEL_RETURN_IF_ERROR( - builder.Add(StringValue(arena, absl::string_view(buffer)))); - --limit; - absl::Cord::Advance(&begin, count); - } - if (begin != end) { - buffer.clear(); - while (begin != end) { - auto chunk = absl::Cord::ChunkRemaining(begin); - buffer.append(chunk); - absl::Cord::Advance(&begin, chunk.size()); - } - buffer.shrink_to_fit(); - CEL_RETURN_IF_ERROR(builder.Add(StringValue(arena, std::move(buffer)))); - } - return 
std::move(builder).Build(); - } -}; - absl::StatusOr Split3( const StringValue& string, const StringValue& delimiter, int64_t limit, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - if (limit == 0) { - // Per spec, when limit is 0 return an empty list. - return ListValue{}; - } - if (limit < 0) { - // Per spec, when limit is negative treat is as unlimited. - limit = std::numeric_limits::max(); - } - auto builder = NewListValueBuilder(arena); - if (string.IsEmpty()) { - // If string is empty, it doesn't matter what the delimiter is or the limit. - // We just return a list with a single empty string. - builder->Reserve(1); - CEL_RETURN_IF_ERROR(builder->Add(StringValue{})); - return std::move(*builder).Build(); - } - if (delimiter.IsEmpty()) { - // If the delimiter is empty, we split between every code point. - return string.NativeValue(SplitWithEmptyDelimiter{arena, limit, *builder}); - } - // At this point we know the string is not empty and the delimiter is not - // empty. - std::string delimiter_scratch; - absl::string_view delimiter_view = delimiter.NativeString(delimiter_scratch); - std::string content_scratch; - absl::string_view content_view = string.NativeString(content_scratch); - while (limit > 1 && !content_view.empty()) { - auto pos = content_view.find(delimiter_view); - if (pos == absl::string_view::npos) { - break; - } - // We assume the original string was well-formed. - CEL_RETURN_IF_ERROR( - builder->Add(StringValue(arena, content_view.substr(0, pos)))); - --limit; - content_view.remove_prefix(pos + delimiter_view.size()); - if (content_view.empty()) { - // We found the delimiter at the end of the string. Add an empty string - // to the end of the list. - CEL_RETURN_IF_ERROR(builder->Add(StringValue{})); - return std::move(*builder).Build(); - } - } - // We have one left in the limit or do not have any more matches. 
Add - // whatever is left as the remaining entry. - // - // We assume the original string was well-formed. - CEL_RETURN_IF_ERROR(builder->Add(StringValue(arena, content_view))); - return std::move(*builder).Build(); + return string.Split(delimiter, limit, arena); } absl::StatusOr Split2( @@ -221,27 +86,7 @@ absl::StatusOr Split2( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Split3(string, delimiter, -1, descriptor_pool, message_factory, arena); -} - -absl::StatusOr LowerAscii(const StringValue& string, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull arena) { - std::string content = string.NativeString(); - absl::AsciiStrToLower(&content); - // We assume the original string was well-formed. - return StringValue(arena, std::move(content)); -} - -absl::StatusOr UpperAscii(const StringValue& string, - const google::protobuf::DescriptorPool* absl_nonnull, - google::protobuf::MessageFactory* absl_nonnull, - google::protobuf::Arena* absl_nonnull arena) { - std::string content = string.NativeString(); - absl::AsciiStrToUpper(&content); - // We assume the original string was well-formed. - return StringValue(arena, std::move(content)); + return string.Split(delimiter, arena); } absl::StatusOr Replace2(const StringValue& string, @@ -250,38 +95,7 @@ absl::StatusOr Replace2(const StringValue& string, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull arena) { - if (limit == 0) { - // When the replacement limit is 0, the result is the original string. - return string; - } - if (limit < 0) { - // Per spec, when limit is negative treat is as unlimited. 
- limit = std::numeric_limits::max(); - } - - std::string result; - std::string old_sub_scratch; - absl::string_view old_sub_view = old_sub.NativeString(old_sub_scratch); - std::string new_sub_scratch; - absl::string_view new_sub_view = new_sub.NativeString(new_sub_scratch); - std::string content_scratch; - absl::string_view content_view = string.NativeString(content_scratch); - while (limit > 0 && !content_view.empty()) { - auto pos = content_view.find(old_sub_view); - if (pos == absl::string_view::npos) { - break; - } - result.append(content_view.substr(0, pos)); - result.append(new_sub_view); - --limit; - content_view.remove_prefix(pos + old_sub_view.size()); - } - // Add the remainder of the string. - if (!content_view.empty()) { - result.append(content_view); - } - - return StringValue(arena, std::move(result)); + return string.Replace(old_sub, new_sub, limit, arena); } absl::StatusOr Replace1( @@ -290,8 +104,21 @@ absl::StatusOr Replace1( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::MessageFactory* absl_nonnull message_factory, google::protobuf::Arena* absl_nonnull arena) { - return Replace2(string, old_sub, new_sub, -1, descriptor_pool, - message_factory, arena); + return string.Replace(old_sub, new_sub, -1, arena); +} + +StringValue LowerAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.LowerAscii(arena); +} + +StringValue UpperAscii(const StringValue& string, + const google::protobuf::DescriptorPool* absl_nonnull, + google::protobuf::MessageFactory* absl_nonnull, + google::protobuf::Arena* absl_nonnull arena) { + return string.UpperAscii(arena); } const Type& ListStringType() { diff --git a/extensions/strings_test.cc b/extensions/strings_test.cc index e2eb5e71f..a2ec7f582 100644 --- a/extensions/strings_test.cc +++ b/extensions/strings_test.cc @@ -15,6 +15,7 @@ 
#include "extensions/strings.h" #include +#include #include #include "cel/expr/syntax.pb.h" @@ -24,6 +25,7 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/decl.h" +#include "common/type.h" #include "common/value.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h"