diff --git a/.circleci/dist_compile.yml b/.circleci/dist_compile.yml index 3278a67f6565..94a71eab4f7e 100644 --- a/.circleci/dist_compile.yml +++ b/.circleci/dist_compile.yml @@ -138,85 +138,8 @@ executors: check: docker: - image : ghcr.io/facebookincubator/velox-dev:check-avx - macos-intel: - macos: - xcode: "14.3.0" - resource_class: macos.x86.medium.gen2 - macos-m1: - macos: - xcode: "14.2.0" - resource_class: macos.m1.large.gen1 jobs: - macos-build: - parameters: - os: - type: executor - executor: << parameters.os >> - environment: - ICU_SOURCE: BUNDLED - simdjson_SOURCE: BUNDLED - steps: - - checkout - - update-submodules - - restore_cache: - name: "Restore Dependency Cache" - # The version number in the key can be incremented - # to manually avoid the case where bad dependencies - # are cached, and has no other meaning. - # If you update it, be sure to update save_cache too. - key: velox-circleci-macos-{{ arch }}-deps-v1-{{ checksum ".circleci/config.yml" }}-{{ checksum "scripts/setup-macos.sh" }} - - run: - name: "Install dependencies" - command: | - set -xu - mkdir -p ~/deps ~/deps-src - curl -L https://github.com/Homebrew/brew/tarball/master | tar xz --strip 1 -C ~/deps - PATH=~/deps/bin:${PATH} DEPENDENCY_DIR=~/deps-src INSTALL_PREFIX=~/deps PROMPT_ALWAYS_RESPOND=n ./scripts/setup-macos.sh - rm -rf ~/deps/.git ~/deps/Library/Taps/ # Reduce cache size by 70%. - no_output_timeout: 20m - - save_cache: - name: "Save Dependency Cache" - # The version number in the key can be incremented - # to manually avoid the case where bad dependencies - # are cached, and has no other meaning. - # If you update it, be sure to update restore_cache too. - key: velox-circleci-macos-{{ arch }}-deps-v1-{{ checksum ".circleci/config.yml" }}-{{ checksum "scripts/setup-macos.sh" }} - paths: - - ~/deps - - run: - name: "Calculate merge-base date for CCache" - command: git show -s --format=%cd --date="format:%Y%m%d" $(git merge-base origin/main HEAD) | tee merge-base-date - - restore_cache: - name: "Restore CCache cache" - keys: - - velox-ccache-debug-{{ arch }}-{{ checksum "merge-base-date" }} - - run: - name: "Build on MacOS" - command: | - export PATH=~/deps/bin:~/deps/opt/bison/bin:~/deps/opt/flex/bin:${PATH} - mkdir -p .ccache - export CCACHE_DIR=$(pwd)/.ccache - ccache -sz -M 5Gi - cmake \ - -B _build/debug \ - -GNinja \ - -DTREAT_WARNINGS_AS_ERRORS=1 \ - -DENABLE_ALL_WARNINGS=1 \ - -DVELOX_ENABLE_PARQUET=ON \ - -DCMAKE_BUILD_TYPE=Debug \ - -DCMAKE_PREFIX_PATH=~/deps \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DFLEX_INCLUDE_DIR=~/deps/opt/flex/include - ninja -C _build/debug - ccache -s - no_output_timeout: 1h - - save_cache: - name: "Save CCache cache" - key: velox-ccache-debug-{{ arch }}-{{ checksum "merge-base-date" }} - paths: - - .ccache/ - linux-build: executor: build environment: @@ -681,10 +604,6 @@ workflows: - linux-build-options - linux-adapters - linux-presto-fuzzer-run - - macos-build: - matrix: - parameters: - os: [macos-intel] - format-check - header-check - doc-gen-job: @@ -692,14 +611,6 @@ workflows: branches: only: - main - - macos-build: - matrix: - parameters: - os: [ macos-m1 ] - filters: - branches: - only: - - main shorter-fuzzer: unless: << pipeline.parameters.run-longer-expression-fuzzer >> @@ -708,10 +619,6 @@ workflows: - linux-pr-fuzzer-run - linux-build-options - linux-adapters - - macos-build: - matrix: - parameters: - os: [ macos-intel ] - format-check - header-check - doc-gen-job: @@ -719,11 +626,3 @@ workflows: branches: only: - main - - macos-build: - matrix: - parameters: - os: [ macos-m1 ] - filters: - branches: - only: - - main diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml new file mode 100644 index 000000000000..7c6c87661cd7 --- /dev/null +++ b/.github/workflows/macos.yml @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +name: macOS Build + +on: + push: + pull_request: + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.repository }}-${{ github.head_ref || github.sha }} + cancel-in-progress: true + +jobs: + macos-build: + name: "${{ matrix.os }}" + strategy: + fail-fast: false + matrix: + # macos-13 = x86_64 Mac + # macos-14 = arm64 Mac + os: [macos-13, macos-14] + runs-on: ${{ matrix.os }} + env: + CCACHE_DIR: '${{ github.workspace }}/.ccache' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Install Dependencies + run: | + brew install \ + bison boost ccache double-conversion flex fmt gflags glog \ + icu4c libevent libsodium lz4 lzo ninja openssl range-v3 simdjson \ + snappy thrift xz xsimd zstd + + echo "NJOBS=`sysctl -n hw.ncpu`" >> $GITHUB_ENV + + - name: Cache ccache + uses: actions/cache@v4 + with: + path: '${{ env.CCACHE_DIR }}' + key: ccache-macos-${{ matrix.os }}-${{ hashFiles('velox/*') }} + restore-keys: ccache-macos-${{ matrix.os }} + + - name: Configure Build + env: + folly_SOURCE: BUNDLED + run: | + ccache -sz -M 5Gi + cmake \ + -B _build/debug \ + -GNinja \ + -DTREAT_WARNINGS_AS_ERRORS=1 \ + -DENABLE_ALL_WARNINGS=1 \ + -DVELOX_ENABLE_PARQUET=ON \ + -DCMAKE_BUILD_TYPE=Debug + + - name: Build + run: | + cmake --build _build/debug -j $NJOBS + ccache -s + - name: Run Tests + if: false + run: ctest -j $NJOBS --test-dir _build/debug --output-on-failure + diff --git a/CMakeLists.txt b/CMakeLists.txt index b9c88d4add33..1c7dc7d568d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -418,7 +418,7 @@ if(${VELOX_ENABLE_DUCKDB}) endif() set_source(fmt) -resolve_dependency(fmt) +resolve_dependency(fmt 9.0.0) if(NOT ${VELOX_BUILD_MINIMAL}) find_package(ZLIB REQUIRED) diff --git a/scripts/setup-adapters.sh b/scripts/setup-adapters.sh index 632dbd33b700..675ff4c0291d 100755 --- a/scripts/setup-adapters.sh +++ b/scripts/setup-adapters.sh @@ -84,6 +84,9 @@ function install_gcs-sdk-cpp { } function install_azure-storage-sdk-cpp { + # Disable VCPKG to install additional static dependencies under the VCPKG installed path + # instead of using system pre-installed dependencies. + export AZURE_SDK_DISABLE_AUTO_VCPKG=ON vcpkg_commit_id=7a6f366cefd27210f6a8309aed10c31104436509 github_checkout azure/azure-sdk-for-cpp azure-storage-files-datalake_12.8.0 sed -i "s/set(VCPKG_COMMIT_STRING .*)/set(VCPKG_COMMIT_STRING $vcpkg_commit_id)/" cmake-modules/AzureVcpkg.cmake diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh index 14ac9f144b91..69760cf85ec0 100755 --- a/scripts/setup-ubuntu.sh +++ b/scripts/setup-ubuntu.sh @@ -24,11 +24,12 @@ CPU_TARGET="${CPU_TARGET:-avx}" COMPILER_FLAGS=$(get_cxx_flags "$CPU_TARGET") export COMPILER_FLAGS FB_OS_VERSION=v2023.12.04.00 +FMT_VERSION=10.1.1 NPROC=$(getconf _NPROCESSORS_ONLN) DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} export CMAKE_BUILD_TYPE=Release -# Install all velox and folly dependencies. +# Install all velox and folly dependencies. # The is an issue on 22.04 where a version conflict prevents glog install, # installing libunwind first fixes this. apt update && apt install sudo @@ -46,7 +47,6 @@ sudo --preserve-env apt update && sudo --preserve-env apt install -y libunwind-d libboost-all-dev \ libicu-dev \ libdouble-conversion-dev \ - libfmt-dev \ libgoogle-glog-dev \ libbz2-dev \ libgflags-dev \ @@ -87,6 +87,11 @@ function prompt { ) 2> /dev/null } +function install_fmt { + github_checkout fmtlib/fmt "${FMT_VERSION}" + cmake_install -DFMT_TEST=OFF +} + function install_folly { github_checkout facebook/folly "${FB_OS_VERSION}" cmake_install -DBUILD_TESTS=OFF -DFOLLY_HAVE_INT128_T=ON @@ -120,6 +125,7 @@ function install_conda { } function install_velox_deps { + run_and_time install_fmt run_and_time install_folly run_and_time install_fizz run_and_time install_wangle diff --git a/velox/common/base/tests/GTestUtils.h b/velox/common/base/tests/GTestUtils.h index f11b61520b1c..4104b6955f94 100644 --- a/velox/common/base/tests/GTestUtils.h +++ b/velox/common/base/tests/GTestUtils.h @@ -63,7 +63,8 @@ << "Expected error message to contain '" << (_errorMessage) \ << "', but received '" << status.message() << "'." -#define VELOX_ASSERT_ERROR_CODE_IMPL(_type, _expression, _errorCode) \ +#define VELOX_ASSERT_ERROR_CODE_IMPL( \ + _type, _expression, _errorCode, _errorMessage) \ try { \ (_expression); \ FAIL() << "Expected an exception"; \ @@ -71,19 +72,26 @@ ASSERT_TRUE(e.errorCode() == _errorCode) \ << "Expected error code to be '" << _errorCode << "', but received '" \ << e.errorCode() << "'."; \ + ASSERT_TRUE(e.message().find(_errorMessage) != std::string::npos) \ + << "Expected error message to contain '" << (_errorMessage) \ + << "', but received '" << e.message() << "'."; \ } -#define VELOX_ASSERT_THROW_CODE(_expression, _errorCode) \ - VELOX_ASSERT_ERROR_CODE_IMPL( \ - facebook::velox::VeloxException, _expression, _errorCode) +#define VELOX_ASSERT_THROW_CODE(_expression, _errorCode, _errorMessage) \ + VELOX_ASSERT_ERROR_CODE_IMPL( \ + facebook::velox::VeloxException, _expression, _errorCode, _errorMessage) -#define VELOX_ASSERT_USER_THROW_CODE(_expression, _errorCode) \ - VELOX_ASSERT_ERROR_CODE_IMPL( \ - facebook::velox::VeloxUserError, _expression, _errorCode) +#define VELOX_ASSERT_USER_THROW_CODE(_expression, _errorCode, _errorMessage) \ + VELOX_ASSERT_ERROR_CODE_IMPL( \ + facebook::velox::VeloxUserError, _expression, _errorCode, _errorMessage) -#define VELOX_ASSERT_RUNTIME_THROW_CODE(_expression, _errorCode) \ - VELOX_ASSERT_ERROR_CODE_IMPL( \ - facebook::velox::VeloxRuntimeError, _expression, _errorCode) +#define VELOX_ASSERT_RUNTIME_THROW_CODE( \ + _expression, _errorCode, _errorMessage) \ + VELOX_ASSERT_ERROR_CODE_IMPL( \ + facebook::velox::VeloxRuntimeError, \ + _expression, \ + _errorCode, \ + _errorMessage) #ifndef NDEBUG #define DEBUG_ONLY_TEST(test_fixture, test_name) TEST(test_fixture, test_name) diff --git a/velox/common/caching/CMakeLists.txt b/velox/common/caching/CMakeLists.txt index 1fa409b965e9..6aa3ee11f735 100644 --- a/velox/common/caching/CMakeLists.txt +++ b/velox/common/caching/CMakeLists.txt @@ -28,6 +28,7 @@ target_link_libraries( velox_exception velox_file velox_memory + velox_process Folly::folly fmt::fmt gflags::gflags diff --git a/velox/common/caching/SsdFile.cpp b/velox/common/caching/SsdFile.cpp index b246e198bbe4..bcb0c6b01223 100644 --- a/velox/common/caching/SsdFile.cpp +++ b/velox/common/caching/SsdFile.cpp @@ -15,12 +15,14 @@ */ #include "velox/common/caching/SsdFile.h" + #include #include #include "velox/common/base/AsyncSource.h" #include "velox/common/base/SuccinctPrinter.h" #include "velox/common/caching/FileIds.h" #include "velox/common/caching/SsdCache.h" +#include "velox/common/process/TraceContext.h" #include #ifdef linux @@ -128,6 +130,7 @@ SsdFile::SsdFile( shardId_(shardId), checkpointIntervalBytes_(checkpointIntervalBytes), executor_(executor) { + process::TraceContext trace("SsdFile::SsdFile"); int32_t oDirect = 0; #ifdef linux oDirect = FLAGS_ssd_odirect ? O_DIRECT : 0; @@ -266,6 +269,7 @@ CoalesceIoStats SsdFile::load( void SsdFile::read( uint64_t offset, const std::vector>& buffers) { + process::TraceContext trace("SsdFile::read"); readFile_->preadv(offset, buffers); } @@ -307,6 +311,7 @@ std::optional> SsdFile::getSpace( } bool SsdFile::growOrEvictLocked() { + process::TraceContext trace("SsdFile::growOrEvictLocked"); if (numRegions_ < maxRegions_) { const auto newSize = (numRegions_ + 1) * kRegionSize; const auto rc = ::ftruncate(fd_, newSize); @@ -360,6 +365,7 @@ void SsdFile::clearRegionEntriesLocked(const std::vector& regions) { } void SsdFile::write(std::vector& pins) { + process::TraceContext trace("SsdFile::write"); // Sorts the pins by their file/offset. In this way what is adjacent in // storage is likely adjacent on SSD. std::sort(pins.begin(), pins.end()); @@ -444,6 +450,7 @@ int32_t indexOfFirstMismatch(char* x, char* y, int n) { } // namespace void SsdFile::verifyWrite(AsyncDataCacheEntry& entry, SsdRun ssdRun) { + process::TraceContext trace("SsdFile::verifyWrite"); auto testData = std::make_unique(entry.size()); const auto rc = ::pread(fd_, testData.get(), entry.size(), ssdRun.offset()); VELOX_CHECK_EQ(rc, entry.size()); @@ -512,6 +519,7 @@ void SsdFile::clear() { } void SsdFile::deleteFile() { + process::TraceContext trace("SsdFile::deleteFile"); if (fd_) { close(fd_); fd_ = 0; @@ -651,6 +659,7 @@ inline const char* asChar(const T* ptr) { } // namespace void SsdFile::checkpoint(bool force) { + process::TraceContext trace("SsdFile::checkpoint"); std::lock_guard l(mutex_); if (!force && (bytesAfterCheckpoint_ < checkpointIntervalBytes_)) { return; diff --git a/velox/common/file/tests/FileTest.cpp b/velox/common/file/tests/FileTest.cpp index 1b73653495de..af5654a72369 100644 --- a/velox/common/file/tests/FileTest.cpp +++ b/velox/common/file/tests/FileTest.cpp @@ -324,5 +324,7 @@ TEST(LocalFile, fileNotFound) { auto path = fmt::format("{}/file", tempFolder->path); auto localFs = filesystems::getFileSystem(path, nullptr); VELOX_ASSERT_RUNTIME_THROW_CODE( - localFs->openFileForRead(path), error_code::kFileNotFound); + localFs->openFileForRead(path), + error_code::kFileNotFound, + "No such file or directory"); } diff --git a/velox/common/memory/ByteStream.h b/velox/common/memory/ByteStream.h index 0c623dfd29ee..677d6659df51 100644 --- a/velox/common/memory/ByteStream.h +++ b/velox/common/memory/ByteStream.h @@ -227,6 +227,10 @@ class ByteOutputStream { void operator=(const ByteOutputStream& other) = delete; + // Forcing a move constructor to be able to return ByteOutputStream objects + // from a function. + ByteOutputStream(ByteOutputStream&&) = default; + /// Sets 'this' to range over 'range'. If this is for purposes of writing, /// lastWrittenPosition specifies the end of any pre-existing content in /// 'range'. diff --git a/velox/common/memory/Memory.cpp b/velox/common/memory/Memory.cpp index d069b2caeabb..ddb5266b35dd 100644 --- a/velox/common/memory/Memory.cpp +++ b/velox/common/memory/Memory.cpp @@ -245,7 +245,7 @@ void MemoryManager::dropPool(MemoryPool* pool) { MemoryPool& MemoryManager::deprecatedSharedLeafPool() { const auto idx = std::hash{}(std::this_thread::get_id()); - folly::SharedMutex::ReadHolder guard{mutex_}; + std::shared_lock guard{mutex_}; return *sharedLeafPools_.at(idx % sharedLeafPools_.size()); } @@ -257,7 +257,7 @@ size_t MemoryManager::numPools() const { size_t numPools = defaultRoot_->getChildCount(); VELOX_CHECK_GE(numPools, 0); { - folly::SharedMutex::ReadHolder guard{mutex_}; + std::shared_lock guard{mutex_}; numPools += pools_.size() - sharedLeafPools_.size(); } return numPools; @@ -303,7 +303,7 @@ std::string MemoryManager::toString(bool detail) const { std::vector> MemoryManager::getAlivePools() const { std::vector> pools; - folly::SharedMutex::ReadHolder guard{mutex_}; + std::shared_lock guard{mutex_}; pools.reserve(pools_.size()); for (const auto& entry : pools_) { auto pool = entry.second.lock(); diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index aa2709e7be29..bec9c2ddc0e3 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -215,7 +215,7 @@ uint64_t MemoryReclaimer::reclaim( }; std::vector candidates; { - folly::SharedMutex::ReadHolder guard{pool->poolMutex_}; + std::shared_lock guard{pool->poolMutex_}; candidates.reserve(pool->children_.size()); for (auto& entry : pool->children_) { auto child = entry.second.lock(); diff --git a/velox/common/memory/MemoryPool.cpp b/velox/common/memory/MemoryPool.cpp index b6954ccad53e..6bf9fe0612ab 100644 --- a/velox/common/memory/MemoryPool.cpp +++ b/velox/common/memory/MemoryPool.cpp @@ -266,7 +266,7 @@ MemoryPool* MemoryPool::root() const { } uint64_t MemoryPool::getChildCount() const { - folly::SharedMutex::ReadHolder guard{poolMutex_}; + std::shared_lock guard{poolMutex_}; return children_.size(); } @@ -274,7 +274,7 @@ void MemoryPool::visitChildren( const std::function& visitor) const { std::vector> children; { - folly::SharedMutex::ReadHolder guard{poolMutex_}; + std::shared_lock guard{poolMutex_}; children.reserve(children_.size()); for (auto& entry : children_) { auto child = entry.second.lock(); diff --git a/velox/common/memory/tests/MemoryAllocatorTest.cpp b/velox/common/memory/tests/MemoryAllocatorTest.cpp index 7bf2fc6a174c..b01550568b1f 100644 --- a/velox/common/memory/tests/MemoryAllocatorTest.cpp +++ b/velox/common/memory/tests/MemoryAllocatorTest.cpp @@ -30,6 +30,10 @@ #include #include +#ifdef linux +#include +#endif // linux + DECLARE_int32(velox_memory_pool_mb); DECLARE_bool(velox_memory_leak_check_enabled); @@ -1407,7 +1411,7 @@ TEST_P(MemoryAllocatorTest, contiguousAllocation) { ASSERT_EQ(movedAllocation.pool(), pool_.get()); *allocation = std::move(movedAllocation); ASSERT_TRUE(!allocation->empty()); // NOLINT - ASSERT_TRUE(movedAllocation.empty()); + ASSERT_TRUE(movedAllocation.empty()); // NOLINT ASSERT_EQ(allocation->pool(), pool_.get()); } ASSERT_THROW(allocation->setPool(pool_.get()), VeloxRuntimeError); diff --git a/velox/common/process/CMakeLists.txt b/velox/common/process/CMakeLists.txt index 22182ed58f12..af0bedd5ce4f 100644 --- a/velox/common/process/CMakeLists.txt +++ b/velox/common/process/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. add_library(velox_process ProcessBase.cpp StackTrace.cpp ThreadDebugInfo.cpp - TraceContext.cpp) + TraceContext.cpp TraceHistory.cpp) target_link_libraries( velox_process diff --git a/velox/common/process/TraceContext.cpp b/velox/common/process/TraceContext.cpp index cad158f48ee7..b0ee5b724097 100644 --- a/velox/common/process/TraceContext.cpp +++ b/velox/common/process/TraceContext.cpp @@ -16,23 +16,34 @@ #include "velox/common/process/TraceContext.h" +#include "velox/common/process/TraceHistory.h" + #include namespace facebook::velox::process { namespace { -folly::Synchronized>& traceMap() { - static folly::Synchronized> - staticTraceMap; - return staticTraceMap; -} + +// We use thread local instead lock here since the critical path is on write +// side. +auto registry = std::make_shared(); +thread_local auto threadLocalTraceData = + std::make_shared(registry); + } // namespace TraceContext::TraceContext(std::string label, bool isTemporary) : label_(std::move(label)), enterTime_(std::chrono::steady_clock::now()), - isTemporary_(isTemporary) { - traceMap().withWLock([&](auto& counts) { + isTemporary_(isTemporary), + traceData_(threadLocalTraceData) { + TraceHistory::push([&](auto& entry) { + entry.time = enterTime_; + entry.file = __FILE__; + entry.line = __LINE__; + snprintf(entry.label, entry.kLabelCapacity, "%s", label_.c_str()); + }); + traceData_->withValue([&](auto& counts) { auto& data = counts[label_]; ++data.numThreads; if (data.numThreads == 1) { @@ -43,17 +54,18 @@ TraceContext::TraceContext(std::string label, bool isTemporary) } TraceContext::~TraceContext() { - traceMap().withWLock([&](auto& counts) { - auto& data = counts[label_]; - --data.numThreads; + traceData_->withValue([&](auto& counts) { + auto it = counts.find(label_); + auto& data = it->second; + if (--data.numThreads == 0 && isTemporary_) { + counts.erase(it); + return; + } auto ms = std::chrono::duration_cast( std::chrono::steady_clock::now() - enterTime_) .count(); data.totalMs += ms; data.maxMs = std::max(data.maxMs, ms); - if (!data.numThreads && isTemporary_) { - counts.erase(label_); - } }); } @@ -61,27 +73,39 @@ TraceContext::~TraceContext() { std::string TraceContext::statusLine() { std::stringstream out; auto now = std::chrono::steady_clock::now(); - traceMap().withRLock([&](auto& counts) { - for (auto& pair : counts) { - if (pair.second.numThreads) { - auto continued = std::chrono::duration_cast( - now - pair.second.startTime) - .count(); - - out << pair.first << "=" << pair.second.numThreads << " entered " - << pair.second.numEnters << " avg ms " - << (pair.second.totalMs / pair.second.numEnters) << " max ms " - << pair.second.maxMs << " continuous for " << continued - << std::endl; - } + auto counts = status(); + for (auto& [label, data] : counts) { + if (data.numThreads > 0) { + auto continued = std::chrono::duration_cast( + now - data.startTime) + .count(); + out << label << ": numThreads=" << data.numThreads + << " numEnters=" << data.numEnters + << " avgMs=" << (data.totalMs / data.numEnters) + << " maxMs=" << data.maxMs << " continued=" << continued << std::endl; } - }); + } return out.str(); } // static -std::unordered_map TraceContext::status() { - return traceMap().withRLock([&](auto& map) { return map; }); +folly::F14FastMap TraceContext::status() { + folly::F14FastMap total; + registry->forAllValues([&](auto& counts) { + for (auto& [k, v] : counts) { + auto& sofar = total[k]; + if (sofar.numEnters == 0) { + sofar.startTime = v.startTime; + } else if (v.numEnters > 0) { + sofar.startTime = std::min(sofar.startTime, v.startTime); + } + sofar.numThreads += v.numThreads; + sofar.numEnters += v.numEnters; + sofar.totalMs += v.totalMs; + sofar.maxMs = std::max(sofar.maxMs, v.maxMs); + } + }); + return total; } } // namespace facebook::velox::process diff --git a/velox/common/process/TraceContext.h b/velox/common/process/TraceContext.h index c3d3a18be142..6e718515b58d 100644 --- a/velox/common/process/TraceContext.h +++ b/velox/common/process/TraceContext.h @@ -16,11 +16,13 @@ #pragma once +#include "velox/common/process/ThreadLocalRegistry.h" + #include #include #include -#include +#include namespace facebook::velox::process { @@ -47,6 +49,8 @@ struct TraceData { // produces a concise report of what the system is doing at any one // time. This is good for diagnosing crashes or hangs which are // difficult to figure out from stacks in a core dump. +// +// NOTE: TraceContext is not sharable between different threads. class TraceContext { public: // Starts a trace context. isTemporary is false if this is a generic @@ -56,6 +60,9 @@ class TraceContext { // which the record should be dropped once the last thread finishes. explicit TraceContext(std::string label, bool isTemporary = false); + TraceContext(const TraceContext&) = delete; + TraceContext& operator=(const TraceContext&) = delete; + ~TraceContext(); // Produces a human readable report of all TraceContexts in existence at the @@ -63,12 +70,18 @@ class TraceContext { static std::string statusLine(); // Returns a copy of the trace status. - static std::unordered_map status(); + static folly::F14FastMap status(); + + // Implementation detail type. Made public to be available with + // std::make_shared. Do not use outside this class. + using Registry = + ThreadLocalRegistry>; private: const std::string label_; const std::chrono::steady_clock::time_point enterTime_; const bool isTemporary_; + std::shared_ptr traceData_; }; } // namespace facebook::velox::process diff --git a/velox/common/process/TraceHistory.cpp b/velox/common/process/TraceHistory.cpp new file mode 100644 index 000000000000..bf7524590802 --- /dev/null +++ b/velox/common/process/TraceHistory.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/process/TraceHistory.h" + +#include + +#include + +namespace facebook::velox::process { + +namespace { +auto registry = std::make_shared>(); +} + +namespace detail { +thread_local ThreadLocalRegistry::Reference traceHistory( + registry); +} + +TraceHistory::TraceHistory() + : threadId_(std::this_thread::get_id()), osTid_(folly::getOSThreadID()) {} + +std::vector TraceHistory::listAll() { + std::vector results; + registry->forAllValues([&](auto& history) { + EntriesWithThreadInfo result; + result.threadId = history.threadId_; + result.osTid = history.osTid_; + for (int i = 0; i < kCapacity; ++i) { + const int j = (history.index_ + kCapacity - 1 - i) % kCapacity; + if (!populated(history.data_[j])) { + break; + } + result.entries.push_back(history.data_[j]); + } + std::reverse(result.entries.begin(), result.entries.end()); + results.push_back(std::move(result)); + }); + return results; +} + +} // namespace facebook::velox::process diff --git a/velox/common/process/TraceHistory.h b/velox/common/process/TraceHistory.h new file mode 100644 index 000000000000..bcee2cec69d7 --- /dev/null +++ b/velox/common/process/TraceHistory.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/process/ThreadLocalRegistry.h" + +#include +#include +#include +#include +#include + +/// Push an entry to the history ring buffer with a label from format string +/// (same as printf) and optional arguments. +#define VELOX_TRACE_HISTORY_PUSH(_format, ...) \ + ::facebook::velox::process::TraceHistory::push([&](auto& entry) { \ + entry.time = ::std::chrono::steady_clock::now(); \ + entry.file = __FILE__; \ + entry.line = __LINE__; \ + ::snprintf(entry.label, entry.kLabelCapacity, _format, ##__VA_ARGS__); \ + }) + +namespace facebook::velox::process { + +class TraceHistory; + +namespace detail { +extern thread_local ThreadLocalRegistry::Reference traceHistory; +} + +/// Keep list of labels in a ring buffer that is fixed sized and thread local. +class TraceHistory { + public: + TraceHistory(); + + /// An entry with tracing information and custom label. + struct Entry { + std::chrono::steady_clock::time_point time; + const char* file; + int32_t line; + + static constexpr int kLabelCapacity = + 64 - sizeof(time) - sizeof(file) - sizeof(line); + char label[kLabelCapacity]; + }; + + /// NOTE: usually VELOX_TRACE_HISTORY_PUSH should be used instead of calling + /// this function directly. + /// + /// Add a new entry to the thread local instance. If there are more than + /// `kCapacity' entries, overwrite the oldest ones. All the mutation on the + /// new entry should be done in the functor `init'. + template + static void push(F&& init) { + detail::traceHistory.withValue( + [init = std::forward(init)](auto& history) { + auto& entry = history.data_[history.index_]; + init(entry); + assert(populated(entry)); + history.index_ = (history.index_ + 1) % kCapacity; + }); + } + + /// All entries in a specific thread. + struct EntriesWithThreadInfo { + std::thread::id threadId; + uint64_t osTid; + std::vector entries; + }; + + /// List all entries from all threads. + static std::vector listAll(); + + /// Keep the last `kCapacity' entries per thread. Must be a power of 2. + static constexpr int kCapacity = 16; + + private: + static_assert((kCapacity & (kCapacity - 1)) == 0); + static_assert(sizeof(Entry) == 64); + + static bool populated(const Entry& entry) { + return entry.file != nullptr; + } + + alignas(64) Entry data_[kCapacity]{}; + const std::thread::id threadId_; + const uint64_t osTid_; + int index_ = 0; +}; + +} // namespace facebook::velox::process diff --git a/velox/common/process/tests/CMakeLists.txt b/velox/common/process/tests/CMakeLists.txt index 836e397466a2..2fce354e31ec 100644 --- a/velox/common/process/tests/CMakeLists.txt +++ b/velox/common/process/tests/CMakeLists.txt @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_process_test TraceContextTest.cpp - ThreadLocalRegistryTest.cpp) +add_executable( + velox_process_test TraceContextTest.cpp ThreadLocalRegistryTest.cpp + TraceHistoryTest.cpp) add_test(velox_process_test velox_process_test) diff --git a/velox/common/process/tests/TraceContextTest.cpp b/velox/common/process/tests/TraceContextTest.cpp index cfa021432a8a..130055e568fa 100644 --- a/velox/common/process/tests/TraceContextTest.cpp +++ b/velox/common/process/tests/TraceContextTest.cpp @@ -15,33 +15,125 @@ */ #include "velox/common/process/TraceContext.h" +#include "velox/common/process/TraceHistory.h" + #include +#include +#include +#include #include + #include -using namespace facebook::velox::process; +namespace facebook::velox::process { +namespace { + +class TraceContextTest : public testing::Test { + public: + void SetUp() override { + ASSERT_TRUE(TraceContext::status().empty()); + } -TEST(TraceContextTest, basic) { - constexpr int32_t kNumThreads = 10; + void TearDown() override { + ASSERT_TRUE(TraceContext::status().empty()); + } +}; + +TEST_F(TraceContextTest, basic) { + constexpr int kNumThreads = 3; std::vector threads; + folly::Baton<> batons[2][kNumThreads]; + folly::Latch latches[2] = { + folly::Latch(kNumThreads), + folly::Latch(kNumThreads), + }; threads.reserve(kNumThreads); - for (int32_t i = 0; i < kNumThreads; ++i) { - threads.push_back(std::thread([i]() { - TraceContext trace1("process data"); - TraceContext trace2(fmt::format("Process chunk {}", i), true); - std::this_thread::sleep_for(std::chrono::milliseconds(3)); - })); + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&, i]() { + { + TraceContext trace1("process data"); + TraceContext trace2(fmt::format("Process chunk {}", i), true); + latches[0].count_down(); + batons[0][i].wait(); + } + latches[1].count_down(); + batons[1][i].wait(); + }); + } + latches[0].wait(); + auto status = TraceContext::status(); + ASSERT_EQ(1 + kNumThreads, status.size()); + ASSERT_EQ(kNumThreads, status.at("process data").numThreads); + for (int i = 0; i < kNumThreads; ++i) { + ASSERT_EQ(1, status.at(fmt::format("Process chunk {}", i)).numThreads); } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - LOG(INFO) << TraceContext::statusLine(); - for (auto& thread : threads) { - thread.join(); + for (int i = 0; i < kNumThreads; ++i) { + batons[0][i].post(); } - LOG(INFO) << TraceContext::statusLine(); - // We expect one entry for "process data". The temporary entries - // are deleted after the treads complete. - auto after = TraceContext::status(); - EXPECT_EQ(1, after.size()); - EXPECT_EQ(kNumThreads, after["process data"].numEnters); - EXPECT_EQ(0, after["process data"].numThreads); + latches[1].wait(); + status = TraceContext::status(); + ASSERT_EQ(1, status.size()); + ASSERT_EQ(0, status.at("process data").numThreads); + ASSERT_EQ(kNumThreads, status.at("process data").numEnters); + for (int i = 0; i < kNumThreads; ++i) { + batons[1][i].post(); + threads[i].join(); + } +} + +TEST_F(TraceContextTest, traceHistory) { + std::thread([] { + TraceContext trace("test"); + TraceContext trace2( + std::string(TraceHistory::Entry::kLabelCapacity + 10, 'x')); + auto results = TraceHistory::listAll(); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].entries.size(), 2); + ASSERT_STREQ(results[0].entries[0].label, "test"); + ASSERT_EQ( + results[0].entries[1].label, + std::string(TraceHistory::Entry::kLabelCapacity - 1, 'x')); + }).join(); } + +TEST_F(TraceContextTest, transferBetweenThreads) { + auto [promise, future] = + folly::makePromiseContract>(); + folly::Baton<> batons[2]; + std::chrono::steady_clock::time_point timeLow, timeHigh; + std::thread receiver([&, future = std::move(future)]() mutable { + auto trace = std::move(future).get(std::chrono::seconds(1)); + { + SCOPE_EXIT { + batons[0].post(); + }; + auto status = TraceContext::status(); + ASSERT_EQ(1, status.size()); + auto& data = status.at("test"); + ASSERT_EQ(data.numThreads, 1); + ASSERT_EQ(data.numEnters, 1); + ASSERT_LE(timeLow, data.startTime); + ASSERT_LE(data.startTime, timeHigh); + } + batons[1].wait(); + auto status = TraceContext::status(); + ASSERT_EQ(1, status.size()); + auto& data = status.at("test"); + ASSERT_EQ(data.numThreads, 1); + ASSERT_EQ(data.numEnters, 1); + ASSERT_LE(timeLow, data.startTime); + ASSERT_LE(data.startTime, timeHigh); + }); + timeLow = std::chrono::steady_clock::now(); + std::thread([&, promise = std::move(promise)]() mutable { + auto trace = std::make_unique("test"); + timeHigh = std::chrono::steady_clock::now(); + promise.setValue(std::move(trace)); + batons[0].wait(); + }).join(); + batons[1].post(); + receiver.join(); +} + +} // namespace +} // namespace facebook::velox::process diff --git a/velox/common/process/tests/TraceHistoryTest.cpp b/velox/common/process/tests/TraceHistoryTest.cpp new file mode 100644 index 000000000000..754fe6f389c3 --- /dev/null +++ b/velox/common/process/tests/TraceHistoryTest.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/process/TraceHistory.h" + +#include +#include +#include +#include + +namespace facebook::velox::process { +namespace { + +class TraceHistoryTest : public testing::Test { + public: + void SetUp() override { + ASSERT_TRUE(TraceHistory::listAll().empty()); + } + + void TearDown() override { + ASSERT_TRUE(TraceHistory::listAll().empty()); + } +}; + +TEST_F(TraceHistoryTest, basic) { + std::thread([] { + auto timeLow = std::chrono::steady_clock::now(); + constexpr int kStartLine = __LINE__; + for (int i = 0; i < TraceHistory::kCapacity + 10; ++i) { + VELOX_TRACE_HISTORY_PUSH("Test %d", i); + } + auto timeHigh = std::chrono::steady_clock::now(); + auto results = TraceHistory::listAll(); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].threadId, std::this_thread::get_id()); + ASSERT_EQ(results[0].osTid, folly::getOSThreadID()); + ASSERT_EQ(results[0].entries.size(), TraceHistory::kCapacity); + auto lastTime = timeLow; + for (int i = 0; i < TraceHistory::kCapacity; ++i) { + auto& entry = results[0].entries[i]; + ASSERT_EQ(entry.line, kStartLine + 2); + ASSERT_STREQ( + entry.file + strlen(entry.file) - 20, "TraceHistoryTest.cpp"); + ASSERT_LE(lastTime, entry.time); + lastTime = entry.time; + ASSERT_EQ(strncmp(entry.label, "Test ", 5), 0); + ASSERT_EQ(atoi(entry.label + 5), i + 10); + } + ASSERT_LE(lastTime, timeHigh); + }).join(); +} + +TEST_F(TraceHistoryTest, multiThread) { + constexpr int kNumThreads = 3; + folly::Latch latch(kNumThreads); + folly::Baton<> batons[kNumThreads]; + std::vector threads; + auto timeLow = std::chrono::steady_clock::now(); + constexpr int kStartLine = __LINE__; + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back([&, i] { + VELOX_TRACE_HISTORY_PUSH("Test"); + VELOX_TRACE_HISTORY_PUSH("Test %d", i); + latch.count_down(); + batons[i].wait(); + }); + } + latch.wait(); + auto timeHigh = std::chrono::steady_clock::now(); + auto results = TraceHistory::listAll(); + ASSERT_EQ(results.size(), kNumThreads); + for (auto& result : results) { + auto threadIndex = + std::find_if( + threads.begin(), + threads.end(), + [&](auto& t) { return t.get_id() == result.threadId; }) - + threads.begin(); + ASSERT_EQ(result.entries.size(), 2); + ASSERT_EQ(result.entries[0].line, kStartLine + 3); + ASSERT_EQ(result.entries[1].line, kStartLine + 4); + ASSERT_STREQ(result.entries[0].label, "Test"); + ASSERT_EQ(result.entries[1].label, fmt::format("Test {}", threadIndex)); + for (auto& entry : result.entries) { + ASSERT_LE(timeLow, entry.time); + ASSERT_LE(entry.time, timeHigh); + ASSERT_TRUE(entry.file); + ASSERT_STREQ( + entry.file + strlen(entry.file) - 20, "TraceHistoryTest.cpp"); + } + } + for (int i = 0; i < kNumThreads; ++i) { + ASSERT_EQ(TraceHistory::listAll().size(), kNumThreads - i); + batons[i].post(); + threads[i].join(); + } +} + +TEST_F(TraceHistoryTest, largeLabel) { + std::thread([] { + VELOX_TRACE_HISTORY_PUSH( + "%s", + std::string(TraceHistory::Entry::kLabelCapacity + 10, 'x').c_str()); + auto results = TraceHistory::listAll(); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].entries.size(), 1); + ASSERT_EQ( + results[0].entries[0].label, + std::string(TraceHistory::Entry::kLabelCapacity - 1, 'x')); + }).join(); +} + +} // namespace +} // namespace facebook::velox::process diff --git a/velox/connectors/hive/HiveConnectorUtil.cpp b/velox/connectors/hive/HiveConnectorUtil.cpp index 4fe174ed7949..464d7560691a 100644 --- a/velox/connectors/hive/HiveConnectorUtil.cpp +++ b/velox/connectors/hive/HiveConnectorUtil.cpp @@ -377,7 +377,8 @@ std::shared_ptr makeScanSpec( } std::unique_ptr parseSerdeParameters( - const std::unordered_map& serdeParameters) { + const std::unordered_map& serdeParameters, + const std::unordered_map& tableParameters) { auto fieldIt = serdeParameters.find(dwio::common::SerDeOptions::kFieldDelim); if (fieldIt == serdeParameters.end()) { fieldIt = serdeParameters.find("serialization.format"); @@ -393,9 +394,13 @@ std::unique_ptr parseSerdeParameters( auto mapKeyIt = serdeParameters.find(dwio::common::SerDeOptions::kMapKeyDelim); + auto nullStringIt = tableParameters.find( + dwio::common::TableParameter::kSerializationNullFormat); + if (fieldIt == serdeParameters.end() && collectionIt == serdeParameters.end() && - mapKeyIt == serdeParameters.end()) { + mapKeyIt == serdeParameters.end() && + nullStringIt == tableParameters.end()) { return nullptr; } @@ -413,6 +418,7 @@ std::unique_ptr parseSerdeParameters( } auto serDeOptions = std::make_unique( fieldDelim, collectionDelim, mapKeyDelim); + serDeOptions->nullString = nullStringIt->second; return serDeOptions; } @@ -420,15 +426,15 @@ void configureReaderOptions( dwio::common::ReaderOptions& readerOptions, const std::shared_ptr& hiveConfig, const Config* sessionProperties, - const RowTypePtr& fileSchema, - std::shared_ptr hiveSplit) { + const std::shared_ptr& hiveTableHandle, + const std::shared_ptr& hiveSplit) { readerOptions.setMaxCoalesceBytes(hiveConfig->maxCoalescedBytes()); readerOptions.setMaxCoalesceDistance(hiveConfig->maxCoalescedDistanceBytes()); readerOptions.setFileColumnNamesReadAsLowerCase( hiveConfig->isFileColumnNamesReadAsLowerCase(sessionProperties)); readerOptions.setUseColumnNamesForColumnMapping( hiveConfig->isOrcUseColumnNames(sessionProperties)); - readerOptions.setFileSchema(fileSchema); + readerOptions.setFileSchema(hiveTableHandle->dataColumns()); readerOptions.setFooterEstimatedSize(hiveConfig->footerEstimatedSize()); readerOptions.setFilePreloadThreshold(hiveConfig->filePreloadThreshold()); @@ -439,7 +445,8 @@ void configureReaderOptions( dwio::common::toString(readerOptions.getFileFormat()), dwio::common::toString(hiveSplit->fileFormat)); } else { - auto serDeOptions = parseSerdeParameters(hiveSplit->serdeParameters); + auto serDeOptions = parseSerdeParameters( + hiveSplit->serdeParameters, hiveTableHandle->tableParameters()); if (serDeOptions) { readerOptions.setSerDeOptions(*serDeOptions); } diff --git a/velox/connectors/hive/HiveConnectorUtil.h b/velox/connectors/hive/HiveConnectorUtil.h index 51335f09e76a..67426bef78ca 100644 --- a/velox/connectors/hive/HiveConnectorUtil.h +++ b/velox/connectors/hive/HiveConnectorUtil.h @@ -26,6 +26,7 @@ namespace facebook::velox::connector::hive { class HiveColumnHandle; +class HiveTableHandle; class HiveConfig; struct HiveConnectorSplit; @@ -57,8 +58,8 @@ void configureReaderOptions( dwio::common::ReaderOptions& readerOptions, const std::shared_ptr& config, const Config* sessionProperties, - const RowTypePtr& fileSchema, - std::shared_ptr hiveSplit); + const std::shared_ptr& hiveTableHandle, + const std::shared_ptr& hiveSplit); void configureRowReaderOptions( dwio::common::RowReaderOptions& rowReaderOptions, diff --git a/velox/connectors/hive/SplitReader.cpp b/velox/connectors/hive/SplitReader.cpp index b6cce9860087..92376e566d38 100644 --- a/velox/connectors/hive/SplitReader.cpp +++ b/velox/connectors/hive/SplitReader.cpp @@ -82,7 +82,7 @@ void SplitReader::configureReaderOptions() { baseReaderOpts_, hiveConfig_, connectorQueryCtx_->sessionProperties(), - hiveTableHandle_->dataColumns(), + hiveTableHandle_, hiveSplit_); } diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp index de86aa7f4386..4382fabb84a6 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.cpp @@ -13,20 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + #include "velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h" -#include "velox/common/file/File.h" -#include "velox/connectors/hive/HiveConfig.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" -#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" -#include "velox/core/Config.h" #include #include #include #include +#include "velox/common/file/File.h" +#include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" +#include "velox/core/Config.h" + namespace facebook::velox::filesystems::abfs { using namespace Azure::Storage::Blobs; + class AbfsConfig { public: AbfsConfig(const Config* config) : config_(config) {} @@ -49,14 +52,12 @@ class AbfsReadFile::Impl { constexpr static uint64_t kReadConcurrency = 8; public: - explicit Impl(const std::string& path, const std::string& connectStr) - : path_(path), connectStr_(connectStr) { - auto abfsAccount = AbfsAccount(path_); - fileSystem_ = abfsAccount.fileSystem(); + explicit Impl(const std::string& path, const std::string& connectStr) { + auto abfsAccount = AbfsAccount(path); fileName_ = abfsAccount.filePath(); fileClient_ = std::make_unique(BlobClient::CreateFromConnectionString( - connectStr_, fileSystem_, fileName_)); + connectStr, abfsAccount.fileSystem(), fileName_)); } void initialize() { @@ -153,9 +154,6 @@ class AbfsReadFile::Impl { reinterpret_cast(position), length); } - const std::string path_; - const std::string connectStr_; - std::string fileSystem_; std::string fileName_; std::unique_ptr fileClient_; @@ -250,4 +248,13 @@ std::unique_ptr AbfsFileSystem::openFileForRead( abfsfile->initialize(); return abfsfile; } + +std::unique_ptr AbfsFileSystem::openFileForWrite( + std::string_view path, + const FileOptions& /*unused*/) { + auto abfsfile = std::make_unique( + std::string(path), impl_->connectionString(std::string(path))); + abfsfile->initialize(); + return abfsfile; +} } // namespace facebook::velox::filesystems::abfs diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h index f67789243545..4b8ec74d5954 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h @@ -44,9 +44,7 @@ class AbfsFileSystem : public FileSystem { std::unique_ptr openFileForWrite( std::string_view path, - const FileOptions& options = {}) override { - VELOX_UNSUPPORTED("write for abfs not implemented"); - } + const FileOptions& options = {}) override; void rename( std::string_view path, diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h index 609623fe47a4..2af0f4239009 100644 --- a/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h @@ -66,11 +66,15 @@ inline const std::string throwStorageExceptionWithOperationDetails( std::string operation, std::string path, Azure::Storage::StorageException& error) { - VELOX_FAIL( + const auto errMsg = fmt::format( "Operation '{}' to path '{}' encountered azure storage exception, Details: '{}'.", operation, path, error.what()); + if (error.StatusCode == Azure::Core::Http::HttpStatusCode::NotFound) { + VELOX_FILE_NOT_FOUND_ERROR(errMsg); + } + VELOX_FAIL(errMsg); } } // namespace facebook::velox::filesystems::abfs diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp new file mode 100644 index 000000000000..c231954258f5 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" + +#include + +namespace facebook::velox::filesystems::abfs { +class BlobStorageFileClient final : public IBlobStorageFileClient { + public: + BlobStorageFileClient(std::unique_ptr client) + : client_(std::move(client)) {} + + void create() override { + client_->Create(); + } + + PathProperties getProperties() override { + return client_->GetProperties().Value; + } + + void append(const uint8_t* buffer, size_t size, uint64_t offset) override { + auto bodyStream = Azure::Core::IO::MemoryBodyStream(buffer, size); + client_->Append(bodyStream, offset); + } + + void flush(uint64_t position) override { + client_->Flush(position); + } + + void close() override { + // do nothing. + } + + private: + const std::unique_ptr client_; +}; + +class AbfsWriteFile::Impl { + public: + explicit Impl(const std::string& path, const std::string& connectStr) + : path_(path), connectStr_(connectStr) { + // Make it a no-op if invoked twice. + if (position_ != -1) { + return; + } + position_ = 0; + } + + void initialize() { + if (!blobStorageFileClient_) { + auto abfsAccount = AbfsAccount(path_); + blobStorageFileClient_ = std::make_unique( + std::make_unique( + DataLakeFileClient::CreateFromConnectionString( + connectStr_, + abfsAccount.fileSystem(), + abfsAccount.filePath()))); + } + + VELOX_CHECK(!checkIfFileExists(), "File already exists"); + blobStorageFileClient_->create(); + } + + void testingSetFileClient( + const std::shared_ptr& blobStorageManager) { + blobStorageFileClient_ = blobStorageManager; + } + + void close() { + if (!closed_) { + flush(); + blobStorageFileClient_->close(); + closed_ = true; + } + } + + void flush() { + if (!closed_) { + blobStorageFileClient_->flush(position_); + } + } + + void append(std::string_view data) { + VELOX_CHECK(!closed_, "File is not open"); + if (data.size() == 0) { + return; + } + append(data.data(), data.size()); + } + + uint64_t size() const { + return blobStorageFileClient_->getProperties().FileSize; + } + + void append(const char* buffer, size_t size) { + blobStorageFileClient_->append( + reinterpret_cast(buffer), size, position_); + position_ += size; + } + + private: + bool checkIfFileExists() { + try { + blobStorageFileClient_->getProperties(); + return true; + } catch (Azure::Storage::StorageException& e) { + if (e.StatusCode == Azure::Core::Http::HttpStatusCode::NotFound) { + return false; + } else { + throwStorageExceptionWithOperationDetails("GetProperties", path_, e); + } + } + } + + const std::string path_; + const std::string connectStr_; + std::string fileSystem_; + std::string fileName_; + std::shared_ptr blobStorageFileClient_; + + uint64_t position_ = -1; + bool closed_ = false; +}; + +AbfsWriteFile::AbfsWriteFile( + const std::string& path, + const std::string& connectStr) { + impl_ = std::make_shared(path, connectStr); +} + +void AbfsWriteFile::initialize() { + impl_->initialize(); +} + +void AbfsWriteFile::close() { + impl_->close(); +} + +void AbfsWriteFile::flush() { + impl_->flush(); +} + +void AbfsWriteFile::append(std::string_view data) { + impl_->append(data); +} + +uint64_t AbfsWriteFile::size() const { + return impl_->size(); +} + +void AbfsWriteFile::testingSetFileClient( + const std::shared_ptr& fileClient) { + impl_->testingSetFileClient(fileClient); +} +} // namespace facebook::velox::filesystems::abfs diff --git a/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h new file mode 100644 index 000000000000..72549720344f --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/file/File.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsUtil.h" + +namespace Azure::Storage::Files::DataLake::Models { +class PathProperties; +} + +namespace facebook::velox::filesystems::abfs { +using namespace Azure::Storage::Files::DataLake; +using namespace Azure::Storage::Files::DataLake::Models; + +/* + * We are using the DFS (Data Lake Storage) endpoint for Azure Blob File write + * operations because the DFS endpoint is designed to be compatible with file + * operation semantics, such as `Append` to a file and file `Flush` operations. + * The legacy Blob endpoint can only be used for blob level append and flush + * operations. When using the Blob endpoint, we would need to manually manage + * the creation, appending, and committing of file-related blocks. + * + * However, the Azurite Simulator does not yet support the DFS endpoint. + * (For more information, see https://github.com/Azure/Azurite/issues/553 and + * https://github.com/Azure/Azurite/issues/409). + * You can find a comparison between DFS and Blob endpoints here: + * https://github.com/Azure/Azurite/wiki/ADLS-Gen2-Implementation-Guidance + * + * To facilitate unit testing of file write scenarios, we define the + * IBlobStorageFileClient here, which can be mocked during testing. + */ +class IBlobStorageFileClient { + public: + virtual void create() = 0; + virtual PathProperties getProperties() = 0; + virtual void append(const uint8_t* buffer, size_t size, uint64_t offset) = 0; + virtual void flush(uint64_t position) = 0; + virtual void close() = 0; +}; + +/// Implementation of abfs write file. Nothing written to the file should be +/// read back until it is closed. +class AbfsWriteFile : public WriteFile { + public: + constexpr static uint64_t kNaturalWriteSize = 8 << 20; // 8M + /// The constructor. + /// @param path The file path to write. + /// @param connectStr the connection string used to auth the storage account. + AbfsWriteFile(const std::string& path, const std::string& connectStr); + + /// check any issue reading file. + void initialize(); + + /// Get the file size. + uint64_t size() const override; + + /// Flush the data. + void flush() override; + + /// Write the data by append mode. + void append(std::string_view data) override; + + /// Close the file. + void close() override; + + /// Used by tests to override the FileSystem client. + void testingSetFileClient( + const std::shared_ptr& fileClient); + + protected: + class Impl; + std::shared_ptr impl_; +}; +} // namespace facebook::velox::filesystems::abfs diff --git a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt index 74e5a8d81c91..cd4cee572e5e 100644 --- a/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/CMakeLists.txt @@ -17,7 +17,8 @@ add_library(velox_abfs RegisterAbfsFileSystem.cpp) if(VELOX_ENABLE_ABFS) - target_sources(velox_abfs PRIVATE AbfsFileSystem.cpp AbfsUtils.cpp) + target_sources(velox_abfs PRIVATE AbfsFileSystem.cpp AbfsWriteFile.cpp + AbfsUtils.cpp) target_link_libraries( velox_abfs PUBLIC velox_file @@ -25,6 +26,7 @@ if(VELOX_ENABLE_ABFS) velox_hive_config velox_dwio_common_exception Azure::azure-storage-blobs + Azure::azure-storage-files-datalake Folly::folly glog::glog fmt::fmt) diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp index 98354f9dd52c..44fa5a3c8508 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AbfsFileSystemTest.cpp @@ -14,29 +14,32 @@ * limitations under the License. */ -#include "velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h" -#include "gtest/gtest.h" +#include +#include +#include +#include + #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/File.h" #include "velox/common/file/FileSystems.h" #include "velox/connectors/hive/FileHandle.h" #include "velox/connectors/hive/HiveConfig.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/abfs/AbfsReadFile.h" +#include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" #include "velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h" +#include "velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.h" #include "velox/exec/tests/utils/PortUtil.h" #include "velox/exec/tests/utils/TempFilePath.h" -#include -#include - using namespace facebook::velox; - +using namespace facebook::velox::filesystems::abfs; using ::facebook::velox::common::Region; constexpr int kOneMB = 1 << 20; static const std::string filePath = "test_file.txt"; static const std::string fullFilePath = - facebook::velox::filesystems::test::AzuriteABFSEndpoint + filePath; + filesystems::test::AzuriteABFSEndpoint + filePath; class AbfsFileSystemTest : public testing::Test { public: @@ -55,14 +58,11 @@ class AbfsFileSystemTest : public testing::Test { } public: - std::shared_ptr - azuriteServer; + std::shared_ptr azuriteServer; void SetUp() override { auto port = facebook::velox::exec::test::getFreePort(); - azuriteServer = - std::make_shared( - port); + azuriteServer = std::make_shared(port); azuriteServer->start(); auto tempFile = createFile(); azuriteServer->addFile(tempFile->path, filePath); @@ -72,13 +72,50 @@ class AbfsFileSystemTest : public testing::Test { azuriteServer->stop(); } + std::unique_ptr openFileForWrite( + std::string_view path, + std::shared_ptr client) { + auto abfsfile = std::make_unique( + std::string(path), azuriteServer->connectionStr()); + abfsfile->testingSetFileClient(client); + abfsfile->initialize(); + return abfsfile; + } + + static std::string generateRandomData(int size) { + static const char charset[] = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::string data(size, ' '); + + for (int i = 0; i < size; ++i) { + int index = rand() % (sizeof(charset) - 1); + data[i] = charset[index]; + } + + return data; + } + private: - static std::shared_ptr<::exec::test::TempFilePath> createFile() { + static std::shared_ptr<::exec::test::TempFilePath> createFile( + uint64_t size = -1) { auto tempFile = ::exec::test::TempFilePath::create(); - tempFile->append("aaaaa"); - tempFile->append("bbbbb"); - tempFile->append(std::string(kOneMB, 'c')); - tempFile->append("ddddd"); + if (size == -1) { + tempFile->append("aaaaa"); + tempFile->append("bbbbb"); + tempFile->append(std::string(kOneMB, 'c')); + tempFile->append("ddddd"); + } else { + const uint64_t totalSize = size * 1024 * 1024; + const uint64_t chunkSize = 5 * 1024 * 1024; + uint64_t remainingSize = totalSize; + while (remainingSize > 0) { + uint64_t dataSize = std::min(remainingSize, chunkSize); + std::string randomData = generateRandomData(dataSize); + tempFile->append(randomData); + remainingSize -= dataSize; + } + } return tempFile; } }; @@ -171,27 +208,54 @@ TEST_F(AbfsFileSystemTest, multipleThreadsWithReadFile) { } TEST_F(AbfsFileSystemTest, missingFile) { - try { - auto hiveConfig = AbfsFileSystemTest::hiveConfig( - {{"fs.azure.account.key.test.dfs.core.windows.net", - azuriteServer->connectionStr()}}); - const std::string abfsFile = - facebook::velox::filesystems::test::AzuriteABFSEndpoint + "test.txt"; - auto abfs = std::make_shared(hiveConfig); - auto readFile = abfs->openFileForRead(abfsFile); - FAIL() << "Expected VeloxException"; - } catch (VeloxException const& err) { - EXPECT_TRUE(err.message().find("404") != std::string::npos); - } -} - -TEST_F(AbfsFileSystemTest, openFileForWriteNotImplemented) { auto hiveConfig = AbfsFileSystemTest::hiveConfig( {{"fs.azure.account.key.test.dfs.core.windows.net", azuriteServer->connectionStr()}}); + const std::string abfsFile = + facebook::velox::filesystems::test::AzuriteABFSEndpoint + "test.txt"; auto abfs = std::make_shared(hiveConfig); + VELOX_ASSERT_RUNTIME_THROW_CODE( + abfs->openFileForRead(abfsFile), error_code::kFileNotFound, "404"); +} + +TEST_F(AbfsFileSystemTest, OpenFileForWriteTest) { + const std::string abfsFile = + filesystems::test::AzuriteABFSEndpoint + "writetest.txt"; + auto mockClient = + std::make_shared( + filesystems::test::MockBlobStorageFileClient()); + auto abfsWriteFile = openFileForWrite(abfsFile, mockClient); + EXPECT_EQ(abfsWriteFile->size(), 0); + std::string dataContent = ""; + uint64_t totalSize = 0; + std::string randomData = + AbfsFileSystemTest::generateRandomData(1 * 1024 * 1024); + for (int i = 0; i < 8; ++i) { + abfsWriteFile->append(randomData); + dataContent += randomData; + } + totalSize = randomData.size() * 8; + abfsWriteFile->flush(); + EXPECT_EQ(abfsWriteFile->size(), totalSize); + + randomData = AbfsFileSystemTest::generateRandomData(9 * 1024 * 1024); + dataContent += randomData; + abfsWriteFile->append(randomData); + totalSize += randomData.size(); + randomData = AbfsFileSystemTest::generateRandomData(2 * 1024 * 1024); + dataContent += randomData; + totalSize += randomData.size(); + abfsWriteFile->append(randomData); + abfsWriteFile->flush(); + EXPECT_EQ(abfsWriteFile->size(), totalSize); + abfsWriteFile->flush(); + abfsWriteFile->close(); + VELOX_ASSERT_THROW(abfsWriteFile->append("abc"), "File is not open"); VELOX_ASSERT_THROW( - abfs->openFileForWrite(fullFilePath), "write for abfs not implemented"); + openFileForWrite(abfsFile, mockClient), "File already exists"); + std::string fileContent = mockClient->readContent(); + ASSERT_EQ(fileContent.size(), dataContent.size()); + ASSERT_EQ(fileContent, dataContent); } TEST_F(AbfsFileSystemTest, renameNotImplemented) { @@ -247,9 +311,7 @@ TEST_F(AbfsFileSystemTest, credNotFOund) { const std::string abfsFile = std::string("abfs://test@test1.dfs.core.windows.net/test"); auto hiveConfig = AbfsFileSystemTest::hiveConfig({}); - auto abfs = - std::make_shared( - hiveConfig); + auto abfs = std::make_shared(hiveConfig); VELOX_ASSERT_THROW( abfs->openFileForRead(abfsFile), "Failed to find storage credentials"); } diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h index 165cb2767c11..4836183f3819 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h +++ b/velox/connectors/hive/storage_adapters/abfs/tests/AzuriteServer.h @@ -36,8 +36,8 @@ static const std::string AzuriteAccountKey{ "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="}; static const std::string AzuriteABFSEndpoint = fmt::format( "abfs://{}@{}.dfs.core.windows.net/", - AzuriteAccountName, - AzuriteContainerName); + AzuriteContainerName, + AzuriteAccountName); class AzuriteServer { public: diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt index 297a7db4e1bc..2fb451171b22 100644 --- a/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt +++ b/velox/connectors/hive/storage_adapters/abfs/tests/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. add_executable(velox_abfs_test AbfsFileSystemTest.cpp AbfsUtilTest.cpp - AzuriteServer.cpp) + AzuriteServer.cpp MockBlobStorageFileClient.cpp) add_test(velox_abfs_test velox_abfs_test) target_link_libraries( velox_abfs_test @@ -26,4 +26,5 @@ target_link_libraries( velox_exec gtest gtest_main - Azure::azure-storage-blobs) + Azure::azure-storage-blobs + Azure::azure-storage-files-datalake) diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.cpp b/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.cpp new file mode 100644 index 000000000000..5f0cf9fa1efd --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.h" + +#include + +#include + +using namespace Azure::Storage::Files::DataLake; +namespace facebook::velox::filesystems::test { +void MockBlobStorageFileClient::create() { + fileStream_ = std::ofstream( + filePath_, + std::ios_base::out | std::ios_base::binary | std::ios_base::app); +} + +PathProperties MockBlobStorageFileClient::getProperties() { + if (!std::filesystem::exists(filePath_)) { + Azure::Storage::StorageException exp(filePath_ + "doesn't exists"); + exp.StatusCode = Azure::Core::Http::HttpStatusCode::NotFound; + throw exp; + } + std::ifstream file(filePath_, std::ios::binary | std::ios::ate); + uint64_t size = static_cast(file.tellg()); + PathProperties ret; + ret.FileSize = size; + return ret; +} + +void MockBlobStorageFileClient::append( + const uint8_t* buffer, + size_t size, + uint64_t offset) { + fileStream_.seekp(offset); + fileStream_.write(reinterpret_cast(buffer), size); +} + +void MockBlobStorageFileClient::flush(uint64_t position) { + fileStream_.flush(); +} + +void MockBlobStorageFileClient::close() { + fileStream_.flush(); + fileStream_.close(); +} +} // namespace facebook::velox::filesystems::test diff --git a/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.h b/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.h new file mode 100644 index 000000000000..046cb094c1b1 --- /dev/null +++ b/velox/connectors/hive/storage_adapters/abfs/tests/MockBlobStorageFileClient.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/connectors/hive/storage_adapters/abfs/AbfsWriteFile.h" + +#include "velox/exec/tests/utils/TempFilePath.h" + +using namespace facebook::velox; +using namespace facebook::velox::filesystems::abfs; + +namespace facebook::velox::filesystems::test { +// A mocked blob storage file client backend with local file store. +class MockBlobStorageFileClient : public IBlobStorageFileClient { + public: + MockBlobStorageFileClient() { + auto tempFile = ::exec::test::TempFilePath::create(); + filePath_ = tempFile->path; + } + + void create() override; + PathProperties getProperties() override; + void append(const uint8_t* buffer, size_t size, uint64_t offset) override; + void flush(uint64_t position) override; + void close() override; + + // for testing purpose to verify the written content if correct. + std::string readContent() { + std::ifstream inputFile(filePath_); + std::string content; + inputFile.seekg(0, std::ios::end); + std::streamsize fileSize = inputFile.tellg(); + inputFile.seekg(0, std::ios::beg); + content.resize(fileSize); + inputFile.read(&content[0], fileSize); + inputFile.close(); + return content; + } + + private: + std::string filePath_; + std::ofstream fileStream_; +}; +} // namespace facebook::velox::filesystems::test diff --git a/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp b/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp index 6eaada548701..cf9371d62a2c 100644 --- a/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/GCSFileSystem.cpp @@ -48,14 +48,17 @@ inline void checkGCSStatus( const std::string& bucket, const std::string& key) { if (!outcome.ok()) { - auto error = outcome.error_info(); - VELOX_FAIL( + const auto errMsg = fmt::format( "{} due to: Path:'{}', SDK Error Type:{}, GCS Status Code:{}, Message:'{}'", errorMsgPrefix, gcsURI(bucket, key), - error.domain(), + outcome.error_info().domain(), getErrorStringFromGCSError(outcome.code()), outcome.message()); + if (outcome.code() == gc::StatusCode::kNotFound) { + VELOX_FILE_NOT_FOUND_ERROR(errMsg); + } + VELOX_FAIL(errMsg); } } diff --git a/velox/connectors/hive/storage_adapters/gcs/tests/GCSFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/gcs/tests/GCSFileSystemTest.cpp index a0cb3c7c5222..545af97ba4a2 100644 --- a/velox/connectors/hive/storage_adapters/gcs/tests/GCSFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/gcs/tests/GCSFileSystemTest.cpp @@ -285,30 +285,20 @@ TEST_F(GCSFileSystemTest, missingFile) { const std::string gcsFile = gcsURI(preexistingBucketName(), file); filesystems::GCSFileSystem gcfs(testGcsOptions()); gcfs.initializeClient(); - try { - gcfs.openFileForRead(gcsFile); - FAIL() << "Expected VeloxException"; - } catch (VeloxException const& err) { - EXPECT_THAT( - err.message(), - ::testing::HasSubstr( - "\\\"message\\\": \\\"Live version of object test1-gcs/newTest.txt does not exist.\\\"")); - } + VELOX_ASSERT_RUNTIME_THROW_CODE( + gcfs.openFileForRead(gcsFile), + error_code::kFileNotFound, + "\\\"message\\\": \\\"Live version of object test1-gcs/newTest.txt does not exist.\\\""); } TEST_F(GCSFileSystemTest, missingBucket) { filesystems::GCSFileSystem gcfs(testGcsOptions()); gcfs.initializeClient(); - try { - const char* gcsFile = "gs://dummy/foo.txt"; - gcfs.openFileForRead(gcsFile); - FAIL() << "Expected VeloxException"; - } catch (VeloxException const& err) { - EXPECT_THAT( - err.message(), - ::testing::HasSubstr( - "\\\"message\\\": \\\"Bucket dummy does not exist.\\\"")); - } + const char* gcsFile = "gs://dummy/foo.txt"; + VELOX_ASSERT_RUNTIME_THROW_CODE( + gcfs.openFileForRead(gcsFile), + error_code::kFileNotFound, + "\\\"message\\\": \\\"Bucket dummy does not exist.\\\""); } TEST_F(GCSFileSystemTest, credentialsConfig) { diff --git a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp index 84bbd217d474..9d99420c9d7e 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/HdfsReadFile.cpp @@ -23,11 +23,17 @@ namespace facebook::velox { HdfsReadFile::HdfsReadFile(hdfsFS hdfs, const std::string_view path) : hdfsClient_(hdfs), filePath_(path) { fileInfo_ = hdfsGetPathInfo(hdfsClient_, filePath_.data()); - VELOX_CHECK_NOT_NULL( - fileInfo_, - "Unable to get file path info for file: {}. got error: {}", - filePath_, - hdfsGetLastError()); + if (fileInfo_ == nullptr) { + auto error = hdfsGetLastError(); + auto errMsg = fmt::format( + "Unable to get file path info for file: {}. got error: {}", + filePath_, + error); + if (std::strstr(error, "FileNotFoundException") != nullptr) { + VELOX_FILE_NOT_FOUND_ERROR(errMsg); + } + VELOX_FAIL(errMsg); + } } HdfsReadFile::~HdfsReadFile() { diff --git a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp index f9d50f5985d6..ac8ed66f7c0f 100644 --- a/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/hdfs/tests/HdfsFileSystemTest.cpp @@ -218,20 +218,14 @@ TEST_F(HdfsFileSystemTest, oneFsInstanceForOneEndpoint) { } TEST_F(HdfsFileSystemTest, missingFileViaFileSystem) { - try { - auto memConfig = - std::make_shared(configurationValues); - auto hdfsFileSystem = - filesystems::getFileSystem(fullDestinationPath, memConfig); - auto readFile = hdfsFileSystem->openFileForRead( - "hdfs://localhost:7777/path/that/does/not/exist"); - FAIL() << "expected VeloxException"; - } catch (VeloxException const& error) { - EXPECT_THAT( - error.message(), - testing::HasSubstr( - "Unable to get file path info for file: /path/that/does/not/exist. got error: FileNotFoundException: Path /path/that/does/not/exist does not exist.")); - } + auto memConfig = std::make_shared(configurationValues); + auto hdfsFileSystem = + filesystems::getFileSystem(fullDestinationPath, memConfig); + VELOX_ASSERT_RUNTIME_THROW_CODE( + hdfsFileSystem->openFileForRead( + "hdfs://localhost:7777/path/that/does/not/exist"), + error_code::kFileNotFound, + "Unable to get file path info for file: /path/that/does/not/exist. got error: FileNotFoundException: Path /path/that/does/not/exist does not exist."); } TEST_F(HdfsFileSystemTest, missingHost) { diff --git a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h index 09f42a47ffb5..ec67fb0c4175 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/S3Util.h +++ b/velox/connectors/hive/storage_adapters/s3fs/S3Util.h @@ -26,6 +26,8 @@ #include "velox/common/base/Exceptions.h" +#include + namespace facebook::velox { namespace { @@ -158,7 +160,7 @@ inline std::string getRequestID( { \ if (!outcome.IsSuccess()) { \ auto error = outcome.GetError(); \ - VELOX_FAIL( \ + auto errMsg = fmt::format( \ "{} due to: '{}'. Path:'{}', SDK Error Type:{}, HTTP Status Code:{}, S3 Service:'{}', Message:'{}', RequestID:'{}'", \ errorMsgPrefix, \ getErrorStringFromS3Error(error), \ @@ -167,7 +169,11 @@ inline std::string getRequestID( error.GetResponseCode(), \ getS3BackendService(error.GetResponseHeaders()), \ error.GetMessage(), \ - getRequestID(error.GetResponseHeaders())) \ + getRequestID(error.GetResponseHeaders())); \ + if (error.GetResponseCode() == Aws::Http::HttpResponseCode::NOT_FOUND) { \ + VELOX_FILE_NOT_FOUND_ERROR(errMsg); \ + } \ + VELOX_FAIL(errMsg) \ } \ } diff --git a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp index 2d0b56fedf6f..1f383e37b381 100644 --- a/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp +++ b/velox/connectors/hive/storage_adapters/s3fs/tests/S3FileSystemTest.cpp @@ -113,16 +113,18 @@ TEST_F(S3FileSystemTest, missingFile) { addBucket(bucketName); auto hiveConfig = minioServer_->hiveConfig(); filesystems::S3FileSystem s3fs(hiveConfig); - VELOX_ASSERT_THROW( + VELOX_ASSERT_RUNTIME_THROW_CODE( s3fs.openFileForRead(s3File), + error_code::kFileNotFound, "Failed to get metadata for S3 object due to: 'Resource not found'. Path:'s3://data1/i-do-not-exist.txt', SDK Error Type:16, HTTP Status Code:404, S3 Service:'MinIO', Message:'No response body.'"); } TEST_F(S3FileSystemTest, missingBucket) { auto hiveConfig = minioServer_->hiveConfig(); filesystems::S3FileSystem s3fs(hiveConfig); - VELOX_ASSERT_THROW( + VELOX_ASSERT_RUNTIME_THROW_CODE( s3fs.openFileForRead(kDummyPath), + error_code::kFileNotFound, "Failed to get metadata for S3 object due to: 'Resource not found'. Path:'s3://dummy/foo.txt', SDK Error Type:16, HTTP Status Code:404, S3 Service:'MinIO', Message:'No response body.'"); } diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 564243826ebf..39bf52cfdab9 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -137,10 +137,6 @@ class QueryConfig { static constexpr const char* kMaxPartitionedOutputBufferSize = "max_page_partitioning_buffer_size"; - /// Deprecated. Use kMaxOutputBufferSize instead. - static constexpr const char* kMaxArbitraryBufferSize = - "max_arbitrary_buffer_size"; - static constexpr const char* kMaxOutputBufferSize = "max_output_buffer_size"; /// Preferred size of batches in bytes to be returned by operators from @@ -434,13 +430,8 @@ class QueryConfig { /// this. The Drivers are resumed when the buffered size goes below /// OutputBufferManager::kContinuePct % of this. uint64_t maxOutputBufferSize() const { - return get(kMaxOutputBufferSize, maxArbitraryBufferSize()); - } - - /// Deprecated. Use maxBufferSize() instead. - uint64_t maxArbitraryBufferSize() const { static constexpr uint64_t kDefault = 32UL << 20; - return get(kMaxArbitraryBufferSize, kDefault); + return get(kMaxOutputBufferSize, kDefault); } uint64_t maxLocalExchangeBufferSize() const { @@ -587,7 +578,7 @@ class QueryConfig { /// calculate the spilling partition number for join spill or aggregation /// spill. uint8_t spillStartPartitionBit() const { - constexpr uint8_t kDefaultStartBit = 29; + constexpr uint8_t kDefaultStartBit = 48; return get(kSpillStartPartitionBit, kDefaultStartBit); } diff --git a/velox/docs/develop/connectors.rst b/velox/docs/develop/connectors.rst index 16bedb50cf77..2d4602508741 100644 --- a/velox/docs/develop/connectors.rst +++ b/velox/docs/develop/connectors.rst @@ -82,6 +82,9 @@ Storage Adapters Hive Connector allows reading and writing files from a variety of distributed storage systems. The supported storage API are S3, HDFS, GCS, Linux FS. +If file is not found when reading, `openFileForRead` API throws `VeloxRuntimeError` with `error_code::kFileNotFound`. +This behavior is necessary to support the `ignore_missing_files` configuration property. + S3 is supported using the `AWS SDK for C++ `_ library. S3 supported schemes are `s3://` (Amazon S3, Minio), `s3a://` (Hadoop 3.x), `s3n://` (Deprecated in Hadoop 3.x), `oss://` (Alibaba cloud storage), and `cos://`, `cosn://` (Tencent cloud storage). diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 726187d8aa13..bd01ed2e4e4b 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -160,6 +160,7 @@ Array Functions SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported .. function:: array_sort(array(T), function(T,U)) -> array(T) + :noindex: Returns the array sorted by values computed using specified lambda in ascending order. U must be an orderable type. Null elements will be placed at the end of @@ -185,6 +186,7 @@ Array Functions SELECT array_sort(ARRAY [ARRAY [1, 2], ARRAY [1, null]]); -- failed: Ordering nulls is not supported .. function:: array_sort_desc(array(T), function(T,U)) -> array(T) + :noindex: Returns the array sorted by values computed using specified lambda in descending order. U must be an orderable type. Null elements will be placed at the end of @@ -221,9 +223,9 @@ Array Functions When 'element' is of complex type, throws if 'x' or 'element' contains nested nulls and these need to be compared to produce a result. :: - SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false. - SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null - SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null + SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false. + SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null + SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null .. function:: element_at(array(E), index) -> E @@ -247,6 +249,7 @@ Array Functions for no-match and first-match-is-null cases. .. function:: find_first(array(T), index, function(T,boolean)) -> E + :noindex: Returns the first element of ``array`` that matches the predicate. Returns ``NULL`` if no element matches the predicate. @@ -268,6 +271,7 @@ Array Functions Returns ``NULL`` if no such element exists. .. function:: find_first_index(array(T), index, function(T,boolean)) -> BIGINT + :noindex: Returns the 1-based index of the first element of ``array`` that matches the predicate. Returns ``NULL`` if no such element exists. @@ -304,7 +308,9 @@ Array Functions the element, ``inputFunction`` takes the current state, initially ``initialState``, and returns the new state. ``outputFunction`` will be invoked to turn the final state into the result value. It may be the - identity function (``i -> i``). :: + identity function (``i -> i``). + + Throws if array has more than 10,000 elements. :: SELECT reduce(ARRAY [], 0, (s, x) -> s + x, s -> s); -- 0 SELECT reduce(ARRAY [5, 20, 50], 0, (s, x) -> s + x, s -> s); -- 75 @@ -327,7 +333,7 @@ Array Functions .. function:: shuffle(array(E)) -> array(E) - Generate a random permutation of the given ``array``:: + Generate a random permutation of the given ``array`` :: SELECT shuffle(ARRAY [1, 2, 3]); -- [3, 1, 2] or any other random permutation SELECT shuffle(ARRAY [0, 0, 0]); -- [0, 0, 0] @@ -375,7 +381,7 @@ Array Functions .. function:: remove_nulls(x) -> array - Remove null values from an array ``array``:: + Remove null values from an array ``array`` :: SELECT remove_nulls(ARRAY[1, NULL, 3, NULL]); -- [1, 3] SELECT remove_nulls(ARRAY[true, false, NULL]); -- [true, false] @@ -392,7 +398,8 @@ Array Functions .. function:: zip_with(array(T), array(U), function(T,U,R)) -> array(R) Merges the two given arrays, element-wise, into a single array using ``function``. - If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying ``function``:: + If one array is shorter, nulls are appended at the end to match the length of the + longer array, before applying ``function`` :: SELECT zip_with(ARRAY[1, 3, 5], ARRAY['a', 'b', 'c'], (x, y) -> (y, x)); -- [ROW('a', 1), ROW('b', 3), ROW('c', 5)] SELECT zip_with(ARRAY[1, 2], ARRAY[3, 4], (x, y) -> x + y); -- [4, 6] diff --git a/velox/docs/functions/presto/regexp.rst b/velox/docs/functions/presto/regexp.rst index c032413fb3e5..6171ad7a90ab 100644 --- a/velox/docs/functions/presto/regexp.rst +++ b/velox/docs/functions/presto/regexp.rst @@ -7,6 +7,9 @@ supports only a subset of PCRE syntax and in particular does not support backtracking and associated features (e.g. back references). See https://github.com/google/re2/wiki/Syntax for more information. +Compiling regular expressions is CPU intensive. Hence, each function is +limited to 20 different expressions per instance and thread of execution. + .. function:: like(string, pattern) -> boolean like(string, pattern, escape) -> boolean @@ -19,9 +22,11 @@ See https://github.com/google/re2/wiki/Syntax for more information. wildcard '_' represents exactly one character. Note: Each function instance allow for a maximum of 20 regular expressions to - be compiled throughout the lifetime of the query. Not all Patterns requires - compilation of regular expressions; for example a pattern 'aa' does not. - Only those that require the compilation of regular expressions are counted. + be compiled per thread of execution. Not all patterns require + compilation of regular expressions. Patterns 'aaa', 'aaa%', '%aaa', where 'aaa' + contains only regular characters and '_' wildcards are evaluated without + using regular expressions. Only those patterns that require the compilation of + regular expressions are counted towards the limit. SELECT like('abc', '%b%'); -- true SELECT like('a_c', '%#_%', '#'); -- true @@ -34,14 +39,22 @@ See https://github.com/google/re2/wiki/Syntax for more information. SELECT regexp_extract('1a 2b 14m', '\d+'); -- 1 .. function:: regexp_extract(string, pattern, group) -> varchar - :noindex: + :noindex: Finds the first occurrence of the regular expression ``pattern`` in ``string`` and returns the capturing group number ``group``:: SELECT regexp_extract('1a 2b 14m', '(\d+)([a-z]+)', 2); -- 'a' -.. function:: regexp_extract_all(string, pattern, group) -> array(varchar) +.. function:: regexp_extract_all(string, pattern) -> array(varchar): + + Returns the substring(s) matched by the regular expression ``pattern`` + in ``string``:: + + SELECT regexp_extract_all('1a 2b 14m', '\d+'); -- [1, 2, 14] + +.. function:: regexp_extract_all(string, pattern, group) -> array(varchar): + :noindex: Finds all occurrences of the regular expression ``pattern`` in ``string`` and returns the capturing group number ``group``:: @@ -69,7 +82,7 @@ See https://github.com/google/re2/wiki/Syntax for more information. SELECT regexp_replace('1a 2b 14m', '\d+[ab] '); -- '14m' .. function:: regexp_replace(string, pattern, replacement) -> varchar - :noindex: + :noindex: Replaces every instance of the substring matched by the regular expression ``pattern`` in ``string`` with ``replacement``. Capturing groups can be referenced in diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index 2183f4f301c2..f80e13923dc0 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -51,6 +51,18 @@ Array Functions SELECT array_min(ARRAY [4.0, float('nan')]); -- 4.0 SELECT array_min(ARRAY [NULL, float('nan')]); -- NaN +.. spark:function:: array_repeat(element, count) -> array(E) + + Returns an array containing ``element`` ``count`` times. If ``count`` is negative or zero, + returns empty array. If ``element`` is NULL, returns an array containing ``count`` NULLs. + If ``count`` is NULL, returns NULL as result. Throws an exception if ``count`` exceeds 10'000. :: + + SELECT array_repeat(100, 3); -- [100, 100, 100] + SELECT array_repeat(NULL, 3); -- [NULL, NULL, NULL] + SELECT array_repeat(100, NULL); -- NULL + SELECT array_repeat(100, 0); -- [] + SELECT array_repeat(100, -1); -- [] + .. spark:function:: array_sort(array(E)) -> array(E) Returns an array which has the sorted order of the input array(E). The elements of array(E) must diff --git a/velox/docs/functions/spark/window.rst b/velox/docs/functions/spark/window.rst index a6305a8d7891..2a4d7921c95a 100644 --- a/velox/docs/functions/spark/window.rst +++ b/velox/docs/functions/spark/window.rst @@ -30,3 +30,8 @@ Returns the rank of a value in a group of values. The rank is one plus the numbe Returns the rank of a value in a group of values. This is similar to rank(), except that tie values do not produce gaps in the sequence. +.. spark:function:: ntile(n) -> integer + +Divides the rows for each window partition into n buckets ranging from 1 to at most ``n``. Bucket values will differ by at most 1. If the number of rows in the partition does not divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket. + +For example, with 6 rows and 4 buckets, the bucket values would be as follows: ``1 1 2 2 3 4`` diff --git a/velox/dwio/common/ColumnLoader.cpp b/velox/dwio/common/ColumnLoader.cpp index c04cb74db271..4db934c35bdd 100644 --- a/velox/dwio/common/ColumnLoader.cpp +++ b/velox/dwio/common/ColumnLoader.cpp @@ -16,6 +16,8 @@ #include "velox/dwio/common/ColumnLoader.h" +#include "velox/common/process/TraceContext.h" + namespace facebook::velox::dwio::common { // Wraps '*result' in a dictionary to make the contiguous values @@ -45,6 +47,7 @@ void ColumnLoader::loadInternal( ValueHook* hook, vector_size_t resultSize, VectorPtr* result) { + process::TraceContext trace("ColumnLoader::loadInternal"); VELOX_CHECK_EQ( version_, structReader_->numReads(), diff --git a/velox/dwio/common/Options.h b/velox/dwio/common/Options.h index ecebe4c62848..c154f8f4a6e0 100644 --- a/velox/dwio/common/Options.h +++ b/velox/dwio/common/Options.h @@ -94,6 +94,8 @@ class SerDeOptions { struct TableParameter { static constexpr const char* kSkipHeaderLineCount = "skip.header.line.count"; + static constexpr const char* kSerializationNullFormat = + "serialization.null.format"; }; /** @@ -131,7 +133,8 @@ class RowReaderOptions { // (in dwrf row reader). todo: encapsulate this and keySelectionCallBack_ in a // struct std::function blockedOnIoCallback_; - std::function decodingTimeMsCallback_; + std::function decodingTimeUsCallback_; + std::function stripeCountCallback_; bool eagerFirstStripeLoad = true; uint64_t skipRows_ = 0; @@ -349,12 +352,21 @@ class RowReaderOptions { return blockedOnIoCallback_; } - void setDecodingTimeMsCallback(std::function decodingTimeMs) { - decodingTimeMsCallback_ = std::move(decodingTimeMs); + void setDecodingTimeUsCallback(std::function decodingTimeUs) { + decodingTimeUsCallback_ = std::move(decodingTimeUs); } - const std::function getDecodingTimeMsCallback() const { - return decodingTimeMsCallback_; + std::function getDecodingTimeUsCallback() const { + return decodingTimeUsCallback_; + } + + void setStripeCountCallback( + std::function stripeCountCallback) { + stripeCountCallback_ = std::move(stripeCountCallback); + } + + std::function getStripeCountCallback() const { + return stripeCountCallback_; } void setSkipRows(uint64_t skipRows) { diff --git a/velox/dwio/common/SelectiveColumnReader.cpp b/velox/dwio/common/SelectiveColumnReader.cpp index f2c157ff9c7e..25aff8eb42c3 100644 --- a/velox/dwio/common/SelectiveColumnReader.cpp +++ b/velox/dwio/common/SelectiveColumnReader.cpp @@ -66,6 +66,7 @@ const std::vector& SelectiveColumnReader::children() } void SelectiveColumnReader::seekTo(vector_size_t offset, bool readsNullsOnly) { + VELOX_TRACE_HISTORY_PUSH("seekTo %d %d", offset, readsNullsOnly); if (offset == readOffset_) { return; } diff --git a/velox/dwio/common/SelectiveColumnReader.h b/velox/dwio/common/SelectiveColumnReader.h index 08740a139914..4b26b0d4d60c 100644 --- a/velox/dwio/common/SelectiveColumnReader.h +++ b/velox/dwio/common/SelectiveColumnReader.h @@ -18,6 +18,7 @@ #include "velox/common/base/RawVector.h" #include "velox/common/memory/Memory.h" #include "velox/common/process/ProcessBase.h" +#include "velox/common/process/TraceHistory.h" #include "velox/dwio/common/ColumnSelector.h" #include "velox/dwio/common/FormatData.h" #include "velox/dwio/common/IntDecoder.h" @@ -189,7 +190,8 @@ class SelectiveColumnReader { // group. Interpretation of 'index' depends on format. Clears counts // of skipped enclosing struct nulls for formats where nulls are // recorded at each nesting level, i.e. not rep-def. - virtual void seekToRowGroup(uint32_t /*index*/) { + virtual void seekToRowGroup(uint32_t index) { + VELOX_TRACE_HISTORY_PUSH("seekToRowGroup %u", index); numParentNulls_ = 0; parentNullsRecordedTo_ = 0; } diff --git a/velox/dwio/common/SelectiveStructColumnReader.cpp b/velox/dwio/common/SelectiveStructColumnReader.cpp index 30e6e748fc1b..1f07f73351e3 100644 --- a/velox/dwio/common/SelectiveStructColumnReader.cpp +++ b/velox/dwio/common/SelectiveStructColumnReader.cpp @@ -16,6 +16,7 @@ #include "velox/dwio/common/SelectiveStructColumnReader.h" +#include "velox/common/process/TraceContext.h" #include "velox/dwio/common/ColumnLoader.h" namespace facebook::velox::dwio::common { @@ -56,6 +57,7 @@ void SelectiveStructColumnReaderBase::next( uint64_t numValues, VectorPtr& result, const Mutation* mutation) { + process::TraceContext trace("SelectiveStructColumnReaderBase::next"); if (children_.empty()) { if (mutation && mutation->deletedRows) { numValues -= bits::countBits(mutation->deletedRows, 0, numValues); @@ -136,6 +138,7 @@ void SelectiveStructColumnReaderBase::read( VELOX_CHECK(!childSpecs.empty()); for (size_t i = 0; i < childSpecs.size(); ++i) { auto& childSpec = childSpecs[i]; + VELOX_TRACE_HISTORY_PUSH("read %s", childSpec->fieldName().c_str()); if (isChildConstant(*childSpec)) { continue; } @@ -339,6 +342,7 @@ void SelectiveStructColumnReaderBase::getValues( } bool lazyPrepared = false; for (auto& childSpec : scanSpec_->children()) { + VELOX_TRACE_HISTORY_PUSH("getValues %s", childSpec->fieldName().c_str()); if (!childSpec->projectOut()) { continue; } diff --git a/velox/dwio/dwrf/reader/DwrfReader.cpp b/velox/dwio/dwrf/reader/DwrfReader.cpp index e78d395eaec3..e0d5accb1088 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.cpp +++ b/velox/dwio/dwrf/reader/DwrfReader.cpp @@ -15,6 +15,9 @@ */ #include "velox/dwio/dwrf/reader/DwrfReader.h" + +#include + #include "velox/dwio/common/TypeUtils.h" #include "velox/dwio/common/exception/Exception.h" #include "velox/dwio/dwrf/reader/ColumnReader.h" @@ -35,6 +38,8 @@ DwrfRowReader::DwrfRowReader( : StripeReaderBase(reader), options_(opts), executor_{options_.getDecodingExecutor()}, + decodingTimeUsCallback_{options_.getDecodingTimeUsCallback()}, + stripeCountCallback_{options_.getStripeCountCallback()}, columnSelector_{std::make_shared( ColumnSelector::apply(opts.getSelector(), reader->getSchema()))} { if (executor_) { @@ -73,6 +78,9 @@ DwrfRowReader::DwrfRowReader( if (stripeCeiling_ == 0) { stripeCeiling_ = firstStripe_; } + if (stripeCountCallback_) { + stripeCountCallback_(stripeCeiling_ - firstStripe_); + } if (currentStripe_ == 0) { previousRow_ = std::numeric_limits::max(); @@ -269,18 +277,23 @@ void DwrfRowReader::readNext( const dwio::common::Mutation* mutation, VectorPtr& result) { if (!selectiveColumnReader_) { - const auto startTime = std::chrono::high_resolution_clock::now(); + std::optional startTime; + if (decodingTimeUsCallback_) { + // We'll use wall time since we have parallel decoding. + // If we move to sequential decoding only, we can use CPU time. + startTime.emplace(std::chrono::steady_clock::now()); + } // TODO: Move row number appending logic here. Currently this is done in // the wrapper reader. VELOX_CHECK( mutation == nullptr, "Mutation pushdown is only supported in selective reader"); columnReader_->next(rowsToRead, result); - auto reportDecodingTimeMsMetric = options_.getDecodingTimeMsCallback(); - if (reportDecodingTimeMsMetric) { - auto decodingTime = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - startTime); - reportDecodingTimeMsMetric(decodingTime.count()); + if (startTime.has_value()) { + decodingTimeUsCallback_( + std::chrono::duration_cast( + std::chrono::steady_clock::now() - startTime.value()) + .count()); } return; } diff --git a/velox/dwio/dwrf/reader/DwrfReader.h b/velox/dwio/dwrf/reader/DwrfReader.h index 79742d447b76..950f0829ba43 100644 --- a/velox/dwio/dwrf/reader/DwrfReader.h +++ b/velox/dwio/dwrf/reader/DwrfReader.h @@ -149,6 +149,8 @@ class DwrfRowReader : public StrideIndexProvider, std::shared_ptr stripeDictionaryCache_; dwio::common::RowReaderOptions options_; std::shared_ptr executor_; + std::function decodingTimeUsCallback_; + std::function stripeCountCallback_; struct PrefetchedStripeState { bool preloaded; diff --git a/velox/dwio/dwrf/reader/ReaderBase.cpp b/velox/dwio/dwrf/reader/ReaderBase.cpp index 8cca5531e56d..4c98bcd3441d 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.cpp +++ b/velox/dwio/dwrf/reader/ReaderBase.cpp @@ -18,6 +18,7 @@ #include +#include "velox/common/process/TraceContext.h" #include "velox/dwio/common/exception/Exception.h" namespace facebook::velox::dwrf { @@ -100,6 +101,7 @@ ReaderBase::ReaderBase( footerEstimatedSize_(footerEstimatedSize), filePreloadThreshold_(filePreloadThreshold), input_(std::move(input)) { + process::TraceContext trace("ReaderBase::ReaderBase"); // read last bytes into buffer to get PostScript // If file is small, load the entire file. // TODO: make a config diff --git a/velox/dwio/dwrf/test/ReaderTest.cpp b/velox/dwio/dwrf/test/ReaderTest.cpp index 627dd82d4333..7d28696b74b5 100644 --- a/velox/dwio/dwrf/test/ReaderTest.cpp +++ b/velox/dwio/dwrf/test/ReaderTest.cpp @@ -1819,6 +1819,92 @@ TEST_F(TestReader, fileColumnNamesReadAsLowerCaseComplexStruct) { EXPECT_EQ(col0_1_1_0_0->childByName("ccint3"), col0_1_1_0_0_0); } +TEST_F(TestReader, TestStripeSizeCallback) { + dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setFilePreloadThreshold(0); + readerOpts.setFooterEstimatedSize(4); + RowReaderOptions rowReaderOpts; + + std::shared_ptr requestedType = std::dynamic_pointer_cast< + const RowType>(HiveTypeParser().parse( + "struct")); + rowReaderOpts.select(std::make_shared(requestedType)); + rowReaderOpts.setEagerFirstStripeLoad(false); + uint16_t stripeCount = 0; + int numCalls = 0; + rowReaderOpts.setStripeCountCallback([&](uint16_t count) { + stripeCount += count; + ++numCalls; + }); + + auto reader = DwrfReader::create( + createFileBufferedInput( + getExampleFilePath("dict_encoded_strings.orc"), + readerOpts.getMemoryPool()), + readerOpts); + auto rowReaderOwner = reader->createRowReader(rowReaderOpts); + EXPECT_EQ(stripeCount, 3); + EXPECT_EQ(numCalls, 1); +} + +TEST_F(TestReader, TestStripeSizeCallbackLimitsOneStripe) { + dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setFilePreloadThreshold(0); + readerOpts.setFooterEstimatedSize(4); + RowReaderOptions rowReaderOpts; + + std::shared_ptr requestedType = std::dynamic_pointer_cast< + const RowType>(HiveTypeParser().parse( + "struct")); + rowReaderOpts.select(std::make_shared(requestedType)); + rowReaderOpts.setEagerFirstStripeLoad(false); + rowReaderOpts.range(600, 600); + uint16_t stripeCount = 0; + int numCalls = 0; + rowReaderOpts.setStripeCountCallback([&](uint16_t count) { + stripeCount += count; + ++numCalls; + }); + + auto reader = DwrfReader::create( + createFileBufferedInput( + getExampleFilePath("dict_encoded_strings.orc"), + readerOpts.getMemoryPool()), + readerOpts); + auto rowReaderOwner = reader->createRowReader(rowReaderOpts); + EXPECT_EQ(stripeCount, 1); + EXPECT_EQ(numCalls, 1); +} + +TEST_F(TestReader, TestStripeSizeCallbackLimitsTwoStripe) { + dwio::common::ReaderOptions readerOpts{pool()}; + readerOpts.setFilePreloadThreshold(0); + readerOpts.setFooterEstimatedSize(4); + RowReaderOptions rowReaderOpts; + + std::shared_ptr requestedType = std::dynamic_pointer_cast< + const RowType>(HiveTypeParser().parse( + "struct")); + rowReaderOpts.select(std::make_shared(requestedType)); + rowReaderOpts.setEagerFirstStripeLoad(false); + rowReaderOpts.range(0, 600); + uint16_t stripeCount = 0; + int numCalls = 0; + rowReaderOpts.setStripeCountCallback([&](uint16_t count) { + stripeCount += count; + ++numCalls; + }); + + auto reader = DwrfReader::create( + createFileBufferedInput( + getExampleFilePath("dict_encoded_strings.orc"), + readerOpts.getMemoryPool()), + readerOpts); + auto rowReaderOwner = reader->createRowReader(rowReaderOpts); + EXPECT_EQ(stripeCount, 2); + EXPECT_EQ(numCalls, 1); +} + TEST_P(TestReaderP, testUpcastBoolean) { MockStripeStreams streams; diff --git a/velox/dwio/parquet/reader/CMakeLists.txt b/velox/dwio/parquet/reader/CMakeLists.txt index 3fb5250b7e64..fbb38dd64eef 100644 --- a/velox/dwio/parquet/reader/CMakeLists.txt +++ b/velox/dwio/parquet/reader/CMakeLists.txt @@ -23,7 +23,6 @@ add_library( ParquetData.cpp RepeatedColumnReader.cpp RleBpDecoder.cpp - Statistics.cpp StructColumnReader.cpp StringColumnReader.cpp) diff --git a/velox/dwio/parquet/reader/Metadata.cpp b/velox/dwio/parquet/reader/Metadata.cpp index c0fa6ab7ca02..771e68e8a595 100644 --- a/velox/dwio/parquet/reader/Metadata.cpp +++ b/velox/dwio/parquet/reader/Metadata.cpp @@ -15,11 +15,142 @@ */ #include "velox/dwio/parquet/reader/Metadata.h" - -#include "velox/dwio/parquet/reader/Statistics.h" +#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" namespace facebook::velox::parquet { +template +inline const T load(const char* ptr) { + T ret; + std::memcpy(&ret, ptr, sizeof(ret)); + return ret; +} + +template +inline std::optional getMin(const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.min_value + ? load(columnChunkStats.min_value.data()) + : (columnChunkStats.__isset.min + ? std::optional(load(columnChunkStats.min.data())) + : std::nullopt); +} + +template +inline std::optional getMax(const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.max_value + ? std::optional(load(columnChunkStats.max_value.data())) + : (columnChunkStats.__isset.max + ? std::optional(load(columnChunkStats.max.data())) + : std::nullopt); +} + +template <> +inline std::optional getMin( + const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.min_value + ? std::optional(columnChunkStats.min_value) + : (columnChunkStats.__isset.min ? std::optional(columnChunkStats.min) + : std::nullopt); +} + +template <> +inline std::optional getMax( + const thrift::Statistics& columnChunkStats) { + return columnChunkStats.__isset.max_value + ? std::optional(columnChunkStats.max_value) + : (columnChunkStats.__isset.max ? std::optional(columnChunkStats.max) + : std::nullopt); +} + +std::unique_ptr buildColumnStatisticsFromThrift( + const thrift::Statistics& columnChunkStats, + const velox::Type& type, + uint64_t numRowsInRowGroup) { + std::optional nullCount = columnChunkStats.__isset.null_count + ? std::optional(columnChunkStats.null_count) + : std::nullopt; + std::optional valueCount = nullCount.has_value() + ? std::optional(numRowsInRowGroup - nullCount.value()) + : std::nullopt; + std::optional hasNull = columnChunkStats.__isset.null_count + ? std::optional(columnChunkStats.null_count > 0) + : std::nullopt; + + switch (type.kind()) { + case TypeKind::BOOLEAN: + return std::make_unique( + valueCount, hasNull, std::nullopt, std::nullopt, std::nullopt); + case TypeKind::TINYINT: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::SMALLINT: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::INTEGER: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::BIGINT: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::REAL: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::DOUBLE: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: + return std::make_unique( + valueCount, + hasNull, + std::nullopt, + std::nullopt, + getMin(columnChunkStats), + getMax(columnChunkStats), + std::nullopt); + + default: + return std::make_unique( + valueCount, hasNull, std::nullopt, std::nullopt); + } +} + common::CompressionKind thriftCodecToCompressionKind( thrift::CompressionCodec::type codec) { switch (codec) { diff --git a/velox/dwio/parquet/reader/ParquetColumnReader.cpp b/velox/dwio/parquet/reader/ParquetColumnReader.cpp index ea3169ae727a..c3816c0e960a 100644 --- a/velox/dwio/parquet/reader/ParquetColumnReader.cpp +++ b/velox/dwio/parquet/reader/ParquetColumnReader.cpp @@ -25,10 +25,8 @@ #include "velox/dwio/parquet/reader/FloatingPointColumnReader.h" #include "velox/dwio/parquet/reader/IntegerColumnReader.h" #include "velox/dwio/parquet/reader/RepeatedColumnReader.h" -#include "velox/dwio/parquet/reader/Statistics.h" #include "velox/dwio/parquet/reader/StringColumnReader.h" #include "velox/dwio/parquet/reader/StructColumnReader.h" -#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" namespace facebook::velox::parquet { diff --git a/velox/dwio/parquet/reader/ParquetData.cpp b/velox/dwio/parquet/reader/ParquetData.cpp index 283190bbfb0a..a2688403ebcd 100644 --- a/velox/dwio/parquet/reader/ParquetData.cpp +++ b/velox/dwio/parquet/reader/ParquetData.cpp @@ -17,7 +17,6 @@ #include "velox/dwio/parquet/reader/ParquetData.h" #include "velox/dwio/common/BufferedInput.h" -#include "velox/dwio/parquet/reader/Statistics.h" namespace facebook::velox::parquet { diff --git a/velox/dwio/parquet/reader/Statistics.cpp b/velox/dwio/parquet/reader/Statistics.cpp deleted file mode 100644 index e7ee86a8b768..000000000000 --- a/velox/dwio/parquet/reader/Statistics.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "velox/dwio/parquet/reader/Statistics.h" - -#include "velox/dwio/common/Statistics.h" -#include "velox/type/Type.h" - -namespace facebook::velox::parquet { - -std::unique_ptr buildColumnStatisticsFromThrift( - const thrift::Statistics& columnChunkStats, - const velox::Type& type, - uint64_t numRowsInRowGroup) { - std::optional nullCount = columnChunkStats.__isset.null_count - ? std::optional(columnChunkStats.null_count) - : std::nullopt; - std::optional valueCount = nullCount.has_value() - ? std::optional(numRowsInRowGroup - nullCount.value()) - : std::nullopt; - std::optional hasNull = columnChunkStats.__isset.null_count - ? std::optional(columnChunkStats.null_count > 0) - : std::nullopt; - - switch (type.kind()) { - case TypeKind::BOOLEAN: - return std::make_unique( - valueCount, hasNull, std::nullopt, std::nullopt, std::nullopt); - case TypeKind::TINYINT: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::SMALLINT: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::INTEGER: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::BIGINT: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::REAL: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::DOUBLE: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - case TypeKind::VARCHAR: - case TypeKind::VARBINARY: - return std::make_unique( - valueCount, - hasNull, - std::nullopt, - std::nullopt, - getMin(columnChunkStats), - getMax(columnChunkStats), - std::nullopt); - - default: - return std::make_unique( - valueCount, hasNull, std::nullopt, std::nullopt); - } -} - -} // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/reader/Statistics.h b/velox/dwio/parquet/reader/Statistics.h deleted file mode 100644 index 18f67d5b13b0..000000000000 --- a/velox/dwio/parquet/reader/Statistics.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "velox/dwio/parquet/thrift/ParquetThriftTypes.h" - -#include -#include - -namespace facebook::velox { -class Type; -} - -namespace facebook::velox::dwio::common { -class ColumnStatistics; -} - -namespace facebook::velox::parquet { - -// TODO: provide function to merge multiple Statistics into one - -template -inline const T load(const char* ptr) { - T ret; - std::memcpy(&ret, ptr, sizeof(ret)); - return ret; -} - -template -inline std::optional getMin(const thrift::Statistics& columnChunkStats) { - return columnChunkStats.__isset.min_value - ? load(columnChunkStats.min_value.data()) - : (columnChunkStats.__isset.min - ? std::optional(load(columnChunkStats.min.data())) - : std::nullopt); -} - -template -inline std::optional getMax(const thrift::Statistics& columnChunkStats) { - return columnChunkStats.__isset.max_value - ? std::optional(load(columnChunkStats.max_value.data())) - : (columnChunkStats.__isset.max - ? std::optional(load(columnChunkStats.max.data())) - : std::nullopt); -} - -template <> -inline std::optional getMin( - const thrift::Statistics& columnChunkStats) { - return columnChunkStats.__isset.min_value - ? std::optional(columnChunkStats.min_value) - : (columnChunkStats.__isset.min ? std::optional(columnChunkStats.min) - : std::nullopt); -} - -template <> -inline std::optional getMax( - const thrift::Statistics& columnChunkStats) { - return columnChunkStats.__isset.max_value - ? std::optional(columnChunkStats.max_value) - : (columnChunkStats.__isset.max ? std::optional(columnChunkStats.max) - : std::nullopt); -} - -std::unique_ptr buildColumnStatisticsFromThrift( - const thrift::Statistics& columnChunkStats, - const velox::Type& type, - uint64_t numRowsInRowGroup); - -} // namespace facebook::velox::parquet diff --git a/velox/dwio/parquet/tests/CMakeLists.txt b/velox/dwio/parquet/tests/CMakeLists.txt index 32d0d67afca4..427a9085d828 100644 --- a/velox/dwio/parquet/tests/CMakeLists.txt +++ b/velox/dwio/parquet/tests/CMakeLists.txt @@ -43,3 +43,6 @@ target_link_libraries( velox_aggregates velox_tpch_gen ${TEST_LINK_LIBS}) + +file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/examples + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/velox/exec/AddressableNonNullValueList.cpp b/velox/exec/AddressableNonNullValueList.cpp index dc0bc686c29b..4f440dc4422c 100644 --- a/velox/exec/AddressableNonNullValueList.cpp +++ b/velox/exec/AddressableNonNullValueList.cpp @@ -18,9 +18,7 @@ namespace facebook::velox::aggregate::prestosql { -AddressableNonNullValueList::Entry AddressableNonNullValueList::append( - const DecodedVector& decoded, - vector_size_t index, +ByteOutputStream AddressableNonNullValueList::initStream( HashStringAllocator* allocator) { ByteOutputStream stream(allocator); if (!firstHeader_) { @@ -30,13 +28,21 @@ AddressableNonNullValueList::Entry AddressableNonNullValueList::append( // and a next pointer. This could be adaptive, with smaller initial // sizes for lots of small arrays. static constexpr int kInitialSize = 44; - currentPosition_ = allocator->newWrite(stream, kInitialSize); firstHeader_ = currentPosition_.header; } else { allocator->extendWrite(currentPosition_, stream); } + return stream; +} + +AddressableNonNullValueList::Entry AddressableNonNullValueList::append( + const DecodedVector& decoded, + vector_size_t index, + HashStringAllocator* allocator) { + auto stream = initStream(allocator); + const auto hash = decoded.base()->hashValueAt(decoded.index(index)); const auto originalSize = stream.size(); @@ -44,7 +50,6 @@ AddressableNonNullValueList::Entry AddressableNonNullValueList::append( // Write value. exec::ContainerRowSerde::serialize( *decoded.base(), decoded.index(index), stream); - ++size_; auto startAndFinish = allocator->finishWrite(stream, 1024); @@ -55,6 +60,21 @@ AddressableNonNullValueList::Entry AddressableNonNullValueList::append( return {startAndFinish.first, writtenSize, hash}; } +HashStringAllocator::Position AddressableNonNullValueList::appendSerialized( + const StringView& value, + HashStringAllocator* allocator) { + auto stream = initStream(allocator); + + const auto originalSize = stream.size(); + stream.appendStringView(value); + ++size_; + + auto startAndFinish = allocator->finishWrite(stream, 1024); + currentPosition_ = startAndFinish.second; + VELOX_CHECK_EQ(stream.size() - originalSize, value.size()); + return {startAndFinish.first}; +} + namespace { ByteInputStream prepareRead(const AddressableNonNullValueList::Entry& entry) { @@ -94,4 +114,12 @@ void AddressableNonNullValueList::read( exec::ContainerRowSerde::deserialize(stream, index, &result); } +// static +void AddressableNonNullValueList::readSerialized( + const Entry& position, + char* dest) { + auto stream = prepareRead(position); + stream.readBytes(dest, position.size); +} + } // namespace facebook::velox::aggregate::prestosql diff --git a/velox/exec/AddressableNonNullValueList.h b/velox/exec/AddressableNonNullValueList.h index cd142385e559..81fd6a66aab9 100644 --- a/velox/exec/AddressableNonNullValueList.h +++ b/velox/exec/AddressableNonNullValueList.h @@ -15,6 +15,7 @@ */ #pragma once +#include "velox/common/base/IOUtils.h" #include "velox/common/memory/HashStringAllocator.h" #include "velox/vector/DecodedVector.h" @@ -57,6 +58,12 @@ class AddressableNonNullValueList { vector_size_t index, HashStringAllocator* allocator); + /// Append a non-null serialized value to the end of the list. + /// Returns position that can be used to access the value later. + HashStringAllocator::Position appendSerialized( + const StringView& value, + HashStringAllocator* allocator); + /// Removes last element. 'position' must be a value returned from the latest /// call to 'append'. void removeLast(const Entry& entry) { @@ -77,6 +84,9 @@ class AddressableNonNullValueList { static void read(const Entry& position, BaseVector& result, vector_size_t index); + /// Copies to 'dest' entry.size bytes at position. + static void readSerialized(const Entry& position, char* dest); + void free(HashStringAllocator& allocator) { if (size_ > 0) { allocator.free(firstHeader_); @@ -84,6 +94,8 @@ class AddressableNonNullValueList { } private: + ByteOutputStream initStream(HashStringAllocator* allocator); + // Memory allocation (potentially multi-part). HashStringAllocator::Header* firstHeader_{nullptr}; HashStringAllocator::Position currentPosition_{nullptr, nullptr}; diff --git a/velox/exec/AggregateInfo.cpp b/velox/exec/AggregateInfo.cpp index 838d4b58c979..47f415b98e59 100644 --- a/velox/exec/AggregateInfo.cpp +++ b/velox/exec/AggregateInfo.cpp @@ -52,13 +52,6 @@ std::vector toAggregateInfo( for (auto i = 0; i < numAggregates; i++) { const auto& aggregate = aggregationNode.aggregates()[i]; - - // TODO: Add support for StreamingAggregation - if (isStreaming && aggregate.distinct) { - VELOX_UNSUPPORTED( - "Streaming aggregation doesn't support aggregations over distinct inputs yet"); - } - AggregateInfo info; // Populate input. auto& channels = info.inputs; diff --git a/velox/exec/DistinctAggregations.cpp b/velox/exec/DistinctAggregations.cpp index 974514e89331..1391fa864325 100644 --- a/velox/exec/DistinctAggregations.cpp +++ b/velox/exec/DistinctAggregations.cpp @@ -135,6 +135,14 @@ class TypedDistinctAggregations : public DistinctAggregations { // Release memory back to HashStringAllocator to allow next // aggregate to re-use it. aggregate.function->destroy(groups); + + // Overwrite empty groups over the destructed groups to keep the container + // in a well formed state. + raw_vector temp; + aggregate.function->initializeNewGroups( + groups.data(), + folly::Range( + iota(groups.size(), temp), groups.size())); } } diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 11b49d095de0..551ac2131a95 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -395,7 +395,10 @@ size_t OpCallStatusRaw::callDuration() const { /*static*/ std::string OpCallStatusRaw::formatCall( Operator* op, const char* operatorMethod) { - return fmt::format("{}::{}", op ? op->operatorType() : "N/A", operatorMethod); + return op + ? fmt::format( + "{}.{}::{}", op->operatorType(), op->planNodeId(), operatorMethod) + : fmt::format("null::{}", operatorMethod); } CpuWallTiming Driver::processLazyTiming( diff --git a/velox/exec/ExchangeClient.h b/velox/exec/ExchangeClient.h index 289d9e974a62..05eb8dd60ed2 100644 --- a/velox/exec/ExchangeClient.h +++ b/velox/exec/ExchangeClient.h @@ -85,7 +85,7 @@ class ExchangeClient : public std::enable_shared_from_this { /// /// If no data is available returns empty list and sets 'atEnd' to true if no /// more data is expected. If data is still expected, sets 'atEnd' to false - /// and sets 'future' to a Future that will comlete when data arrives. + /// and sets 'future' to a Future that will complete when data arrives. /// /// The data may be compressed, in which case 'maxBytes' applies to compressed /// size. diff --git a/velox/exec/ExchangeQueue.cpp b/velox/exec/ExchangeQueue.cpp index 438b19d12fca..a01a218db162 100644 --- a/velox/exec/ExchangeQueue.cpp +++ b/velox/exec/ExchangeQueue.cpp @@ -111,7 +111,7 @@ std::vector> ExchangeQueue::dequeueLocked( if (queue_.empty()) { if (atEnd_) { *atEnd = true; - } else { + } else if (pages.empty()) { promises_.emplace_back("ExchangeQueue::dequeue"); *future = promises_.back().getSemiFuture(); } diff --git a/velox/exec/ExchangeQueue.h b/velox/exec/ExchangeQueue.h index f77ba1afdba6..d19babf26ce7 100644 --- a/velox/exec/ExchangeQueue.h +++ b/velox/exec/ExchangeQueue.h @@ -114,7 +114,7 @@ class ExchangeQueue { /// sets 'atEnd' to false and 'future' to a Future that will complete when /// data arrives. If no more data is expected, sets 'atEnd' to true. Returns /// at least one page if data is available. If multiple pages are available, - /// returns as many pages as fit within 'maxBytes', but no fewer than onc. + /// returns as many pages as fit within 'maxBytes', but no fewer than one. /// Calling this method with 'maxBytes' of 1 returns at most one page. /// /// The data may be compressed, in which case 'maxBytes' applies to compressed diff --git a/velox/exec/ExchangeSource.h b/velox/exec/ExchangeSource.h index 452c8ded9f42..8cee990d8c02 100644 --- a/velox/exec/ExchangeSource.h +++ b/velox/exec/ExchangeSource.h @@ -40,12 +40,6 @@ class ExchangeSource : public std::enable_shared_from_this { std::shared_ptr queue, memory::MemoryPool* pool); - /// Temporary API to indicate whether 'request(maxBytes, maxWaitSeconds)' API - /// is supported. - virtual bool supportsFlowControlV2() const { - VELOX_UNREACHABLE(); - } - /// Temporary API to indicate whether 'metrics()' API /// is supported. virtual bool supportsMetrics() const { @@ -65,14 +59,6 @@ class ExchangeSource : public std::enable_shared_from_this { return requestPending_; } - /// Requests the producer to generate up to 'maxBytes' more data. - /// Returns a future that completes when producer responds either with 'data' - /// or with a message indicating that all data has been already produced or - /// data will take more time to produce. - virtual ContinueFuture request(uint32_t /*maxBytes*/) { - VELOX_NYI(); - } - struct Response { /// Size of the response in bytes. Zero means response didn't contain any /// data. diff --git a/velox/exec/GroupingSet.cpp b/velox/exec/GroupingSet.cpp index e45f4e5cb8c6..a0fb9136a89a 100644 --- a/velox/exec/GroupingSet.cpp +++ b/velox/exec/GroupingSet.cpp @@ -224,7 +224,12 @@ void GroupingSet::addInputForActiveRows( TestValue::adjust( "facebook::velox::exec::GroupingSet::addInputForActiveRows", this); - table_->prepareForGroupProbe(*lookup_, input, activeRows_, ignoreNullKeys_); + table_->prepareForGroupProbe( + *lookup_, + input, + activeRows_, + ignoreNullKeys_, + BaseHashTable::kNoSpillInputStartPartitionBit); if (lookup_->rows.empty()) { // No rows to probe. Can happen when ignoreNullKeys_ is true and all rows // have null keys. diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index c2743f1739d4..bfeb1cd6ff4e 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -334,6 +334,18 @@ void HashBuild::addInput(RowVectorPtr input) { hashers[i]->decode(*key, activeRows_); } + // Update statistics for null keys in join operator. + // We use activeRows_ to store which rows have some null keys, + // and reset it after using it. + // Only process when input is not spilled, to avoid overcounting. + if (!isInputFromSpill()) { + auto lockedStats = stats_.wlock(); + deselectRowsWithNulls(hashers, activeRows_); + lockedStats->numNullKeys += + activeRows_.size() - activeRows_.countSelected(); + activeRows_.setAll(); + } + if (!isRightJoin(joinType_) && !isFullJoin(joinType_) && !isRightSemiProjectJoin(joinType_) && !isLeftNullAwareJoinWithFilter(joinNode_)) { @@ -811,7 +823,9 @@ bool HashBuild::finishHashBuild() { table_->prepareJoinTable( std::move(otherTables), allowParallelJoinBuild ? operatorCtx_->task()->queryCtx()->executor() - : nullptr); + : nullptr, + isInputFromSpill() ? spillConfig()->startPartitionBit + : BaseHashTable::kNoSpillInputStartPartitionBit); addRuntimeStats(); if (joinBridge_->setHashTable( std::move(table_), std::move(spillPartitions), joinHasNullKeys_)) { diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 6f4281be01f5..228358c350a9 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -581,6 +581,17 @@ void HashProbe::addInput(RowVectorPtr input) { } activeRows_ = nonNullInputRows_; + // Update statistics for null keys in join operator. + // Updating here means we will report 0 null keys when build side is empty. + // If we want more accurate stats, we will have to decode input vector + // even when not needed. So we tradeoff less accurate stats for more + // performance. + { + auto lockedStats = stats_.wlock(); + lockedStats->numNullKeys += + activeRows_.size() - activeRows_.countSelected(); + } + table_->prepareForJoinProbe(*lookup_.get(), input_, activeRows_, false); passingInputRowsInitialized_ = false; diff --git a/velox/exec/HashTable.cpp b/velox/exec/HashTable.cpp index 1e0bbd6ffb3d..d09b3bc041f4 100644 --- a/velox/exec/HashTable.cpp +++ b/velox/exec/HashTable.cpp @@ -20,6 +20,7 @@ #include "velox/common/base/Portability.h" #include "velox/common/base/SimdUtil.h" #include "velox/common/process/ProcessBase.h" +#include "velox/common/process/TraceContext.h" #include "velox/common/testutil/TestValue.h" #include "velox/exec/OperatorUtils.h" #include "velox/vector/VectorTypeUtils.h" @@ -857,6 +858,7 @@ bool HashTable::canApplyParallelJoinBuild() const { template void HashTable::parallelJoinBuild() { + process::TraceContext trace("HashTable::parallelJoinBuild"); TestValue::adjust( "facebook::velox::exec::HashTable::parallelJoinBuild", rows_->pool()); VELOX_CHECK_LE(1 + otherTables_.size(), std::numeric_limits::max()); @@ -1605,7 +1607,8 @@ bool mayUseValueIds(const BaseHashTable& table) { template void HashTable::prepareJoinTable( std::vector> tables, - folly::Executor* executor) { + folly::Executor* executor, + int8_t spillInputStartPartitionBit) { buildExecutor_ = executor; otherTables_.reserve(tables.size()); for (auto& table : tables) { @@ -1648,6 +1651,7 @@ void HashTable::prepareJoinTable( } else { decideHashMode(0); } + checkHashBitsOverlap(spillInputStartPartitionBit); } template @@ -1980,7 +1984,9 @@ void BaseHashTable::prepareForGroupProbe( HashLookup& lookup, const RowVectorPtr& input, SelectivityVector& rows, - bool ignoreNullKeys) { + bool ignoreNullKeys, + int8_t spillInputStartPartitionBit) { + checkHashBitsOverlap(spillInputStartPartitionBit); auto& hashers = lookup.hashers; for (auto& hasher : hashers) { @@ -2013,7 +2019,8 @@ void BaseHashTable::prepareForGroupProbe( decideHashMode(input->size()); // Do not forward 'ignoreNullKeys' to avoid redundant evaluation of // deselectRowsWithNulls. - prepareForGroupProbe(lookup, input, rows, false); + prepareForGroupProbe( + lookup, input, rows, false, spillInputStartPartitionBit); return; } } diff --git a/velox/exec/HashTable.h b/velox/exec/HashTable.h index 5dc6e128934c..eec394caf599 100644 --- a/velox/exec/HashTable.h +++ b/velox/exec/HashTable.h @@ -121,6 +121,8 @@ class BaseHashTable { /// Specifies the hash mode of a table. enum class HashMode { kHash, kArray, kNormalizedKey }; + static constexpr int8_t kNoSpillInputStartPartitionBit = -1; + /// Returns the string of the given 'mode'. static std::string modeString(HashMode mode); @@ -181,7 +183,8 @@ class BaseHashTable { HashLookup& lookup, const RowVectorPtr& input, SelectivityVector& rows, - bool ignoreNullKeys); + bool ignoreNullKeys, + int8_t spillInputStartPartitionBit); /// Finds or creates a group for each key in 'lookup'. The keys are /// returned in 'lookup.hits'. @@ -248,7 +251,8 @@ class BaseHashTable { virtual void prepareJoinTable( std::vector> tables, - folly::Executor* executor = nullptr) = 0; + folly::Executor* executor = nullptr, + int8_t spillInputStartPartitionBit = kNoSpillInputStartPartitionBit) = 0; /// Returns the memory footprint in bytes for any data structures /// owned by 'this'. @@ -328,7 +332,12 @@ class BaseHashTable { /// Extracts a 7 bit tag from a hash number. The high bit is always set. static uint8_t hashTag(uint64_t hash) { - return static_cast(hash >> 32) | 0x80; + // This is likely all 0 for small key types (<= 32 bits). Not an issue + // because small types have a range that makes them normalized key cases. + // If there are multiple small type keys, they are mixed which makes them a + // 64 bit hash. Normalized keys are mixed before being used as hash + // numbers. + return static_cast(hash >> 38) | 0x80; } /// Loads a vector of tags for bulk comparison. Disables tsan errors @@ -365,6 +374,20 @@ class BaseHashTable { virtual void setHashMode(HashMode mode, int32_t numNew) = 0; + virtual int sizeBits() const = 0; + + // We don't want any overlap in the bit ranges used by bucket index and those + // used by spill partitioning; otherwise because we receive data from only one + // partition, the overlapped bits would be the same and only a fraction of the + // buckets would be used. This would cause the insertion taking very long + // time and block driver threads. + void checkHashBitsOverlap(int8_t spillInputStartPartitionBit) { + if (spillInputStartPartitionBit != kNoSpillInputStartPartitionBit && + hashMode() != HashMode::kArray) { + VELOX_CHECK_LE(sizeBits(), spillInputStartPartitionBit); + } + } + std::vector> hashers_; std::unique_ptr rows_; @@ -525,7 +548,9 @@ class HashTable : public BaseHashTable { // and VectorHashers and decides the hash mode and representation. void prepareJoinTable( std::vector> tables, - folly::Executor* executor = nullptr) override; + folly::Executor* executor = nullptr, + int8_t spillInputStartPartitionBit = + kNoSpillInputStartPartitionBit) override; uint64_t hashTableSizeIncrease(int32_t numNewDistinct) const override { if (numDistinct_ + numNewDistinct > rehashSize()) { @@ -587,10 +612,6 @@ class HashTable : public BaseHashTable { // occupy exactly two (64 bytes) cache lines. class Bucket { public: - Bucket() { - static_assert(sizeof(Bucket) == 128); - } - uint8_t tagAt(int32_t slotIndex) { return reinterpret_cast(&tags_)[slotIndex]; } @@ -622,6 +643,7 @@ class HashTable : public BaseHashTable { char padding_[16]; }; + static_assert(sizeof(Bucket) == 128); static constexpr uint64_t kBucketSize = sizeof(Bucket); // Returns the bucket at byte offset 'offset' from 'table_'. @@ -881,6 +903,10 @@ class HashTable : public BaseHashTable { } } + int sizeBits() const final { + return sizeBits_; + } + // The min table size in row to trigger parallel join table build. const uint32_t minTableSizeForParallelJoinBuild_; @@ -938,7 +964,7 @@ class HashTable : public BaseHashTable { // Executor for parallelizing hash join build. This may be the // executor for Drivers. If this executor is indefinitely taken by - // other work, the thread of prepareJoinTables() will sequentially + // other work, the thread of prepareJoinTable() will sequentially // execute the parallel build steps. folly::Executor* buildExecutor_{nullptr}; diff --git a/velox/exec/LocalPlanner.cpp b/velox/exec/LocalPlanner.cpp index 62e78b35e450..58dceab18b3e 100644 --- a/velox/exec/LocalPlanner.cpp +++ b/velox/exec/LocalPlanner.cpp @@ -223,6 +223,12 @@ uint32_t maxDrivers( } else if (std::dynamic_pointer_cast(node)) { // Merge join must run single-threaded. return 1; + } else if ( + auto join = std::dynamic_pointer_cast(node)) { + // Right semi project doesn't support multi-threaded execution. + if (join->isRightSemiProjectJoin()) { + return 1; + } } else if ( auto tableWrite = std::dynamic_pointer_cast(node)) { diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index a85d5a8e95eb..4f1804e372da 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -481,6 +481,8 @@ void OperatorStats::add(const OperatorStats& other) { spilledRows += other.spilledRows; spilledPartitions += other.spilledPartitions; spilledFiles += other.spilledFiles; + + numNullKeys += other.numNullKeys; } void OperatorStats::clear() { diff --git a/velox/exec/Operator.h b/velox/exec/Operator.h index 0a978188c8fe..189f209dd864 100644 --- a/velox/exec/Operator.h +++ b/velox/exec/Operator.h @@ -152,6 +152,11 @@ struct OperatorStats { int64_t lastLazyCpuNanos{0}; int64_t lastLazyWallNanos{0}; + // Total null keys processed by the operator. + // Currently populated only by HashJoin/HashBuild. + // HashProbe doesn't populate numNullKeys when build side is empty. + int64_t numNullKeys{0}; + std::unordered_map runtimeStats; int numDrivers = 0; diff --git a/velox/exec/RowNumber.cpp b/velox/exec/RowNumber.cpp index f753cde0e38f..04f289c818ee 100644 --- a/velox/exec/RowNumber.cpp +++ b/velox/exec/RowNumber.cpp @@ -78,7 +78,12 @@ void RowNumber::addInput(RowVectorPtr input) { } SelectivityVector rows(numInput); - table_->prepareForGroupProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe( + *lookup_, + input, + rows, + false, + BaseHashTable::kNoSpillInputStartPartitionBit); table_->groupProbe(*lookup_); // Initialize new partitions with zeros. @@ -93,7 +98,8 @@ void RowNumber::addInput(RowVectorPtr input) { void RowNumber::addSpillInput() { const auto numInput = input_->size(); SelectivityVector rows(numInput); - table_->prepareForGroupProbe(*lookup_, input_, rows, false); + table_->prepareForGroupProbe( + *lookup_, input_, rows, false, spillConfig_->startPartitionBit); table_->groupProbe(*lookup_); // Initialize new partitions with zeros. @@ -157,7 +163,8 @@ void RowNumber::restoreNextSpillPartition() { const auto numInput = input->size(); SelectivityVector rows(numInput); - table_->prepareForGroupProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe( + *lookup_, input, rows, false, spillConfig_->startPartitionBit); table_->groupProbe(*lookup_); auto* counts = data->children().back()->as>(); diff --git a/velox/exec/StreamingAggregation.cpp b/velox/exec/StreamingAggregation.cpp index b526c3fb2883..c2d600aebc3a 100644 --- a/velox/exec/StreamingAggregation.cpp +++ b/velox/exec/StreamingAggregation.cpp @@ -64,6 +64,17 @@ void StreamingAggregation::initialize() { // Setup SortedAggregations. sortedAggregations_ = SortedAggregations::create(aggregates_, inputType, pool()); + + distinctAggregations_.reserve(aggregates_.size()); + for (auto& aggregate : aggregates_) { + if (aggregate.distinct) { + distinctAggregations_.emplace_back( + DistinctAggregations::create({&aggregate}, inputType, pool())); + } else { + distinctAggregations_.push_back(nullptr); + } + } + masks_ = std::make_unique(extractMaskChannels(aggregates_)); rows_ = makeRowContainer(groupingKeyTypes); @@ -138,6 +149,11 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { if (!aggregate.sortingKeys.empty()) { continue; } + + if (aggregate.distinct) { + continue; + } + const auto& function = aggregate.function; auto& result = output->childAt(numKeys + i); if (isPartialOutput(step_)) { @@ -152,6 +168,13 @@ RowVectorPtr StreamingAggregation::createOutput(size_t numGroups) { folly::Range(groups_.data(), numGroups), output); } + for (const auto& aggregation : distinctAggregations_) { + if (aggregation != nullptr) { + aggregation->extractValues( + folly::Range(groups_.data(), numGroups), output); + } + } + return output; } @@ -209,6 +232,16 @@ void StreamingAggregation::evaluateAggregates() { continue; } + const auto& rows = getSelectivityVector(i); + if (!rows.hasSelections()) { + continue; + } + + if (aggregate.distinct) { + distinctAggregations_.at(i)->addInput(inputGroups_.data(), input_, rows); + continue; + } + const auto& function = aggregate.function; const auto& inputs = aggregate.inputs; const auto& constantInputs = aggregate.constantInputs; @@ -222,8 +255,6 @@ void StreamingAggregation::evaluateAggregates() { } } - const auto& rows = getSelectivityVector(i); - if (isRawInput(step_)) { function->addRawInput(inputGroups_.data(), rows, args, false); } else { @@ -297,6 +328,12 @@ std::unique_ptr StreamingAggregation::makeRowContainer( accumulators.push_back(sortedAggregations_->accumulator()); } + for (const auto& aggregation : distinctAggregations_) { + if (aggregation != nullptr) { + accumulators.push_back(aggregation->accumulator()); + } + } + return std::make_unique( groupingKeyTypes, !aggregationNode_->ignoreNullKeys(), @@ -318,43 +355,64 @@ void StreamingAggregation::initializeNewGroups(size_t numPrevGroups) { newGroups.resize(numGroups_ - numPrevGroups); std::iota(newGroups.begin(), newGroups.end(), numPrevGroups); - for (const auto& aggregate : aggregates_) { + for (auto i = 0; i < aggregates_.size(); ++i) { + const auto& aggregate = aggregates_.at(i); if (!aggregate.sortingKeys.empty()) { continue; } - aggregate.function->initializeNewGroups( - groups_.data(), folly::Range(newGroups.data(), newGroups.size())); + if (aggregate.distinct) { + distinctAggregations_.at(i)->initializeNewGroups( + groups_.data(), newGroups); + continue; + } + + aggregate.function->initializeNewGroups(groups_.data(), newGroups); } if (sortedAggregations_) { - sortedAggregations_->initializeNewGroups( - groups_.data(), folly::Range(newGroups.data(), newGroups.size())); + sortedAggregations_->initializeNewGroups(groups_.data(), newGroups); } } void StreamingAggregation::initializeAggregates(uint32_t numKeys) { - for (auto i = 0; i < aggregates_.size(); ++i) { - auto& function = aggregates_[i].function; + int32_t columnIndex = numKeys; + for (auto& aggregate : aggregates_) { + auto& function = aggregate.function; function->setAllocator(&rows_->stringAllocator()); - const auto rowColumn = rows_->columnAt(numKeys + i); + const auto rowColumn = rows_->columnAt(columnIndex); function->setOffsets( rowColumn.offset(), rowColumn.nullByte(), rowColumn.nullMask(), rows_->rowSizeOffset()); + columnIndex++; } if (sortedAggregations_) { sortedAggregations_->setAllocator(&rows_->stringAllocator()); - const auto rowColumn = - rows_->columnAt(rows_->keyTypes().size() + aggregates_.size()); + const auto& rowColumn = rows_->columnAt(columnIndex); sortedAggregations_->setOffsets( rowColumn.offset(), rowColumn.nullByte(), rowColumn.nullMask(), rows_->rowSizeOffset()); + columnIndex++; + } + + for (const auto& aggregation : distinctAggregations_) { + if (aggregation != nullptr) { + aggregation->setAllocator(&rows_->stringAllocator()); + + const auto& rowColumn = rows_->columnAt(columnIndex); + aggregation->setOffsets( + rowColumn.offset(), + rowColumn.nullByte(), + rowColumn.nullMask(), + rows_->rowSizeOffset()); + columnIndex++; + } } }; diff --git a/velox/exec/StreamingAggregation.h b/velox/exec/StreamingAggregation.h index 65ff35fb51e6..953ffd0a33ce 100644 --- a/velox/exec/StreamingAggregation.h +++ b/velox/exec/StreamingAggregation.h @@ -18,6 +18,7 @@ #include "velox/exec/Aggregate.h" #include "velox/exec/AggregateInfo.h" #include "velox/exec/AggregationMasks.h" +#include "velox/exec/DistinctAggregations.h" #include "velox/exec/Operator.h" #include "velox/exec/SortedAggregations.h" @@ -93,6 +94,7 @@ class StreamingAggregation : public Operator { std::vector groupingKeys_; std::vector aggregates_; std::unique_ptr sortedAggregations_; + std::vector> distinctAggregations_; std::unique_ptr masks_; std::vector decodedKeys_; diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index d925fc47c4d7..fd5addd7af67 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -2000,11 +2000,13 @@ bool Task::getLongRunningOpCalls( if (!opCallStatus.empty()) { auto callDurationMs = opCallStatus.callDuration(); if (callDurationMs > thresholdDurationMs) { + auto* op = driver->findOperatorNoThrow(opCallStatus.opId); out.push_back({ .durationMs = callDurationMs, .tid = driver->state().tid, .opId = opCallStatus.opId, .taskId = taskId_, + .opCall = OpCallStatusRaw::formatCall(op, opCallStatus.method), }); } } diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 6d3ec629fe94..cb4a8507f13a 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -257,9 +257,12 @@ class Task : public std::enable_shared_from_this { /// Information about an operator call that helps debugging stuck calls. struct OpCallInfo { size_t durationMs; + /// Thread id of where the operator got stuck. int32_t tid; int32_t opId; std::string taskId; + /// Call in the format of ".::". + std::string opCall; }; /// Collect long running operator calls across all drivers in this task. diff --git a/velox/exec/TopNRowNumber.cpp b/velox/exec/TopNRowNumber.cpp index ceb9dc2131e3..5ad9184c0bd5 100644 --- a/velox/exec/TopNRowNumber.cpp +++ b/velox/exec/TopNRowNumber.cpp @@ -191,7 +191,12 @@ void TopNRowNumber::addInput(RowVectorPtr input) { ensureInputFits(input); SelectivityVector rows(numInput); - table_->prepareForGroupProbe(*lookup_, input, rows, false); + table_->prepareForGroupProbe( + *lookup_, + input, + rows, + false, + BaseHashTable::kNoSpillInputStartPartitionBit); table_->groupProbe(*lookup_); // Initialize new partitions. diff --git a/velox/exec/fuzzer/AggregationFuzzer.cpp b/velox/exec/fuzzer/AggregationFuzzer.cpp index c168d7161c0e..8b22c95746d3 100644 --- a/velox/exec/fuzzer/AggregationFuzzer.cpp +++ b/velox/exec/fuzzer/AggregationFuzzer.cpp @@ -83,7 +83,6 @@ class AggregationFuzzer : public AggregationFuzzerBase { // Number of iterations using aggregations over distinct inputs. size_t numDistinctInputs{0}; - // Number of iterations using window expressions. size_t numWindow{0}; @@ -1142,6 +1141,21 @@ bool AggregationFuzzer::verifyDistinctAggregation( std::vector plans; plans.push_back({firstPlan, {}}); + if (!groupingKeys.empty()) { + plans.push_back( + {PlanBuilder() + .values(input) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode(), + {}}); + } + // Alternate between using Values and TableScan node. std::shared_ptr directory; @@ -1156,6 +1170,21 @@ bool AggregationFuzzer::verifyDistinctAggregation( .singleAggregation(groupingKeys, aggregates, masks) .planNode(), splits}); + + if (!groupingKeys.empty()) { + plans.push_back( + {PlanBuilder() + .tableScan(inputRowType) + .orderBy(groupingKeys, false) + .streamingAggregation( + groupingKeys, + aggregates, + masks, + core::AggregationNode::Step::kSingle, + false) + .planNode(), + splits}); + } } if (persistAndRunOnce_) { diff --git a/velox/exec/tests/AddressableNonNullValueListTest.cpp b/velox/exec/tests/AddressableNonNullValueListTest.cpp index 546d3fb47476..bb4983ea0308 100644 --- a/velox/exec/tests/AddressableNonNullValueListTest.cpp +++ b/velox/exec/tests/AddressableNonNullValueListTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/exec/AddressableNonNullValueList.h" #include +#include "velox/common/base/IOUtils.h" #include "velox/vector/tests/utils/VectorTestBase.h" namespace facebook::velox::aggregate::prestosql { @@ -28,14 +29,17 @@ class AddressableNonNullValueListTest : public testing::Test, memory::MemoryManager::testingSetInstance({}); } - void test(const VectorPtr& data, const VectorPtr& uniqueData) { - using T = AddressableNonNullValueList::Entry; - using Set = folly::F14FastSet< - T, - AddressableNonNullValueList::Hash, - AddressableNonNullValueList::EqualTo, - AlignedStlAllocator>; + using T = AddressableNonNullValueList::Entry; + using Set = folly::F14FastSet< + T, + AddressableNonNullValueList::Hash, + AddressableNonNullValueList::EqualTo, + AlignedStlAllocator>; + + static constexpr size_t kSizeOfHash = sizeof(uint64_t); + static constexpr size_t kSizeOfLength = sizeof(vector_size_t); + void test(const VectorPtr& data, const VectorPtr& uniqueData) { Set uniqueValues{ 0, AddressableNonNullValueList::Hash{}, @@ -46,6 +50,10 @@ class AddressableNonNullValueListTest : public testing::Test, std::vector entries; + // Tracks the number of bytes for serializing the + // AddressableNonNullValueList. + vector_size_t totalSize = 0; + DecodedVector decodedVector(*data); for (auto i = 0; i < data->size(); ++i) { auto entry = values.append(decodedVector, i, allocator()); @@ -56,6 +64,9 @@ class AddressableNonNullValueListTest : public testing::Test, } entries.push_back(entry); + // The total size for serialization is + // (size of length + size of hash + actual value size) for each entry. + totalSize += entry.size + kSizeOfHash + kSizeOfLength; ASSERT_TRUE(uniqueValues.insert(entry).second); ASSERT_TRUE(uniqueValues.contains(entry)); @@ -65,7 +76,19 @@ class AddressableNonNullValueListTest : public testing::Test, ASSERT_EQ(uniqueData->size(), values.size()); ASSERT_EQ(uniqueData->size(), uniqueValues.size()); - auto copy = BaseVector::create(data->type(), uniqueData->size(), pool()); + testDirectRead(entries, uniqueValues, uniqueData); + testSerialization(entries, totalSize, uniqueData); + } + + // Test direct read from AddressableNonNullValueList. + // Reads AddressableNonNullValueList into a vector, and validates its + // content. + void testDirectRead( + const std::vector& entries, + const Set& uniqueValues, + const VectorPtr& uniqueData) { + auto copy = + BaseVector::create(uniqueData->type(), uniqueData->size(), pool()); for (auto i = 0; i < entries.size(); ++i) { auto entry = entries[i]; ASSERT_TRUE(uniqueValues.contains(entry)); @@ -75,6 +98,57 @@ class AddressableNonNullValueListTest : public testing::Test, test::assertEqualVectors(uniqueData, copy); } + // Test copy/appendSerialized round-trip for AddressableNonNullValueList. + // Steps in the test: + // i) Copy entry length, hash and value of each entry to a stream. + // ii) Deserialize stream to a new set of entries. + // iii) Read deserialized entries back into a vector. + // iv) Validate the result vector. + void testSerialization( + const std::vector& entries, + vector_size_t totalSize, + const VectorPtr& uniqueData) { + size_t offset = 0; + auto buffer = AlignedBuffer::allocate(totalSize, pool()); + auto* rawBuffer = buffer->asMutable(); + + auto append = [&](const void* value, size_t size) { + memcpy((void*)(rawBuffer + offset), value, size); + offset += size; + }; + + for (const auto& entry : entries) { + append(&entry.size, kSizeOfLength); + append(&entry.hash, kSizeOfHash); + AddressableNonNullValueList::readSerialized( + entry, (char*)(rawBuffer + offset)); + offset += entry.size; + } + ASSERT_EQ(offset, totalSize); + + // Deserialize entries from the stream. + AddressableNonNullValueList deserialized; + std::vector deserializedEntries; + common::InputByteStream stream(rawBuffer); + while (stream.offset() < totalSize) { + auto length = stream.read(); + auto hash = stream.read(); + StringView contents(stream.read(length), length); + auto position = deserialized.appendSerialized(contents, allocator()); + deserializedEntries.push_back({position, contents.size(), hash}); + } + + // Direct read from deserialized AddressableNonNullValueList. Validate the + // results. + auto deserializedCopy = + BaseVector::create(uniqueData->type(), uniqueData->size(), pool()); + for (auto i = 0; i < deserializedEntries.size(); ++i) { + auto entry = deserializedEntries[i]; + AddressableNonNullValueList::read(entry, *deserializedCopy, i); + } + test::assertEqualVectors(uniqueData, deserializedCopy); + } + HashStringAllocator* allocator() { return allocator_.get(); } diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index 5237c5538db7..64f3492acc99 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -1418,3 +1418,66 @@ DEBUG_ONLY_TEST_F(DriverTest, driverCpuTimeSlicingCheck) { } } } + +class OpCallStatusTest : public OperatorTestBase {}; + +// Test that the opCallStatus is returned properly and formats the call as +// expected. +TEST_F(OpCallStatusTest, basic) { + std::vector data{ + makeRowVector({"c0"}, {makeFlatVector({1, 2, 3})})}; + + const int firstNodeId{17}; + auto planNodeIdGenerator = + std::make_shared(firstNodeId); + auto fragment = PlanBuilder(planNodeIdGenerator).values(data).planFragment(); + + std::unordered_map queryConfig; + auto task = Task::create( + "t19", + fragment, + 0, + std::make_shared( + driverExecutor_.get(), std::move(queryConfig)), + [](RowVectorPtr /*unused*/, ContinueFuture* /*unused*/) { + return exec::BlockingReason::kNotBlocked; + }); + + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Values::getOutput", + std::function([&](const exec::Values* values) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + auto* driver = values->testingOperatorCtx()->driver(); + auto ocs = driver->opCallStatus(); + // Check osc to be not empty and the correct format. + EXPECT_FALSE(ocs.empty()); + const auto formattedOpCall = + ocs.formatCall(driver->findOperatorNoThrow(ocs.opId), ocs.method); + EXPECT_EQ( + formattedOpCall, + fmt::format("Values.{}::{}", firstNodeId, ocs.method)); + // Check the correct format when operator is not found. + ocs.method = "randomName"; + EXPECT_EQ( + ocs.formatCall( + driver->findOperatorNoThrow(ocs.opId + 10), ocs.method), + fmt::format("null::{}", ocs.method)); + + // Check that the task returns correct long running op call. + std::vector stuckCalls; + const std::chrono::milliseconds lockTimeoutMs(10); + task->getLongRunningOpCalls(lockTimeoutMs, 10, stuckCalls); + EXPECT_EQ(stuckCalls.size(), 1); + if (!stuckCalls.empty()) { + const auto& stuckCall = stuckCalls[0]; + EXPECT_EQ(stuckCall.opId, ocs.opId); + EXPECT_GE(stuckCall.durationMs, 100); + EXPECT_EQ(stuckCall.tid, driver->state().tid); + EXPECT_EQ(stuckCall.taskId, task->taskId()); + EXPECT_EQ(stuckCall.opCall, formattedOpCall); + } + })); + + task->start(1, 1); + ASSERT_TRUE(waitForTaskCompletion(task.get(), 600'000'000)); +}; diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp index ea2273d883a9..782a2b7452b7 100644 --- a/velox/exec/tests/ExchangeClientTest.cpp +++ b/velox/exec/tests/ExchangeClientTest.cpp @@ -255,11 +255,14 @@ TEST_F(ExchangeClientTest, multiPageFetch) { auto client = std::make_shared("test", 17, 1 << 20, pool(), executor()); - bool atEnd; - ContinueFuture future; - auto pages = client->next(1, &atEnd, &future); - ASSERT_EQ(0, pages.size()); - ASSERT_FALSE(atEnd); + { + bool atEnd; + ContinueFuture future = ContinueFuture::makeEmpty(); + auto pages = client->next(1, &atEnd, &future); + ASSERT_EQ(0, pages.size()); + ASSERT_FALSE(atEnd); + ASSERT_TRUE(future.valid()); + } const auto& queue = client->queue(); addSources(*queue, 1); @@ -269,20 +272,25 @@ TEST_F(ExchangeClientTest, multiPageFetch) { } // Fetch one page. - pages = client->next(1, &atEnd, &future); + bool atEnd; + ContinueFuture future = ContinueFuture::makeEmpty(); + auto pages = client->next(1, &atEnd, &future); ASSERT_EQ(1, pages.size()); ASSERT_FALSE(atEnd); + ASSERT_FALSE(future.valid()); // Fetch multiple pages. Each page is slightly larger than 1K bytes, hence, // only 4 pages fit. pages = client->next(5'000, &atEnd, &future); ASSERT_EQ(4, pages.size()); ASSERT_FALSE(atEnd); + ASSERT_FALSE(future.valid()); // Fetch the rest of the pages. pages = client->next(10'000, &atEnd, &future); ASSERT_EQ(5, pages.size()); ASSERT_FALSE(atEnd); + ASSERT_FALSE(future.valid()); // Signal no-more-data. enqueue(*queue, nullptr); @@ -290,6 +298,7 @@ TEST_F(ExchangeClientTest, multiPageFetch) { pages = client->next(10'000, &atEnd, &future); ASSERT_EQ(0, pages.size()); ASSERT_TRUE(atEnd); + ASSERT_FALSE(future.valid()); client->close(); } diff --git a/velox/exec/tests/FilterProjectTest.cpp b/velox/exec/tests/FilterProjectTest.cpp index b530b007690b..57e8d8280be1 100644 --- a/velox/exec/tests/FilterProjectTest.cpp +++ b/velox/exec/tests/FilterProjectTest.cpp @@ -325,42 +325,6 @@ TEST_F(FilterProjectTest, projectAndIdentityOverLazy) { assertQuery(plan, "SELECT c0 < 10 AND c1 < 10, c1 FROM tmp"); } -// Verify that nulls on nested parent are propagated to child without copying -// the child. Note that null on top level columns are handled separately in -// Expr::evalWithNulls; this happens only once per expression tree so we are not -// optimizing that code. We are testing the optimization of potentially more -// expensive case of FieldReference::evalSpecialForm here. -TEST_F(FilterProjectTest, nestedFieldReference) { - auto vector = makeRowVector({ - makeRowVector({ - makeRowVector( - { - makeRowVector({ - makeFlatVector(10, folly::identity), - }), - }, - nullEvery(2)), - }), - }); - CursorParameters params; - params.planNode = - PlanBuilder().values({vector}).project({"(c0).c0.c0.c0"}).planNode(); - params.copyResult = false; - auto cursor = TaskCursor::create(params); - ASSERT_TRUE(cursor->moveNext()); - auto result = cursor->current(); - auto* actual = result->as()->childAt(0).get(); - const BaseVector* expected = vector.get(); - for (int i = 0; i < 4; ++i) { - expected = expected->as()->childAt(0).get(); - } - ASSERT_EQ(*actual->type(), *expected->type()); - ASSERT_EQ(actual, expected); - for (int i = 0; i < actual->size(); ++i) { - ASSERT_EQ(actual->isNullAt(i), i % 2 == 0); - } -} - // Verify the optimization of avoiding copy in null propagation does not break // the case when the field is shared between multiple parents. TEST_F(FilterProjectTest, nestedFieldReferenceSharedChild) { diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index b026b5967c1e..08dc5b8f7604 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -2344,10 +2344,84 @@ TEST_P(MultiThreadedHashJoinTest, leftJoin) { .buildVectors(std::move(buildVectors)) .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) .joinType(core::JoinType::kLeft) - //.joinOutputLayout({"row_number", "c0", "c1", "u_c1"}) .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) .referenceQuery( "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + int nullJoinBuildKeyCount = 0; + int nullJoinProbeKeyCount = 0; + + for (auto& pipeline : task->taskStats().pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == "HashBuild") { + nullJoinBuildKeyCount += op.numNullKeys; + } + if (op.operatorType == "HashProbe") { + nullJoinProbeKeyCount += op.numNullKeys; + } + } + } + ASSERT_EQ(nullJoinBuildKeyCount, 33 * GetParam().numDrivers); + ASSERT_EQ(nullJoinProbeKeyCount, 34 * GetParam().numDrivers); + }) + .run(); +} + +TEST_P(MultiThreadedHashJoinTest, nullStatsWithEmptyBuild) { + std::vector probeVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector( + {"c0", "c1", "row_number"}, + { + makeFlatVector( + 77, [](auto row) { return row % 21; }, nullEvery(13)), + makeFlatVector(77, [](auto row) { return row; }), + makeFlatVector(77, [](auto row) { return row; }), + }); + }); + + // All null keys on build side. + std::vector buildVectors = + makeBatches(1, [&](int32_t /*unused*/) { + return makeRowVector({ + makeFlatVector( + 1, [](auto row) { return row % 5; }, nullEvery(1)), + makeFlatVector( + 1, [](auto row) { return -111 + row * 2; }, nullEvery(1)), + }); + }); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .probeKeys({"c0"}) + .probeVectors(std::move(probeVectors)) + .buildKeys({"u_c0"}) + .buildVectors(std::move(buildVectors)) + .buildProjections({"c0 AS u_c0", "c1 AS u_c1"}) + .joinType(core::JoinType::kLeft) + .joinOutputLayout({"row_number", "c0", "c1", "u_c0"}) + .referenceQuery( + "SELECT t.row_number, t.c0, t.c1, u.c0 FROM t LEFT JOIN u ON t.c0 = u.c0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + int nullJoinBuildKeyCount = 0; + int nullJoinProbeKeyCount = 0; + + for (auto& pipeline : task->taskStats().pipelineStats) { + for (auto op : pipeline.operatorStats) { + if (op.operatorType == "HashBuild") { + nullJoinBuildKeyCount += op.numNullKeys; + } + if (op.operatorType == "HashProbe") { + nullJoinProbeKeyCount += op.numNullKeys; + } + } + } + // Due to inaccurate stats tracking in case of empty build side, + // we will report 0 null keys on probe side. + ASSERT_EQ(nullJoinProbeKeyCount, 0); + ASSERT_EQ(nullJoinBuildKeyCount, 1 * GetParam().numDrivers); + }) + .checkSpillStats(false) .run(); } @@ -4964,6 +5038,22 @@ TEST_F(HashJoinTest, spillFileSize) { } } +TEST_F(HashJoinTest, spillPartitionBitsOverlap) { + auto builder = + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .keyTypes({BIGINT(), BIGINT()}) + .probeVectors(2'000, 3) + .buildVectors(2'000, 3) + .referenceQuery( + "SELECT t_k0, t_k1, t_data, u_k0, u_k1, u_data FROM t, u WHERE t_k0 = u_k0 and t_k1 = u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "8") + .config(core::QueryConfig::kJoinSpillPartitionBits, "1") + .checkSpillStats(false) + .maxSpillLevel(0); + VELOX_ASSERT_THROW(builder.run(), "vs. 8"); +} + // The test is to verify if the hash build reservation has been released on // task error. DEBUG_ONLY_TEST_F(HashJoinTest, buildReservationReleaseCheck) { @@ -5168,6 +5258,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { .spillDirectory(testData.spillEnabled ? tempDirectory->path : "") .referenceQuery( "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") .verifier([&](const std::shared_ptr& task, bool /*unused*/) { const auto statsPair = taskSpilledStats(*task); if (testData.expectedReclaimable) { @@ -5320,6 +5411,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { .spillDirectory(tempDirectory->path) .referenceQuery( "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") .verifier([&](const std::shared_ptr& task, bool /*unused*/) { const auto statsPair = taskSpilledStats(*task); ASSERT_GT(statsPair.first.spilledBytes, 0); @@ -5714,6 +5806,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { .spillDirectory(tempDirectory->path) .referenceQuery( "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") .verifier([&](const std::shared_ptr& task, bool /*unused*/) { const auto statsPair = taskSpilledStats(*task); ASSERT_GT(statsPair.first.spilledBytes, 0); @@ -6277,6 +6370,7 @@ TEST_F(HashJoinTest, exceededMaxSpillLevel) { .spillDirectory(tempDirectory->path) .referenceQuery( "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1") + .config(core::QueryConfig::kSpillStartPartitionBit, "29") .verifier([&](const std::shared_ptr& task, bool /*unused*/) { auto joinStats = task->taskStats() .pipelineStats.back() diff --git a/velox/exec/tests/JoinFuzzer.cpp b/velox/exec/tests/JoinFuzzer.cpp index 159b2c088aba..6338de322387 100644 --- a/velox/exec/tests/JoinFuzzer.cpp +++ b/velox/exec/tests/JoinFuzzer.cpp @@ -140,10 +140,15 @@ class JoinFuzzer { JoinFuzzer::JoinFuzzer(size_t initialSeed) : vectorFuzzer_{getFuzzerOptions(), pool_.get()} { filesystems::registerLocalFileSystem(); + + // Make sure not to run out of open file descriptors. + const std::unordered_map hiveConfig = { + {connector::hive::HiveConfig::kNumCacheFileHandles, "1000"}}; auto hiveConnector = connector::getConnectorFactory( connector::hive::HiveConnectorFactory::kHiveConnectorName) - ->newConnector(kHiveConnectorId, std::make_shared()); + ->newConnector( + kHiveConnectorId, std::make_shared(hiveConfig)); connector::registerConnector(hiveConnector); seed(initialSeed); @@ -277,23 +282,10 @@ std::vector flatten(const std::vector& vectors) { return flatVectors; } -bool isNullAwareRightSemiProjectJoin(const core::PlanNodePtr& plan) { - if (auto joinNode = dynamic_cast(plan.get())) { - return joinNode->isNullAware() && - joinNode->joinType() == core::JoinType::kRightSemiProject; - } - - return false; -} - RowVectorPtr JoinFuzzer::execute(const PlanWithSplits& plan, bool injectSpill) { LOG(INFO) << "Executing query plan: " << std::endl << plan.plan->toString(true, true); - // Null-aware right semi project join doesn't support multi-threaded - // execution. - const int maxDrivers = isNullAwareRightSemiProjectJoin(plan.plan) ? 1 : 2; - AssertQueryBuilder builder(plan.plan); for (const auto& [nodeId, nodeSplits] : plan.splits) { builder.splits(nodeId, nodeSplits); @@ -308,7 +300,7 @@ RowVectorPtr JoinFuzzer::execute(const PlanWithSplits& plan, bool injectSpill) { .spillDirectory(spillDirectory->path); } - auto result = builder.maxDrivers(maxDrivers).copyResults(pool_.get()); + auto result = builder.maxDrivers(2).copyResults(pool_.get()); LOG(INFO) << "Results: " << result->toString(); if (VLOG_IS_ON(1)) { VLOG(1) << std::endl << result->toString(0, result->size()); diff --git a/velox/exec/tests/StreamingAggregationTest.cpp b/velox/exec/tests/StreamingAggregationTest.cpp index ed30f34b978d..bf0a4cf5b0c7 100644 --- a/velox/exec/tests/StreamingAggregationTest.cpp +++ b/velox/exec/tests/StreamingAggregationTest.cpp @@ -32,7 +32,7 @@ class StreamingAggregationTest : public OperatorTestBase { void testAggregation( const std::vector& keys, - uint32_t outputBatchSize = 1'024) { + uint32_t outputBatchSize) { std::vector data; vector_size_t totalSize = 0; @@ -113,7 +113,7 @@ class StreamingAggregationTest : public OperatorTestBase { void testSortedAggregation( const std::vector& keys, - uint32_t outputBatchSize = 1'024) { + uint32_t outputBatchSize) { std::vector data; vector_size_t totalSize = 0; @@ -146,6 +146,60 @@ class StreamingAggregationTest : public OperatorTestBase { "SELECT c0, max(c1 order by c2), max(c1 order by c2 desc), array_agg(c1 order by c2) FROM tmp GROUP BY c0"); } + void testDistinctAggregation( + const std::vector& keys, + uint32_t outputBatchSize) { + std::vector data; + + vector_size_t totalSize = 0; + for (const auto& keyVector : keys) { + auto size = keyVector->size(); + auto payload = makeFlatVector( + size, [totalSize](auto row) { return totalSize + row; }); + data.push_back(makeRowVector({keyVector, payload, payload})); + totalSize += size; + } + createDuckDbTable(data); + + { + auto plan = PlanBuilder() + .values(data) + .streamingAggregation( + {"c0"}, + {"array_agg(distinct c1)", + "array_agg(c1 order by c2)", + "count(distinct c1)", + "array_agg(c2)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(outputBatchSize)) + .assertResults( + "SELECT c0, array_agg(distinct c1), array_agg(c1 order by c2), " + "count(distinct c1), array_agg(c2) FROM tmp GROUP BY c0"); + } + + { + auto plan = + PlanBuilder() + .values(data) + .streamingAggregation( + {"c0"}, {}, {}, core::AggregationNode::Step::kSingle, false) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(outputBatchSize)) + .assertResults("SELECT distinct c0 FROM tmp"); + } + } + std::vector addPayload(const std::vector& keys) { auto numKeys = keys[0]->type()->size(); @@ -172,7 +226,7 @@ class StreamingAggregationTest : public OperatorTestBase { void testMultiKeyAggregation( const std::vector& keys, - uint32_t outputBatchSize = 1'024) { + uint32_t outputBatchSize) { testMultiKeyAggregation( keys, keys[0]->type()->asRow().names(), outputBatchSize); } @@ -180,7 +234,7 @@ class StreamingAggregationTest : public OperatorTestBase { void testMultiKeyAggregation( const std::vector& keys, const std::vector& preGroupedKeys, - uint32_t outputBatchSize = 1'024) { + uint32_t outputBatchSize) { auto data = addPayload(keys); createDuckDbTable(data); @@ -224,6 +278,73 @@ class StreamingAggregationTest : public OperatorTestBase { EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); } + + void testMultiKeyDistinctAggregation( + const std::vector& keys, + uint32_t outputBatchSize) { + auto data = addPayload(keys); + createDuckDbTable(data); + + { + auto plan = + PlanBuilder() + .values(data) + .streamingAggregation( + keys[0]->type()->asRow().names(), + {"count(distinct c1)", "array_agg(c1)", "sumnonpod(1)"}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + + // Generate a list of grouping keys to use in the query: c0, c1, c2,.. + std::ostringstream keySql; + keySql << "c0"; + for (auto i = 1; i < numKeys(keys); i++) { + keySql << ", c" << i; + } + + const auto sql = fmt::format( + "SELECT {}, count(distinct c1), array_agg(c1), sum(1) FROM tmp GROUP BY {}", + keySql.str(), + keySql.str()); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(outputBatchSize)) + .assertResults(sql); + + EXPECT_EQ(NonPODInt64::constructed, NonPODInt64::destructed); + } + + { + auto plan = PlanBuilder() + .values(data) + .streamingAggregation( + keys[0]->type()->asRow().names(), + {}, + {}, + core::AggregationNode::Step::kSingle, + false) + .planNode(); + + // Generate a list of grouping keys to use in the query: c0, c1, c2,.. + std::ostringstream keySql; + keySql << "c0"; + for (auto i = 1; i < numKeys(keys); i++) { + keySql << ", c" << i; + } + + const auto sql = fmt::format("SELECT distinct {} FROM tmp", keySql.str()); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config( + core::QueryConfig::kPreferredOutputBatchRows, + std::to_string(outputBatchSize)) + .assertResults(sql); + } + } }; TEST_F(StreamingAggregationTest, smallInputBatches) { @@ -236,7 +357,7 @@ TEST_F(StreamingAggregationTest, smallInputBatches) { makeFlatVector({6, 7, 8}), }; - testAggregation(keys); + testAggregation(keys, 1024); // Cut output into tiny batches of size 3. testAggregation(keys, 3); @@ -258,7 +379,7 @@ TEST_F(StreamingAggregationTest, multipleKeys) { }), }; - testMultiKeyAggregation(keys); + testMultiKeyAggregation(keys, 1024); // Cut output into tiny batches of size 3. testMultiKeyAggregation(keys, 3); @@ -277,7 +398,7 @@ TEST_F(StreamingAggregationTest, regularSizeInputBatches) { 78, [size](auto row) { return (3 * size + row) / 5; }), }; - testAggregation(keys); + testAggregation(keys, 1024); // Cut output into small batches of size 100. testAggregation(keys, 100); @@ -294,7 +415,7 @@ TEST_F(StreamingAggregationTest, uniqueKeys) { makeFlatVector(78, [size](auto row) { return 3 * size + row; }), }; - testAggregation(keys); + testAggregation(keys, 1024); // Cut output into small batches of size 100. testAggregation(keys, 100); @@ -335,7 +456,7 @@ TEST_F(StreamingAggregationTest, partialStreaming) { }), }; - testMultiKeyAggregation(keys, {"c0"}); + testMultiKeyAggregation(keys, {"c0"}, 1024); } // Test StreamingAggregation being closed without being initialized. Create a @@ -369,7 +490,7 @@ TEST_F(StreamingAggregationTest, closeUninitialized) { } TEST_F(StreamingAggregationTest, sortedAggregations) { - auto size = 128; + auto size = 1024; std::vector keys = { makeFlatVector(size, [](auto row) { return row; }), @@ -380,27 +501,40 @@ TEST_F(StreamingAggregationTest, sortedAggregations) { 78, [size](auto row) { return (3 * size + row); }), }; - testSortedAggregation(keys); + testSortedAggregation(keys, 1024); testSortedAggregation(keys, 32); } TEST_F(StreamingAggregationTest, distinctAggregations) { - auto data = makeRowVector({ - makeFlatVector(1'000, [](auto row) { return row / 4; }), - makeFlatVector(1'000, [](auto row) { return row % 3; }), - }); + auto size = 1024; - auto plan = PlanBuilder() - .values({data}) - .streamingAggregation( - {"c0"}, - {"array_agg(distinct c1)"}, - {}, - core::AggregationNode::Step::kSingle, - false) - .planNode(); + std::vector keys = { + makeFlatVector(size, [](auto row) { return row; }), + makeFlatVector(size, [size](auto row) { return (size + row); }), + makeFlatVector( + size, [size](auto row) { return (2 * size + row); }), + makeFlatVector( + 78, [size](auto row) { return (3 * size + row); }), + }; - VELOX_ASSERT_THROW( - AssertQueryBuilder(plan).copyResults(pool()), - "Streaming aggregation doesn't support aggregations over distinct inputs yet"); + testDistinctAggregation(keys, 1024); + testDistinctAggregation(keys, 32); + + std::vector multiKeys = { + makeRowVector({ + makeFlatVector({1, 1, 2, 2, 2}), + makeFlatVector({10, 20, 20, 30, 30}), + }), + makeRowVector({ + makeFlatVector({2, 3, 3, 3, 4}), + makeFlatVector({30, 30, 40, 40, 40}), + }), + makeRowVector({ + makeNullableFlatVector({5, 5, 6, 6, 6}), + makeNullableFlatVector({40, 50, 50, 50, 50}), + }), + }; + + testMultiKeyDistinctAggregation(multiKeys, 1024); + testMultiKeyDistinctAggregation(multiKeys, 3); } diff --git a/velox/exec/tests/TableScanTest.cpp b/velox/exec/tests/TableScanTest.cpp index a0919adb5d5b..76485ebb37fa 100644 --- a/velox/exec/tests/TableScanTest.cpp +++ b/velox/exec/tests/TableScanTest.cpp @@ -1400,7 +1400,9 @@ TEST_F(TableScanTest, fileNotFound) { }; assertMissingFile(true); VELOX_ASSERT_RUNTIME_THROW_CODE( - assertMissingFile(false), error_code::kFileNotFound); + assertMissingFile(false), + error_code::kFileNotFound, + "No such file or directory"); } // A valid ORC file (containing headers) but no data. diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.cpp b/velox/exec/tests/utils/ArbitratorTestUtil.cpp index aeb75584beab..62761c9f33d8 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.cpp +++ b/velox/exec/tests/utils/ArbitratorTestUtil.cpp @@ -99,6 +99,7 @@ QueryTestResult runHashJoinTask( .spillDirectory(spillDirectory->path) .config(core::QueryConfig::kSpillEnabled, true) .config(core::QueryConfig::kJoinSpillEnabled, true) + .config(core::QueryConfig::kSpillStartPartitionBit, "29") .queryCtx(queryCtx) .maxDrivers(numDrivers) .copyResults(pool, result.task); diff --git a/velox/expression/CastExpr-inl.h b/velox/expression/CastExpr-inl.h index 8d9e732d33a9..5397a6f23aa1 100644 --- a/velox/expression/CastExpr-inl.h +++ b/velox/expression/CastExpr-inl.h @@ -345,7 +345,14 @@ void CastExpr::applyCastKernel( FromKind == TypeKind::TIMESTAMP && (ToKind == TypeKind::VARCHAR || ToKind == TypeKind::VARBINARY)) { auto writer = exec::StringWriter<>(result, row); - hooks_->castTimestampToString(inputRowValue, writer); + const auto& queryConfig = context.execCtx()->queryCtx()->queryConfig(); + auto sessionTzName = queryConfig.sessionTimezone(); + if (queryConfig.adjustTimestampToTimezone() && !sessionTzName.empty()) { + const auto* timeZone = date::locate_zone(sessionTzName); + hooks_->castTimestampToString(inputRowValue, writer, timeZone); + } else { + hooks_->castTimestampToString(inputRowValue, writer); + } return; } diff --git a/velox/expression/CastHooks.h b/velox/expression/CastHooks.h index 9ad439afc154..0721ebe5afa3 100644 --- a/velox/expression/CastHooks.h +++ b/velox/expression/CastHooks.h @@ -35,7 +35,8 @@ class CastHooks { // Cast from timestamp to string and write the result to string writer. virtual void castTimestampToString( const Timestamp& timestamp, - StringWriter& out) const = 0; + StringWriter& out, + const date::time_zone* timeZone = nullptr) const = 0; // Returns whether legacy cast semantics are enabled. virtual bool legacy() const = 0; diff --git a/velox/expression/FieldReference.cpp b/velox/expression/FieldReference.cpp index cff2a7437af9..f00280a541d1 100644 --- a/velox/expression/FieldReference.cpp +++ b/velox/expression/FieldReference.cpp @@ -30,31 +30,6 @@ void FieldReference::computeDistinctFields() { } } -// Fast path to avoid copying result. An alternative way to do this is to -// ensure that children has null if parent has nulls on corresponding rows, -// whenever the RowVector is constructed or mutated (eager propagation of -// nulls). The current lazy propagation might still be better (more efficient) -// when adding extra nulls. -bool FieldReference::addNullsFast( - const SelectivityVector& rows, - EvalCtx& context, - VectorPtr& result, - const RowVector* row) { - if (result) { - return false; - } - auto& child = - inputs_.empty() ? context.getField(index_) : row->childAt(index_); - if (row->mayHaveNulls()) { - if (!child.unique()) { - return false; - } - addNulls(rows, row->rawNulls(), context, const_cast(child)); - } - result = child; - return true; -} - void FieldReference::apply( const SelectivityVector& rows, EvalCtx& context, @@ -116,9 +91,6 @@ void FieldReference::apply( VELOX_CHECK(rowType); index_ = rowType->getChildIdx(field_); } - if (!useDecode && addNullsFast(rows, context, result, row)) { - return; - } VectorPtr child = inputs_.empty() ? context.getField(index_) : row->childAt(index_); if (child->encoding() == VectorEncoding::Simple::LAZY) { diff --git a/velox/expression/FieldReference.h b/velox/expression/FieldReference.h index 97ef4e0d8d33..ed1b12bc2c15 100644 --- a/velox/expression/FieldReference.h +++ b/velox/expression/FieldReference.h @@ -86,12 +86,6 @@ class FieldReference : public SpecialForm { void apply(const SelectivityVector& rows, EvalCtx& context, VectorPtr& result); - bool addNullsFast( - const SelectivityVector& rows, - EvalCtx& context, - VectorPtr& result, - const RowVector* row); - const std::string field_; int32_t index_ = -1; }; diff --git a/velox/expression/PrestoCastHooks.cpp b/velox/expression/PrestoCastHooks.cpp index 0d1d0a473c55..1805d7cd0072 100644 --- a/velox/expression/PrestoCastHooks.cpp +++ b/velox/expression/PrestoCastHooks.cpp @@ -30,13 +30,25 @@ int32_t PrestoCastHooks::castStringToDate(const StringView& dateString) const { void PrestoCastHooks::castTimestampToString( const Timestamp& timestamp, - StringWriter& out) const { - out.copy_from( - legacyCast_ - ? util::Converter:: - cast(timestamp) - : util::Converter:: - cast(timestamp)); + StringWriter& out, + const date::time_zone* timeZone) const { + if (legacyCast_) { + out.copy_from( + util::Converter::cast( + timestamp)); + } else { + if (timeZone) { + Timestamp adjustedTimestamp(timestamp); + adjustedTimestamp.toTimezone(*timeZone); + out.copy_from( + util::Converter:: + cast(adjustedTimestamp)); + } else { + out.copy_from( + util::Converter:: + cast(timestamp)); + } + } out.finalize(); } diff --git a/velox/expression/PrestoCastHooks.h b/velox/expression/PrestoCastHooks.h index 161bbd7f5b5f..0e792d7c6dc5 100644 --- a/velox/expression/PrestoCastHooks.h +++ b/velox/expression/PrestoCastHooks.h @@ -36,7 +36,8 @@ class PrestoCastHooks : public CastHooks { // Applies different cast options according to 'isLegacyCast' config. void castTimestampToString( const Timestamp& timestamp, - StringWriter& out) const override; + StringWriter& out, + const date::time_zone* timeZone) const override; // Follows 'isLegacyCast' config. bool legacy() const override; diff --git a/velox/expression/SpecialFormRegistry.cpp b/velox/expression/SpecialFormRegistry.cpp index 616378ce4152..3957981905b3 100644 --- a/velox/expression/SpecialFormRegistry.cpp +++ b/velox/expression/SpecialFormRegistry.cpp @@ -29,8 +29,10 @@ SpecialFormRegistry& specialFormRegistryInternal() { void SpecialFormRegistry::registerFunctionCallToSpecialForm( const std::string& name, std::unique_ptr functionCallToSpecialForm) { - registry_.withWLock( - [&](auto& map) { map[name] = std::move(functionCallToSpecialForm); }); + const auto sanitizedName = sanitizeName(name); + registry_.withWLock([&](auto& map) { + map[sanitizedName] = std::move(functionCallToSpecialForm); + }); } void SpecialFormRegistry::unregisterAllFunctionCallToSpecialForm() { @@ -39,9 +41,10 @@ void SpecialFormRegistry::unregisterAllFunctionCallToSpecialForm() { FunctionCallToSpecialForm* FOLLY_NULLABLE SpecialFormRegistry::getSpecialForm(const std::string& name) const { + const auto sanitizedName = sanitizeName(name); FunctionCallToSpecialForm* specialForm = nullptr; registry_.withRLock([&](const auto& map) { - auto it = map.find(name); + auto it = map.find(sanitizedName); if (it != map.end()) { specialForm = it->second.get(); } diff --git a/velox/expression/tests/CastExprTest.cpp b/velox/expression/tests/CastExprTest.cpp index 62a7de0ffd6b..e8d2ba40637d 100644 --- a/velox/expression/tests/CastExprTest.cpp +++ b/velox/expression/tests/CastExprTest.cpp @@ -606,6 +606,7 @@ TEST_F(CastExprTest, timestampToString) { testCast( "string", { + Timestamp(0, 0), Timestamp(946729316, 123), Timestamp(-50049331200, 0), Timestamp(253405036800, 0), @@ -613,12 +614,34 @@ TEST_F(CastExprTest, timestampToString) { std::nullopt, }, { + "1970-01-01T00:00:00.000", "2000-01-01T12:21:56.000", "384-01-01T08:00:00.000", "10000-02-01T16:00:00.000", "-10-02-01T10:00:00.000", std::nullopt, }); + + setLegacyCast(false); + setTimezone("America/Los_Angeles"); + testCast( + "string", + { + Timestamp(0, 0), + Timestamp(946729316, 123), + Timestamp(-50049331622, 0), + Timestamp(253405036800, 0), + Timestamp(-62480038022, 0), + std::nullopt, + }, + { + "1969-12-31 16:00:00.000", + "2000-01-01 04:21:56.000", + "0384-01-01 00:00:00.000", + "10000-02-01 08:00:00.000", + "-0010-02-01 02:00:00.000", + std::nullopt, + }); } TEST_F(CastExprTest, dateToTimestamp) { diff --git a/velox/expression/tests/ExpressionFuzzerTest.cpp b/velox/expression/tests/ExpressionFuzzerTest.cpp index d206add7c3bf..007e8a428e1d 100644 --- a/velox/expression/tests/ExpressionFuzzerTest.cpp +++ b/velox/expression/tests/ExpressionFuzzerTest.cpp @@ -54,6 +54,10 @@ int main(int argc, char** argv) { "width_bucket", // Fuzzer cannot generate valid 'comparator' lambda. "array_sort(array(T),constant function(T,T,bigint)) -> array(T)", + // https://github.com/facebookincubator/velox/issues/8438#issuecomment-1907234044 + "regexp_extract", + "regexp_extract_all", + "regexp_like", }; size_t initialSeed = FLAGS_seed == 0 ? std::time(nullptr) : FLAGS_seed; return FuzzerRunner::run(initialSeed, skipFunctions); diff --git a/velox/expression/tests/FunctionCallToSpecialFormTest.cpp b/velox/expression/tests/FunctionCallToSpecialFormTest.cpp index e3b72a513ec6..7c9bd3a59022 100644 --- a/velox/expression/tests/FunctionCallToSpecialFormTest.cpp +++ b/velox/expression/tests/FunctionCallToSpecialFormTest.cpp @@ -185,3 +185,48 @@ TEST_F(FunctionCallToSpecialFormTest, notASpecialForm) { config_); ASSERT_EQ(specialForm, nullptr); } + +class FunctionCallToSpecialFormSanitizeNameTest : public testing::Test, + public VectorTestBase { + protected: + static void SetUpTestCase() { + // This class does not pre-register the special forms. + memory::MemoryManager::testingSetInstance({}); + } +}; + +TEST_F(FunctionCallToSpecialFormSanitizeNameTest, sanitizeName) { + // Make sure no special forms are registered. + unregisterAllFunctionCallToSpecialForm(); + + ASSERT_FALSE(isFunctionCallToSpecialFormRegistered("and")); + ASSERT_FALSE(isFunctionCallToSpecialFormRegistered("AND")); + ASSERT_FALSE(isFunctionCallToSpecialFormRegistered("or")); + ASSERT_FALSE(isFunctionCallToSpecialFormRegistered("OR")); + + registerFunctionCallToSpecialForm( + "and", std::make_unique(true /* isAnd */)); + registerFunctionCallToSpecialForm( + "OR", std::make_unique(false /* isAnd */)); + + auto testLookup = [this](const std::string& name) { + auto type = resolveTypeForSpecialForm(name, {BOOLEAN(), BOOLEAN()}); + ASSERT_EQ(type, BOOLEAN()); + + auto specialForm = constructSpecialForm( + name, + BOOLEAN(), + {std::make_shared( + vectorMaker_.constantVector({true})), + std::make_shared( + vectorMaker_.constantVector({false}))}, + false, + core::QueryConfig{{}}); + ASSERT_EQ(typeid(*specialForm), typeid(const ConjunctExpr&)); + }; + + testLookup("and"); + testLookup("AND"); + testLookup("or"); + testLookup("OR"); +} diff --git a/velox/functions/lib/CMakeLists.txt b/velox/functions/lib/CMakeLists.txt index 85fa8d64c5ec..fe1bdb405ff9 100644 --- a/velox/functions/lib/CMakeLists.txt +++ b/velox/functions/lib/CMakeLists.txt @@ -28,6 +28,7 @@ add_library( KllSketch.cpp MapConcat.cpp Re2Functions.cpp + Repeat.cpp StringEncodingUtils.cpp SubscriptUtil.cpp CheckNestedNulls.cpp diff --git a/velox/functions/lib/Re2Functions.cpp b/velox/functions/lib/Re2Functions.cpp index d14a2bbe8f7e..8c2509f03adc 100644 --- a/velox/functions/lib/Re2Functions.cpp +++ b/velox/functions/lib/Re2Functions.cpp @@ -23,6 +23,48 @@ namespace { static const int kMaxCompiledRegexes = 20; +void checkForBadPattern(const RE2& re) { + if (UNLIKELY(!re.ok())) { + VELOX_USER_FAIL("invalid regular expression:{}", re.error()); + } +} + +template +re2::StringPiece toStringPiece(const T& s) { + return re2::StringPiece(s.data(), s.size()); +} + +// A cache of compiled regular expressions (RE2 instances). Allows up to +// 'kMaxCompiledRegexes' different expressions. +// +// Compiling regular expressions is expensive. It can take up to 200 times +// more CPU time to compile a regex vs. evaluate it. +class ReCache { + public: + RE2* findOrCompile(const StringView& pattern) { + const std::string key = pattern; + + auto reIt = cache_.find(key); + if (reIt != cache_.end()) { + return reIt->second.get(); + } + + VELOX_USER_CHECK_LT( + cache_.size(), kMaxCompiledRegexes, "Max number of regex reached"); + + auto re = std::make_unique(toStringPiece(pattern), RE2::Quiet); + checkForBadPattern(*re); + + auto [it, inserted] = cache_.emplace(key, std::move(re)); + VELOX_CHECK(inserted); + + return it->second.get(); + } + + private: + folly::F14FastMap> cache_; +}; + std::string printTypesCsv( const std::vector& inputArgs) { std::string result; @@ -34,11 +76,6 @@ std::string printTypesCsv( return result; } -template -re2::StringPiece toStringPiece(const T& s) { - return re2::StringPiece(s.data(), s.size()); -} - // If v is a non-null constant vector, returns the constant value. Otherwise // returns nullopt. template @@ -50,12 +87,6 @@ std::optional getIfConstant(const BaseVector& v) { return std::nullopt; } -void checkForBadPattern(const RE2& re) { - if (UNLIKELY(!re.ok())) { - VELOX_USER_FAIL("invalid regular expression:{}", re.error()); - } -} - FlatVector& ensureWritableBool( const SelectivityVector& rows, exec::EvalCtx& context, @@ -220,11 +251,13 @@ class Re2Match final : public exec::VectorFunction { exec::LocalDecodedVector toSearch(context, *args[0], rows); exec::LocalDecodedVector pattern(context, *args[1], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); result.set(row, Fn(toSearch->valueAt(row), re)); }); } + + private: + mutable ReCache cache_; }; void checkForBadGroupId(int64_t groupId, const RE2& re) { @@ -348,8 +381,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { if (args.size() == 2) { groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); mustRefSourceStrings |= re2Extract(result, i, re, toSearch, groups, 0, emptyNoMatch_); }); @@ -357,8 +389,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t i) { const auto groupId = groupIds->valueAt(i); - RE2 re(toStringPiece(pattern->valueAt(i)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(i)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); mustRefSourceStrings |= @@ -372,6 +403,7 @@ class Re2SearchAndExtract final : public exec::VectorFunction { private: const bool emptyNoMatch_; + mutable ReCache cache_; }; namespace { @@ -839,8 +871,9 @@ class LikeWithRe2 final : public exec::VectorFunction { // This function is constructed when pattern or escape are not constants. // It allows up to kMaxCompiledRegexes different regular expressions to be -// compiled throughout the query life per function, note that optimized regular -// expressions that are not compiled are not counted. +// compiled throughout the query lifetime per expression and thread of +// execution, note that optimized regular expressions that are not compiled are +// not counted. class LikeGeneric final : public exec::VectorFunction { void apply( const SelectivityVector& rows, @@ -853,25 +886,8 @@ class LikeGeneric final : public exec::VectorFunction { auto applyWithRegex = [&](const StringView& input, const StringView& pattern, const std::optional& escapeChar) -> bool { - RE2::Options opt{RE2::Quiet}; - opt.set_dot_nl(true); - bool validEscapeUsage; - auto regex = likePatternToRe2(pattern, escapeChar, validEscapeUsage); - VELOX_USER_CHECK( - validEscapeUsage, - "Escape character must be followed by '%', '_' or the escape character itself"); - - auto key = - std::pair>{pattern, escapeChar}; - - auto [it, inserted] = compiledRegularExpressions_.emplace( - key, std::make_unique(toStringPiece(regex), opt)); - VELOX_USER_CHECK_LE( - compiledRegularExpressions_.size(), - kMaxCompiledRegexes, - "Max number of regex reached"); - checkForBadPattern(*it->second); - return re2FullMatch(input, *it->second); + auto* re = findOrCompileRegex(pattern, escapeChar); + return re2FullMatch(input, *re); }; auto applyRow = [&](const StringView& input, @@ -963,6 +979,40 @@ class LikeGeneric final : public exec::VectorFunction { } private: + RE2* findOrCompileRegex( + const StringView& pattern, + std::optional escapeChar) const { + const auto key = + std::pair>{pattern, escapeChar}; + + auto reIt = compiledRegularExpressions_.find(key); + if (reIt != compiledRegularExpressions_.end()) { + return reIt->second.get(); + } + + VELOX_USER_CHECK_LT( + compiledRegularExpressions_.size(), + kMaxCompiledRegexes, + "Max number of regex reached"); + + bool validEscapeUsage; + auto regex = likePatternToRe2(pattern, escapeChar, validEscapeUsage); + VELOX_USER_CHECK( + validEscapeUsage, + "Escape character must be followed by '%', '_' or the escape character itself"); + + RE2::Options opt{RE2::Quiet}; + opt.set_dot_nl(true); + auto re = std::make_unique(toStringPiece(regex), opt); + checkForBadPattern(*re); + + auto [it, inserted] = + compiledRegularExpressions_.emplace(key, std::move(re)); + VELOX_CHECK(inserted); + + return it->second.get(); + } + mutable folly::F14FastMap< std::pair>, std::unique_ptr> @@ -1108,8 +1158,7 @@ class Re2ExtractAll final : public exec::VectorFunction { // groups.resize(1); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); re2ExtractAll(resultWriter, re, inputStrs, row, groups, 0); }); } else { @@ -1118,8 +1167,7 @@ class Re2ExtractAll final : public exec::VectorFunction { exec::LocalDecodedVector groupIds(context, *args[2], rows); context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { const T groupId = groupIds->valueAt(row); - RE2 re(toStringPiece(pattern->valueAt(row)), RE2::Quiet); - checkForBadPattern(re); + auto& re = *cache_.findOrCompile(pattern->valueAt(row)); checkForBadGroupId(groupId, re); groups.resize(groupId + 1); re2ExtractAll(resultWriter, re, inputStrs, row, groups, groupId); @@ -1132,6 +1180,9 @@ class Re2ExtractAll final : public exec::VectorFunction { ->asFlatVector() ->acquireSharedStringBuffers(inputStrs->base()); } + + private: + mutable ReCache cache_; }; template @@ -1152,9 +1203,8 @@ std::shared_ptr makeRe2MatchImpl( return std::make_shared>( constantPattern->as>()->valueAt(0)); } - static std::shared_ptr> kMatchExpr = - std::make_shared>(); - return kMatchExpr; + + return std::make_shared>(); } } // namespace diff --git a/velox/functions/prestosql/Repeat.cpp b/velox/functions/lib/Repeat.cpp similarity index 77% rename from velox/functions/prestosql/Repeat.cpp rename to velox/functions/lib/Repeat.cpp index 39993b6b1f1e..885f13b6b43a 100644 --- a/velox/functions/prestosql/Repeat.cpp +++ b/velox/functions/lib/Repeat.cpp @@ -18,10 +18,14 @@ namespace facebook::velox::functions { namespace { - // See documentation at https://prestodb.io/docs/current/functions/array.html class RepeatFunction : public exec::VectorFunction { public: + // @param allowNegativeCount If true, negative 'count' is allowed + // and treated the same as zero (Spark's behavior). + explicit RepeatFunction(bool allowNegativeCount) + : allowNegativeCount_(allowNegativeCount) {} + static constexpr int32_t kMaxResultEntries = 10'000; bool isDefaultNullBehavior() const override { @@ -37,29 +41,36 @@ class RepeatFunction : public exec::VectorFunction { VectorPtr& result) const override { VectorPtr localResult; if (args[1]->isConstantEncoding()) { - localResult = applyConstant(rows, args, outputType, context); + localResult = applyConstantCount(rows, args, outputType, context); if (localResult == nullptr) { return; } } else { - localResult = applyFlat(rows, args, outputType, context); + localResult = applyNonConstantCount(rows, args, outputType, context); } context.moveOrCopyResult(localResult, rows, result); } private: - static void checkCount(const int32_t count) { - VELOX_USER_CHECK_GE( - count, - 0, - "Count argument of repeat function must be greater than or equal to 0"); + // Check count to make sure it is in valid range. + static int32_t checkCount(int32_t count, bool allowNegativeCount) { + if (count < 0) { + if (allowNegativeCount) { + return 0; + } + VELOX_USER_FAIL( + "({} vs. {}) Count argument of repeat function must be greater than or equal to 0", + count, + 0); + } VELOX_USER_CHECK_LE( count, kMaxResultEntries, "Count argument of repeat function must be less than or equal to 10000"); + return count; } - VectorPtr applyConstant( + VectorPtr applyConstantCount( const SelectivityVector& rows, std::vector& args, const TypePtr& outputType, @@ -73,14 +84,13 @@ class RepeatFunction : public exec::VectorFunction { return BaseVector::createNullConstant(outputType, numRows, pool); } - const auto count = constantCount->valueAt(0); + auto count = constantCount->valueAt(0); try { - checkCount(count); + count = checkCount(count, allowNegativeCount_); } catch (const VeloxUserError&) { context.setErrors(rows, std::current_exception()); return nullptr; } - const auto totalCount = count * numRows; // Allocate new vectors for indices, lengths and offsets. @@ -109,7 +119,7 @@ class RepeatFunction : public exec::VectorFunction { BaseVector::wrapInDictionary(nullptr, indices, totalCount, args[0])); } - VectorPtr applyFlat( + VectorPtr applyNonConstantCount( const SelectivityVector& rows, std::vector& args, const TypePtr& outputType, @@ -120,7 +130,7 @@ class RepeatFunction : public exec::VectorFunction { context.applyToSelectedNoThrow(rows, [&](auto row) { auto count = countDecoded->isNullAt(row) ? 0 : countDecoded->valueAt(row); - checkCount(count); + count = checkCount(count, allowNegativeCount_); totalCount += count; }); @@ -156,6 +166,9 @@ class RepeatFunction : public exec::VectorFunction { return; } auto count = countDecoded->valueAt(row); + if (count < 0) { + count = 0; + } rawSizes[row] = count; rawOffsets[row] = offset; std::fill(rawIndices + offset, rawIndices + offset + count, row); @@ -171,9 +184,12 @@ class RepeatFunction : public exec::VectorFunction { sizes, BaseVector::wrapInDictionary(nullptr, indices, totalCount, args[0])); } + + const bool allowNegativeCount_; }; +} // namespace -static std::vector> signatures() { +std::vector> repeatSignatures() { // T, integer -> array(T) return {exec::FunctionSignatureBuilder() .typeVariable("T") @@ -182,11 +198,19 @@ static std::vector> signatures() { .argumentType("integer") .build()}; } -} // namespace -VELOX_DECLARE_VECTOR_FUNCTION( - udf_repeat, - signatures(), - std::make_unique()); +std::shared_ptr makeRepeat( + const std::string& /* name */, + const std::vector& /* inputArgs */, + const core::QueryConfig& /*config*/) { + return std::make_unique(false); +} + +std::shared_ptr makeRepeatAllowNegativeCount( + const std::string& /* name */, + const std::vector& /* inputArgs */, + const core::QueryConfig& /*config*/) { + return std::make_unique(true); +} } // namespace facebook::velox::functions diff --git a/velox/functions/lib/Repeat.h b/velox/functions/lib/Repeat.h new file mode 100644 index 000000000000..2721ede1e9a8 --- /dev/null +++ b/velox/functions/lib/Repeat.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/expression/VectorFunction.h" + +namespace facebook::velox::functions { + +std::vector> repeatSignatures(); + +// Does not allow negative count. +std::shared_ptr makeRepeat( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& config); + +// Allows negative count (Spark's behavior). +std::shared_ptr makeRepeatAllowNegativeCount( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& config); +} // namespace facebook::velox::functions diff --git a/velox/functions/lib/tests/CMakeLists.txt b/velox/functions/lib/tests/CMakeLists.txt index 424aef6daf47..8427d4c34044 100644 --- a/velox/functions/lib/tests/CMakeLists.txt +++ b/velox/functions/lib/tests/CMakeLists.txt @@ -20,6 +20,7 @@ add_executable( KllSketchTest.cpp MapConcatTest.cpp Re2FunctionsTest.cpp + RepeatTest.cpp ZetaDistributionTest.cpp CheckNestedNullsTest.cpp) diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp index 1d0bdfb505b2..bb7503d74b5f 100644 --- a/velox/functions/lib/tests/Re2FunctionsTest.cpp +++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp @@ -1431,5 +1431,45 @@ TEST_F(Re2FunctionsTest, regexExtractAllLarge) { "No group 4611686018427387904 in regex '(\\d+)([a-z]+)") } +// Make sure we do not compile more than kMaxCompiledRegexes. +TEST_F(Re2FunctionsTest, limit) { + auto data = makeRowVector({ + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples and oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { return fmt::format("Apples (.*) oranges {}", row); }), + makeFlatVector( + 100, + [](auto row) { + return fmt::format("Apples (.*) oranges {}", row % 20); + }), + }); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_extract_all(c0, c1, 1)", data), + "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_extract_all(c0, c2, 1)", data)); + + VELOX_ASSERT_THROW( + evaluate("regexp_like(c0, c1)", data), "Max number of regex reached"); + ASSERT_NO_THROW(evaluate("regexp_like(c0, c2)", data)); +} + } // namespace } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/RepeatTest.cpp b/velox/functions/lib/tests/RepeatTest.cpp similarity index 75% rename from velox/functions/prestosql/tests/RepeatTest.cpp rename to velox/functions/lib/tests/RepeatTest.cpp index 95275c8c7e4f..0969f8e95c06 100644 --- a/velox/functions/prestosql/tests/RepeatTest.cpp +++ b/velox/functions/lib/tests/RepeatTest.cpp @@ -14,17 +14,27 @@ * limitations under the License. */ +#include "velox/functions/lib/Repeat.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" -using namespace facebook::velox; using namespace facebook::velox::test; -using namespace facebook::velox::functions::test; +namespace facebook::velox::functions { namespace { -class RepeatTest : public FunctionBaseTest { +class RepeatTest : public functions::test::FunctionBaseTest { protected: + static void SetUpTestCase() { + FunctionBaseTest::SetUpTestCase(); + exec::registerStatefulVectorFunction( + "repeat", functions::repeatSignatures(), functions::makeRepeat); + exec::registerStatefulVectorFunction( + "repeat_allow_negative_count", + functions::repeatSignatures(), + functions::makeRepeatAllowNegativeCount); + } + void testExpression( const std::string& expression, const std::vector& input, @@ -41,7 +51,6 @@ class RepeatTest : public FunctionBaseTest { evaluate(expression, makeRowVector(input)), expectedError); } }; -} // namespace TEST_F(RepeatTest, repeat) { const auto elementVector = makeNullableFlatVector( @@ -124,3 +133,34 @@ TEST_F(RepeatTest, repeatWithInvalidCount) { {elementVector}, "(10001 vs. 10000) Count argument of repeat function must be less than or equal to 10000"); } + +TEST_F(RepeatTest, repeatAllowNegativeCount) { + const auto elementVector = makeNullableFlatVector( + {0.0, -2.0, 3.333333, 4.0004, std::nullopt, 5.12345}); + auto expected = makeArrayVector({{}, {}, {}, {}, {}, {}}); + + // Test negative count. + auto countVector = + makeNullableFlatVector({-1, -2, -3, -5, -10, -100}); + testExpression( + "repeat_allow_negative_count(C0, C1)", + {elementVector, countVector}, + expected); + + // Test using a constant as the count argument. + testExpression( + "repeat_allow_negative_count(C0, '-5'::INTEGER)", + {elementVector}, + expected); + + // Test mixed case. + expected = makeArrayVector( + {{0.0}, {-2.0, -2.0}, {}, {}, {}, {5.12345, 5.12345, 5.12345}}); + countVector = makeNullableFlatVector({1, 2, -1, 0, -10, 3}); + testExpression( + "repeat_allow_negative_count(C0, C1)", + {elementVector, countVector}, + expected); +} +} // namespace +} // namespace facebook::velox::functions diff --git a/velox/functions/lib/window/CMakeLists.txt b/velox/functions/lib/window/CMakeLists.txt index 687472181e79..ba8ba199bcb6 100644 --- a/velox/functions/lib/window/CMakeLists.txt +++ b/velox/functions/lib/window/CMakeLists.txt @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_functions_window NthValue.cpp Rank.cpp RowNumber.cpp) +add_library(velox_functions_window NthValue.cpp Rank.cpp RowNumber.cpp + Ntile.cpp) target_link_libraries(velox_functions_window velox_buffer velox_exec Folly::folly) diff --git a/velox/functions/prestosql/window/Ntile.cpp b/velox/functions/lib/window/Ntile.cpp similarity index 84% rename from velox/functions/prestosql/window/Ntile.cpp rename to velox/functions/lib/window/Ntile.cpp index 99f542facfa7..eacd562d2744 100644 --- a/velox/functions/prestosql/window/Ntile.cpp +++ b/velox/functions/lib/window/Ntile.cpp @@ -19,21 +19,23 @@ #include "velox/expression/FunctionSignature.h" #include "velox/vector/FlatVector.h" -namespace facebook::velox::window::prestosql { +namespace facebook::velox::functions::window { namespace { +template class NtileFunction : public exec::WindowFunction { public: explicit NtileFunction( const std::vector& args, + const TypePtr& resultType, velox::memory::MemoryPool* pool) - : WindowFunction(BIGINT(), pool, nullptr) { + : WindowFunction(resultType, pool, nullptr) { if (args[0].constantValue) { auto argBuckets = args[0].constantValue; if (!argBuckets->isNullAt(0)) { numFixedBuckets_ = - argBuckets->as>()->valueAt(0); + argBuckets->as>()->valueAt(0); VELOX_USER_CHECK_GE( numFixedBuckets_.value(), 1, "{}", kBucketErrorString); } @@ -41,8 +43,8 @@ class NtileFunction : public exec::WindowFunction { } bucketColumn_ = args[0].index; - bucketVector_ = BaseVector::create(BIGINT(), 0, pool); - bucketFlatVector_ = bucketVector_->asFlatVector(); + bucketVector_ = BaseVector::create(resultType_, 0, pool); + bucketFlatVector_ = bucketVector_->asFlatVector(); } void resetPartition(const exec::WindowPartition* partition) override { @@ -87,21 +89,21 @@ class NtileFunction : public exec::WindowFunction { struct BucketMetrics { // To compute the bucket number for a row, we find the number of rows in // a bucket as the (number of rows in partition) / (number of buckets). - int64_t rowsPerBucket; + TResult rowsPerBucket; // There could be some buckets with rowsPerBucket + 1 number of rows, // as the partition rows might not be exactly divisible // by the number of buckets. There are // (number of rows in partition) % (number of buckets) such buckets. - int64_t bucketsWithExtraRow; + TResult bucketsWithExtraRow; // When assigning bucket numbers, the first 'bucketsWithExtraRow' buckets // will have (rowsPerBucket + 1) rows. This row number at this boundary is // extraBucketsBoundary = bucketsWithExtraRow * (rowsPerBucket + 1). Beyond // this row number in the partition, the buckets will have only // rowsPerBucket number of rows. This boundary is useful when computing the // bucket value. - int64_t extraBucketsBoundary; + TResult extraBucketsBoundary; - int64_t computeBucketValue(vector_size_t rowNumber) const { + TResult computeBucketValue(vector_size_t rowNumber) const { if (rowNumber < extraBucketsBoundary) { return rowNumber / (rowsPerBucket + 1) + 1; } @@ -115,7 +117,7 @@ class NtileFunction : public exec::WindowFunction { vector_size_t numRows, int64_t partitionOffset, vector_size_t resultOffset, - int64_t* rawResultValues) { + TResult* rawResultValues) { int64_t i = 0; // This loop terminates if it reaches extraBucketBoundary or numRows // in the result vector are filled. @@ -130,7 +132,7 @@ class NtileFunction : public exec::WindowFunction { } }; - BucketMetrics computeBucketMetrics(int64_t numBuckets) const { + BucketMetrics computeBucketMetrics(TResult numBuckets) const { auto rowsPerBucket = numPartitionRows_ / numBuckets; auto bucketsWithExtraRow = numPartitionRows_ % numBuckets; auto extraBucketsBoundary = (rowsPerBucket + 1) * bucketsWithExtraRow; @@ -145,7 +147,7 @@ class NtileFunction : public exec::WindowFunction { partition_->extractColumn( bucketColumn_.value(), partitionOffset_, numRows, 0, bucketVector_); - auto* resultFlatVector = result->asFlatVector(); + auto* resultFlatVector = result->asFlatVector(); auto* rawValues = resultFlatVector->mutableRawValues(); for (auto i = 0; i < numRows; i++) { if (bucketFlatVector_->isNullAt(i)) { @@ -170,7 +172,7 @@ class NtileFunction : public exec::WindowFunction { vector_size_t resultOffset, const VectorPtr& result) { if (numFixedBuckets_.has_value()) { - auto rawValues = result->asFlatVector()->mutableRawValues(); + auto rawValues = result->asFlatVector()->mutableRawValues(); if (fixedBucketsMoreThanPartition_) { std::iota( rawValues + resultOffset, @@ -183,7 +185,7 @@ class NtileFunction : public exec::WindowFunction { } else { // This is a function call with a constant null value. Set all result // rows to null. - auto* resultVector = result->asFlatVector(); + auto* resultVector = result->asFlatVector(); auto mutableRawNulls = resultVector->mutableRawNulls(); bits::fillBits( mutableRawNulls, resultOffset, resultOffset + numRows, bits::kNull); @@ -195,7 +197,7 @@ class NtileFunction : public exec::WindowFunction { // Number of buckets if a constant value. Is optional as the value could // be null. - std::optional numFixedBuckets_; + std::optional numFixedBuckets_; // If number of buckets is greater than the partition rows, then the output // bucket number is simply row number + 1. So bucket computation can be @@ -209,29 +211,30 @@ class NtileFunction : public exec::WindowFunction { // Current WindowPartition used for accessing rows in the apply method. const exec::WindowPartition* partition_; - int64_t numPartitionRows_ = 0; + TResult numPartitionRows_ = 0; // Denotes how far along the partition rows are output already. int64_t partitionOffset_ = 0; // Vector used to read the bucket column values. VectorPtr bucketVector_; - FlatVector* bucketFlatVector_; + FlatVector* bucketFlatVector_; static const std::string kBucketErrorString; }; -const std::string NtileFunction::kBucketErrorString = +template +const std::string NtileFunction::kBucketErrorString = "Buckets must be greater than 0"; } // namespace -void registerNtile(const std::string& name) { - // ntile(bigint) -> bigint. +template +void registerNtile(const std::string& name, const std::string& type) { std::vector signatures{ exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("bigint") + .returnType(type) + .argumentType(type) .build(), }; @@ -240,13 +243,21 @@ void registerNtile(const std::string& name) { std::move(signatures), [name]( const std::vector& args, - const TypePtr& /*resultType*/, + const TypePtr& resultType, bool /*ignoreNulls*/, velox::memory::MemoryPool* pool, HashStringAllocator* /*stringAllocator*/, const core::QueryConfig& /*queryConfig*/) -> std::unique_ptr { - return std::make_unique(args, pool); + return std::make_unique>(args, resultType, pool); }); } -} // namespace facebook::velox::window::prestosql + +void registerNtileBigint(const std::string& name) { + registerNtile(name, "bigint"); +} +void registerNtileInteger(const std::string& name) { + registerNtile(name, "integer"); +} + +} // namespace facebook::velox::functions::window diff --git a/velox/functions/lib/window/RegistrationFunctions.h b/velox/functions/lib/window/RegistrationFunctions.h index 726b49551837..17dc185d777c 100644 --- a/velox/functions/lib/window/RegistrationFunctions.h +++ b/velox/functions/lib/window/RegistrationFunctions.h @@ -54,4 +54,12 @@ void registerDenseRankInteger(const std::string& name); // Returns the percentage ranking of a value in a group of values. void registerPercentRank(const std::string& name); +// Register the Presto function ntile() with the bigint data type +// for the return and input value. +void registerNtileBigint(const std::string& name); + +// Register the Spark function ntile() with the integer data type +// for the return and input value. +void registerNtileInteger(const std::string& name); + } // namespace facebook::velox::functions::window diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index d2a17ca1c32a..fcb4d97e3e22 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -44,7 +44,6 @@ add_library( MapZipWith.cpp Not.cpp Reduce.cpp - Repeat.cpp Reverse.cpp RowFunction.cpp Sequence.cpp diff --git a/velox/functions/prestosql/Reduce.cpp b/velox/functions/prestosql/Reduce.cpp index 363eeb9a9244..33c47d45ff06 100644 --- a/velox/functions/prestosql/Reduce.cpp +++ b/velox/functions/prestosql/Reduce.cpp @@ -19,6 +19,34 @@ namespace facebook::velox::functions { namespace { +// Throws if any array in any of 'rows' has more than 10K elements. +// Evaluating 'reduce' lambda function on very large arrays is too slow. +void checkArraySizes( + const SelectivityVector& rows, + DecodedVector& decodedArray, + exec::EvalCtx& context) { + const auto* indices = decodedArray.indices(); + const auto* rawSizes = decodedArray.base()->as()->rawSizes(); + + static const vector_size_t kMaxArraySize = 10'000; + + rows.applyToSelected([&](auto row) { + if (decodedArray.isNullAt(row)) { + return; + } + const auto size = rawSizes[indices[row]]; + try { + VELOX_USER_CHECK_LT( + size, + kMaxArraySize, + "reduce lambda function doesn't support arrays with more than {} elements", + kMaxArraySize); + } catch (VeloxUserError&) { + context.setError(row, std::current_exception()); + } + }); +} + /// Populates indices of the n-th elements of the arrays. /// Selects 'row' in 'arrayRows' if corresponding array has an n-th element. /// Sets elementIndices[row] to the index of the n-th element in the 'elements' @@ -75,6 +103,36 @@ class ReduceFunction : public exec::VectorFunction { exec::LocalDecodedVector arrayDecoder(context, *args[0], rows); auto& decodedArray = *arrayDecoder.get(); + checkArraySizes(rows, decodedArray, context); + + exec::LocalSelectivityVector remainingRows(context, rows); + context.deselectErrors(*remainingRows); + + doApply(*remainingRows, args, decodedArray, outputType, context, result); + } + + static std::vector> signatures() { + // array(T), S, function(S, T, S), function(S, R) -> R + return {exec::FunctionSignatureBuilder() + .typeVariable("T") + .typeVariable("S") + .typeVariable("R") + .returnType("R") + .argumentType("array(T)") + .argumentType("S") + .argumentType("function(S,T,S)") + .argumentType("function(S,R)") + .build()}; + } + + private: + void doApply( + const SelectivityVector& rows, + std::vector& args, + DecodedVector& decodedArray, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { auto flatArray = flattenArray(rows, args[0], decodedArray); // Identify the rows need to be computed. exec::LocalSelectivityVector nonNullRowsHolder(*context.execCtx()); @@ -157,6 +215,7 @@ class ReduceFunction : public exec::VectorFunction { n++; } } + // Apply output function. VectorPtr localResult; auto outputFuncIt = @@ -178,20 +237,6 @@ class ReduceFunction : public exec::VectorFunction { } context.moveOrCopyResult(localResult, rows, result); } - - static std::vector> signatures() { - // array(T), S, function(S, T, S), function(S, R) -> R - return {exec::FunctionSignatureBuilder() - .typeVariable("T") - .typeVariable("S") - .typeVariable("R") - .returnType("R") - .argumentType("array(T)") - .argumentType("S") - .argumentType("function(S,T,S)") - .argumentType("function(S,R)") - .build()}; - } }; } // namespace diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index 46c7400ce8e3..7524a274ce7e 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -17,6 +17,7 @@ #include #include "velox/functions/Registerer.h" +#include "velox/functions/lib/Repeat.h" #include "velox/functions/prestosql/ArrayConstructor.h" #include "velox/functions/prestosql/ArrayFunctions.h" #include "velox/functions/prestosql/ArraySort.h" @@ -144,7 +145,8 @@ void registerArrayFunctions(const std::string& prefix) { }); VELOX_REGISTER_VECTOR_FUNCTION(udf_array_sum, prefix + "array_sum"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_repeat, prefix + "repeat"); + exec::registerStatefulVectorFunction( + prefix + "repeat", repeatSignatures(), makeRepeat); VELOX_REGISTER_VECTOR_FUNCTION(udf_sequence, prefix + "sequence"); exec::registerStatefulVectorFunction( diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index e893aad1d98a..4ba742ff61c8 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -76,7 +76,6 @@ add_executable( RandTest.cpp ReduceTest.cpp RegexpReplaceTest.cpp - RepeatTest.cpp ReverseTest.cpp RoundTest.cpp ScalarFunctionRegTest.cpp diff --git a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp index c361037215b9..c9344f095e5a 100644 --- a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp @@ -314,6 +314,18 @@ TEST_F(DecimalArithmeticTest, decimalDivTest) { "divide(c0, c1)", {shortFlat, longFlat}); + testDecimalExpr( + makeFlatVector( + {HugeInt::parse("20000000000000000"), + HugeInt::parse("50000000000000000")}, + DECIMAL(38, 19)), + "divide(c0, c1)", + {makeFlatVector({100, 200}, DECIMAL(17, 4)), + makeFlatVector( + {HugeInt::parse("50000000000000000000"), + HugeInt::parse("40000000000000000000")}, + DECIMAL(21, 19))}); + // Divide long and long, returning long. testDecimalExpr( makeFlatVector({500, 300}, DECIMAL(22, 2)), diff --git a/velox/functions/prestosql/tests/FindFirstTest.cpp b/velox/functions/prestosql/tests/FindFirstTest.cpp index dfaf4f742811..121721297b85 100644 --- a/velox/functions/prestosql/tests/FindFirstTest.cpp +++ b/velox/functions/prestosql/tests/FindFirstTest.cpp @@ -212,7 +212,7 @@ TEST_F(FindFirstTest, invalidIndex) { "SQL array indices start at 1. Got 0."); // Mark 3rd row null. Expect no error. - data->setNull(2, true); + data->childAt(1)->setNull(2, true); expected = makeAllNullFlatVector(4); verify("find_first(c0, c1, x -> (x > 0))", data, expected); } diff --git a/velox/functions/prestosql/tests/ReduceTest.cpp b/velox/functions/prestosql/tests/ReduceTest.cpp index 39ec8a770632..a75cb4ba76b4 100644 --- a/velox/functions/prestosql/tests/ReduceTest.cpp +++ b/velox/functions/prestosql/tests/ReduceTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" using namespace facebook::velox; @@ -242,3 +243,29 @@ TEST_F(ReduceTest, nullArray) { assertEqualVectors( makeNullableFlatVector({std::nullopt, std::nullopt}), result); } + +// Verify limit on the number of array elements. +TEST_F(ReduceTest, limit) { + // Make array vector with huge arrays in rows 2 and 4. + auto data = makeRowVector({makeArrayVector( + {0, 1'000, 10'000, 100'000, 100'010}, makeConstant(123, 1'000'000))}); + + VELOX_ASSERT_THROW( + evaluate("reduce(c0, 0, (s, x) -> s + x, s -> s)", data), + "reduce lambda function doesn't support arrays with more than 10000 elements"); + + // Exclude huge arrays. + SelectivityVector rows(4); + rows.setValid(2, false); + rows.updateBounds(); + auto result = evaluate("reduce(c0, 0, (s, x) -> s + x, s -> s)", data, rows); + auto expected = + makeFlatVector({123 * 1'000, 123 * 9'000, -1, 123 * 10}); + assertEqualVectors(expected, result, rows); + + // Mask errors with TRY. + result = evaluate("TRY(reduce(c0, 0, (s, x) -> s + x, s -> s))", data); + expected = makeNullableFlatVector( + {123 * 1'000, 123 * 9'000, std::nullopt, 123 * 10, std::nullopt}); + assertEqualVectors(expected, result); +} diff --git a/velox/functions/prestosql/window/CMakeLists.txt b/velox/functions/prestosql/window/CMakeLists.txt index 9770146d8719..61bcfedd81e2 100644 --- a/velox/functions/prestosql/window/CMakeLists.txt +++ b/velox/functions/prestosql/window/CMakeLists.txt @@ -15,7 +15,7 @@ if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) endif() -add_library(velox_window CumeDist.cpp FirstLastValue.cpp LeadLag.cpp Ntile.cpp +add_library(velox_window CumeDist.cpp FirstLastValue.cpp LeadLag.cpp WindowFunctionsRegistration.cpp) target_link_libraries(velox_window velox_buffer velox_exec diff --git a/velox/functions/prestosql/window/WindowFunctionsRegistration.cpp b/velox/functions/prestosql/window/WindowFunctionsRegistration.cpp index 4c7a97beb913..351f3c6bc92b 100644 --- a/velox/functions/prestosql/window/WindowFunctionsRegistration.cpp +++ b/velox/functions/prestosql/window/WindowFunctionsRegistration.cpp @@ -21,7 +21,7 @@ namespace facebook::velox::window { namespace prestosql { extern void registerCumeDist(const std::string& name); -extern void registerNtile(const std::string& name); +extern void registerNtileBigint(const std::string& name); extern void registerFirstValue(const std::string& name); extern void registerLastValue(const std::string& name); extern void registerLag(const std::string& name); @@ -33,7 +33,7 @@ void registerAllWindowFunctions(const std::string& prefix) { functions::window::registerDenseRankBigint(prefix + "dense_rank"); functions::window::registerPercentRank(prefix + "percent_rank"); registerCumeDist(prefix + "cume_dist"); - registerNtile(prefix + "ntile"); + functions::window::registerNtileBigint(prefix + "ntile"); functions::window::registerNthValueBigint(prefix + "nth_value"); registerFirstValue(prefix + "first_value"); registerLastValue(prefix + "last_value"); diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 4f4c14bdc495..a24c0bf215ae 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -21,6 +21,7 @@ #include "velox/functions/lib/IsNull.h" #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/lib/Repeat.h" #include "velox/functions/prestosql/DateTimeFunctions.h" #include "velox/functions/prestosql/JsonFunctions.h" #include "velox/functions/prestosql/StringFunctions.h" @@ -257,6 +258,11 @@ void registerFunctions(const std::string& prefix) { exec::registerStatefulVectorFunction( prefix + "sort_array", sortArraySignatures(), makeSortArray); + exec::registerStatefulVectorFunction( + prefix + "array_repeat", + repeatSignatures(), + makeRepeatAllowNegativeCount); + // Register date functions. registerFunction({prefix + "year"}); registerFunction({prefix + "year"}); diff --git a/velox/functions/sparksql/specialforms/SparkCastHooks.cpp b/velox/functions/sparksql/specialforms/SparkCastHooks.cpp index 7c0750b741a5..3daf9a0f4bcf 100644 --- a/velox/functions/sparksql/specialforms/SparkCastHooks.cpp +++ b/velox/functions/sparksql/specialforms/SparkCastHooks.cpp @@ -41,7 +41,8 @@ int32_t SparkCastHooks::castStringToDate(const StringView& dateString) const { void SparkCastHooks::castTimestampToString( const Timestamp& timestamp, - exec::StringWriter& out) const { + exec::StringWriter& out, + const date::time_zone* /*timeZone*/) const { static constexpr TimestampToStringOptions options = { .precision = TimestampToStringOptions::Precision::kMicroseconds, .leadingPositiveSign = true, diff --git a/velox/functions/sparksql/specialforms/SparkCastHooks.h b/velox/functions/sparksql/specialforms/SparkCastHooks.h index 8c441aa96d46..44c5321a448d 100644 --- a/velox/functions/sparksql/specialforms/SparkCastHooks.h +++ b/velox/functions/sparksql/specialforms/SparkCastHooks.h @@ -35,7 +35,8 @@ class SparkCastHooks : public exec::CastHooks { /// first if the year exceeds 9999. void castTimestampToString( const Timestamp& timestamp, - exec::StringWriter& out) const override; + exec::StringWriter& out, + const date::time_zone* timeZone) const override; // Returns false. bool legacy() const override; diff --git a/velox/functions/sparksql/window/WindowFunctionsRegistration.cpp b/velox/functions/sparksql/window/WindowFunctionsRegistration.cpp index 7100b779bcec..5a603b75f602 100644 --- a/velox/functions/sparksql/window/WindowFunctionsRegistration.cpp +++ b/velox/functions/sparksql/window/WindowFunctionsRegistration.cpp @@ -23,6 +23,7 @@ void registerWindowFunctions(const std::string& prefix) { functions::window::registerRowNumberInteger(prefix + "row_number"); functions::window::registerRankInteger(prefix + "rank"); functions::window::registerDenseRankInteger(prefix + "dense_rank"); + functions::window::registerNtileInteger(prefix + "ntile"); } } // namespace facebook::velox::functions::window::sparksql diff --git a/velox/functions/sparksql/window/tests/SparkWindowTest.cpp b/velox/functions/sparksql/window/tests/SparkWindowTest.cpp index df87c432f9ba..97e3f7217eb2 100644 --- a/velox/functions/sparksql/window/tests/SparkWindowTest.cpp +++ b/velox/functions/sparksql/window/tests/SparkWindowTest.cpp @@ -28,7 +28,12 @@ static const std::vector kSparkWindowFunctions = { std::string("nth_value(c0, c3)"), std::string("row_number()"), std::string("rank()"), - std::string("dense_rank()")}; + std::string("dense_rank()"), + std::string("ntile(c3)"), + std::string("ntile(1)"), + std::string("ntile(7)"), + std::string("ntile(10)"), + std::string("ntile(16)")}; struct SparkWindowTestParam { const std::string function; diff --git a/velox/serializers/CompactRowSerializer.cpp b/velox/serializers/CompactRowSerializer.cpp index a47be6c7e742..b44f804995ef 100644 --- a/velox/serializers/CompactRowSerializer.cpp +++ b/velox/serializers/CompactRowSerializer.cpp @@ -28,7 +28,7 @@ void CompactRowVectorSerde::estimateSerializedSize( } namespace { -class CompactRowVectorSerializer : public VectorSerializer { +class CompactRowVectorSerializer : public IterativeVectorSerializer { public: using TRowSize = uint32_t; @@ -120,7 +120,8 @@ std::string concatenatePartialRow( } // namespace -std::unique_ptr CompactRowVectorSerde::createSerializer( +std::unique_ptr +CompactRowVectorSerde::createIterativeSerializer( RowTypePtr /* type */, int32_t /* numRows */, StreamArena* streamArena, diff --git a/velox/serializers/CompactRowSerializer.h b/velox/serializers/CompactRowSerializer.h index 10c5a928cb49..8b6f1a998450 100644 --- a/velox/serializers/CompactRowSerializer.h +++ b/velox/serializers/CompactRowSerializer.h @@ -33,7 +33,7 @@ class CompactRowVectorSerde : public VectorSerde { // This method is not used in production code. It is only used to // support round-trip tests for deserialization. - std::unique_ptr createSerializer( + std::unique_ptr createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena, diff --git a/velox/serializers/PrestoSerializer.cpp b/velox/serializers/PrestoSerializer.cpp index e7363e762e6d..149a70948fc5 100644 --- a/velox/serializers/PrestoSerializer.cpp +++ b/velox/serializers/PrestoSerializer.cpp @@ -980,6 +980,50 @@ void checkTypeEncoding(std::string_view encoding, const TypePtr& type) { encoding); } +// This is used when there's a mismatch between the encoding in the serialized +// page and the expected output encoding. If the serialized encoding is +// BYTE_ARRAY, it may represent an all-null vector of the expected output type. +// We attempt to read the serialized page as an UNKNOWN type, check if all +// values are null, and set the columnResult accordingly. If all values are +// null, we return true; otherwise, we return false. +bool tryReadNullColumn( + ByteInputStream* source, + velox::memory::MemoryPool* pool, + const TypePtr& columnType, + VectorPtr& columnResult, + vector_size_t resultOffset, + bool useLosslessTimestamp) { + auto unknownType = UNKNOWN(); + VectorPtr tempResult = BaseVector::create(unknownType, 0, pool); + read( + source, + unknownType, + pool, + tempResult, + 0 /*resultOffset*/, + useLosslessTimestamp); + auto deserializedSize = tempResult->size(); + // Ensure it contains all null values. + auto numNulls = BaseVector::countNulls(tempResult->nulls(), deserializedSize); + if (deserializedSize != numNulls) { + return false; + } + if (resultOffset == 0) { + columnResult = + BaseVector::createNullConstant(columnType, deserializedSize, pool); + } else { + columnResult->resize(resultOffset + deserializedSize); + + SelectivityVector nullRows(resultOffset + deserializedSize, false); + nullRows.setValidRange(resultOffset, resultOffset + deserializedSize, true); + nullRows.updateBounds(); + + BaseVector::ensureWritable(nullRows, columnType, pool, columnResult); + columnResult->addNulls(nullRows); + } + return true; +} + void readColumns( ByteInputStream* source, velox::memory::MemoryPool* pool, @@ -1037,7 +1081,20 @@ void readColumns( resultOffset, useLosslessTimestamp); } else { - checkTypeEncoding(encoding, columnType); + auto typeToEncoding = typeToEncodingName(columnType); + if (encoding != typeToEncoding) { + if (encoding == "BYTE_ARRAY" && + tryReadNullColumn( + source, + pool, + columnType, + columnResult, + resultOffset, + useLosslessTimestamp)) { + return; + } + checkTypeEncoding(encoding, columnType); + } const auto it = readers.find(columnType->kind()); VELOX_CHECK( it != readers.end(), @@ -3259,9 +3316,9 @@ class PrestoBatchVectorSerializer : public BatchVectorSerializer { const std::unique_ptr codec_; }; -class PrestoVectorSerializer : public VectorSerializer { +class PrestoIterativeVectorSerializer : public IterativeVectorSerializer { public: - PrestoVectorSerializer( + PrestoIterativeVectorSerializer( const RowTypePtr& rowType, std::vector encodings, int32_t numRows, @@ -3292,7 +3349,7 @@ class PrestoVectorSerializer : public VectorSerializer { // Constructor that takes a row vector instead of only the types. This is // different because then we know exactly how each vector is encoded // (recursively). - PrestoVectorSerializer( + PrestoIterativeVectorSerializer( const RowVectorPtr& rowVector, StreamArena* streamArena, bool useLosslessTimestamp, @@ -3398,13 +3455,14 @@ void PrestoVectorSerde::estimateSerializedSize( estimateSerializedSizeInt(vector->loadedVector(), rows, sizes, scratch); } -std::unique_ptr PrestoVectorSerde::createSerializer( +std::unique_ptr +PrestoVectorSerde::createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena, const Options* options) { const auto prestoOptions = toPrestoOptions(options); - return std::make_unique( + return std::make_unique( type, prestoOptions.encodings, numRows, @@ -3427,7 +3485,7 @@ void PrestoVectorSerde::deprecatedSerializeEncoded( const Options* options, OutputStream* out) { auto prestoOptions = toPrestoOptions(options); - auto serializer = std::make_unique( + auto serializer = std::make_unique( vector, streamArena, prestoOptions.useLosslessTimestamp, @@ -3525,6 +3583,43 @@ void PrestoVectorSerde::deserialize( (*result)->size(), 0, nullptr, nullptr, **result, resultOffset); } +void PrestoVectorSerde::deserializeSingleColumn( + ByteInputStream* source, + velox::memory::MemoryPool* pool, + TypePtr type, + VectorPtr* result, + const Options* options) { + const auto prestoOptions = toPrestoOptions(options); + VELOX_CHECK_EQ( + prestoOptions.compressionKind, + common::CompressionKind::CompressionKind_NONE); + const bool useLosslessTimestamp = prestoOptions.useLosslessTimestamp; + + if (*result && result->unique()) { + VELOX_CHECK( + *(*result)->type() == *type, + "Unexpected type: {} vs. {}", + (*result)->type()->toString(), + type->toString()); + (*result)->prepareForReuse(); + } else { + *result = BaseVector::create(type, 0, pool); + } + + auto types = {type}; + std::vector resultList = {*result}; + readColumns(source, pool, types, resultList, 0, useLosslessTimestamp); + + auto rowType = asRowType(ROW(types)); + RowVectorPtr tempRow = std::make_shared( + pool, rowType, nullptr, resultList[0]->size(), resultList); + scatterStructNulls(tempRow->size(), 0, nullptr, nullptr, *tempRow, 0); + // A copy of the 'result' shared_ptr was passed to scatterStructNulls() via + // 'resultList'. Make sure we re-assign 'result' in case the copy was replaced + // with a new vector. + *result = resultList[0]; +} + void testingScatterStructNulls( vector_size_t size, vector_size_t scatterSize, diff --git a/velox/serializers/PrestoSerializer.h b/velox/serializers/PrestoSerializer.h index 7c3202bdd6b0..a957e27ce76e 100644 --- a/velox/serializers/PrestoSerializer.h +++ b/velox/serializers/PrestoSerializer.h @@ -24,10 +24,11 @@ namespace facebook::velox::serializer::presto { /// There are two ways to serialize data using PrestoVectorSerde: /// /// 1. In order to append multiple RowVectors into the same serialized payload, -/// one can first create a VectorSerializer using createSerializer(), then -/// append successive RowVectors using VectorSerializer::append(). In this case, -/// since different RowVector might encode columns differently, data is always -/// flattened in the serialized payload. +/// one can first create an IterativeVectorSerializer using +/// createIterativeSerializer(), then append successive RowVectors using +/// IterativeVectorSerializer::append(). In this case, since different RowVector +/// might encode columns differently, data is always flattened in the serialized +/// payload. /// /// Note that there are two flavors of append(), one that takes a range of rows, /// and one that takes a list of row ids. The former is useful when serializing @@ -76,7 +77,7 @@ class PrestoVectorSerde : public VectorSerde { vector_size_t** sizes, Scratch& scratch) override; - std::unique_ptr createSerializer( + std::unique_ptr createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena, @@ -129,6 +130,18 @@ class PrestoVectorSerde : public VectorSerde { vector_size_t resultOffset, const Options* options) override; + /// This function is used to deserialize a single column that is serialized in + /// PrestoPage format. It is important to note that the PrestoPage format used + /// here does not include the Presto page header. Therefore, the 'source' + /// should contain uncompressed, serialized binary data, beginning at the + /// column header. + void deserializeSingleColumn( + ByteInputStream* source, + velox::memory::MemoryPool* pool, + TypePtr type, + VectorPtr* result, + const Options* options); + static void registerVectorSerde(); }; diff --git a/velox/serializers/UnsafeRowSerializer.cpp b/velox/serializers/UnsafeRowSerializer.cpp index 342311c5d926..0d940000fd70 100644 --- a/velox/serializers/UnsafeRowSerializer.cpp +++ b/velox/serializers/UnsafeRowSerializer.cpp @@ -29,7 +29,7 @@ void UnsafeRowVectorSerde::estimateSerializedSize( } namespace { -class UnsafeRowVectorSerializer : public VectorSerializer { +class UnsafeRowVectorSerializer : public IterativeVectorSerializer { public: using TRowSize = uint32_t; @@ -122,7 +122,8 @@ std::string concatenatePartialRow( } // namespace -std::unique_ptr UnsafeRowVectorSerde::createSerializer( +std::unique_ptr +UnsafeRowVectorSerde::createIterativeSerializer( RowTypePtr /* type */, int32_t /* numRows */, StreamArena* streamArena, diff --git a/velox/serializers/UnsafeRowSerializer.h b/velox/serializers/UnsafeRowSerializer.h index 1bf41c4f0cae..1c793c98717e 100644 --- a/velox/serializers/UnsafeRowSerializer.h +++ b/velox/serializers/UnsafeRowSerializer.h @@ -31,7 +31,7 @@ class UnsafeRowVectorSerde : public VectorSerde { // This method is not used in production code. It is only used to // support round-trip tests for deserialization. - std::unique_ptr createSerializer( + std::unique_ptr createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena, diff --git a/velox/serializers/tests/CompactRowSerializerTest.cpp b/velox/serializers/tests/CompactRowSerializerTest.cpp index a556c1e538ab..50fcd8d2c91f 100644 --- a/velox/serializers/tests/CompactRowSerializerTest.cpp +++ b/velox/serializers/tests/CompactRowSerializerTest.cpp @@ -43,7 +43,8 @@ class CompactRowSerializerTest : public ::testing::Test, auto arena = std::make_unique(pool_.get()); auto rowType = asRowType(rowVector->type()); - auto serializer = serde_->createSerializer(rowType, numRows, arena.get()); + auto serializer = + serde_->createIterativeSerializer(rowType, numRows, arena.get()); Scratch scratch; serializer->append(rowVector, folly::Range(rows.data(), numRows), scratch); diff --git a/velox/serializers/tests/PrestoSerializerTest.cpp b/velox/serializers/tests/PrestoSerializerTest.cpp index 8922eb8c106b..6c75a67d6c3a 100644 --- a/velox/serializers/tests/PrestoSerializerTest.cpp +++ b/velox/serializers/tests/PrestoSerializerTest.cpp @@ -36,8 +36,10 @@ class PrestoSerializerTest : public ::testing::TestWithParam, public VectorTestBase { protected: - static void SetUpTestCase() { - serializer::presto::PrestoVectorSerde::registerVectorSerde(); + static void SetUpTestSuite() { + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } memory::MemoryManager::testingSetInstance({}); } @@ -90,8 +92,8 @@ class PrestoSerializerTest auto rowType = asRowType(rowVector->type()); auto numRows = rowVector->size(); auto paramOptions = getParamSerdeOptions(serdeOptions); - auto serializer = - serde_->createSerializer(rowType, numRows, arena.get(), ¶mOptions); + auto serializer = serde_->createIterativeSerializer( + rowType, numRows, arena.get(), ¶mOptions); vector_size_t sizeEstimate = 0; Scratch scratch; @@ -383,6 +385,125 @@ class PrestoSerializerTest }); } + RowVectorPtr encodingsArrayElementsTestVector() { + auto baseNoNulls = makeFlatVector({1, 2, 3, 4}); + auto baseWithNulls = + makeNullableFlatVector({1, std::nullopt, 2, 3}); + auto baseArray = + makeArrayVector({{1, 2, 3}, {}, {4, 5}, {6, 7, 8, 9, 10}}); + auto elementIndices = makeIndices(16, [](auto row) { return row / 4; }); + std::vector offsets{0, 2, 4, 6, 8, 10, 12, 14, 16}; + + return makeRowVector({ + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseNoNulls)), + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseWithNulls)), + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseArray)), + makeArrayVector( + offsets, + BaseVector::createConstant(INTEGER(), 123, 16, pool_.get())), + makeArrayVector( + offsets, + BaseVector::createNullConstant(VARCHAR(), 16, pool_.get())), + makeArrayVector(offsets, BaseVector::wrapInConstant(16, 1, baseArray)), + makeRowVector({ + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseNoNulls)), + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseWithNulls)), + makeArrayVector( + offsets, + BaseVector::wrapInDictionary( + nullptr, elementIndices, 16, baseArray)), + makeArrayVector( + offsets, + BaseVector::createConstant(INTEGER(), 123, 16, pool_.get())), + makeArrayVector( + offsets, + BaseVector::createNullConstant(VARCHAR(), 16, pool_.get())), + makeArrayVector( + offsets, BaseVector::wrapInConstant(16, 1, baseArray)), + }), + }); + } + + RowVectorPtr encodingsMapValuesTestVector() { + auto baseNoNulls = makeFlatVector({1, 2, 3, 4}); + auto baseWithNulls = + makeNullableFlatVector({1, std::nullopt, 2, 3}); + auto baseArray = + makeArrayVector({{1, 2, 3}, {}, {4, 5}, {6, 7, 8, 9, 10}}); + auto valueIndices = makeIndices(16, [](auto row) { return row / 4; }); + std::vector offsets{0, 2, 4, 6, 8, 10, 12, 14, 16}; + auto mapKeys = makeFlatVector(16, [](auto row) { return row; }); + + return makeRowVector({ + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary( + nullptr, valueIndices, 16, baseNoNulls)), + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary( + nullptr, valueIndices, 16, baseWithNulls)), + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary(nullptr, valueIndices, 16, baseArray)), + makeMapVector( + offsets, + mapKeys, + BaseVector::createConstant(INTEGER(), 123, 16, pool_.get())), + makeMapVector( + offsets, + mapKeys, + BaseVector::createNullConstant(VARCHAR(), 16, pool_.get())), + makeMapVector( + offsets, mapKeys, BaseVector::wrapInConstant(16, 1, baseArray)), + makeRowVector({ + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary( + nullptr, valueIndices, 16, baseNoNulls)), + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary( + nullptr, valueIndices, 16, baseWithNulls)), + makeMapVector( + offsets, + mapKeys, + BaseVector::wrapInDictionary( + nullptr, valueIndices, 16, baseArray)), + makeMapVector( + offsets, + mapKeys, + BaseVector::createConstant(INTEGER(), 123, 16, pool_.get())), + makeMapVector( + offsets, + mapKeys, + BaseVector::createNullConstant(VARCHAR(), 16, pool_.get())), + makeMapVector( + offsets, mapKeys, BaseVector::wrapInConstant(16, 1, baseArray)), + }), + }); + } + std::unique_ptr serde_; }; @@ -486,16 +607,60 @@ TEST_P(PrestoSerializerTest, intervalDayTime) { } TEST_P(PrestoSerializerTest, unknown) { + // Verify vectors of UNKNOWN type. Also verifies a special case where a + // vector, not of UNKNOWN type and with all nulls is serialized as an UNKNOWN + // type having BYTE_ARRAY encoding. + auto testAllNullSerializedAsUnknown = [&](VectorPtr vector, + TypePtr outputType) { + auto rowVector = makeRowVector({vector}); + auto expected = makeRowVector( + {BaseVector::createNullConstant(outputType, vector->size(), pool())}); + std::ostringstream out; + serialize(rowVector, &out, nullptr); + + auto rowType = asRowType(expected->type()); + auto deserialized = deserialize(rowType, out.str(), nullptr); + assertEqualVectors(expected, deserialized); + + if (rowVector->size() < 3) { + return; + } + + // Split input into 3 batches. Serialize each separately. Then, deserialize + // all into one vector. + auto splits = split(rowVector, 3); + std::vector serialized; + for (const auto& split : splits) { + std::ostringstream oss; + serialize(split, &oss, nullptr); + serialized.push_back(oss.str()); + } + + auto paramOptions = getParamSerdeOptions(nullptr); + RowVectorPtr result; + vector_size_t offset = 0; + for (auto i = 0; i < serialized.size(); ++i) { + auto byteStream = toByteStream(serialized[i]); + serde_->deserialize( + &byteStream, pool_.get(), rowType, &result, offset, ¶mOptions); + offset = result->size(); + } + + assertEqualVectors(expected, result); + }; + const vector_size_t size = 123; auto constantVector = - BaseVector::createNullConstant(UNKNOWN(), 123, pool_.get()); + BaseVector::createNullConstant(UNKNOWN(), size, pool_.get()); testRoundTrip(constantVector); + testAllNullSerializedAsUnknown(constantVector, BIGINT()); auto flatVector = BaseVector::create(UNKNOWN(), size, pool_.get()); for (auto i = 0; i < size; i++) { flatVector->setNull(i, true); } testRoundTrip(flatVector); + testAllNullSerializedAsUnknown(flatVector, BIGINT()); } TEST_P(PrestoSerializerTest, multiPage) { @@ -591,6 +756,28 @@ TEST_P(PrestoSerializerTest, encodingsBatchVectorSerializer) { testBatchVectorSerializerRoundTrip(encodingsTestVector()); } +// Test that array elements have their encodings preserved. +TEST_P(PrestoSerializerTest, encodingsArrayElements) { + testEncodedRoundTrip(encodingsArrayElementsTestVector()); +} + +// Test that array elements have their encodings preserved by the +// PrestoBatchVectorSerializer. +TEST_P(PrestoSerializerTest, encodingsArrayElementsBatchVectorSerializer) { + testBatchVectorSerializerRoundTrip(encodingsArrayElementsTestVector()); +} + +// Test that map values have their encodings preserved. +TEST_P(PrestoSerializerTest, encodingsMapValues) { + testEncodedRoundTrip(encodingsMapValuesTestVector()); +} + +// Test that map values have their encodings preserved by the +// PrestoBatchVectorSerializer. +TEST_P(PrestoSerializerTest, encodingsMapValuesBatchVectorSerializer) { + testBatchVectorSerializerRoundTrip(encodingsMapValuesTestVector()); +} + TEST_P(PrestoSerializerTest, scatterEncoded) { // Makes a struct with nulls and constant/dictionary encoded children. The // children need to get gaps where the parent struct has a null. @@ -733,3 +920,76 @@ INSTANTIATE_TEST_SUITE_P( common::CompressionKind::CompressionKind_ZSTD, common::CompressionKind::CompressionKind_LZ4, common::CompressionKind::CompressionKind_GZIP)); + +TEST_F(PrestoSerializerTest, deserializeSingleColumn) { + // Verify that deserializeSingleColumn API can handle all supported types. + static const size_t kPrestoPageHeaderBytes = 21; + static const size_t kNumOfColumnsSerializedBytes = sizeof(int32_t); + static const size_t kBytesToTrim = + kPrestoPageHeaderBytes + kNumOfColumnsSerializedBytes; + + auto testRoundTripSingleColumn = [&](const VectorPtr& vector) { + auto rowVector = makeRowVector({vector}); + // Serialize to PrestoPage format. + std::ostringstream output; + auto arena = std::make_unique(pool_.get()); + auto rowType = asRowType(rowVector->type()); + auto numRows = rowVector->size(); + auto serializer = + serde_->createSerializer(rowType, numRows, arena.get(), nullptr); + serializer->append(rowVector); + facebook::velox::serializer::presto::PrestoOutputStreamListener listener; + OStreamOutputStream out(&output, &listener); + serializer->flush(&out); + + // Remove the PrestoPage header and Number of columns section from the + // serialized data. + std::string input = output.str().substr(kBytesToTrim); + + auto byteStream = toByteStream(input); + VectorPtr deserialized; + serde_->deserializeSingleColumn( + &byteStream, pool(), vector->type(), &deserialized, nullptr); + assertEqualVectors(vector, deserialized); + }; + + std::vector typesToTest = { + BOOLEAN(), + TINYINT(), + SMALLINT(), + INTEGER(), + BIGINT(), + REAL(), + DOUBLE(), + VARCHAR(), + TIMESTAMP(), + ROW({VARCHAR(), INTEGER()}), + ARRAY(INTEGER()), + ARRAY(INTEGER()), + MAP(VARCHAR(), INTEGER()), + MAP(VARCHAR(), ARRAY(INTEGER())), + }; + + VectorFuzzer::Options opts; + opts.vectorSize = 5; + opts.nullRatio = 0.1; + opts.dictionaryHasNulls = false; + opts.stringVariableLength = true; + opts.stringLength = 20; + opts.containerVariableLength = false; + opts.timestampPrecision = + VectorFuzzer::Options::TimestampPrecision::kMilliSeconds; + opts.containerLength = 10; + + auto seed = 0; + + LOG(ERROR) << "Seed: " << seed; + SCOPED_TRACE(fmt::format("seed: {}", seed)); + VectorFuzzer fuzzer(opts, pool_.get(), seed); + + for (const auto& type : typesToTest) { + SCOPED_TRACE(fmt::format("Type: {}", type->toString())); + auto data = fuzzer.fuzz(type); + testRoundTripSingleColumn(data); + } +} diff --git a/velox/serializers/tests/UnsafeRowSerializerTest.cpp b/velox/serializers/tests/UnsafeRowSerializerTest.cpp index a8cc0d4c2773..ce6ad23aa1f5 100644 --- a/velox/serializers/tests/UnsafeRowSerializerTest.cpp +++ b/velox/serializers/tests/UnsafeRowSerializerTest.cpp @@ -43,7 +43,8 @@ class UnsafeRowSerializerTest : public ::testing::Test, auto arena = std::make_unique(pool_.get()); auto rowType = std::dynamic_pointer_cast(rowVector->type()); - auto serializer = serde_->createSerializer(rowType, numRows, arena.get()); + auto serializer = + serde_->createIterativeSerializer(rowType, numRows, arena.get()); Scratch scratch; serializer->append(rowVector, folly::Range(rows.data(), numRows), scratch); diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index deb47cb81f2d..e04bc6491dc3 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -251,7 +251,7 @@ class DecimalUtil { uint8_t /*bRescale*/) { VELOX_USER_CHECK_NE(b, 0, "Division by zero"); int resultSign = 1; - A unsignedDividendRescaled(a); + R unsignedDividendRescaled(a); if (a < 0) { resultSign = -1; unsignedDividendRescaled *= -1; diff --git a/velox/vector/VectorStream.cpp b/velox/vector/VectorStream.cpp index 758874c11239..069f2472f894 100644 --- a/velox/vector/VectorStream.cpp +++ b/velox/vector/VectorStream.cpp @@ -41,7 +41,7 @@ class DefaultBatchVectorSerializer : public BatchVectorSerializer { } StreamArena arena(pool_); - auto serializer = serde_->createSerializer( + auto serializer = serde_->createIterativeSerializer( asRowType(vector->type()), numRows, &arena, options_); serializer->append(vector, ranges, scratch); serializer->flush(stream); @@ -67,7 +67,7 @@ getNamedVectorSerdeImpl() { } // namespace -void VectorSerializer::append(const RowVectorPtr& vector) { +void IterativeVectorSerializer::append(const RowVectorPtr& vector) { const IndexRange allRows{0, vector->size()}; Scratch scratch; append(vector, folly::Range(&allRows, 1), scratch); @@ -144,7 +144,7 @@ void VectorStreamGroup::createStreamTree( RowTypePtr type, int32_t numRows, const VectorSerde::Options* options) { - serializer_ = serde_->createSerializer(type, numRows, this, options); + serializer_ = serde_->createIterativeSerializer(type, numRows, this, options); } void VectorStreamGroup::append( diff --git a/velox/vector/VectorStream.h b/velox/vector/VectorStream.h index ef9a38eb8094..6590532b03b1 100644 --- a/velox/vector/VectorStream.h +++ b/velox/vector/VectorStream.h @@ -39,9 +39,9 @@ struct IndexRange { /// Uses successive calls to `append` to add more rows to the serialization /// buffer. Then call `flush` to write the aggregate serialized data to an /// OutputStream. -class VectorSerializer { +class IterativeVectorSerializer { public: - virtual ~VectorSerializer() = default; + virtual ~IterativeVectorSerializer() = default; /// Serialize a subset of rows in a vector. virtual void append( @@ -159,7 +159,19 @@ class VectorSerde { /// /// This is more appropriate if the use case involves many small writes, e.g. /// partitioning a RowVector across multiple destinations. - virtual std::unique_ptr createSerializer( + /// + /// TODO: Remove createSerializer once Presto is updated to call + /// createIterativeSerializer. + virtual std::unique_ptr createSerializer( + RowTypePtr type, + int32_t numRows, + StreamArena* streamArena, + const Options* options = nullptr) { + return createIterativeSerializer( + std::move(type), numRows, streamArena, options); + } + + virtual std::unique_ptr createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena, @@ -298,7 +310,7 @@ class VectorStreamGroup : public StreamArena { const VectorSerde::Options* options = nullptr); private: - std::unique_ptr serializer_; + std::unique_ptr serializer_; VectorSerde* serde_{nullptr}; }; diff --git a/velox/vector/arrow/Bridge.cpp b/velox/vector/arrow/Bridge.cpp index 601cbc08e847..6522ab9f18fc 100644 --- a/velox/vector/arrow/Bridge.cpp +++ b/velox/vector/arrow/Bridge.cpp @@ -248,6 +248,8 @@ const char* exportArrowFormatStr( return "u"; // utf-8 string case TypeKind::VARBINARY: return "z"; // binary + case TypeKind::UNKNOWN: + return "n"; // NullType case TypeKind::TIMESTAMP: return "ttn"; // time64 [nanoseconds] @@ -598,6 +600,7 @@ void exportFlat( case TypeKind::REAL: case TypeKind::DOUBLE: case TypeKind::TIMESTAMP: + case TypeKind::UNKNOWN: exportValues(vec, rows, out, pool, holder); break; case TypeKind::VARCHAR: @@ -940,6 +943,8 @@ TypePtr importFromArrowImpl( return REAL(); case 'g': return DOUBLE(); + case 'n': + return UNKNOWN(); // Map both utf-8 and large utf-8 string to varchar. case 'u': diff --git a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp index 8def65ce8e8e..a880d93f1a97 100644 --- a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp @@ -195,6 +195,8 @@ TEST_F(ArrowBridgeSchemaExportTest, scalar) { testScalarType(DECIMAL(10, 4), "d:10,4"); testScalarType(DECIMAL(20, 15), "d:20,15"); + + testScalarType(UNKNOWN(), "n"); } TEST_F(ArrowBridgeSchemaExportTest, nested) { @@ -238,24 +240,14 @@ TEST_F(ArrowBridgeSchemaExportTest, constant) { testConstant(DOUBLE(), "g"); testConstant(VARCHAR(), "u"); testConstant(DATE(), "tdD"); + testConstant(UNKNOWN(), "n"); testConstant(ARRAY(INTEGER()), "+l"); + testConstant(ARRAY(UNKNOWN()), "+l"); testConstant(MAP(BOOLEAN(), REAL()), "+m"); + testConstant(MAP(UNKNOWN(), REAL()), "+m"); testConstant(ROW({TIMESTAMP(), DOUBLE()}), "+s"); -} - -TEST_F(ArrowBridgeSchemaExportTest, unsupported) { - // Try some combination of unsupported types to ensure there's no crash or - // memory leak in failure scenarios. - EXPECT_THROW(testScalarType(UNKNOWN(), ""), VeloxException); - - EXPECT_THROW(testScalarType(ARRAY(UNKNOWN()), ""), VeloxException); - EXPECT_THROW(testScalarType(MAP(UNKNOWN(), INTEGER()), ""), VeloxException); - EXPECT_THROW(testScalarType(MAP(BIGINT(), UNKNOWN()), ""), VeloxException); - - EXPECT_THROW(testScalarType(ROW({BIGINT(), UNKNOWN()}), ""), VeloxException); - EXPECT_THROW( - testScalarType(ROW({BIGINT(), REAL(), UNKNOWN()}), ""), VeloxException); + testConstant(ROW({UNKNOWN(), UNKNOWN()}), "+s"); } class ArrowBridgeSchemaImportTest : public ArrowBridgeSchemaExportTest { @@ -395,7 +387,6 @@ TEST_F(ArrowBridgeSchemaImportTest, complexTypes) { } TEST_F(ArrowBridgeSchemaImportTest, unsupported) { - EXPECT_THROW(testSchemaImport("n"), VeloxUserError); EXPECT_THROW(testSchemaImport("C"), VeloxUserError); EXPECT_THROW(testSchemaImport("S"), VeloxUserError); EXPECT_THROW(testSchemaImport("I"), VeloxUserError); diff --git a/velox/vector/tests/VectorStreamTest.cpp b/velox/vector/tests/VectorStreamTest.cpp index ecc9f8cda2e0..a5b6707163fd 100644 --- a/velox/vector/tests/VectorStreamTest.cpp +++ b/velox/vector/tests/VectorStreamTest.cpp @@ -25,7 +25,7 @@ class MockVectorSerde : public VectorSerde { const folly::Range& ranges, vector_size_t** sizes) override {} - std::unique_ptr createSerializer( + std::unique_ptr createIterativeSerializer( RowTypePtr type, int32_t numRows, StreamArena* streamArena,