diff --git a/src/cypher/execution_plan/ops/op_filter.h b/src/cypher/execution_plan/ops/op_filter.h index 6684247e2e..b4cfd8f275 100644 --- a/src/cypher/execution_plan/ops/op_filter.h +++ b/src/cypher/execution_plan/ops/op_filter.h @@ -24,6 +24,7 @@ namespace cypher { class OpFilter : public OpBase { friend class EdgeFilterPushdownExpand; + friend class EdgeFilterPushdownVarLenExpand; std::shared_ptr filter_; /* FilterState * Different states in which ExpandAll can be at. */ diff --git a/src/cypher/execution_plan/ops/op_var_len_expand.cpp b/src/cypher/execution_plan/ops/op_var_len_expand.cpp index 48318bf6ff..9cf22daf8e 100644 --- a/src/cypher/execution_plan/ops/op_var_len_expand.cpp +++ b/src/cypher/execution_plan/ops/op_var_len_expand.cpp @@ -14,6 +14,454 @@ // // Created by wt on 18-8-30. +// Modified by bxj on 24-3-30. // #include "cypher/execution_plan/ops/op_var_len_expand.h" + +namespace cypher { + +// DFS State Class +DfsState::DfsState(RTContext *ctx, lgraph::VertexId id, int level, cypher::Relationship *relp, + ExpandTowards expand_direction, bool needNext, bool isMaxHop) + : currentNodeId(id), level(level), count(1), needNext(needNext) { + auto &types = relp->Types(); + auto iter_type = lgraph::EIter::NA; + switch (expand_direction) { + case ExpandTowards::FORWARD: + iter_type = types.empty() ? lgraph::EIter::OUT_EDGE : lgraph::EIter::TYPE_OUT_EDGE; + break; + case ExpandTowards::REVERSED: + iter_type = types.empty() ? lgraph::EIter::IN_EDGE : lgraph::EIter::TYPE_IN_EDGE; + break; + case ExpandTowards::BIDIRECTIONAL: + iter_type = types.empty() ? 
lgraph::EIter::BI_EDGE : lgraph::EIter::BI_TYPE_EDGE; + break; + } + if (!isMaxHop) { + // if reach max hop, do not init eiter + (relp->ItsRef()[level]).Initialize(ctx->txn_->GetTxn().get(), iter_type, id, types); + currentEit = &(relp->ItsRef()[level]); + } +} + +// Predicate Class +bool HeadPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + // get first edge's timestamp, check whether it fits the condition + for (auto &eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + FieldData head = FieldData(ret.array->front()); + switch (op) { + case lgraph::CompareOp::LBR_GT: + return head > operand; + case lgraph::CompareOp::LBR_GE: + return head >= operand; + case lgraph::CompareOp::LBR_LT: + return head < operand; + case lgraph::CompareOp::LBR_LE: + return head <= operand; + case lgraph::CompareOp::LBR_EQ: + return head == operand; + case lgraph::CompareOp::LBR_NEQ: + return head != operand; + default: + break; + } + return false; +} + +bool LastPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + // get last edge's timestamp, check whether it fits the condition + for (auto &eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + FieldData last = FieldData(ret.array->back()); + switch (op) { + case lgraph::CompareOp::LBR_GT: + return last > operand; + case lgraph::CompareOp::LBR_GE: + return last >= operand; + case lgraph::CompareOp::LBR_LT: + return last < operand; + case lgraph::CompareOp::LBR_LE: + return last <= operand; + case lgraph::CompareOp::LBR_EQ: + return last == operand; + case lgraph::CompareOp::LBR_NEQ: + return last != operand; + default: + break; + } + return false; +} + +bool IsAscPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + for (auto 
&eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + for (size_t i = 1; i < ret.array->size(); i++) { + if ((*ret.array)[i - 1] >= (*ret.array)[i]) { + return false; + } + } + return true; +} + +bool IsDescPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + for (auto &eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + for (size_t i = 1; i < ret.array->size(); i++) { + if ((*ret.array)[i - 1] <= (*ret.array)[i]) { + return false; + } + } + return true; +} + +bool MaxInListPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + for (auto &eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + // find max in path + size_t pos = 0; + for (size_t i = 0; i < ret.array->size(); i++) { + if ((*ret.array)[i] > (*ret.array)[pos]) { + pos = i; + } + } + + FieldData maxInList = cypher::FieldData((*ret.array)[pos]); + switch (op) { + case lgraph::CompareOp::LBR_GT: + return maxInList > operand; + case lgraph::CompareOp::LBR_GE: + return maxInList >= operand; + case lgraph::CompareOp::LBR_LT: + return maxInList < operand; + case lgraph::CompareOp::LBR_LE: + return maxInList <= operand; + case lgraph::CompareOp::LBR_EQ: + return maxInList == operand; + case lgraph::CompareOp::LBR_NEQ: + return maxInList != operand; + default: + break; + } + return false; +} + +bool MinInListPredicate::eval(std::vector &eits) { + auto ret = cypher::FieldData::Array(0); + for (auto &eit : eits) { + if (eit.IsValid()) { + ret.array->emplace_back(lgraph::FieldData(eit.GetField("timestamp"))); + } + } + if (ret.array->empty()) { + return true; + } + // find min in path + size_t pos = 0; + for (size_t i = 0; i < 
ret.array->size(); i++) { + if ((*ret.array)[i] < (*ret.array)[pos]) { + pos = i; + } + } + + FieldData minInList = cypher::FieldData((*ret.array)[pos]); + switch (op) { + case lgraph::CompareOp::LBR_GT: + return minInList > operand; + case lgraph::CompareOp::LBR_GE: + return minInList >= operand; + case lgraph::CompareOp::LBR_LT: + return minInList < operand; + case lgraph::CompareOp::LBR_LE: + return minInList <= operand; + case lgraph::CompareOp::LBR_EQ: + return minInList == operand; + case lgraph::CompareOp::LBR_NEQ: + return minInList != operand; + default: + break; + } + return false; +} + +// VarLenExpand Class +bool VarLenExpand::NextWithFilter(RTContext *ctx) { + while (!stack.empty()) { + if (needPop) { + // reach here means, in the previous loop, the path returns, + // so the relp_->path needs pop + relp_->path_.PopBack(); + needPop = false; + } + auto &currentState = stack.back(); + auto currentNodeId = currentState.currentNodeId; + auto &currentEit = currentState.currentEit; + auto currentLevel = currentState.level; + + // the number of the neighbor + auto &currentCount = currentState.count; + + // if currentNodeId's needNext is true, currentEit.next(), then set needNext to false + auto &needNext = currentState.needNext; + if (needNext) { + currentEit->Next(); + currentCount++; + needNext = false; + } + + if (currentLevel == max_hop_) { + // When reach here, the top eiter must be invalid, and the path meets the condition. 
+ // check path unique + if (ctx->path_unique_ && relp_->path_.Length() != 0) { + CYPHER_THROW_ASSERT(pattern_graph_->VisitedEdges().Erase( + relp_->path_.GetNthEdgeWithTid(relp_->path_.Length() - 1))); + } + stack.pop_back(); + + neighbor_->PushVid(currentNodeId); + + if (relp_->path_.Length() != 0) { + needPop = true; + } + return true; + } + + if (currentEit->IsValid()) { + // eit is valid, set currentNodeId's eiter's needNext to true + needNext = true; + + // check predicates here, path derived from eiters in stack + bool passPredicate = true; + for (auto &p : predicates) { + if (!p->eval(relp_->ItsRef())) { + passPredicate = false; + break; + } + } + + if (passPredicate) { + // check path unique + if (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(*currentEit)) { + currentEit->Next(); + currentCount++; + needNext = false; + continue; + } else if (ctx->path_unique_) { + pattern_graph_->VisitedEdges().Add(*currentEit); + } + // add edge's euid to path + relp_->path_.Append(currentEit->GetUid()); + auto neighbor = currentEit->GetNbr(expand_direction_); + stack.emplace_back(ctx, neighbor, currentLevel + 1, relp_, expand_direction_, false, + currentLevel + 1 == max_hop_); + } + } else { + // check unique + if (ctx->path_unique_ && relp_->path_.Length() != 0) { + CYPHER_THROW_ASSERT(pattern_graph_->VisitedEdges().Erase( + relp_->path_.GetNthEdgeWithTid(relp_->path_.Length() - 1))); + } + + stack.pop_back(); + if (currentLevel >= min_hop_) { + neighbor_->PushVid(currentNodeId); + + if (relp_->path_.Length() != 0) { + needPop = true; + } + + return true; + } + if (relp_->path_.Length() != 0) { + relp_->path_.PopBack(); + } + } + } + return false; +} + +bool VarLenExpand::Next(RTContext *ctx) { + // check label filter + do { + if (!NextWithFilter(ctx)) return false; + } while (!neighbor_->Label().empty() && neighbor_->IsValidAfterMaterialize(ctx) && + neighbor_->ItRef()->GetLabel() != neighbor_->Label()); + return true; +} + 
+VarLenExpand::VarLenExpand(PatternGraph *pattern_graph, Node *start, Node *neighbor, + Relationship *relp) + : OpBase(OpType::VAR_LEN_EXPAND, "Variable Length Expand"), + pattern_graph_(pattern_graph), + start_(start), + neighbor_(neighbor), + relp_(relp), + min_hop_(relp->MinHop()), + max_hop_(relp->MaxHop()) { + modifies.emplace_back(neighbor_->Alias()); + modifies.emplace_back(relp_->Alias()); + auto &sym_tab = pattern_graph->symbol_table; + auto sit = sym_tab.symbols.find(start_->Alias()); + auto dit = sym_tab.symbols.find(neighbor_->Alias()); + auto rit = sym_tab.symbols.find(relp_->Alias()); + CYPHER_THROW_ASSERT(sit != sym_tab.symbols.end() && dit != sym_tab.symbols.end() && + rit != sym_tab.symbols.end()); + expand_direction_ = relp_->Undirected() ? BIDIRECTIONAL + : relp_->Src() == start_->ID() ? FORWARD + : REVERSED; + start_rec_idx_ = sit->second.id; + nbr_rec_idx_ = dit->second.id; + relp_rec_idx_ = rit->second.id; +} + +void VarLenExpand::addPredicate(std::unique_ptr p) { + predicates.push_back(std::move(p)); +} + +void VarLenExpand::PushFilter(std::shared_ptr filter) { + if (filter) { + if (filter->Type() == lgraph::Filter::RANGE_FILTER) { + std::shared_ptr tmp_filter = + std::static_pointer_cast(filter); + if (tmp_filter->GetAeLeft().op.type == cypher::ArithOpNode::AR_OP_FUNC) { + std::string func_name = tmp_filter->GetAeLeft().op.func_name; + std::transform(func_name.begin(), func_name.end(), func_name.begin(), ::tolower); + if (func_name == "isasc") { + auto p = std::make_unique(); + addPredicate(std::move(p)); + } else if (func_name == "isdesc") { + auto p = std::make_unique(); + addPredicate(std::move(p)); + } else if (func_name == "head") { + lgraph::CompareOp op = tmp_filter->GetCompareOp(); + FieldData operand = tmp_filter->GetAeRight().operand.constant; + auto p = std::make_unique(op, operand); + addPredicate(std::move(p)); + } else if (func_name == "last") { + lgraph::CompareOp op = tmp_filter->GetCompareOp(); + FieldData operand = 
tmp_filter->GetAeRight().operand.constant; + auto p = std::make_unique(op, operand); + addPredicate(std::move(p)); + } else if (func_name == "maxinlist") { + lgraph::CompareOp op = tmp_filter->GetCompareOp(); + FieldData operand = tmp_filter->GetAeRight().operand.constant; + auto p = std::make_unique(op, operand); + addPredicate(std::move(p)); + } else if (func_name == "mininlist") { + lgraph::CompareOp op = tmp_filter->GetCompareOp(); + FieldData operand = tmp_filter->GetAeRight().operand.constant; + auto p = std::make_unique(op, operand); + addPredicate(std::move(p)); + } else { + throw lgraph::CypherException("Not in 6 predicates."); + } + } + } + PushFilter(filter->Left()); + PushFilter(filter->Right()); + } + return; +} + +void VarLenExpand::PushDownEdgeFilter(std::shared_ptr edge_filter) { + edge_filter_ = edge_filter; + // add filter to local Predicates + PushFilter(edge_filter); +} + +OpBase::OpResult VarLenExpand::Initialize(RTContext *ctx) { + CYPHER_THROW_ASSERT(!children.empty()); + auto child = children[0]; + auto res = child->Initialize(ctx); + if (res != OP_OK) return res; + record = child->record; + record->values[start_rec_idx_].type = Entry::NODE; + record->values[start_rec_idx_].node = start_; + record->values[nbr_rec_idx_].type = Entry::NODE; + record->values[nbr_rec_idx_].node = neighbor_; + record->values[relp_rec_idx_].type = Entry::VAR_LEN_RELP; + record->values[relp_rec_idx_].relationship = relp_; + relp_->ItsRef().resize(max_hop_); + needPop = false; + return OP_OK; +} + +OpBase::OpResult VarLenExpand::RealConsume(RTContext *ctx) { + CYPHER_THROW_ASSERT(!children.empty()); + auto child = children[0]; + while (!Next(ctx)) { + auto res = child->Consume(ctx); + relp_->path_.Clear(); + if (res != OP_OK) { + return res; + } + // init the first of stack + lgraph::VertexId startVid = start_->PullVid(); + if (startVid < 0) { + continue; + } + CYPHER_THROW_ASSERT(stack.empty()); + // push the first node and the related eiter into the stack + 
stack.emplace_back(ctx, startVid, 0, relp_, expand_direction_, false, !max_hop_); + + relp_->path_.SetStart(startVid); + } + return OP_OK; +} + +OpBase::OpResult VarLenExpand::ResetImpl(bool complete) { + std::vector().swap(stack); + // stack.clear(); + relp_->path_.Clear(); + return OP_OK; +} + +std::string VarLenExpand::ToString() const { + auto towards = expand_direction_ == FORWARD ? "-->" + : expand_direction_ == REVERSED ? "<--" + : "--"; + std::string edgefilter_str = "VarLenEdgeFilter"; + return fma_common::StringFormatter::Format( + "{}({}) [{} {}*{}..{} {} {}]", name, "All", start_->Alias(), towards, + std::to_string(min_hop_), std::to_string(max_hop_), neighbor_->Alias(), + edge_filter_ ? edgefilter_str.append(" (").append(edge_filter_->ToString()).append(")") + : ""); +} + +} // namespace cypher diff --git a/src/cypher/execution_plan/ops/op_var_len_expand.h b/src/cypher/execution_plan/ops/op_var_len_expand.h index b846285c8f..a937d9c710 100644 --- a/src/cypher/execution_plan/ops/op_var_len_expand.h +++ b/src/cypher/execution_plan/ops/op_var_len_expand.h @@ -14,10 +14,12 @@ // // Created by wt on 18-8-30. +// Modified by bxj on 24-3-30. 
// #pragma once #include "cypher/execution_plan/ops/op.h" +#include "filter/filter.h" #ifndef NDEBUG #define VAR_LEN_EXP_DUMP_FOR_DEBUG() \ @@ -32,274 +34,101 @@ namespace cypher { +struct DfsState { + // current node id + lgraph::VertexId currentNodeId; + // current index for current node + lgraph::EIter *currentEit; + // level, or path length + int level; + // number of neighbors of this node + int count; + // whether the eiter need Next() + bool needNext; + + DfsState(RTContext *ctx, lgraph::VertexId id, int level, cypher::Relationship *relp, + ExpandTowards expand_direction, bool needNext, bool isMaxHop); +}; + +class Predicate { + public: + virtual bool eval(std::vector &eits) = 0; + virtual ~Predicate() = default; +}; + +class HeadPredicate : public Predicate { + private: + // operator + lgraph::CompareOp op; + // operand, on the right + FieldData operand; + + public: + HeadPredicate(lgraph::CompareOp op, FieldData operand) : op(op), operand(operand) {} + bool eval(std::vector &eits) override; +}; + +class LastPredicate : public Predicate { + private: + // operator + lgraph::CompareOp op; + // operand, on the right + FieldData operand; + + public: + LastPredicate(lgraph::CompareOp op, FieldData operand) : op(op), operand(operand) {} + bool eval(std::vector &eits) override; +}; + +class IsAscPredicate : public Predicate { + public: + IsAscPredicate() {} + bool eval(std::vector &eits) override; +}; + +class IsDescPredicate : public Predicate { + public: + IsDescPredicate() {} + bool eval(std::vector &eits) override; +}; + +class MaxInListPredicate : public Predicate { + private: + lgraph::CompareOp op; + FieldData operand; + + public: + MaxInListPredicate(lgraph::CompareOp op, FieldData operand) : op(op), operand(operand) {} + bool eval(std::vector &eits) override; +}; + +class MinInListPredicate : public Predicate { + private: + lgraph::CompareOp op; + FieldData operand; + + public: + MinInListPredicate(lgraph::CompareOp op, FieldData operand) : op(op), 
operand(operand) {} + bool eval(std::vector &eits) override; +}; + /* Variable Length Expand */ class VarLenExpand : public OpBase { - void _InitializeEdgeIter(RTContext *ctx, int64_t vid, lgraph::EIter &eit, size_t &count) { - auto &types = relp_->Types(); - auto iter_type = lgraph::EIter::NA; - switch (expand_direction_) { - case ExpandTowards::FORWARD: - iter_type = types.empty() ? lgraph::EIter::OUT_EDGE : lgraph::EIter::TYPE_OUT_EDGE; - break; - case ExpandTowards::REVERSED: - iter_type = types.empty() ? lgraph::EIter::IN_EDGE : lgraph::EIter::TYPE_IN_EDGE; - break; - case ExpandTowards::BIDIRECTIONAL: - iter_type = types.empty() ? lgraph::EIter::BI_EDGE : lgraph::EIter::BI_TYPE_EDGE; - break; - } - eit.Initialize(ctx->txn_->GetTxn().get(), iter_type, vid, types); - count = 1; - } - -#if 0 // 20210704 - void _CollectFrontierByDFS(RTContext *ctx, int64_t vid, const std::set &types, int min_hop, int max_hop) { // NOLINT - if (hop_ >= min_hop) { - if (neighbor_->Label().empty() - || ctx->txn_->GetVertexLabel(ctx->txn_->GetVertexIterator(vid)) == neighbor_->Label()) { // NOLINT - frontier_buffer_.emplace(vid); - path_buffer_.emplace(relp_->path_); - } -#ifndef NDEBUG - FMA_DBG() << __func__ << ": hop=" << hop_ << ",vid=" << vid; - FMA_DBG() << pattern_graph_->VisitedEdges().Dump(); -#endif - } - if (hop_ == max_hop) return; - hop_++; - lgraph::EIter eit; - _InitializeEdgeIter(ctx, vid, eit); - while (eit.IsValid()) { - if (!pattern_graph_->VisitedEdges().Contains(eit)) { - auto r = pattern_graph_->VisitedEdges().Add(eit); - if (!r.second) CYPHER_INTL_ERR(); - relp_->path_.Append(eit.GetUid()); - _CollectFrontierByDFS(ctx, eit.GetNbr(expand_direction_), types, min_hop, max_hop); // NOLINT - relp_->path_.PopBack(); - pattern_graph_->VisitedEdges().Erase(r.first); - } - eit.Next(); - } - hop_--; - } - - void _CollectFrontierByDFS(RTContext *ctx, int64_t vid, const std::set &types, int min_hop) { // NOLINT - if (hop_ == min_hop) { - if (neighbor_->Label().empty() - || 
ctx->txn_->GetVertexLabel(ctx->txn_->GetVertexIterator(vid)) == neighbor_->Label()) { // NOLINT - frontier_buffer_.emplace(vid); - path_buffer_.emplace(relp_->path_); - } -#ifndef NDEBUG - FMA_LOG() << __func__ << ": hop=" << hop_ << ",vid=" << vid; - FMA_LOG() << pattern_graph_->VisitedEdges().Dump(); -#endif - return; - } - hop_++; - lgraph::EIter eit; - _InitializeEdgeIter(ctx, vid, eit); - while (eit.IsValid()) { - if (!pattern_graph_->VisitedEdges().Contains(eit)) { - auto r = pattern_graph_->VisitedEdges().Add(eit); - if (!r.second) CYPHER_INTL_ERR(); - relp_->path_.Append(eit.GetUid()); - _CollectFrontierByDFS(ctx, eit.GetNbr(expand_direction_), types, min_hop); - relp_->path_.PopBack(); - pattern_graph_->VisitedEdges().Erase(r.first); - } - eit.Next(); - } - hop_--; - } - - OpResult Next(RTContext *ctx) { - if (state_ == Uninitialized) return OP_REFRESH; - /* Start node iterator may be invalid, such as when the start is an argument - * produced by OPTIONAL MATCH. */ - if (!start_it_->IsValid()) return OP_REFRESH; - auto &types = relp_->Types(); - if (collect_all_ || min_hop_ == 0) { // we didnot handle 0hop in other branch - if (state_ == Resetted) { - relp_->path_.SetStart(start_it_->GetId()); - /* collect all the vertex, save them into result_buffer_ */ - _CollectFrontierByDFS(ctx, start_it_->GetId(), types, - min_hop_, max_hop_); - state_ = Consuming; - } - if (frontier_buffer_.empty()) return OP_REFRESH; - nbr_it_->Initialize(ctx->txn_.get(), lgraph::VIter::VERTEX_ITER, - frontier_buffer_.front()); - frontier_buffer_.pop(); - relp_->path_ = path_buffer_.front(); - path_buffer_.pop(); - } else { - // produce one by one - if (state_ == Resetted) { - relp_->path_.SetStart(start_it_->GetId()); - hop_ = 0; - _CollectFrontierByDFS(ctx, start_it_->GetId(), types, min_hop_); - state_ = Consuming; - } - if (frontier_buffer_.empty()) return OP_REFRESH; - auto vid = frontier_buffer_.front(); - frontier_buffer_.pop(); - nbr_it_->Initialize(ctx->txn_.get(), 
lgraph::VIter::VERTEX_ITER, vid); - relp_->path_ = path_buffer_.front(); - path_buffer_.pop(); - if (relp_->path_.Length() < max_hop_) { - lgraph::EIter eit; - _InitializeEdgeIter(ctx, vid, eit); - // construct visitedEdges from relp_->path_ - pattern_graph_->VisitedEdges().euid_hash_set.clear(); - for (size_t i = 0; i < relp_->path_.Length(); i++) { - pattern_graph_->VisitedEdges().euid_hash_set.emplace(relp_->path_.GetNthEdge(i)); // NOLINT - } - while (eit.IsValid()) { - if (!pattern_graph_->VisitedEdges().Contains(eit)) { - if (neighbor_->Label().empty() || - ctx->txn_->GetVertexLabel( - ctx->txn_->GetVertexIterator(vid)) == - neighbor_->Label()) { - frontier_buffer_.emplace(eit.GetNbr(expand_direction_)); - relp_->path_.Append(eit.GetUid()); - path_buffer_.emplace(relp_->path_); - relp_->path_.PopBack(); - } - } - eit.Next(); - } - } - } // if collect all -#ifndef NDEBUG - FMA_DBG() << "[" << __FILE__ << "] neighbor:" << nbr_it_->GetId(); -#endif - return OP_OK; - } -#endif + bool Next(RTContext *ctx); + bool NextWithFilter(RTContext *ctx); + + void PushFilter(std::shared_ptr filter); + + // save 6 types of predicates + std::vector> predicates; + // add predicate to the vector + void addPredicate(std::unique_ptr p); - bool PerNodeLimit(RTContext *ctx, size_t k) { - return !ctx->per_node_limit_.has_value() || - expand_counts_[k] <= ctx->per_node_limit_.value(); - } - - int64_t GetFirstFromKthHop(RTContext *ctx, size_t k) { - auto start_id = start_->PullVid(); - relp_->path_.Clear(); - relp_->path_.SetStart(start_id); - if (k == 0) return start_id; - _InitializeEdgeIter(ctx, start_id, eits_[0], expand_counts_[0]); - if (!eits_[0].IsValid() || !PerNodeLimit(ctx, 0)) { - return -1; - } - if (k == 1) { - relp_->path_.Append(eits_[0].GetUid()); - if (ctx->path_unique_) pattern_graph_->VisitedEdges().Add(eits_[0]); - return eits_[0].GetNbr(expand_direction_); - } - // k >= 2 - for (size_t i = 0; i < k; i++) { - lgraph::EdgeUid dummy(start_id, start_id, -1, 0, -1); - 
relp_->path_.Append(dummy); - } - return GetNextFromKthHop(ctx, k, true); - } - - // curr_hop start from 1,2,3.. - int64_t GetNextFromKthHop(RTContext *ctx, size_t k, bool get_first) { - if (k == 0) return -1; - if (ctx->path_unique_) pattern_graph_->VisitedEdges().Erase(eits_[k - 1]); - relp_->path_.PopBack(); - /* If get the first node, the 1st edge(eits[0]) is the only iterator - * that is initialized and should not go next. - **/ - if (!get_first || k != 1 || - (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[k - 1]))) { - do { - expand_counts_[k - 1] += 1; - eits_[k - 1].Next(); - } while (eits_[k - 1].IsValid() && PerNodeLimit(ctx, k - 1) && ctx->path_unique_ && - pattern_graph_->VisitedEdges().Contains(eits_[k - 1])); - } - do { - if (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)) { - auto id = GetNextFromKthHop(ctx, k - 1, get_first); - if (id < 0) return id; - _InitializeEdgeIter(ctx, id, eits_[k - 1], expand_counts_[k - 1]); - /* We have called get_next previously, mark get_first as - * false. */ - get_first = false; - } - while (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[k - 1])) { - expand_counts_[k - 1] += 1; - eits_[k - 1].Next(); - } - } while (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)); - if (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)) { - return -1; - } - relp_->path_.Append(eits_[k - 1].GetUid()); - if (ctx->path_unique_) pattern_graph_->VisitedEdges().Add(eits_[k - 1]); - return eits_[k - 1].GetNbr(expand_direction_); - } - - OpResult NextWithoutLabelFilter(RTContext *ctx) { - if (state_ == Uninitialized) return OP_REFRESH; - /* Start node iterator may be invalid, such as when the start is an argument - * produced by OPTIONAL MATCH. 
*/ - if (start_->PullVid() < 0) return OP_REFRESH; - if (state_ == Resetted) { - // go to min_hop - hop_ = min_hop_; - int64_t nbr_id = GetFirstFromKthHop(ctx, hop_); - if (nbr_id < 0) return OP_REFRESH; - neighbor_->PushVid(nbr_id); - VAR_LEN_EXP_DUMP_FOR_DEBUG(); - state_ = Consuming; - return OP_OK; - } - auto vid = GetNextFromKthHop(ctx, hop_, false); - if (vid >= 0) { - neighbor_->PushVid(vid); - VAR_LEN_EXP_DUMP_FOR_DEBUG(); - return OP_OK; - } else { - // need expand to next hop - if (hop_ == max_hop_) return OP_REFRESH; - hop_++; - auto vid = GetFirstFromKthHop(ctx, hop_ - 1); - if (vid < 0) return OP_REFRESH; - if (hop_ > 1 && !eits_[hop_ - 2].IsValid()) CYPHER_INTL_ERR(); - _InitializeEdgeIter(ctx, vid, eits_[hop_ - 1], expand_counts_[hop_ - 1]); - // TODO(anyone) merge these code similiar to GetNextFromKthHop - do { - if (!eits_[hop_ - 1].IsValid() || !PerNodeLimit(ctx, hop_ - 1)) { - auto v = GetNextFromKthHop(ctx, hop_ - 1, false); - if (v < 0) return OP_REFRESH; - _InitializeEdgeIter(ctx, v, eits_[hop_ - 1], expand_counts_[hop_ - 1]); - } - while (ctx->path_unique_ && - pattern_graph_->VisitedEdges().Contains(eits_[hop_ - 1])) { - expand_counts_[hop_ - 1] += 1; - eits_[hop_ - 1].Next(); - } - } while (!eits_[hop_ - 1].IsValid() || !PerNodeLimit(ctx, hop_ - 1)); - neighbor_->PushVid(eits_[hop_ - 1].GetNbr(expand_direction_)); - relp_->path_.Append(eits_[hop_ - 1].GetUid()); - // TODO(anyone) remove in last hop - if (ctx->path_unique_) pattern_graph_->VisitedEdges().Add(eits_[hop_ - 1]); - VAR_LEN_EXP_DUMP_FOR_DEBUG(); - return OP_OK; - } - } - - OpResult Next(RTContext *ctx) { - do { - if (NextWithoutLabelFilter(ctx) != OP_OK) return OP_REFRESH; - } while (!neighbor_->Label().empty() && neighbor_->IsValidAfterMaterialize(ctx) && - neighbor_->ItRef()->GetLabel() != neighbor_->Label()); - return OP_OK; - } + // stack for DFS + std::vector stack; + + // this flag decides whether need to pop relp_->Path + bool needPop; public: cypher::PatternGraph 
*pattern_graph_ = nullptr; @@ -311,98 +140,22 @@ class VarLenExpand : public OpBase { int relp_rec_idx_; int min_hop_; int max_hop_; - int hop_; // current hop working on - bool collect_all_; ExpandTowards expand_direction_; - std::vector &eits_; - std::vector expand_counts_; - enum State { - Uninitialized, /* ExpandAll wasn't initialized it. */ - Resetted, /* ExpandAll was just restarted. */ - Consuming, /* ExpandAll consuming data. */ - } state_; - - VarLenExpand(PatternGraph *pattern_graph, Node *start, Node *neighbor, Relationship *relp) - : OpBase(OpType::VAR_LEN_EXPAND, "Variable Length Expand"), - pattern_graph_(pattern_graph), - start_(start), - neighbor_(neighbor), - relp_(relp), - min_hop_(relp->MinHop()), - max_hop_(relp->MaxHop()), - hop_(0), - collect_all_(false), - eits_(relp_->ItsRef()) { - modifies.emplace_back(neighbor_->Alias()); - modifies.emplace_back(relp_->Alias()); - auto &sym_tab = pattern_graph->symbol_table; - auto sit = sym_tab.symbols.find(start_->Alias()); - auto dit = sym_tab.symbols.find(neighbor_->Alias()); - auto rit = sym_tab.symbols.find(relp_->Alias()); - CYPHER_THROW_ASSERT(sit != sym_tab.symbols.end() && dit != sym_tab.symbols.end() && - rit != sym_tab.symbols.end()); - expand_direction_ = relp_->Undirected() ? BIDIRECTIONAL - : relp_->Src() == start_->ID() ? 
FORWARD - : REVERSED; - start_rec_idx_ = sit->second.id; - nbr_rec_idx_ = dit->second.id; - relp_rec_idx_ = rit->second.id; - expand_counts_.resize(eits_.size()); - state_ = Uninitialized; - } - - OpResult Initialize(RTContext *ctx) override { - CYPHER_THROW_ASSERT(!children.empty()); - auto child = children[0]; - auto res = child->Initialize(ctx); - if (res != OP_OK) return res; - record = child->record; - record->values[start_rec_idx_].type = Entry::NODE; - record->values[start_rec_idx_].node = start_; - record->values[nbr_rec_idx_].type = Entry::NODE; - record->values[nbr_rec_idx_].node = neighbor_; - record->values[relp_rec_idx_].type = Entry::VAR_LEN_RELP; - record->values[relp_rec_idx_].relationship = relp_; - eits_.resize(max_hop_); - return OP_OK; - } - - OpResult RealConsume(RTContext *ctx) override { - CYPHER_THROW_ASSERT(!children.empty()); - auto child = children[0]; - while (state_ == Uninitialized || Next(ctx) == OP_REFRESH) { - auto res = child->Consume(ctx); - relp_->path_.Clear(); - state_ = Resetted; - if (res != OP_OK) { - /* When consume after the stream is DEPLETED, make sure - * the result always be DEPLETED. */ - state_ = Uninitialized; - return res; - } - /* Most of the time, the start_it is definitely valid after child's Consume - * returns OK, except when the child is an OPTIONAL operation. */ - } - return OP_OK; - } - - OpResult ResetImpl(bool complete) override { - state_ = Uninitialized; - // std::queue().swap(frontier_buffer_); - // std::queue().swap(path_buffer_); - hop_ = 0; - // TODO(anyone) reset modifies - return OP_OK; - } - - std::string ToString() const override { - auto towards = expand_direction_ == FORWARD ? "-->" - : expand_direction_ == REVERSED ? 
"<--" - : "--"; - return fma_common::StringFormatter::Format( - "{}({}) [{} {}*{}..{} {}]", name, "All", start_->Alias(), towards, - std::to_string(min_hop_), std::to_string(max_hop_), neighbor_->Alias()); - } + + // edge_filter_ is temp used + std::shared_ptr edge_filter_ = nullptr; + + VarLenExpand(PatternGraph *pattern_graph, Node *start, Node *neighbor, Relationship *relp); + + void PushDownEdgeFilter(std::shared_ptr edge_filter); + + OpResult Initialize(RTContext *ctx) override; + + OpResult RealConsume(RTContext *ctx) override; + + OpResult ResetImpl(bool complete) override; + + std::string ToString() const override; Node *GetStartNode() const { return start_; } Node *GetNeighborNode() const { return neighbor_; } diff --git a/src/cypher/execution_plan/optimization/edge_filter_pushdown_varlenexpand.h b/src/cypher/execution_plan/optimization/edge_filter_pushdown_varlenexpand.h new file mode 100644 index 0000000000..ae683d4aea --- /dev/null +++ b/src/cypher/execution_plan/optimization/edge_filter_pushdown_varlenexpand.h @@ -0,0 +1,126 @@ +/** + * Copyright 2024 AntGroup CO., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +/* + * Created by bxj on 3/25/23. 
+ */
+#pragma once
+
+#include "cypher/execution_plan/ops/op_filter.h"
+#include "cypher/execution_plan/ops/op_var_len_expand.h"
+#include "cypher/execution_plan/optimization/opt_pass.h"
+
+namespace cypher {
+/*
+ * EdgeFilterPushdownVarLenExpand:
+ * MATCH p=(src:Account)-[e:transfer*1..3]->(dst:Account) WHERE
+ * isAsc(getMemberProp(e,'timestamp'))=true
+ *
+ * Plan before optimization:
+ * Filter [{isasc(false,getmemberprop(false,e1,timestamp)) = true}]
+ *     Variable Length Expand(All) [acc -->*1..3 dst]
+ *
+ * Plan after optimization:
+ * Variable Length Expand(All) [acc -->*1..3 dst VarLenEdgeFilter
+ * {isasc(false,getmemberprop(false,e1,timestamp)) = true}]
+ *
+ * A Filter is rewritten only when it has a parent, its single child is a
+ * VarLenExpand, its filter mentions the expand's relationship alias and
+ * BinaryOnlyContainsAND() holds for it. Predicates that do not mention the
+ * alias stay behind in a new Filter placed directly above the VarLenExpand.
+ */
+
+class EdgeFilterPushdownVarLenExpand : public OptPass {
+    // Hand op_filter's predicate to op_varlenexpand and splice op_filter out of the plan:
+    // its parent adopts op_filter->children[0] in the slot op_filter used to occupy.
+    // op_filter is deleted and reset to nullptr once it has been replaced.
+    void _AddEdgeFilterOp(OpFilter *&op_filter, VarLenExpand *&op_varlenexpand) {
+        // op_post -> op_filter -> op_no_edge_filter -> op_varlenexpand
+        op_varlenexpand->PushDownEdgeFilter(op_filter->filter_);
+        auto op_post = op_filter->parent;
+        auto &siblings = op_post->children;
+        // find the place of op_filter, then replace it by op_filter->children[0]
+        for (auto it = siblings.begin(); it != siblings.end(); ++it) {
+            if (*it != op_filter) continue;
+            // RemoveChild() may modify `children`, after which `it` must not be reused;
+            // remember the slot as an offset and rebuild the iterator afterwards.
+            const auto pos = it - siblings.begin();
+            auto op_child = op_filter->children[0];
+            op_post->RemoveChild(op_filter);
+            auto slot = siblings.begin() + pos;
+            op_post->InsertChild(slot, op_child);
+            delete op_filter;
+            op_filter = nullptr;
+            break;
+        }
+    }
+
+    // Depth-first search below root for a Filter whose predicate can be pushed into the
+    // VarLenExpand directly beneath it. On a hit, op_filter / op_varlenexpand are set,
+    // op_filter->filter_ is reduced to the sub-filters that mention the edge alias, the
+    // remaining sub-filters (if any) are moved into a new Filter between the two ops,
+    // and true is returned. Returns false when the subtree holds no such Filter.
+    bool _FindEdgeFilter(OpBase *root, OpFilter *&op_filter, VarLenExpand *&op_varlenexpand) {
+        auto op = root;
+        // a Filter without a parent cannot be spliced out by _AddEdgeFilterOp, so skip it
+        if (op->type == OpType::FILTER && op->parent != nullptr && op->children.size() == 1 &&
+            op->children[0]->type == OpType::VAR_LEN_EXPAND) {
+            op_filter = dynamic_cast<OpFilter *>(op);
+            op_varlenexpand = dynamic_cast<VarLenExpand *>(op->children[0]);
+            // if exist filter on varlenexpand edge
+            std::string edge_alias = op_varlenexpand->relp_->Alias();
+            if (op_filter->filter_->ContainAlias({edge_alias}) &&
+                op_filter->filter_->BinaryOnlyContainsAND()) {
+                // if filter has edge_filter, split filters
+                auto clone_filter = op_filter->filter_->Clone();
+
+                // keep only the sub-filters that mention edge_alias: the callback asks
+                // for removal when none of the aliases in [b, e) equals edge_alias
+                // e.g. head(getMemberProp(e2, 'timestamp')) > 1662123596189
+                op_filter->filter_->RemoveFilterWhen(
+                    op_filter->filter_, [&edge_alias](const auto &b, const auto &e) {
+                        for (auto it = b; it != e; it++) {
+                            if (*it == edge_alias) return false;
+                        }
+                        return true;
+                    });
+
+                // keep only the sub-filters that do not mention edge_alias: the callback
+                // asks for removal as soon as one alias in [b, e) equals edge_alias
+                // e.g. {dst.id = 4687403336918373745}
+                clone_filter->RemoveFilterWhen(
+                    clone_filter, [&edge_alias](const auto &b, const auto &e) {
+                        for (auto it = b; it != e; it++) {
+                            if (*it == edge_alias) return true;
+                        }
+                        return false;
+                    });
+
+                // split into two filters, when both are not nullptr
+                if (clone_filter && op_filter->filter_) {
+                    // op_filter -> op_no_edge_filter -> varlenexpand
+                    auto op_no_edge_filter = new OpFilter(clone_filter);
+                    op_no_edge_filter->AddChild(op_filter->children[0]);
+                    op_filter->RemoveChild(op_filter->children[0]);
+                    op_filter->AddChild(op_no_edge_filter);
+                }
+                return true;
+            }
+        }
+        for (auto child : op->children) {
+            if (_FindEdgeFilter(child, op_filter, op_varlenexpand)) return true;
+        }
+        return false;
+    }
+
+    // Push down one Filter per iteration until _FindEdgeFilter reports no candidate.
+    void _AdjustFilter(OpBase *root) {
+        OpFilter *op_filter = nullptr;
+        VarLenExpand *op_varlenexpand = nullptr;
+        // traverse the query execution plan to judge whether edge_filter exists
+        while (_FindEdgeFilter(root, op_filter, op_varlenexpand)) {
+            _AddEdgeFilterOp(op_filter, op_varlenexpand);
+        }
+    }
+
+ public:
+    EdgeFilterPushdownVarLenExpand() : OptPass(typeid(EdgeFilterPushdownVarLenExpand).name()) {}
+
+    // This pass is always enabled.
+    bool Gate() override { return true; }
+
+    // Rewrite the plan rooted at root in place; always returns 0.
+    int Execute(OpBase *root) override {
+        _AdjustFilter(root);
+        return 0;
+    }
+};
+}  // namespace cypher
diff --git a/src/cypher/execution_plan/optimization/pass_manager.h b/src/cypher/execution_plan/optimization/pass_manager.h
index 948d68f720..e2b57baab2 100644
--- a/src/cypher/execution_plan/optimization/pass_manager.h
+++ b/src/cypher/execution_plan/optimization/pass_manager.h
@@ -25,6 +25,7 @@
 #include
"cypher/execution_plan/optimization/locate_node_by_indexed_prop.h"
 #include "cypher/execution_plan/optimization/parallel_traversal.h"
 #include "cypher/execution_plan/optimization/opt_rewrite_with_schema_inference.h"
+#include "cypher/execution_plan/optimization/edge_filter_pushdown_varlenexpand.h"
 
 namespace cypher {
 
@@ -41,6 +42,7 @@ class PassManager {
         all_passes_.emplace_back(new PassVarLenExpandWithLimit());
         all_passes_.emplace_back(new LocateNodeByVid());
         all_passes_.emplace_back(new LocateNodeByIndexedProp());
+        // moves edge predicates out of a Filter into the VarLenExpand directly below it
+        all_passes_.emplace_back(new EdgeFilterPushdownVarLenExpand());
         // todo(kehuang): ParallelTraversal will cause a crash, temporarily disabling it.
         // all_passes_.emplace_back(new ParallelTraversal());
     }
diff --git a/src/cypher/filter/iterator.h b/src/cypher/filter/iterator.h
index 9d21c79f18..99de23ac26 100644
--- a/src/cypher/filter/iterator.h
+++ b/src/cypher/filter/iterator.h
@@ -1137,7 +1137,11 @@ struct EuidHashSet {
 
+    // Erase the edge eit points at; an invalid iterator erases nothing and yields false.
     bool Erase(const lgraph::EIter &eit) {
         if (!eit.IsValid()) return false;
-        auto it = euid_hash_set.find(eit.GetUid());
+        return Erase(eit.GetUid());
+    }
+
+    // Erase euid from euid_hash_set: true when it was present, false when it was not found.
+    bool Erase(const lgraph::EdgeUid &euid) {
+        auto it = euid_hash_set.find(euid);
         if (it == euid_hash_set.end()) return false;
         euid_hash_set.erase(it);
         return true;
diff --git a/src/cypher/graph/common.h b/src/cypher/graph/common.h
index 29cd525861..deba7d753e 100644
--- a/src/cypher/graph/common.h
+++ b/src/cypher/graph/common.h
@@ -43,6 +43,7 @@ struct Path {
     std::vector dirs_;
     std::vector lids_;
     std::vector ids_;
+    // tid of each edge on the path, pushed alongside lids_ whenever an edge is appended
+    std::vector tids_;
 
     Path() = default;
 
@@ -63,12 +64,14 @@ struct Path {
             lids_.push_back(edge.lid);
             ids_.push_back(edge.eid);
             ids_.push_back(edge.dst);
+            tids_.push_back(edge.tid);
         } else if (ids_.back() == edge.dst) {
             // backward
             dirs_.push_back(false);
             lids_.push_back(edge.lid);
             ids_.push_back(edge.eid);
             ids_.push_back(edge.src);
+            tids_.push_back(edge.tid);
         } else {
             throw std::runtime_error("The edge doesn't match the path's end.");
         }
@@ -79,6 +82,7 @@ struct Path {
std::reverse(lids_.begin(), lids_.end());
         std::reverse(dirs_.begin(), dirs_.end());
         for (size_t i = 0; i < dirs_.size(); i++) dirs_[i] = !dirs_[i];
+        // tids_ is per-edge like lids_, so it is reversed with it
+        std::reverse(tids_.begin(), tids_.end());
     }
 
     void PopBack() {
@@ -86,12 +90,14 @@ struct Path {
         lids_.pop_back();
         ids_.pop_back();
         ids_.pop_back();
+        tids_.pop_back();
     }
 
     void Clear() {
         dirs_.clear();
         lids_.clear();
         ids_.clear();
+        tids_.clear();
     }
 
     lgraph::EdgeUid GetNthEdge(size_t n) const {
@@ -105,6 +111,17 @@ struct Path {
                                  ids_[1 + n * 2]);  // TODO(heng)
     }
 
+    // Build the EdgeUid of the n-th edge (0-based) of the path, including its tid.
+    // Reads ids_[2n] and ids_[2n + 2] as the edge's two endpoint ids and ids_[2n + 1] as
+    // the edge id (the append code pushes edge.eid followed by the far-end vertex id).
+    // When dirs_[n] is false (edge appended backward) the two endpoint arguments are
+    // swapped. NOTE(review): the argument order is presumably (src, dst, lid, tid, eid);
+    // confirm against the lgraph::EdgeUid constructor.
+    // Throws std::runtime_error("Access out of range.") when n >= dirs_.size().
+    lgraph::EdgeUid GetNthEdgeWithTid(size_t n) const {
+        size_t length = dirs_.size();
+        if (n >= length) {
+            throw std::runtime_error("Access out of range.");
+        }
+        return dirs_[n] ? lgraph::EdgeUid(ids_[0 + n * 2], ids_[2 + n * 2], lids_[n], tids_[n],
+                                          ids_[1 + n * 2])
+                        : lgraph::EdgeUid(ids_[2 + n * 2], ids_[0 + n * 2], lids_[n], tids_[n],
+                                          ids_[1 + n * 2]);
+    }
+
     std::string ToString() const {
         std::string str = "[";
         int64_t src, dst;