Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Oct 29, 2023
1 parent dd7a721 commit 0781a2a
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 45 deletions.
19 changes: 11 additions & 8 deletions opteryx/components/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ def locate_identifier_in_loaded_schemas(
return column, found_source_relation


def locate_identifier(node: Node, context: Dict[str, Any]) -> Tuple[Node, Dict]:
def locate_identifier(node: Node, context: "BindingContext") -> Tuple[Node, Dict]:
"""
Locate which schema the identifier is defined in. We return a populated node
and the context.
Parameters:
node: Node
The node representing the identifier
context: Dict[str, Any]
context: BindingContext
The current query context.
Returns:
Expand All @@ -110,8 +110,9 @@ def locate_identifier(node: Node, context: Dict[str, Any]) -> Tuple[Node, Dict]:
UnexpectedDatasetReferenceError: If the source dataset is not found.
ColumnNotFoundError: If the column is not found in the schema.
"""
from opteryx.components.binder import BindingContext

def create_variable_node(node: Node, context: Dict[str, Any]) -> Node:
def create_variable_node(node: Node, context: BindingContext) -> Node:
"""Populates a Node object for a variable."""
schema_column = context.connection.variables.as_column(node.value)
new_node = Node(
Expand All @@ -122,11 +123,6 @@ def create_variable_node(node: Node, context: Dict[str, Any]) -> Node:
)
return new_node

# Check if the identifier is a variable
if node.current_name[0] == "@":
node = create_variable_node(node, context)
return node, context

# get the list of candidate schemas
if node.source:
candidate_schemas = {
Expand All @@ -148,6 +144,12 @@ def create_variable_node(node: Node, context: Dict[str, Any]) -> Node:

# if we didn't find the column, suggest alternatives
if not column:
# Check if the identifier is a variable
if node.current_name[0] == "@":
node = create_variable_node(node, context)
context.schemas["$derived"].columns.append(node.schema_column)
return node, context

from opteryx.utils import suggest_alternative

suggestion = suggest_alternative(
Expand Down Expand Up @@ -198,6 +200,7 @@ def inner_binder(node: Node, context: Dict[str, Any], step: str) -> Tuple[Node,
if node_type in (NodeType.IDENTIFIER, NodeType.EVALUATED):
return locate_identifier(node, context)

# Expression Lists are part of how CASE statements are represented
if node_type == NodeType.EXPRESSION_LIST:
node.value, new_contexts = zip(*(inner_binder(parm, context, step) for parm in node.value))
merged_schemas = merge_schemas(*[ctx.schemas for ctx in new_contexts])
Expand Down
86 changes: 85 additions & 1 deletion opteryx/components/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from orso.schema import FlatColumn
from orso.schema import RelationSchema
from orso.tools import random_string
from orso.types import OrsoTypes

from opteryx.components.binder.binder import inner_binder
from opteryx.components.binder.binder import locate_identifier_in_loaded_schemas
Expand All @@ -26,6 +27,7 @@
from opteryx.exceptions import AmbiguousDatasetError
from opteryx.exceptions import ColumnNotFoundError
from opteryx.exceptions import InvalidInternalStateError
from opteryx.exceptions import UnsupportedSyntaxError
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import get_all_nodes_of_type
from opteryx.models import LogicalColumn
Expand Down Expand Up @@ -308,6 +310,16 @@ def visit_aggregate_and_group(

visit_aggregate = visit_aggregate_and_group

def visit_distinct(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
    """
    Bind a DISTINCT [ON (...)] node.

    When an ON clause is present, each logical column in the clause is bound
    to its physical column, and the schemas produced by each binding pass are
    merged back into the context.
    """
    if not node.on:
        # Plain DISTINCT — nothing needs binding.
        return node, context

    bound_columns = []
    binding_contexts = []
    for column in node.on:
        bound_column, bound_context = inner_binder(column, context, node.identity)
        bound_columns.append(bound_column)
        binding_contexts.append(bound_context)

    node.on = tuple(bound_columns)
    context.schemas = merge_schemas(*(ctx.schemas for ctx in binding_contexts))

    return node, context

def visit_exit(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
# LOG: Exit

Expand Down Expand Up @@ -398,7 +410,6 @@ def visit_function_dataset(
context.schemas[relation_name] = schema
node.columns = columns
elif node.function == "GENERATE_SERIES":
node.alias = node.alias or "generate_series"
node.relation_name = node.alias
columns = [
LogicalColumn(
Expand All @@ -414,6 +425,64 @@ def visit_function_dataset(
)
context.schemas[node.relation_name] = schema
node.columns = columns
elif node.function == "FAKE":
from orso.schema import ColumnDisposition

node.relation_name = node.alias
node.rows = int(node.args[0].value)

if node.args[1].node_type == NodeType.NESTED:
column_definition = [node.args[1].centre]
else:
column_definition = node.args[1].value

special_handling = {
"NAME": (OrsoTypes.VARCHAR, ColumnDisposition.NAME),
"AGE": (OrsoTypes.INTEGER, ColumnDisposition.AGE),
}

columns = []
if isinstance(column_definition, tuple):
for i, column_type in enumerate(column_definition):
name = node.columns[i] if i < len(node.columns) else f"column_{i}"
column_type = str(column_type).upper()
if column_type in special_handling:
actual_type, disposition = special_handling[column_type]
schema_column = FlatColumn(
name=name, type=actual_type, disposition=disposition
)
else:
schema_column = FlatColumn(name=name, type=column_type)
columns.append(
LogicalColumn(
node_type=NodeType.IDENTIFIER,
source_column=schema_column.name,
source=node.alias,
schema_column=schema_column,
)
)
node.columns = columns
else:
column_definition = int(column_definition)
names = node.columns + tuple(
f"column_{i}" for i in range(len(node.columns), column_definition)
)
node.columns = [
LogicalColumn(
node_type=NodeType.IDENTIFIER,
source_column=names[i],
source=node.alias,
schema_column=FlatColumn(name=names[i], type=OrsoTypes.INTEGER),
)
for i in range(column_definition)
]

schema = RelationSchema(
name=node.relation_name,
columns=[c.schema_column for c in node.columns],
)
context.schemas[node.relation_name] = schema
node.schema = schema
else:
raise NotImplementedError(f"{node.function} does not exist")
return node, context
Expand Down Expand Up @@ -463,6 +532,11 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding

raise IncompatibleTypesError(**mismatches)

if get_all_nodes_of_type(node.on, (NodeType.LITERAL,)):
raise UnsupportedSyntaxError(
"JOIN conditions cannot include literal constant values."
)

if node.using:
# Remove the columns used in the join condition from both relations, they're in
# the result set but not belonging to either table, whilst still belonging to both.
Expand All @@ -489,6 +563,15 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
context.schemas[f"$shared-{random_string()}"] = RelationSchema(
name=f"^{left_relation_name}#^{right_relation_name}#", columns=columns
)

# SEMI and ANTI joins only return columns from one table
if node.type in ("left anti", "left semi"):
for schema in node.left_relation_names:
context.schemas.pop(schema)
if node.type in ("right anti", "right semi"):
for schema in node.right_relation_names:
context.schemas.pop(schema)

if node.column:
if not node.alias:
node.alias = f"UNNEST({node.column.query_column})"
Expand Down Expand Up @@ -577,6 +660,7 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind
# We always have a $derived schema, even if it's empty
if "$derived" in context.schemas:
context.schemas["$project"] = context.schemas.pop("$derived")
context.schemas["$project"].name = "$project"
if not "$derived" in context.schemas:
context.schemas["$derived"] = derived.schema()

Expand Down
3 changes: 3 additions & 0 deletions opteryx/components/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def create_node_relation(relation):
else function["alias"]["name"]["value"]
)
function_step.args = [logical_planner_builders.build(arg) for arg in function["args"]]
function_step.columns = tuple(col["value"] for col in function["alias"]["columns"])

step_id = random_string()
sub_plan.add_node(step_id, function_step)
Expand Down Expand Up @@ -653,6 +654,8 @@ def plan_query(statement):
parent_plan_exit_id = parent_plan.get_entry_points()[0]
plan.add_edge(step_id, parent_plan_exit_id)

raise UnsupportedSyntaxError("Set operators are not supported")

return plan

# we do some minor AST rewriting
Expand Down
12 changes: 6 additions & 6 deletions opteryx/components/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,12 +483,12 @@ def hex_literal(branch, alias=None, key=None):


def tuple_literal(branch, alias=None, key=None):
    """
    Build a LITERAL node of ARRAY type from a parsed tuple expression.

    Parameters:
        branch: the AST fragment holding the tuple's element expressions
        alias: optional alias for the resulting literal
        key: unused; present for builder-signature consistency

    Returns:
        A Node of type ARRAY whose value is a tuple of the element values.
    """
    values = [build(element).value for element in branch]
    # Some parse shapes wrap identifiers in an {"Identifier": ...} mapping
    # instead of yielding the value directly — unwrap them here.
    # NOTE(review): leftover debug print() calls removed from this function.
    if values and isinstance(values[0], dict):
        values = [build(value["Identifier"]).value for value in values]
    return Node(NodeType.LITERAL, type=OrsoTypes.ARRAY, value=tuple(values), alias=alias)


def substring(branch, alias=None, key=None):
Expand Down
7 changes: 4 additions & 3 deletions opteryx/operators/distinct_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from pyarrow import concat_tables

from opteryx.exceptions import SqlError
from opteryx.models import QueryProperties
from opteryx.operators import BasePlanNode
from opteryx.third_party.pyarrow_ops import drop_duplicates
Expand All @@ -30,7 +29,9 @@
class DistinctNode(BasePlanNode):
def __init__(self, properties: QueryProperties, **config):
    """
    Operator implementing DISTINCT [ON (columns)].

    Parameters:
        properties: the shared query properties object
        config: operator configuration; the optional 'on' entry carries the
            bound columns for a DISTINCT ON clause.
    """
    super().__init__(properties=properties)
    # Resolve the ON columns to their physical column identities up front;
    # None means deduplicate across all columns.
    self._distinct_on = config.get("on")
    if self._distinct_on:
        self._distinct_on = [col.schema_column.identity for col in self._distinct_on]

@property
def config(self): # pragma: no cover
Expand All @@ -47,4 +48,4 @@ def name(self): # pragma: no cover
def execute(self) -> Iterable:
    """Yield the deduplicated input, optionally keyed on the DISTINCT ON columns."""
    morsels = self._producers[0]  # type:ignore
    # Concatenate all input morsels first so duplicates spanning morsel
    # boundaries are also removed.
    yield drop_duplicates(concat_tables(morsels.execute()), self._distinct_on)
19 changes: 12 additions & 7 deletions opteryx/operators/function_dataset_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
This Node creates datasets based on function calls like VALUES and UNNEST.
"""
import random
import time
from typing import Iterable

Expand Down Expand Up @@ -51,11 +50,14 @@ def _values(**parameters):
return [{columns[i]: value.value for i, value in enumerate(values)} for values in values_array]


def _fake_data(**kwargs):
    """
    Generate a table of fake data for the FAKE() dataset function.

    Expects 'rows' (row count) and 'schema' (a RelationSchema) in kwargs;
    returns whatever structure orso's fake-data generator produces.
    """
    from orso.faker import generate_fake_data

    rows = kwargs["rows"]
    schema = kwargs["schema"]
    # Downstream operators address columns by identity, so rename the schema
    # columns to their identities before generating the data.
    for column in schema.columns:
        column.name = column.identity
    return generate_fake_data(schema, rows)


FUNCTIONS = {
Expand Down Expand Up @@ -103,7 +105,10 @@ def execute(self) -> Iterable:
)
raise err

table = pyarrow.Table.from_pylist(data)
if isinstance(data, list):
table = pyarrow.Table.from_pylist(data)
if hasattr(data, "arrow"):
table = data.arrow()

self.statistics.rows_read += table.num_rows
self.statistics.columns_read += len(table.column_names)
Expand Down
1 change: 0 additions & 1 deletion opteryx/operators/limit_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import time
from typing import Iterable

from opteryx.exceptions import SqlError
from opteryx.models import QueryProperties
from opteryx.operators import BasePlanNode
from opteryx.utils import arrow
Expand Down
2 changes: 1 addition & 1 deletion opteryx/operators/projection_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, properties: QueryProperties, **config):

@property
def config(self):  # pragma: no cover
    """Human-readable operator configuration: the projected columns."""
    return str(self.projection)

@property
def name(self): # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion opteryx/third_party/pyarrow_ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyarrow
from pyarrow import compute

from .helpers import columns_to_array
from opteryx.third_party.pyarrow_ops.helpers import columns_to_array

# Added for Opteryx, comparisons in filter_operators updated to match
# this set is from sqloxide
Expand Down
8 changes: 8 additions & 0 deletions opteryx/utils/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def limit_records(
"""
remaining_rows = limit if limit is not None else float("inf")
rows_left_to_skip = max(0, offset)
at_least_one = False

for morsel in morsels:
if rows_left_to_skip > 0:
Expand All @@ -48,8 +49,15 @@ def limit_records(
if morsel.num_rows > 0:
if morsel.num_rows < remaining_rows:
yield morsel
at_least_one = True
else:
yield morsel.slice(offset=0, length=remaining_rows)
at_least_one = True

if not at_least_one:
# make sure we return at least an empty morsel from this function
yield morsel.slice(offset=0, length=0)
at_least_one = True

remaining_rows -= morsel.num_rows
if remaining_rows <= 0:
Expand Down
Loading

0 comments on commit 0781a2a

Please sign in to comment.