Skip to content

Commit ba2bde3

Browse files
pdan101Evergreen Agent
authored andcommitted
SERVER-51563 Support expression $trim in SBE
1 parent 8f376bd commit ba2bde3

File tree

13 files changed

+404
-163
lines changed

13 files changed

+404
-163
lines changed

jstests/aggregation/expressions/trim.js

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"use strict";
66
load("jstests/aggregation/extras/utils.js"); // For assertErrorCode, testExpression and
77
// testExpressionWithCollation.
8+
load("jstests/libs/sbe_assert_error_override.js");
89

910
const coll = db.trim_expressions;
1011

@@ -81,10 +82,48 @@ assert.eq(
8182
{_id: 4, proof: null},
8283
]);
8384

84-
// Test that errors are reported correctly.
85-
assertErrorCode(coll, [{$project: {x: {$trim: " x "}}}], 50696);
86-
assertErrorCode(coll, [{$project: {x: {$trim: {input: 4}}}}], 50699);
87-
assertErrorCode(coll, [{$project: {x: {$trim: {input: {$add: [4, 2]}}}}}], 50699);
88-
assertErrorCode(coll, [{$project: {x: {$trim: {input: "$_id"}}}}], 50699);
89-
assertErrorCode(coll, [{$project: {x: {$trim: {input: " x ", chars: "$_id"}}}}], 50700);
85+
// Semantically same as the tests above but non-constant input for 'chars'
86+
coll.drop();
87+
assert.commandWorked(coll.insert([
88+
{_id: 0, proof: "Left as an exercise for the reader∎", extra: "∎"},
89+
{_id: 1, proof: "∎∃ proof∎", extra: "∎"},
90+
{
91+
_id: 2,
92+
proof: "Just view the problem as a continuous DAG whose elements are taylor series∎",
93+
extra: "∎"
94+
},
95+
{_id: 3, proof: null},
96+
{_id: 4},
97+
]));
98+
assert.eq(
99+
coll.aggregate(
100+
[{$sort: {_id: 1}}, {$project: {proof: {$rtrim: {input: "$proof", chars: "$extra"}}}}])
101+
.toArray(),
102+
[
103+
{_id: 0, proof: "Left as an exercise for the reader"},
104+
{_id: 1, proof: "∎∃ proof"},
105+
{
106+
_id: 2,
107+
proof: "Just view the problem as a continuous DAG whose elements are taylor series"
108+
},
109+
{_id: 3, proof: null},
110+
{_id: 4, proof: null},
111+
]);
112+
113+
coll.drop();
114+
assert.commandWorked(coll.insert([
115+
{_id: 0, nonObject: " x "},
116+
{_id: 1, constantNum: 4},
117+
]));
118+
119+
// Test that errors are reported correctly (for all of $trim, $ltrim, $rtrim).
120+
for (const op of ["$trim", "$ltrim", "$rtrim"]) {
121+
assertErrorCode(coll, [{$project: {x: {[op]: {}}}}], 50695);
122+
assertErrorCode(coll, [{$project: {x: {[op]: "$nonObject"}}}], 50696);
123+
assertErrorCode(coll, [{$project: {x: {[op]: {input: "$constantNum"}}}}], 50699);
124+
assertErrorCode(
125+
coll, [{$project: {x: {[op]: {input: {$add: ["$constantNum", "$constantNum"]}}}}}], 50699);
126+
assertErrorCode(coll, [{$project: {x: {[op]: {input: "$_id"}}}}], 50699);
127+
assertErrorCode(coll, [{$project: {x: {[op]: {input: "$nonObject", chars: "$_id"}}}}], 50700);
128+
}
90129
}());

jstests/libs/sbe_assert_error_override.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ const equivalentErrorCodesList = [
163163
[5787903, 7548606],
164164
[5787908, 7548606],
165165
[ErrorCodes.BadValue, 4938500],
166+
[50700, 5156303],
167+
[50699, 5156302],
168+
[50697, 5156304],
169+
[50698, 5156305],
166170
[5155800, 34473],
167171
[5155801, 34470],
168172
];

