Merge pull request #1262 from mabel-dev/#1261-Part1
joocer authored Nov 14, 2023
2 parents 617051f + 7630c39 commit 457c604
Showing 5 changed files with 127 additions and 21 deletions.
5 changes: 4 additions & 1 deletion opteryx/components/__init__.py
@@ -50,6 +50,7 @@ def query_planner(operation: str, parameters: list, connection, qid: str):

    from opteryx.components.ast_rewriter import do_ast_rewriter
    from opteryx.components.binder import do_bind_phase
    from opteryx.components.heuristic_optimizer import do_heuristic_optimizer
    from opteryx.components.logical_planner import do_logical_planning_phase
    from opteryx.components.sql_rewriter import do_sql_rewrite
    from opteryx.components.temporary_physical_planner import create_physical_plan
@@ -102,7 +103,9 @@ def query_planner(operation: str, parameters: list, connection, qid: str):
        # common_table_expressions=ctes,
    )

    heuristic_optimized_plan = do_heuristic_optimizer(bound_plan)

    # before we write the new optimizer and execution engine, convert to a V1 plan
    query_properties = QueryProperties(qid=qid, variables=connection.context.variables)
    physical_plan = create_physical_plan(bound_plan, query_properties)
    physical_plan = create_physical_plan(heuristic_optimized_plan, query_properties)
    yield physical_plan
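
Taken together, this file now threads the bound plan through the new heuristic optimizer before physical planning. A rough sketch of the resulting phase order, assuming the earlier `do_*` phases behave as their names suggest (their exact signatures aren't shown in this diff, and `parse_sql` is a hypothetical stand-in for the parser):

```python
# Rough sketch of the planning phases after this commit; signatures simplified.
def plan_sketch(sql, parameters, connection, qid):
    rewritten_sql = do_sql_rewrite(sql)                          # assumed signature
    ast = do_ast_rewriter(parse_sql(rewritten_sql), parameters)  # parse_sql is hypothetical
    logical_plan = do_logical_planning_phase(ast)                # assumed signature
    bound_plan = do_bind_phase(logical_plan, connection, qid)    # assumed signature
    optimized_plan = do_heuristic_optimizer(bound_plan)          # new in this commit
    properties = QueryProperties(qid=qid, variables=connection.context.variables)
    return create_physical_plan(optimized_plan, properties)
```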
3 changes: 2 additions & 1 deletion opteryx/components/binder/binder.py
@@ -85,9 +85,10 @@ def locate_identifier_in_loaded_schemas(
        found = schema.find_column(value)
        if found:
            if column and found_source_relation:
                # test for duplicates
                raise AmbiguousIdentifierError(identifier=value)
            found_source_relation = schema
            column = found
            column = found  # don't exit here, so we can test for duplicates

    return column, found_source_relation
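
The comment on `column = found` is the heart of this change: the loop deliberately does not return on the first match, so a second match in another schema can be reported as ambiguous. The same pattern in isolation (illustrative names, not the opteryx API):

```python
# Illustrative only: locate a column across several schemas, erroring on duplicates.
def find_once(column_name, schemas):
    match = None
    for schema_name, columns in schemas.items():
        if column_name in columns:
            if match is not None:
                raise ValueError(f"'{column_name}' is ambiguous")  # second hit
            match = (schema_name, column_name)  # don't return yet; keep scanning
    return match

find_once("id", {"orders": ["id", "total"], "users": ["name"]})  # ('orders', 'id')
# find_once("id", {"orders": ["id"], "users": ["id"]})           # raises ValueError
```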

50 changes: 33 additions & 17 deletions opteryx/components/binder/binder_visitor.py
@@ -224,7 +224,6 @@ def visit_node(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
"""
node_type = node.node_type.name
visit_method_name = f"visit_{CAMEL_TO_SNAKE.sub('_', node_type).lower()}"

visit_method = getattr(self, visit_method_name, None)
if visit_method is None:
return node, context
@@ -233,17 +232,22 @@

        if not isinstance(return_context, BindingContext):
            raise InvalidInternalStateError(
                f"Internal Error - function {visit_method_name} didn't return a BindingContext"
                f"Internal Error - function '{visit_method_name}' didn't return a BindingContext"
            )

        if not all(isinstance(schema, RelationSchema) for schema in context.schemas.values()):
            raise InvalidInternalStateError(
                f"Internal Error - function {visit_method_name} returned invalid Schemas"
                f"Internal Error - function '{visit_method_name}' returned invalid Schemas"
            )

        if not all(isinstance(col, (Node, LogicalColumn)) for col in return_node.columns or []):
            raise InvalidInternalStateError(
                f"Internal Error - function {visit_method_name} put unexpected items in 'columns' attribute"
                f"Internal Error - function '{visit_method_name}' put unexpected items in 'columns' attribute"
            )

        if return_node.node_type.name != "Scan" and return_node.columns is None:
            raise InvalidInternalStateError(
                f"Internal Error - function {visit_method_name} did not populate 'columns'"
            )

        return return_node, return_context
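
The dispatch at the top of `visit_node` mangles the node type's CamelCase name into a `visit_*` method name. A minimal sketch of that convention (the real `CAMEL_TO_SNAKE` pattern isn't shown in this hunk, so the regex below is an assumption):

```python
import re

# Assumed equivalent of the module's CAMEL_TO_SNAKE pattern.
CAMEL_TO_SNAKE = re.compile(r"(?<!^)(?=[A-Z])")

def visit_method_name(node_type: str) -> str:
    return f"visit_{CAMEL_TO_SNAKE.sub('_', node_type).lower()}"

print(visit_method_name("AggregateAndGroup"))  # visit_aggregate_and_group
print(visit_method_name("Filter"))             # visit_filter
```

The new invariant check at the end ties into the rest of this commit: every visitor except `visit_scan` must now populate `node.columns`, which is what the heuristic optimizer later harvests.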
@@ -280,14 +284,11 @@ def visit_aggregate_and_group(
        tmp_groups, _ = zip(*(inner_binder(group, context) for group in node.groups))
        columns_to_keep = {col.schema_column.identity for col in tmp_groups}
        # 2) the columns referenced in the SELECT
        all_identifiers = [
            node.schema_column.identity
            for node in get_all_nodes_of_type(
                node.aggregates + node.groups, select_nodes=(NodeType.IDENTIFIER,)
            )
        ]
        node.columns = get_all_nodes_of_type(
            node.aggregates + node.groups, select_nodes=(NodeType.IDENTIFIER,)
        )
        all_identifiers = [node.schema_column.identity for node in node.columns]
        columns_to_keep = columns_to_keep.union(all_identifiers)
        node.all_identifiers = columns_to_keep

        for name, schema in list(context.schemas.items()):
            schema_columns = [
@@ -310,10 +311,12 @@ def visit_aggregate_and_group(
    visit_aggregate = visit_aggregate_and_group

    def visit_distinct(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
        node.columns = []
        if node.on:
            # Bind the local columns to physical columns
            node.on, group_contexts = zip(*(inner_binder(col, context) for col in node.on))
            context.schemas = merge_schemas(*[ctx.schemas for ctx in group_contexts])
            node.columns = get_all_nodes_of_type(node.on, (NodeType.IDENTIFIER,))

        return node, context

@@ -394,6 +397,14 @@ def keep_column(column, identities):

        return node, context

    def visit_filter(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
        # We don't update the context, otherwise we'd be adding the predicates as columns
        original_context = context.copy()
        node.condition, context = inner_binder(node.condition, context)
        node.columns = get_all_nodes_of_type(node.condition, (NodeType.IDENTIFIER,))

        return node, original_context
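
Note the asymmetry here: the bound condition and the identifiers it references are stored on the node (so the new heuristic optimizer can collect them for push-downs), but the caller gets back the pre-binding context, because predicates shouldn't surface as columns in downstream schemas.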

    def visit_function_dataset(
        self, node: Node, context: BindingContext
    ) -> Tuple[Node, BindingContext]:
@@ -526,6 +537,7 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
        Tuple[Node, Dict]
            Updated node and context.
        """
        node.columns = []
        # Handle 'natural join' by converting to a 'using'
        if node.type == "natural join":
            left_columns = [
@@ -563,6 +575,10 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
"JOIN conditions cannot include literal constant values."
)

# we need to put the referenced columns into the columns attribute for the
# optimizers
node.columns = get_all_nodes_of_type(node.on, (NodeType.IDENTIFIER,))

if node.using:
# Remove the columns used in the join condition from both relations, they're in
# the result set but not belonging to either table, whilst still belonging to both.
@@ -608,6 +624,7 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
            node.unnest_alias = f"UNNEST({node.unnest_column.query_column})"
            # this is the column which is being unnested
            node.unnest_column, context = inner_binder(node.unnest_column, context)
            node.columns += [node.unnest_column]
            # this is the column that is being created - find it from its name
            node.unnest_target, found_source_relation = locate_identifier_in_loaded_schemas(
                node.unnest_alias, context.schemas
@@ -697,14 +714,9 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind

        return node, context

    def visit_filter(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
        original_context = context.copy()
        node.condition, context = inner_binder(node.condition, context)

        return node, original_context

    def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
        order_by = []
        columns = []
        for column, direction in node.order_by:
            bound_column, context = inner_binder(column, context)

@@ -714,8 +726,10 @@ def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, Bindin
"ascending" if direction else "descending",
)
)
columns.append(bound_column)

node.order_by = order_by
node.columns = columns
return node, context

    def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
@@ -754,11 +768,13 @@ def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, Binding

    def visit_set(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
        node.variables = context.connection.variables
        node.columns = []
        return node, context

    def visit_show_columns(
        self, node: Node, context: BindingContext
    ) -> Tuple[Node, BindingContext]:
        node.columns = []
        node.schema = context.schemas[node.relation]
        return node, context

89 changes: 87 additions & 2 deletions opteryx/components/heuristic_optimizer.py
@@ -39,7 +39,7 @@
└───────────┘ └───────────┘ ╚═══════════╝
~~~
The plan rewriter does basic heuristic rewrites of the plan, this is an evolution of the old optimizer
The plan rewriter does basic heuristic rewrites of the plan, this is an evolution of the old optimizer.
Do things like:
- split predicates into as many AND conditions as possible
@@ -49,5 +49,90 @@
New things:
- replace subqueries with joins
- use knowledge about value ranges to prefilter (e.g. prune at read-time before joins)
This is written as a Visitor; unlike the binder, which works from the scanners up to
the projection, this starts at the projection and works toward the scanners. This suits
the main activities here: splitting nodes, rewriting individual nodes, and pushing operations down.
"""
from opteryx.components.logical_planner import LogicalPlan
from opteryx.components.logical_planner import LogicalPlanStepType


# Context object to carry state
class HeuristicOptimizerContext:
    def __init__(self, tree: LogicalPlan):
        self.pre_optimized_tree = tree
        self.optimized_tree = LogicalPlan()

        # We collect predicates that reference single relations so we can push them
        # as close to the read as possible, including off to remote systems
        self.collected_predicates = []

        # We collect column identities so we can push column selection as close to the
        # read as possible, including off to remote systems
        self.collected_identities = []


# Optimizer Visitor
class HeuristicOptimizerVisitor:
    def rewrite_predicates(self, node):
        pass

    def collect_columns(self, node):
        if node.columns:
            return [col.schema_column.identity for col in node.columns if col.schema_column]
        return []

    def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
        # collect column references to push PROJECTION
        # rewrite conditions to get as many AND conditions as possible
        # collect predicates which reference one relation to push SELECTIONS
        # get rid of NESTED nodes

        node = context.pre_optimized_tree[nid]

        # do this before any transformations
        if node.node_type != LogicalPlanStepType.Scan:
            context.collected_identities.extend(self.collect_columns(node))

        if node.node_type == LogicalPlanStepType.Filter:
            # rewrite predicates, to favor conjunctions and reduce negations
            # split conjunctions
            # collect predicates
            pass
        if node.node_type == LogicalPlanStepType.Scan:
            # push projections - build the set of referenced identities once, then keep
            # only the schema columns something later in the plan actually referenced
            collected_identities = set(context.collected_identities)
            node_columns = [
                col for col in node.schema.columns if col.identity in collected_identities
            ]
            # note: node_columns is collected but not yet applied to the node in this commit
            # push selections
            pass

        context.optimized_tree.add_node(nid, node)
        if parent:
            context.optimized_tree.add_edge(nid, parent)

        return context

    def traverse(self, tree: LogicalPlan):
        root = tree.get_exit_points().pop()
        context = HeuristicOptimizerContext(tree)

        def _inner(parent, node, context):
            context = self.visit(parent, node, context)
            for child, _, _ in tree.ingoing_edges(node):
                _inner(node, child, context)

        _inner(None, root, context)
        return context.optimized_tree


def do_heuristic_optimizer(plan: LogicalPlan) -> LogicalPlan:
    optimizer = HeuristicOptimizerVisitor()
    return optimizer.traverse(plan)
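
Because `traverse` starts at the plan's exit point and follows ingoing edges, nodes are visited projection-first. A quick sketch of that ordering on a linear plan, using only the `LogicalPlan` calls that appear above (the three node objects are hypothetical placeholders):

```python
# Hypothetical linear plan: scan -> filter -> project, with project as the exit point.
plan = LogicalPlan()
plan.add_node("scan", scan_node)      # scan_node, filter_node and project_node are
plan.add_node("filter", filter_node)  # placeholder plan nodes, not shown here
plan.add_node("project", project_node)
plan.add_edge("scan", "filter")
plan.add_edge("filter", "project")

optimized = do_heuristic_optimizer(plan)
# Visit order: project, filter, scan. Identifiers are collected on the way down,
# so by the time the Scan is visited the full set of referenced columns is known.
```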
1 change: 1 addition & 0 deletions opteryx/models/query_properties.py
@@ -26,6 +26,7 @@ def __init__(self, qid: int, variables):
        self.variables: dict[str, Any] = variables
        self.temporal_filters: list = []
        self.date = datetime.datetime.utcnow().date()
        self.current_time = datetime.datetime.utcnow()
        self.cache = None
        self.qid = qid
        self.ctes: dict = {}
