Skip to content

Commit

Permalink
Merge pull request #1276 from mabel-dev/#1275
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Nov 19, 2023
2 parents e2d844b + d7a4875 commit e7553bc
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 62 deletions.
77 changes: 58 additions & 19 deletions opteryx/components/heuristic_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@
the projection, this starts at the projection and works toward the scanners. This works well because
the main activity we're doing is splitting nodes, individual node rewrites, and push downs.
"""
from orso.tools import random_string

from opteryx.components.logical_planner import LogicalPlan
from opteryx.components.logical_planner import LogicalPlanNode
from opteryx.components.logical_planner import LogicalPlanStepType
from opteryx.components.rules import heuristic_optimizer
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import get_all_nodes_of_type
from opteryx.models import Node


# Context object to carry state
Expand Down Expand Up @@ -111,38 +116,72 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
if node.node_type == LogicalPlanStepType.Filter:
# rewrite predicates, to favor conjuctions and reduce negations
# split conjunctions
nodes = heuristic_optimizer.rule_split_conjunctive_predicates(node)
# deduplicate the nodes - note this 'randomizes' the order
nodes = _unique_nodes(nodes)

previous = parent
for predicate_node in nodes:
predicate_nid = random_string()
plan_node = LogicalPlanNode(
node_type=LogicalPlanStepType.Filter, condition=predicate_node
)
context.optimized_tree.add_node(predicate_nid, plan_node)
context.optimized_tree.add_edge(predicate_nid, previous)
previous = predicate_nid

# collect predicates
pass
if node.node_type == LogicalPlanStepType.Scan:
# push projections
node_columns = [
col for col in node.schema.columns if col.identity in context.collected_identities
]
# these are the pushed columns
node.columns = node_columns
if node.node_type == LogicalPlanStepType.Join:
# push predicates which reference multiple relations here
pass

context.optimized_tree.add_node(nid, node)
if parent:
context.optimized_tree.add_edge(nid, parent)

return context

return previous, context

else:
if node.node_type == LogicalPlanStepType.Scan:
# push projections
node_columns = [
col
for col in node.schema.columns
if col.identity in context.collected_identities
]
# these are the pushed columns
node.columns = node_columns
elif node.node_type == LogicalPlanStepType.Join:
# push predicates which reference multiple relations here
pass

context.optimized_tree.add_node(nid, LogicalPlanNode(**node.properties))
if parent:
context.optimized_tree.add_edge(nid, parent)

return None, context

def traverse(self, tree: LogicalPlan):
root = tree.get_exit_points().pop()
context = HeuristicOptimizerContext(tree)

def _inner(parent, node, context):
context = self.visit(parent, node, context)
parent, context = self.visit(parent, node, context)
for child, _, _ in tree.ingoing_edges(node):
_inner(node, child, context)
_inner(parent or node, child, context)

_inner(None, root, context)
# print(context.optimized_tree.draw())
return context.optimized_tree


def _unique_nodes(nodes: list) -> list:
seen_identities = {}

for node in nodes:
identity = node.schema_column.identity
if identity not in seen_identities:
seen_identities[identity] = node
else:
if node.left.schema_column and node.right.schema_column:
seen_identities[identity] = node

return list(seen_identities.values())


def do_heuristic_optimizer(plan: LogicalPlan) -> LogicalPlan:
optimizer = HeuristicOptimizerVisitor()
return optimizer.traverse(plan)
1 change: 1 addition & 0 deletions opteryx/components/rules/heuristic_optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .split_conjuctive_predicates import rule_split_conjunctive_predicates
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
Type: Heuristic
Goal: Reduce rows
"""
from opteryx import operators

from opteryx.managers.expression import NodeType
from opteryx.utils import random_string


def split_conjunctive_predicates(plan, properties):
def rule_split_conjunctive_predicates(node):
"""
Conjunctive Predicates (ANDs) can be split and executed in any order to get the
same result. This means we can split them into separate steps in the plan.
Expand All @@ -37,38 +36,14 @@ def split_conjunctive_predicates(plan, properties):
the check (a numeric check is faster than a string check)
"""

def _inner_split(plan, nid, operator):
selection = operator.filter
if selection.node_type != NodeType.AND:
return plan
def _inner_split(node):
if node.node_type != NodeType.AND:
return [node]

# get the left and right filters
left_node = operators.SelectionNode(filter=selection.left, properties=properties)
right_node = operators.SelectionNode(filter=selection.right, properties=properties)
# insert them into the plan and remove the old node
# we're chaining the new operators
uid = random_string() # avoid collisions
plan.insert_node_before(f"{nid}-{uid}-right", right_node, nid)
plan.insert_node_before(f"{nid}-{uid}-left", left_node, f"{nid}-{uid}-right")
plan.remove_node(nid, heal=True)

# recurse until we get to a non-AND condition
plan = _inner_split(plan, f"{nid}-{uid}-right", right_node)
plan = _inner_split(plan, f"{nid}-{uid}-left", left_node)

return plan

