Skip to content

Commit

Permalink
Merge pull request #2107 from mabel-dev/#2100/1
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Nov 27, 2024
2 parents c90f7f1 + 0e95cca commit 1e5e25f
Show file tree
Hide file tree
Showing 27 changed files with 349 additions and 207 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 865
__build__ = 871

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
301 changes: 188 additions & 113 deletions opteryx/models/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
The Execution Tree is the Graph which defines a Query Plan.
The execution tree contains functionality to:
- build and define the plan
- execute the plan
- manipulate the plan
"""

from queue import Empty
from queue import Queue
from threading import Lock
from threading import Thread
from typing import Any
from typing import Generator
from typing import Optional
from typing import Tuple

import pyarrow

from opteryx import EOS
from opteryx import config
from opteryx.constants import ResultType
from opteryx.exceptions import InvalidInternalStateError
from opteryx.third_party.travers import Graph

import pyarrow

# Guards updates to the per-node morsel counts kept while a plan executes.
morsel_lock = Lock()
# Guards the module-level ``active_tasks`` counter below.
active_task_lock = Lock()
# Running total adjusted through ``active_tasks_increment``. This is module
# state, so it is shared by everything in the process that uses this module.
active_tasks: int = 0


def active_tasks_increment(value: int):
    """
    Adjust the module-level ``active_tasks`` counter by ``value``.

    The update is performed while holding ``active_task_lock``.

    Parameters:
        value (int): Signed amount to add; pass a negative number to decrement.
    """
    global active_tasks
    with active_task_lock:
        active_tasks = active_tasks + value


class PhysicalPlan(Graph):
"""
The execution tree is defined separately to the planner to simplify the
complex code which is the planner from the tree that describes the plan.
The execution tree is defined separately from the planner to simplify the
complex code that is the planner from the tree that describes the plan.
"""

def explainv2(self, analyze: bool) -> Generator[pyarrow.Table, None, None]:
def depth_first_search_flat(
    self, node: Optional[str] = None, visited: Optional[set] = None
) -> list:
    """
    Return a flat, depth-first (pre-order) listing of the plan, walking from a
    node back along its ingoing edges.

    The nodes feeding a given node are visited in the order left, right, then
    any other/unlabelled edge, so the left side of a join is always listed
    before the right side.

    Parameters:
        node (Optional[str]): Node id to start from; defaults to the first
            exit point of the graph.
        visited (Optional[set]): Node ids already emitted. Callers normally
            leave this unset; the same set is shared by the recursive calls
            and is mutated in place.

    Returns:
        list: ``(node_id, node)`` tuples in traversal order, each node once.
    """

    def _edge_rank(edge) -> int:
        # Explicit ranking: "left" first, "right" second, anything else
        # (None, "", or another label) last. A tuple-of-booleans key such as
        # (rel == "right", rel == "") sorts ""-labelled edges ahead of
        # "right" and ties None with "left", which is not the intended order.
        relationship = edge[2]
        if relationship == "left":
            return 0
        if relationship == "right":
            return 1
        return 2

    if node is None:
        node = self.get_exit_points()[0]

    if visited is None:
        visited = set()

    visited.add(node)
    traversal_list = [(node, self[node])]

    # sorted() is stable, so edges of equal rank keep the graph's own order.
    neighbors = sorted(self.ingoing_edges(node), key=_edge_rank)

    for neighbor, _, _ in neighbors:
        # A node reachable through several paths is only emitted the first time.
        if neighbor not in visited:
            traversal_list.extend(self.depth_first_search_flat(neighbor, visited))

    return traversal_list

def explain(self, analyze: bool) -> Generator[pyarrow.Table, None, None]:
from opteryx import operators

def _inner_explain(node, depth):
Expand Down Expand Up @@ -86,61 +112,69 @@ def _inner_explain(node, depth):
plan = list(_inner_explain(head[0], 1))

table = pyarrow.Table.from_pylist(plan)
print(table)
return table

yield table

def depth_first_search_flat(
    self, node: Optional[str] = None, visited: Optional[set] = None
) -> list:
    """
    Returns a flat list representing the depth-first traversal of the graph with left/right ordering.

    We do this so we always evaluate the left side of a join before the right side. It technically
    doesn't need the entire plan flattened DFS-wise, but this is what we are doing here to achieve
    the outcome we're after.

    Parameters:
        node (Optional[str]): Node id to start from; defaults to the first
            exit point of the graph.
        visited (Optional[set]): Node ids already emitted. Callers normally
            leave this unset; the same set is shared by the recursive calls
            and is mutated in place.

    Returns:
        list: ``(node_id, node)`` tuples in traversal order, each node once.
    """

    def _edge_rank(edge) -> int:
        # Explicit ranking: "left" first, "right" second, anything else
        # (None, "", or another label) last. A tuple-of-booleans key such as
        # (rel == "right", rel == "") sorts ""-labelled edges ahead of
        # "right" and ties None with "left", which is not the intended order.
        relationship = edge[2]
        if relationship == "left":
            return 0
        if relationship == "right":
            return 1
        return 2

    if node is None:
        node = self.get_exit_points()[0]

    if visited is None:
        visited = set()

    visited.add(node)

    # Collect this node's information in a flat list format
    traversal_list = [(node, self[node])]

    # sorted() is stable, so edges of equal rank keep the graph's own order.
    neighbors = sorted(self.ingoing_edges(node), key=_edge_rank)

    # Traverse each child, prioritizing left, then right, then unlabelled
    for neighbor, _, _ in neighbors:
        # A node reachable through several paths is only emitted the first time.
        if neighbor not in visited:
            traversal_list.extend(self.depth_first_search_flat(neighbor, visited))

    return traversal_list

def execute(self, head_node=None) -> Tuple[Generator[pyarrow.Table, Any, Any], ResultType]:
def execute(self, head_node=None) -> Generator[Tuple[Any, ResultType], Any, Any]:
from opteryx.operators import ExplainNode
from opteryx.operators import JoinNode
from opteryx.operators import ReaderNode
from opteryx.operators import SetVariableNode
from opteryx.operators import ShowCreateNode
from opteryx.operators import ShowValueNode

# Validate query plan to ensure it's acyclic
morsel_accounting = {nid: 0 for nid in self.nodes()} # Total morsels received by each node
node_exhaustion = {nid: False for nid in self.nodes()} # Exhaustion state of each node

def mark_node_exhausted(node_id):
"""
Mark a node as exhausted and propagate exhaustion downstream.
"""
# Closure over execute()'s node_exhaustion and morsel_accounting dicts. It also
# uses work_queue, which execute() only creates further down, so this must not
# be called before that queue exists.
if node_exhaustion[node_id]:
return # Node is already marked as exhausted

node_exhaustion[node_id] = True
# NOTE(review): leftover debug output; it runs the first time each node is
# marked exhausted and writes to stdout.
print("+", node_id, self[node_id].name)

# Notify downstream nodes
for _, downstream_node, _ in self.outgoing_edges(node_id):
# Check if all parents of downstream_node are exhausted
# ("parents" are the source nodes of the edges coming into downstream_node)
if all(
node_exhaustion[parent] for parent, _, _ in self.ingoing_edges(downstream_node)
):
work_queue.put((downstream_node, EOS)) # EOS signals exhaustion
# The queued EOS is counted as an active task, like any other queued work.
active_tasks_increment(+1)
# NOTE(review): this bumps the count for node_id (the node just marked
# exhausted) rather than downstream_node, which is the node that received the
# EOS task, and it does so without taking morsel_lock - confirm both are intended.
morsel_accounting[node_id] += 1

def update_morsel_accounting(node_id, morsel_count_change: int):
"""
Updates the morsel accounting for a node and checks for exhaustion.
Parameters:
node_id (str): The ID of the node to update.
morsel_count_change (int): The change in morsel count (+1 for increment, -1 for decrement).
Returns:
None
"""
# Called from the engine loop when work is queued (+1) and from worker threads
# when a task has been processed (-1), so the shared counter is changed while
# holding morsel_lock.
with morsel_lock:
morsel_accounting[node_id] += morsel_count_change
# print(">", node_id, morsel_accounting[node_id], morsel_count_change, self[node_id].name)

# Check if the node is exhausted
if morsel_accounting[node_id] <= 0: # No more pending morsels for this node
# Ensure all parent nodes are exhausted
# all() over an empty sequence is True, so a node with no ingoing edges passes
# this check as soon as its count is at or below zero.
all_parents_exhausted = all(
node_exhaustion[parent] for parent, _, _ in self.ingoing_edges(node_id)
)
if all_parents_exhausted:
# mark_node_exhausted returns early if the node was already marked.
mark_node_exhausted(node_id)

if not self.is_acyclic():
raise InvalidInternalStateError("Query plan is cyclic, cannot execute.")

# Retrieve the tail of the query plan, which should ideally be a single head node
head_nodes = list(set(self.get_exit_points()))

if len(head_nodes) != 1:
raise InvalidInternalStateError(
f"Query plan has {len(head_nodes)} heads, expected exactly 1."
Expand All @@ -149,77 +183,118 @@ def execute(self, head_node=None) -> Tuple[Generator[pyarrow.Table, Any, Any], R
if head_node is None:
head_node = self[head_nodes[0]]

# add the left/right labels to the edges coming into the joins
joins = [(nid, node) for nid, node in self.nodes(True) if isinstance(node, JoinNode)]
for nid, join in joins:
for s, t, r in self.breadth_first_search(nid, reverse=True):
source_relations = self[s].parameters.get("all_relations", set())
if set(join._left_relation).intersection(source_relations):
self.remove_edge(s, t, r)
self.add_edge(s, t, "left")
elif set(join._right_relation).intersection(source_relations):
self.remove_edge(s, t, r)
self.add_edge(s, t, "right")

# Special case handling for 'Explain' queries
if isinstance(head_node, ExplainNode):
yield self.explainv2(head_node.analyze), ResultType.TABULAR

# Special case handling for 'Set' queries
elif isinstance(head_node, SetVariableNode):
yield head_node(None), ResultType.NON_TABULAR
yield self.explain(head_node.analyze), ResultType.TABULAR

elif isinstance(head_node, (ShowValueNode, ShowCreateNode)):
elif isinstance(head_node, (SetVariableNode, ShowValueNode, ShowCreateNode)):
yield head_node(None), ResultType.TABULAR

else:
# Work queue for worker tasks
work_queue = Queue()
# Response queue for results sent back to the engine
response_queue = Queue()
num_workers = 1
workers = []

def worker_process():
"""
Worker thread: Processes tasks from the work queue and sends results to the response queue.
"""
while True:
# Queue.get() with no arguments blocks until a task is available.
task = work_queue.get()
# None is the shutdown sentinel; execute() queues one per worker when the
# plan has finished.
if task is None:
break

# Every other task is a (node id, morsel) pair.
node_id, morsel = task
# NOTE(review): morsel_accounting holds integer counts (initialised to 0 and
# only changed with +=), so `is False` can never be true and this warning
# never fires. The exhaustion flags live in node_exhaustion - confirm which
# was meant.
if morsel_accounting[node_id] is False:
print("RUNNING AN EXHAUSTED NODE")
operator = self[node_id]
# NOTE(review): the loop below iterates whatever the operator returns; the
# earlier process_node implementation guarded against a None return, this
# does not - confirm operators always return an iterable here.
results = operator(morsel)

for result in results:
# Send results back to the response queue
response_queue.put((node_id, result))

# Presumably undoes the +1 recorded when this task was queued for node_id;
# the call may also mark the node exhausted.
update_morsel_accounting(node_id, -1)

work_queue.task_done()

# Launch worker threads
for _ in range(num_workers):
worker = Thread(target=worker_process)
worker.daemon = True
worker.start()
workers.append(worker)

def inner_execute(plan):
# Get the pump nodes from the plan and execute them in order
# Identify pump nodes
global active_tasks

pump_nodes = [
(nid, node)
for nid, node in self.depth_first_search_flat()
if isinstance(node, ReaderNode)
]

# Main engine loop processes pump nodes and coordinates work
for pump_nid, pump_instance in pump_nodes:
for morsel in pump_instance(None):
yield from plan.process_node(pump_nid, morsel)
# Initial morsels pushed to the work queue
# Determine downstream operators
next_nodes = [target for _, target, _ in self.outgoing_edges(pump_nid)]
for downstream_node in next_nodes:
# Queue tasks for downstream operators
work_queue.put((downstream_node, morsel))
active_tasks_increment(+1)
update_morsel_accounting(downstream_node, +1)

# Pump is exhausted after emitting all morsels
mark_node_exhausted(pump_nid)

# Process results from the response queue
def should_stop():
    """
    Report whether the engine loop can finish.

    Returns:
        bool: True only when every node is marked exhausted, both the work
        queue and the response queue are empty, and ``active_tasks`` is at
        or below zero.
    """
    # NOTE(review): these are unsynchronised snapshots taken while worker
    # threads may still be running, and Queue.empty() is documented as not
    # reliable under concurrency, so the three conditions are only trusted
    # in combination.
    all_nodes_exhausted = all(node_exhaustion.values())
    queues_empty = work_queue.empty() and response_queue.empty()
    all_nodes_inactive = active_tasks <= 0
    return all_nodes_exhausted and queues_empty and all_nodes_inactive

while not should_stop():
# Wait for results from workers
try:
node_id, result = response_queue.get(timeout=0.1)
except Empty:
continue

# Handle EOS
if result is None or result == EOS:
active_tasks_increment(-1)
continue

# Determine downstream operators
downstream_nodes = [target for _, target, _ in self.outgoing_edges(node_id)]
if len(downstream_nodes) == 0:
# print("YIELD")
yield result
else:
for downstream_node in downstream_nodes:
# Queue tasks for downstream operators
active_tasks_increment(+1)
work_queue.put((downstream_node, result))
update_morsel_accounting(downstream_node, +1)

yield inner_execute(self), ResultType.TABULAR
# decrement _after_ we've done the work related to handling the task
active_tasks_increment(-1)

def process_node(self, nid, morsel):
from opteryx.operators import ReaderNode
# print("DONE!", node_exhaustion, work_queue.empty(), response_queue.empty())

node = self[nid]
for worker in workers:
work_queue.put(None)

if isinstance(node, ReaderNode):
children = (t for s, t, r in self.outgoing_edges(nid))
for child in children:
results = self.process_node(child, morsel)
results = list(results)
yield from results
else:
results = node(morsel)
if results is None:
return None
if not isinstance(results, list):
results = [results]
if morsel == EOS and not any(r == EOS for r in results):
results.append(EOS)
for result in results:
if result is not None:
children = [t for s, t, r in self.outgoing_edges(nid)]
for child in children:
yield from self.process_node(child, result)
if len(children) == 0 and result != EOS:
yield result
# Wait for all workers to complete
for worker in workers:
worker.join()

def sensors(self):
    """
    Collect the sensor readings of every node in the plan.

    Returns:
        dict: Maps each node's ``identity`` to the value returned by that
        node's ``sensors()`` call.
    """
    collected = {}
    for node in (self[nid] for nid in self.nodes()):
        reading = node.sensors()
        collected[node.identity] = reading
    return collected

def __del__(self):
    """No-op finaliser: performs no cleanup when the plan object is collected."""
yield inner_execute(self), ResultType.TABULAR
3 changes: 3 additions & 0 deletions opteryx/models/query_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __setattr__(self, attr, value):
else:
self._stats[attr] = value

def increase(self, attr: str, amount: float):
    """
    Add ``amount`` to the statistic stored under ``attr``.

    Parameters:
        attr (str): Name of the statistic to adjust.
        amount (float): Value added to the current reading.
    """
    counters = self._stats
    counters[attr] += amount

def add_message(self, message: str):
"""collect warnings"""
if "messages" not in self._stats:
Expand Down
Loading

0 comments on commit 1e5e25f

Please sign in to comment.