diff --git a/velox/docs/functions/presto/regexp.rst b/velox/docs/functions/presto/regexp.rst index c8a8f4b81ac6..6171ad7a90ab 100644 --- a/velox/docs/functions/presto/regexp.rst +++ b/velox/docs/functions/presto/regexp.rst @@ -7,6 +7,9 @@ supports only a subset of PCRE syntax and in particular does not support backtracking and associated features (e.g. back references). See https://github.com/google/re2/wiki/Syntax for more information. +Compiling regular expressions is CPU intensive. Hence, each function is +limited to 20 different expressions per instance and thread of execution. + .. function:: like(string, pattern) -> boolean like(string, pattern, escape) -> boolean @@ -19,9 +22,11 @@ See https://github.com/google/re2/wiki/Syntax for more information. wildcard '_' represents exactly one character. Note: Each function instance allow for a maximum of 20 regular expressions to - be compiled throughout the lifetime of the query. Not all Patterns requires - compilation of regular expressions; for example a pattern 'aa' does not. - Only those that require the compilation of regular expressions are counted. + be compiled per thread of execution. Not all patterns require + compilation of regular expressions. Patterns 'aaa', 'aaa%', '%aaa', where 'aaa' + contains only regular characters and '_' wildcards are evaluated without + using regular expressions. Only those patterns that require the compilation of + regular expressions are counted towards the limit. SELECT like('abc', '%b%'); -- true SELECT like('a_c', '%#_%', '#'); -- true diff --git a/velox/expression/tests/ExpressionFuzzerTest.cpp b/velox/expression/tests/ExpressionFuzzerTest.cpp index d206add7c3bf..007e8a428e1d 100644 --- a/velox/expression/tests/ExpressionFuzzerTest.cpp +++ b/velox/expression/tests/ExpressionFuzzerTest.cpp @@ -54,6 +54,10 @@ int main(int argc, char** argv) { "width_bucket", // Fuzzer cannot generate valid 'comparator' lambda. "array_sort(array(T),constant function(T,T,bigint)) -> array(T)", + // https://github.com/facebookincubator/velox/issues/8438#issuecomment-1907234044 + "regexp_extract", + "regexp_extract_all", + "regexp_like", }; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; return FuzzerRunner::run(initialSeed, skipFunctions); diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index ff7053526a6d..8c2509f03adc 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -23,6 +23,48 @@ namespace { static const int kMaxCompiledRegexes = 20; +void checkForBadPattern(const RE2& re) { + if (UNLIKELY(!re.ok())) { + VELOX_USER_FAIL("invalid regular expression:{}", re.error()); + } +} + +template +re2::StringPiece toStringPiece(const T& s) { + return re2::StringPiece(s.data(), s.size()); +} + +// A cache of compiled regular expressions (RE2 instances). Allows up to +// 'kMaxCompiledRegexes' different expressions. +// +// Compiling regular expressions is expensive. It can take up to 200 times +// more CPU time to compile a regex vs. evaluate it. +class ReCache { + public: + RE2* findOrCompile(const StringView& pattern) { + const std::string key = pattern; + + auto reIt = cache_.find(key); + if (reIt != cache_.end()) { + return reIt->second.get(); + } + + VELOX_USER_CHECK_LT( + cache_.size(), kMaxCompiledRegexes, "Max number of regex reached"); + + auto re = std::make_unique(toStringPiece(pattern), RE2::Quiet); + checkForBadPattern(*re); + + auto [it, inserted] = cache_.emplace(key, std::move(re)); + VELOX_CHECK(inserted); + + return it->second.get(); + } + + private: + folly::F14FastMap> cache_; +}; + std::string printTypesCsv( const std::vector& inputArgs) { std::string result; @@ -34,11 +76,6 @@ std::string printTypesCsv( return result; } -template -re2::StringPiece toStringPiece(const T& s) { - return re2::StringPiece(s.data(), s.size()); -} - // If v is a non-null constant vector, returns the constant value. Otherwise // returns nullopt. template @@ -50,12 +87,6 @@ std::optional getIfConstant(const BaseVector& v) { return std::nullopt; } -void checkForBadPattern(const RE2& re) { - if (UNLIKELY(!re.ok())) { - VELOX_USER_FAIL("invalid regular expression:{}", re.error()); - } -} - FlatVector& ensureWritableBool( const SelectivityVector& rows, exec::EvalCtx& context, @@ -220,11 +251,13 @@ class Re2Match final : public exec::VectorFunction { exec::LocalDecodedVector toSearch(context, *args[0], rows); exec::LocalDecodedVector pattern(context, *args[1], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); result.set(row, Fn(toSearch->valueAt(row), re)); }); } + + private: + mutable ReCache cache_; }; void checkForBadGroupId(int64_t groupId, const RE2& re) { @@ -348,8 +381,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { if (args.size() == 2) { groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); mustRefSourceStrings |= re2Extract(result, i, re, toSearch, groups, 0, emptyNoMatch_); }); @@ -357,8 +389,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { const auto groupId = groupIds->valueAt(i); - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); mustRefSourceStrings |= @@ -372,6 +403,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { private: const bool emptyNoMatch_; + mutable ReCache cache_; }; namespace { @@ -1126,8 +1158,7 @@ class Re2ExtractAll final : public exec::VectorFunction { // groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); re2ExtractAll(resultWriter, re, inputStrs, row, groups, 0); }); } else { @@ -1136,8 +1167,7 @@ class Re2ExtractAll final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { const T groupId = groupIds->valueAt(row); - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); re2ExtractAll(resultWriter, re, inputStrs, row, groups, groupId); @@ -1150,6 +1180,9 @@ class Re2ExtractAll final : public exec::VectorFunction { ->asFlatVector() ->acquireSharedStringBuffers(inputStrs->base()); } + + private: + mutable ReCache cache_; }; template @@ -1170,9 +1203,8 @@ std::shared_ptr makeRe2MatchImpl( return std::make_shared>( constantPattern->as>()->valueAt(0)); } - static std::shared_ptr> kMatchExpr = - std::make_shared>(); - return kMatchExpr; + + return std::make_shared>(); } } // namespace diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 1d0bdfb505b2..bb7503d74b5f 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -1431,5 +1431,45 @@ TEST_F(Re2FunctionsTest, regexExtractAllLarge) { "No group 4611686018427387904 in regex '(\\d+)([a-z]+)") } +// Make sure we do not compile more than kMaxCompiledRegexes. +TEST_F(Re2FunctionsTest, limit) { + auto data = makeRowVector({ + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples and oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples (.*) oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { + return fmt::format("Apples (.*) oranges {}", row % 20); + }), + }); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_like(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_like(c0, c2)", data)); +} + } // namespace } // namespace facebook::velox::functions