Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Oct 25, 2023
1 parent e34d519 commit 76e78fa
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 22 deletions.
5 changes: 1 addition & 4 deletions opteryx/components/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
The plan does not try to be efficient or clever, at this point it is only trying to be correct.
"""

import os
import sys
from enum import Enum
from enum import auto
from typing import List
Expand All @@ -61,8 +59,6 @@
from opteryx.models import Node
from opteryx.third_party.travers import Graph

sys.path.insert(1, os.path.join(sys.path[0], "../../../..")) # isort:skip


class LogicalPlanStepType(int, Enum):
Project = auto() # field selection
Expand Down Expand Up @@ -746,6 +742,7 @@ def plan_show_variables(statement):
previous_step_id, step_id = step_id, random_string()
plan.add_node(step_id, select_step)
plan.add_edge(previous_step_id, step_id)
raise UnsupportedSyntaxError("Cannot filter by Variable Names")

exit_step = LogicalPlanNode(node_type=LogicalPlanStepType.Exit)
exit_step.columns = [Node(node_type=NodeType.WILDCARD)] # We are always SELECT *
Expand Down
30 changes: 18 additions & 12 deletions opteryx/functions/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import array
import datetime
from typing import Any
from typing import Dict
from typing import Tuple
from typing import Union

import numpy
Expand Down Expand Up @@ -41,6 +43,7 @@
BINARY_OPERATORS = set(OPERATOR_FUNCTION_MAP.keys())

INTERVALS = (pyarrow.lib.MonthDayNano, pyarrow.lib.MonthDayNanoIntervalArray)
LISTS = (pyarrow.Array, numpy.ndarray, list, array.ArrayType)

# Also supported by the AST but not implemented

Expand All @@ -51,7 +54,7 @@

def _date_plus_interval(left, right):
# left is the date, right is the interval
if type(left) in INTERVALS or (isinstance(left, list) and type(left[0]) in INTERVALS):
if isinstance(left, INTERVALS) or (isinstance(left, LISTS) and type(left[0]) in INTERVALS):
left, right = right, left

result = []
Expand All @@ -62,11 +65,11 @@ def _date_plus_interval(left, right):
interval = interval.value
months = interval.months
days = interval.days
nano = interval.nanoseconds
nanoseconds = interval.nanoseconds

date = dates.parse_iso(date)
date = date + datetime.timedelta(days=days)
date = date + datetime.timedelta(microseconds=(nano * 1000))
# Subtract days and nanoseconds (as microseconds)
date += datetime.timedelta(days=days, microseconds=nanoseconds // 1000)
date = dates.add_months(date, months)

result.append(date)
Expand All @@ -76,7 +79,7 @@ def _date_plus_interval(left, right):

def _date_minus_interval(left, right):
# left is the date, right is the interval
if type(left) in INTERVALS or (isinstance(left, list) and type(left[0]) in INTERVALS):
if isinstance(left, INTERVALS) or (isinstance(left, LISTS) and type(left[0]) in INTERVALS):
left, right = right, left

result = []
Expand All @@ -87,11 +90,11 @@ def _date_minus_interval(left, right):
interval = interval.value
months = interval.months
days = interval.days
nano = interval.nanoseconds
nanoseconds = interval.nanoseconds

date = dates.parse_iso(date)
date = date - datetime.timedelta(days=days)
date = date - datetime.timedelta(microseconds=(nano * 1000))
# Subtract days and nanoseconds (as microseconds)
date -= datetime.timedelta(days=days, microseconds=nanoseconds // 1000)
date = dates.add_months(date, (0 - months))

result.append(date)
Expand All @@ -100,11 +103,14 @@ def _date_minus_interval(left, right):


def _has_intervals(left, right):
def _check_type(obj, types: Union[type, Tuple[type, ...]]) -> bool:
return any(isinstance(obj, t) for t in types)

return (
type(left) in INTERVALS
or type(right) in INTERVALS
or (isinstance(left, list) and type(left[0]) in INTERVALS)
or (isinstance(right, list) and type(right[0]) in INTERVALS)
_check_type(left, INTERVALS)
or _check_type(right, INTERVALS)
or (_check_type(left, LISTS) and _check_type(left[0], INTERVALS))
or (_check_type(right, LISTS) and _check_type(right[0], INTERVALS))
)


Expand Down
3 changes: 2 additions & 1 deletion opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext):
if literal_type == OrsoTypes.VARCHAR:
return numpy.array([root.value] * table.num_rows, dtype=numpy.unicode_)
if literal_type == OrsoTypes.INTERVAL:
return pyarrow.array([root.value] * table.num_rows)
value = pyarrow.MonthDayNano(root.value)
return pyarrow.array([value] * table.num_rows)
return numpy.full(
shape=table.num_rows, fill_value=root.value, dtype=ORSO_TO_NUMPY_MAP[literal_type]
) # type:ignore
Expand Down
14 changes: 12 additions & 2 deletions opteryx/utils/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,28 @@
UNIX_EPOCH: datetime.date = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)


def add_months(start_date, number_of_months):
def add_months(start_date: datetime.datetime, number_of_months: int):
"""
Add months to a date, makes assumptions about how to handle the end of the month.
"""
new_year, new_month = divmod(start_date.month - 1 + number_of_months, 12)
new_year += start_date.year
new_month += 1
# Ensure the month is valid
new_month = min(max(1, new_month), 12)
last_day_of_month = (
datetime.datetime(new_year, new_month % 12 + 1, 1) - datetime.timedelta(days=1)
).day
new_day = min(start_date.day, last_day_of_month)
return datetime.datetime(new_year, new_month, new_day)
return datetime.datetime(
new_year,
new_month,
new_day,
start_date.hour,
start_date.minute,
start_date.second,
start_date.microsecond,
)


def add_interval(
Expand Down
6 changes: 3 additions & 3 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,9 +937,9 @@
("SET @id = 3; SELECT name FROM $planets WHERE id < @id OR id > @id;", 8, 1, None),
("SET @dob = '1950-01-01'; SELECT name FROM $astronauts WHERE birth_date < @dob;", 149, 1, None),
("SET @dob = '1950-01-01'; SET @mission = 'Apollo 11'; SELECT name FROM $astronauts WHERE birth_date < @dob AND @mission IN UNNEST(missions);", 3, 1, None),
("SET @pples = 'b'; SET @ngles = 90; SHOW VARIABLES LIKE '@%s'", 2, 4, None),
("SET @pples = 'b'; SET @rgon = 90; SHOW VARIABLES LIKE '@%gon'", 1, 4, None),
("SET @variable = 44; SET @var = 'name'; SHOW VARIABLES LIKE '@%ri%';", 1, 4, None),
("SET @pples = 'b'; SET @ngles = 90; SHOW VARIABLES LIKE '@%s'", 2, 4, UnsupportedSyntaxError),
("SET @pples = 'b'; SET @rgon = 90; SHOW VARIABLES LIKE '@%gon'", 1, 4, UnsupportedSyntaxError),
("SET @variable = 44; SET @var = 'name'; SHOW VARIABLES LIKE '@%ri%';", 1, 4, UnsupportedSyntaxError),
("SHOW PARAMETER disable_optimizer", 1, 2, None),
("SET disable_optimizer = true; SHOW PARAMETER disable_optimizer;", 1, 2, None),
]
Expand Down

0 comments on commit 76e78fa

Please sign in to comment.