src/mongo/db/SConscript

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1709,6 +1709,7 @@ env.Library(
17091709
LIBDEPS=[
17101710
'$BUILD_DIR/mongo/bson/util/bson_extract',
17111711
'$BUILD_DIR/mongo/crypto/fle_crypto',
1712+
'$BUILD_DIR/mongo/db/query/str_trim_utils',
17121713
'$BUILD_DIR/mongo/scripting/scripting',
17131714
'$BUILD_DIR/mongo/scripting/scripting_common',
17141715
'$BUILD_DIR/mongo/util/pcre_util',

src/mongo/db/exec/sbe/SConscript

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ sbeEnv.Library(
6666
],
6767
LIBDEPS_PRIVATE=[
6868
'$BUILD_DIR/mongo/db/bson/dotted_path_support',
69+
'$BUILD_DIR/mongo/db/query/str_trim_utils',
6970
'$BUILD_DIR/mongo/db/sorter/sorter_idl',
7071
'$BUILD_DIR/mongo/db/sorter/sorter_stats',
7172
],

src/mongo/db/exec/sbe/expressions/expression.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,9 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
692692
{"strLenBytes", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::strLenBytes, false}},
693693
{"toLower", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::toLower, false}},
694694
{"toUpper", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::toUpper, false}},
695+
{"trim", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::trim, false}},
696+
{"ltrim", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::ltrim, false}},
697+
{"rtrim", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::rtrim, false}},
695698
{"coerceToBool", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::coerceToBool, false}},
696699
{"coerceToString",
697700
BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::coerceToString, false}},

src/mongo/db/exec/sbe/vm/vm.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
#include "mongo/db/query/collation/collation_index_key.h"
8787
#include "mongo/db/query/datetime/date_time_support.h"
8888
#include "mongo/db/query/query_knobs_gen.h"
89+
#include "mongo/db/query/str_trim_utils.h"
8990
#include "mongo/db/storage/column_store.h"
9091
#include "mongo/db/storage/key_string.h"
9192
#include "mongo/logv2/log.h"
@@ -4036,6 +4037,26 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinConcatArrays(Ari
40364037
return {true, resTag, resVal};
40374038
}
40384039

