def hex_literal(branch, alias=None, key=None):
    """
    Build a LITERAL node from a hexadecimal string literal.

    Parameters:
        branch: str
            The hexadecimal text to convert; it is parsed with
            `int(branch, 16)`.
        alias: optional
            Accepted so the signature matches the other builders in this
            module; it is not applied to the returned node.
        key: optional
            Accepted for the same reason as `alias`; unused here.
    Returns:
        Node
            A `NodeType.LITERAL` node typed `OrsoTypes.INTEGER` whose value
            is the parsed integer.
    """
    return Node(
        NodeType.LITERAL,
        type=OrsoTypes.INTEGER,
        value=int(branch, 16),
    )
import datetime
from typing import Any
from typing import Dict
from typing import Union

import numpy
import pyarrow

from opteryx.utils import dates

# NOTE(review): `compute` is referenced below but its import is not visible in
# this view - presumably `from pyarrow import compute`; confirm.


def _string_concat(left, right):
    """Concatenate two string arrays element-wise, with nothing between them."""
    # `binary_join_element_wise` treats its LAST argument as the separator, so
    # it cannot be used directly as a two-argument (left, right) operator: the
    # right operand would be consumed as the separator.
    empty = numpy.full(len(left), "")
    return compute.binary_join_element_wise(left, right, empty)


# Maps the operator name to a callable taking (left, right).
# Every entry must be safe to call with exactly those two operands.
OPERATOR_FUNCTION_MAP: Dict[str, Any] = {
    "Divide": numpy.divide,
    "Minus": numpy.subtract,
    "Modulo": numpy.mod,
    "Multiply": numpy.multiply,
    "Plus": numpy.add,
    "StringConcat": _string_concat,
    "MyIntegerDivide": lambda left, right: numpy.trunc(numpy.divide(left, right)).astype(
        numpy.int64
    ),
    "BitwiseOr": numpy.bitwise_or,
    "BitwiseAnd": numpy.bitwise_and,
    "BitwiseXor": numpy.bitwise_xor,
    "ShiftLeft": numpy.left_shift,
    "ShiftRight": numpy.right_shift,
}

# The supported operator names are derived from the table so the two can
# never drift apart.
BINARY_OPERATORS = set(OPERATOR_FUNCTION_MAP)

INTERVALS = (pyarrow.lib.MonthDayNano, pyarrow.lib.MonthDayNanoIntervalArray)

# Also supported by the AST but not implemented

# PGBitwiseXor => ("#"), -- not supported in mysql
# PGBitwiseShiftLeft => ("<<"), -- not supported in mysql
# PGBitwiseShiftRight => (">>"), -- not supported in mysql


def binary_operations(left, operator: str, right) -> Union[numpy.ndarray, pyarrow.Array]:
    """
    Execute inline operators (e.g. the add in 3 + 4).

    Parameters:
        left: Union[numpy.ndarray, pyarrow.Array]
            The left operand
        operator: str
            The operator to be applied, one of the keys of
            OPERATOR_FUNCTION_MAP
        right: Union[numpy.ndarray, pyarrow.Array]
            The right operand
    Returns:
        Union[numpy.ndarray, pyarrow.Array]
            The result of the binary operation
    Raises:
        NotImplementedError
            If `operator` has no entry in OPERATOR_FUNCTION_MAP
    """
    operation = OPERATOR_FUNCTION_MAP.get(operator)

    if operation is None:
        raise NotImplementedError(f"Operator `{operator}` is not implemented!")

    # Plus and Minus take a separate path when an interval is involved; the
    # helpers used here are defined in a part of the module not visible in
    # this view.
    if operator in ("Minus", "Plus") and _has_intervals(left, right):
        if operator == "Minus":
            return _date_minus_interval(left, right)
        return _date_plus_interval(left, right)

    return operation(left, right)
""" -from decimal import Decimal import numpy diff --git a/opteryx/managers/expression/__init__.py b/opteryx/managers/expression/__init__.py index 86ec27d30..0400091be 100644 --- a/opteryx/managers/expression/__init__.py +++ b/opteryx/managers/expression/__init__.py @@ -18,6 +18,8 @@ Expressions are evaluated against an entire morsel at a time. """ from enum import Enum +from typing import Callable +from typing import Dict from typing import Optional import numpy @@ -94,6 +96,12 @@ class NodeType(int, Enum): OrsoTypes.NULL: numpy.dtype("O"), } +LOGICAL_OPERATIONS: Dict[NodeType, Callable] = { + NodeType.AND: pyarrow.compute.and_, + NodeType.OR: pyarrow.compute.or_, + NodeType.XOR: pyarrow.compute.xor, +} + class ExecutionContext: def __init__(self): @@ -161,23 +169,14 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext): # BOOLEAN OPERATORS if node_type & LOGICAL_TYPE == LOGICAL_TYPE: - left, right, centre = None, None, None - - if root.left is not None: - left = _inner_evaluate(root.left, table, context) - if root.right is not None: - right = _inner_evaluate(root.right, table, context) - if root.centre is not None: - centre = _inner_evaluate(root.centre, table, context) + if node_type in LOGICAL_OPERATIONS: + left = _inner_evaluate(root.left, table, context) if root.left else None + right = _inner_evaluate(root.right, table, context) if root.right else None + return LOGICAL_OPERATIONS[node_type](left, right) - if node_type == NodeType.AND: - return pyarrow.compute.and_(left, right) - if node_type == NodeType.OR: - return pyarrow.compute.or_(left, right) if node_type == NodeType.NOT: + centre = _inner_evaluate(root.centre, table, context) if root.centre else None return pyarrow.compute.invert(centre) - if node_type == NodeType.XOR: - return pyarrow.compute.xor(left, right) # INTERAL IDENTIFIERS if node_type & INTERNAL_TYPE == INTERNAL_TYPE: diff --git a/opteryx/managers/expression/formatter.py b/opteryx/managers/expression/formatter.py index 
dfed25044..a3545ae6d 100644 --- a/opteryx/managers/expression/formatter.py +++ b/opteryx/managers/expression/formatter.py @@ -59,6 +59,12 @@ def format_expression(root): "Multiply": "*", "Divide": "/", "MyIntegerDivide": "div", + "Modulo": "%", + "BitwiseOr": "|", + "BitwiseAnd": "&", + "BitwiseXor": "^", + "ShiftLeft": "<<", + "ShiftRight": ">>", } return f"{format_expression(root.left)} {_map.get(root.value, root.value).upper()} {format_expression(root.right)}" if node_type == NodeType.COMPARISON_OPERATOR: diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 350e4090c..21021d985 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -266,6 +266,10 @@ ("SELECT * FROM $satellites WHERE id = (3 * 1) + 2", 1, 8, None), ("SELECT * FROM $satellites WHERE id = 6 DIV (3 + 1)", 1, 8, None), ("SELECT * FROM $satellites WHERE id BETWEEN 4 AND 6", 3, 8, None), + ("SELECT * FROM $satellites WHERE id ^ 1", 176, 8, None), + ("SELECT * FROM $satellites WHERE id & 1", 89, 8, None), + ("SELECT * FROM $satellites WHERE id | 1", 177, 8, None), + ("SELECT * FROM $satellites WHERE id = 0x08", 1, 8, None), ("SELECT * FROM $satellites WHERE magnitude = 5.29", 1, 8, None), ("SELECT * FROM $satellites WHERE id = 5 AND magnitude = 5.29", 1, 8, None),