Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Sep 15, 2023
1 parent 6365b4d commit 6cfd3dc
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 54 deletions.
11 changes: 11 additions & 0 deletions opteryx/components/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,16 @@ def nested(branch, alias=None, key=None):
)


def hex_literal(branch, alias=None, key=None):
value = int(branch, 16)
return Node(
NodeType.LITERAL,
type=OrsoTypes.INTEGER,
value=value,
# alias=alias or f"0x{branch}"
)


def tuple_literal(branch, alias=None, key=None):
return Node(
NodeType.LITERAL,
Expand Down Expand Up @@ -661,6 +671,7 @@ def build(value, alias=None, key=None):
"Extract": extract,
"Floor": floor,
"Function": function,
"HexStringLiteral": hex_literal,
"Identifier": identifier,
"ILike": pattern_match,
"InList": in_list,
Expand Down
87 changes: 48 additions & 39 deletions opteryx/functions/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,39 @@
# limitations under the License.

import datetime
from typing import Any
from typing import Dict
from typing import Union

import numpy
import pyarrow
from pyarrow import compute

from opteryx.utils import dates

BINARY_OPERATORS = (
"Divide",
"Minus",
"Modulo",
"Multiply",
"Plus",
"StringConcat",
"MyIntegerDivide",
)
OPERATOR_FUNCTION_MAP: Dict[str, Any] = {
"Divide": numpy.divide,
"Minus": numpy.subtract,
"Modulo": numpy.mod,
"Multiply": numpy.multiply,
"Plus": numpy.add,
"StringConcat": compute.binary_join_element_wise,
"MyIntegerDivide": lambda left, right: numpy.trunc(numpy.divide(left, right)).astype(
numpy.int64
),
"BitwiseOr": numpy.bitwise_or,
"BitwiseAnd": numpy.bitwise_and,
"BitwiseXor": numpy.bitwise_xor,
"ShiftLeft": numpy.left_shift,
"ShiftRight": numpy.right_shift,
}

BINARY_OPERATORS = set(OPERATOR_FUNCTION_MAP.keys())

INTERVALS = (pyarrow.lib.MonthDayNano, pyarrow.lib.MonthDayNanoIntervalArray)

# Also supported by the AST but not implemented
# BitwiseOr => ("|"),
# BitwiseAnd => ("&"),
# BitwiseXor => ("^"),

# PGBitwiseXor => ("#"), -- not supported in mysql
# PGBitwiseShiftLeft => ("<<"), -- not supported in mysql
# PGBitwiseShiftRight => (">>"), -- not supported in mysql
Expand Down Expand Up @@ -97,38 +108,36 @@ def _has_intervals(left, right):
)


def binary_operations(left, operator, right):
def binary_operations(left, operator: str, right) -> Union[numpy.ndarray, pyarrow.Array]:
"""
Execute inline operators (e.g. the add in 3 + 4)
Execute inline operators (e.g. the add in 3 + 4).
Parameters:
left: Union[numpy.ndarray, pyarrow.Array]
The left operand
operator: str
The operator to be applied
right: Union[numpy.ndarray, pyarrow.Array]
The right operand
Returns:
Union[numpy.ndarray, pyarrow.Array]
The result of the binary operation
"""
operation = OPERATOR_FUNCTION_MAP.get(operator)

# if all of the values are null
# if (
# compute.is_null(left, nan_is_null=True).false_count == 0
# or compute.is_null(right, nan_is_null=True).false_count == 0
# ):
# return numpy.full(right.size, False)

# new operations for Opteryx
if operator == "Divide":
return numpy.divide(left, right)
if operator == "Minus":
if _has_intervals(left, right):
return _date_minus_interval(left, right)
return numpy.subtract(left, right)
if operator == "Modulo":
return numpy.mod(left, right)
if operator == "Multiply":
return numpy.multiply(left, right)
if operator == "Plus":
if operation is None:
raise NotImplementedError(f"Operator `{operator}` is not implemented!")