4040+
FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinTrim(ArityType arity,
4041+
bool trimLeft,
4042+
bool trimRight) {
4043+
auto [ownedChars, tagChars, valChars] = getFromStack(1);
4044+
auto [ownedInput, tagInput, valInput] = getFromStack(0);
4045+
4046+
if (!value::isString(tagInput)) {
4047+
return {false, value::TypeTags::Nothing, 0};
4048+
}
4049+
4050+
auto replacementChars = !value::isNullish(tagChars)
4051+
? str_trim_utils::extractCodePointsFromChars(value::getStringView(tagChars, valChars))
4052+
: str_trim_utils::kDefaultTrimWhitespaceChars;
4053+
auto inputString = value::getStringView(tagInput, valInput);
4054+
4055+
auto [strTag, strValue] = sbe::value::makeNewString(
4056+
str_trim_utils::doTrim(inputString, replacementChars, trimLeft, trimRight));
4057+
return {true, strTag, strValue};
4058+
}
4059+
40394060
FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggConcatArraysCapped(
40404061
ArityType arity) {
40414062
auto [ownArr, tagArr, valArr] = getFromStack(0);
@@ -6935,6 +6956,12 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin
69356956
return builtinToUpper(arity);
69366957
case Builtin::toLower:
69376958
return builtinToLower(arity);
6959+
case Builtin::trim:
6960+
return builtinTrim(arity, true, true);
6961+
case Builtin::ltrim:
6962+
return builtinTrim(arity, true, false);
6963+
case Builtin::rtrim:
6964+
return builtinTrim(arity, false, true);
69386965
case Builtin::coerceToBool:
69396966
return builtinCoerceToBool(arity);
69406967
case Builtin::coerceToString:
@@ -7260,6 +7287,12 @@ std::string builtinToString(Builtin b) {
72607287
return "toUpper";
72617288
case Builtin::toLower:
72627289
return "toLower";
7290+
case Builtin::trim:
7291+
return "trim";
7292+
case Builtin::ltrim:
7293+
return "ltrim";
7294+
case Builtin::rtrim:
7295+
return "rtrim";
72637296
case Builtin::coerceToBool:
72647297
return "coerceToBool";
72657298
case Builtin::coerceToString:

src/mongo/db/exec/sbe/vm/vm.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,9 @@ enum class Builtin : uint8_t {
696696
coerceToString,
697697
concat,
698698
concatArrays,
699+
trim,
700+
ltrim,
701+
rtrim,
699702

700703
// Agg function to concatenate arrays, failing when the accumulator reaches a specified size.
701704
aggConcatArraysCapped,
@@ -1664,6 +1667,9 @@ class ByteCode {
16641667
FastTuple<bool, value::TypeTags, value::Value> builtinRound(ArityType arity);
16651668
FastTuple<bool, value::TypeTags, value::Value> builtinConcat(ArityType arity);
16661669
FastTuple<bool, value::TypeTags, value::Value> builtinConcatArrays(ArityType arity);
1670+
FastTuple<bool, value::TypeTags, value::Value> builtinTrim(ArityType arity,
1671+
bool trimLeft,
1672+
bool trimRight);
16671673
FastTuple<bool, value::TypeTags, value::Value> builtinAggConcatArraysCapped(ArityType arity);
16681674
FastTuple<bool, value::TypeTags, value::Value> builtinAggSetUnion(ArityType arity);
16691675
FastTuple<bool, value::TypeTags, value::Value> builtinAggSetUnionCapped(ArityType arity);

src/mongo/db/pipeline/expression.cpp

Lines changed: 12 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
#include "mongo/db/query/collation/collator_interface.h"
8282
#include "mongo/db/query/datetime/date_time_support.h"
8383
#include "mongo/db/query/query_knobs_gen.h"
84+
#include "mongo/db/query/sort_pattern.h"
85+
#include "mongo/db/query/str_trim_utils.h"
8486
#include "mongo/db/query/util/make_data_structure.h"
8587
#include "mongo/db/record_id.h"
8688
#include "mongo/db/stats/counters.h"
@@ -6022,74 +6024,6 @@ intrusive_ptr<Expression> ExpressionTrim::parse(ExpressionContext* const expCtx,
60226024
return new ExpressionTrim(expCtx, trimType, name, input, characters);
60236025
}
60246026

6025-
namespace {
6026-
const std::vector<StringData> kDefaultTrimWhitespaceChars = {
6027-
"\0"_sd, // Null character. Avoid using "\u0000" syntax to work around a gcc bug:
6028-
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=53690.
6029-
"\u0020"_sd, // Space
6030-
"\u0009"_sd, // Horizontal tab
6031-
"\u000A"_sd, // Line feed/new line
6032-
"\u000B"_sd, // Vertical tab
6033-
"\u000C"_sd, // Form feed
6034-
"\u000D"_sd, // Horizontal tab
6035-
"\u00A0"_sd, // Non-breaking space
6036-
"\u1680"_sd, // Ogham space mark
6037-
"\u2000"_sd, // En quad
6038-
"\u2001"_sd, // Em quad
6039-
"\u2002"_sd, // En space
6040-
"\u2003"_sd, // Em space
6041-
"\u2004"_sd, // Three-per-em space
6042-
"\u2005"_sd, // Four-per-em space
6043-
"\u2006"_sd, // Six-per-em space
6044-
"\u2007"_sd, // Figure space
6045-
"\u2008"_sd, // Punctuation space
6046-
"\u2009"_sd, // Thin space
6047-
"\u200A"_sd // Hair space
6048-
};
6049-
6050-
/**
6051-
* Assuming 'charByte' is the beginning of a UTF-8 code point, returns the number of bytes that
6052-
* should be used to represent the code point. Said another way, computes how many continuation
6053-
* bytes are expected to be present after 'charByte' in a UTF-8 encoded string.
6054-
*/
6055-
inline size_t numberOfBytesForCodePoint(char charByte) {
6056-
if ((charByte & 0b11111000) == 0b11110000) {
6057-
return 4;
6058-
} else if ((charByte & 0b11110000) == 0b11100000) {
6059-
return 3;
6060-
} else if ((charByte & 0b11100000) == 0b11000000) {
6061-
return 2;
6062-
} else {
6063-
return 1;
6064-
}
6065-
}
6066-
6067-
/**
6068-
* Returns a vector with one entry per code point to trim, or throws an exception if 'utf8String'
6069-
* contains invalid UTF-8.
6070-
*/
6071-
std::vector<StringData> extractCodePointsFromChars(StringData utf8String,
6072-
StringData expressionName) {
6073-
std::vector<StringData> codePoints;
6074-
std::size_t i = 0;
6075-
while (i < utf8String.size()) {
6076-
uassert(50698,
6077-
str::stream() << "Failed to parse \"chars\" argument to " << expressionName
6078-
<< ": Detected invalid UTF-8. Got continuation byte when expecting "
6079-
"the start of a new code point.",
6080-
!str::isUTF8ContinuationByte(utf8String[i]));
6081-
codePoints.push_back(utf8String.substr(i, numberOfBytesForCodePoint(utf8String[i])));
6082-
i += numberOfBytesForCodePoint(utf8String[i]);
6083-
}
6084-
uassert(50697,
6085-
str::stream()
6086-
<< "Failed to parse \"chars\" argument to " << expressionName
6087-
<< ": Detected invalid UTF-8. Missing expected continuation byte at end of string.",
6088-
i <= utf8String.size());
6089-
return codePoints;
6090-
}
6091-
} // namespace
6092-
60936027
Value ExpressionTrim::evaluate(const Document& root, Variables* variables) const {
60946028
auto unvalidatedInput = _children[_kInput]->evaluate(root, variables);
60956029
if (unvalidatedInput.nullish()) {
@@ -6103,7 +6037,11 @@ Value ExpressionTrim::evaluate(const Document& root, Variables* variables) const
61036037
const StringData input(unvalidatedInput.getStringData());
61046038

61056039
if (!_children[_kCharacters]) {
6106-
return Value(doTrim(input, kDefaultTrimWhitespaceChars));
6040+
return Value(
6041+
str_trim_utils::doTrim(input,
6042+
str_trim_utils::kDefaultTrimWhitespaceChars,
6043+
_trimType == TrimType::kBoth || _trimType == TrimType::kLeft,
6044+
_trimType == TrimType::kBoth || _trimType == TrimType::kRight));
61076045
}
61086046
auto unvalidatedUserChars = _children[_kCharacters]->evaluate(root, variables);
61096047
if (unvalidatedUserChars.nullish()) {
@@ -6115,65 +6053,11 @@ Value ExpressionTrim::evaluate(const Document& root, Variables* variables) const
61156053
<< typeName(unvalidatedUserChars.getType()) << ") instead.",
61166054
unvalidatedUserChars.getType() == BSONType::String);
61176055

6118-
return Value(
6119-
doTrim(input, extractCodePointsFromChars(unvalidatedUserChars.getStringData(), _name)));
6120-
}
6121-
6122-
bool ExpressionTrim::codePointMatchesAtIndex(const StringData& input,
6123-
std::size_t indexOfInput,
6124-
const StringData& testCP) {
6125-
for (size_t i = 0; i < testCP.size(); ++i) {
6126-
if (indexOfInput + i >= input.size() || input[indexOfInput + i] != testCP[i]) {
6127-
return false;
6128-
}
6129-
}
6130-
return true;
6131-
};
6132-
6133-
StringData ExpressionTrim::trimFromLeft(StringData input, const std::vector<StringData>& trimCPs) {
6134-
std::size_t bytesTrimmedFromLeft = 0u;
6135-
while (bytesTrimmedFromLeft < input.size()) {
6136-
// Look for any matching code point to trim.
6137-
auto matchingCP = std::find_if(trimCPs.begin(), trimCPs.end(), [&](auto& testCP) {
6138-
return codePointMatchesAtIndex(input, bytesTrimmedFromLeft, testCP);
6139-
});
6140-
if (matchingCP == trimCPs.end()) {
6141-
// Nothing to trim, stop here.
6142-
break;
6143-
}
6144-
bytesTrimmedFromLeft += matchingCP->size();
6145-
}
6146-
return input.substr(bytesTrimmedFromLeft);
6147-
}
6148-
6149-
StringData ExpressionTrim::trimFromRight(StringData input, const std::vector<StringData>& trimCPs) {
6150-
std::size_t bytesTrimmedFromRight = 0u;
6151-
while (bytesTrimmedFromRight < input.size()) {
6152-
std::size_t indexToTrimFrom = input.size() - bytesTrimmedFromRight;
6153-
auto matchingCP = std::find_if(trimCPs.begin(), trimCPs.end(), [&](auto& testCP) {
6154-
if (indexToTrimFrom < testCP.size()) {
6155-
// We've gone off the left of the string.
6156-
return false;
6157-
}
6158-
return codePointMatchesAtIndex(input, indexToTrimFrom - testCP.size(), testCP);
6159-
});
6160-
if (matchingCP == trimCPs.end()) {
6161-
// Nothing to trim, stop here.
6162-
break;
6163-
}
6164-
bytesTrimmedFromRight += matchingCP->size();
6165-
}
6166-
return input.substr(0, input.size() - bytesTrimmedFromRight);
6167-
}
6168-
6169-
StringData ExpressionTrim::doTrim(StringData input, const std::vector<StringData>& trimCPs) const {
6170-
if (_trimType == TrimType::kBoth || _trimType == TrimType::kLeft) {
6171-
input = trimFromLeft(input, trimCPs);
6172-
}
6173-
if (_trimType == TrimType::kBoth || _trimType == TrimType::kRight) {
6174-
input = trimFromRight(input, trimCPs);
6175-
}
6176-
return input;
6056+
return Value(str_trim_utils::doTrim(
6057+
input,
6058+
str_trim_utils::extractCodePointsFromChars(unvalidatedUserChars.getStringData()),
6059+
_trimType == TrimType::kBoth || _trimType == TrimType::kLeft,
6060+
_trimType == TrimType::kBoth || _trimType == TrimType::kRight));
61776061
}
61786062

61796063
boost::intrusive_ptr<Expression> ExpressionTrim::optimize() {

0 commit comments

Comments
 (0)