# find the in-scope nodes
selection_nodes = plan.get_nodes_of_type(operators.SelectionNode)

# killer questions - if any aren't met, bail
if selection_nodes is None:
return plan
left_nodes = _inner_split(node.left)
right_nodes = _inner_split(node.right)

# HAVING and WHERE are selection nodes
for nid in selection_nodes:
# get the node from the node_id
operator = plan[nid]
plan = _inner_split(plan, nid, operator)
return left_nodes + right_nodes

return plan
return _inner_split(node.condition)
19 changes: 12 additions & 7 deletions opteryx/connectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def register_store(prefix, connector, *, remove_prefix: bool = False, **kwargs):


def register_df(name, frame):
"""register a pandas or Polars dataframe"""
"""register a orso, pandas or Polars dataframe"""
# polars (maybe others) - the polars to arrow API is a mess
if hasattr(frame, "_df"):
frame = frame._df
Expand All @@ -61,6 +61,11 @@ def register_df(name, frame):
arrow = pyarrow.Table.from_batches(arrow)
register_arrow(name, arrow)
return
# orso
if hasattr(frame, "arrow"):
arrow = frame.arrow()
register_arrow(name, arrow)
return
# pandas
frame_type = str(type(frame))
if "pandas" in frame_type:
Expand Down Expand Up @@ -89,19 +94,19 @@ def connector_factory(dataset, statistics, **config):

# Look up the prefix from the registered prefixes
connector_entry: dict = config
for prefix in _storage_prefixes.keys():
if dataset.startswith(prefix):
connector_entry = _storage_prefixes[prefix].copy() # type: ignore
for prefix, storage_details in _storage_prefixes.items():
if dataset == prefix or dataset.startswith(prefix + "."):
connector_entry = storage_details.copy() # type: ignore
connector = connector_entry.pop("connector")
break
else:
if os.path.isfile(dataset):
from opteryx.connectors import file_connector

return file_connector.FileConnector(dataset=dataset, statistics=statistics)
else:
# fall back to the default connector (local disk if not set)
connector = _storage_prefixes.get("_default", DiskConnector)

# fall back to the default connector (local disk if not set)
connector = _storage_prefixes.get("_default", DiskConnector)

prefix = connector_entry.pop("prefix", "")
remove_prefix = connector_entry.pop("remove_prefix", False)
Expand Down
21 changes: 21 additions & 0 deletions tests/misc/test_connector_prefixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
"""
import os
import sys
import pytest

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

import opteryx
from opteryx.connectors import GcpFireStoreConnector, SqlConnector, register_store
from sqlalchemy.exc import NoSuchTableError
from opteryx.exceptions import DatasetNotFoundError

register_store(
"sqlite",
Expand Down Expand Up @@ -44,7 +47,25 @@ def test_connector_prefixes():
assert cur.rowcount == 7, cur.rowcount


def test_connector_prefixes_negative_tests():
with pytest.raises(NoSuchTableError):
# this should be the SQLAlchemy error
opteryx.query("SELECT * from planets.planets")

with pytest.raises(DatasetNotFoundError):
# this should NOT be the SQLAlchemy error
opteryx.query("SELECT * FROM planetsplanets.planets")

with pytest.raises(DatasetNotFoundError):
# this should NOT be the SQLAlchemy error
opteryx.query("SELECT * FROM planets_planets.planets")

with pytest.raises(DatasetNotFoundError):
opteryx.query("SELECT * FROM fsu.til")


if __name__ == "__main__": # pragma: no cover
from tests.tools import run_tests

test_connector_prefixes_negative_tests()
run_tests()
4 changes: 2 additions & 2 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,8 +685,8 @@
("SET disable_morsel_defragmentation = 100;", None, None, ValueError),
("SET disable_morsel_defragmentation = true; EXPLAIN SELECT * FROM $satellites WHERE id = 8", 2, 3, None),
("SET disable_optimizer = false; EXPLAIN SELECT * FROM $satellites WHERE id = 8", 2, 3, None),
("SET disable_optimizer = true; EXPLAIN SELECT * FROM $satellites WHERE id = 8 AND id = 7", 2, 3, None),
("SET disable_optimizer = false; EXPLAIN SELECT * FROM $satellites WHERE id = 8 AND id = 7", 2, 3, None),
("SET disable_optimizer = true; EXPLAIN SELECT * FROM $satellites WHERE id = 8 AND id = 7", 3, 3, None),
("SET disable_optimizer = false; EXPLAIN SELECT * FROM $satellites WHERE id = 8 AND id = 7", 3, 3, None),
("SET disable_optimizer = false; EXPLAIN SELECT * FROM $planets ORDER BY id LIMIT 5", 3, 3, None),
("SET disable_optimizer = true; EXPLAIN SELECT * FROM $planets ORDER BY id LIMIT 5", 3, 3, None),
("EXPLAIN SELECT * FROM $planets ORDER BY id LIMIT 5", 3, 3, None),
Expand Down

0 comments on commit e7553bc

Please sign in to comment.