if operator == "Minus" or operator == "Plus":
if _has_intervals(left, right):
return _date_plus_interval(left, right)
return numpy.add(left, right)
return (
_date_minus_interval(left, right)
if operator == "Minus"
else _date_plus_interval(left, right)
)

if operator == "StringConcat":
empty = numpy.full(len(left), "")
joined = compute.binary_join_element_wise(left, right, empty)
return joined
if operator == "MyIntegerDivide":
return numpy.trunc(numpy.divide(left, right)).astype(numpy.int64)

raise NotImplementedError(f"Operator `{operator}` is not implemented!")
return operation(left, right)
1 change: 0 additions & 1 deletion opteryx/functions/unary_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
This are executed as functions on arrays rather than functions on elements in arrays.
"""
from decimal import Decimal

import numpy

Expand Down
27 changes: 13 additions & 14 deletions opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
Expressions are evaluated against an entire morsel at a time.
"""
from enum import Enum
from typing import Callable
from typing import Dict
from typing import Optional

import numpy
Expand Down Expand Up @@ -94,6 +96,12 @@ class NodeType(int, Enum):
OrsoTypes.NULL: numpy.dtype("O"),
}

LOGICAL_OPERATIONS: Dict[NodeType, Callable] = {
NodeType.AND: pyarrow.compute.and_,
NodeType.OR: pyarrow.compute.or_,
NodeType.XOR: pyarrow.compute.xor,
}


class ExecutionContext:
def __init__(self):
Expand Down Expand Up @@ -161,23 +169,14 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext):

# BOOLEAN OPERATORS
if node_type & LOGICAL_TYPE == LOGICAL_TYPE:
left, right, centre = None, None, None

if root.left is not None:
left = _inner_evaluate(root.left, table, context)
if root.right is not None:
right = _inner_evaluate(root.right, table, context)
if root.centre is not None:
centre = _inner_evaluate(root.centre, table, context)
if node_type in LOGICAL_OPERATIONS:
left = _inner_evaluate(root.left, table, context) if root.left else None
right = _inner_evaluate(root.right, table, context) if root.right else None
return LOGICAL_OPERATIONS[node_type](left, right)

if node_type == NodeType.AND:
return pyarrow.compute.and_(left, right)
if node_type == NodeType.OR:
return pyarrow.compute.or_(left, right)
if node_type == NodeType.NOT:
centre = _inner_evaluate(root.centre, table, context) if root.centre else None
return pyarrow.compute.invert(centre)
if node_type == NodeType.XOR:
return pyarrow.compute.xor(left, right)

# INTERAL IDENTIFIERS
if node_type & INTERNAL_TYPE == INTERNAL_TYPE:
Expand Down
6 changes: 6 additions & 0 deletions opteryx/managers/expression/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def format_expression(root):
"Multiply": "*",
"Divide": "/",
"MyIntegerDivide": "div",
"Modulo": "%",
"BitwiseOr": "|",
"BitwiseAnd": "&",
"BitwiseXor": "^",
"ShiftLeft": "<<",
"ShiftRight": ">>",
}
return f"{format_expression(root.left)} {_map.get(root.value, root.value).upper()} {format_expression(root.right)}"
if node_type == NodeType.COMPARISON_OPERATOR:
Expand Down
4 changes: 4 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@
("SELECT * FROM $satellites WHERE id = (3 * 1) + 2", 1, 8, None),
("SELECT * FROM $satellites WHERE id = 6 DIV (3 + 1)", 1, 8, None),
("SELECT * FROM $satellites WHERE id BETWEEN 4 AND 6", 3, 8, None),
("SELECT * FROM $satellites WHERE id ^ 1", 176, 8, None),
("SELECT * FROM $satellites WHERE id & 1", 89, 8, None),
("SELECT * FROM $satellites WHERE id | 1", 177, 8, None),
("SELECT * FROM $satellites WHERE id = 0x08", 1, 8, None),

("SELECT * FROM $satellites WHERE magnitude = 5.29", 1, 8, None),
("SELECT * FROM $satellites WHERE id = 5 AND magnitude = 5.29", 1, 8, None),
Expand Down

0 comments on commit 6cfd3dc

Please sign in to comment.