Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Nov 22, 2024
1 parent edc73b7 commit 6bf5d70
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 120 deletions.
9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
name = "compute"
version = "0.1.0"
authors = ["@joocer"]
edition = "2018"
edition = "2021"

[lib]
name = "compute"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20.3", features = ["extension-module", "abi3-py39"] }
pythonize = "0.20"
pythonize = "0.22"
serde = "1.0.171"

[dependencies.pyo3]
version = "0.22"
features = ["extension-module"]

[dependencies.sqlparser]
version = "0.52.0"
features = ["serde", "visitor"]
59 changes: 59 additions & 0 deletions opteryx/managers/expression/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def filter_operations(arr, left_type, operator, value, right_type):
"AnyOpGtEq",
"AnyOpLt",
"AnyOpLtEq",
"AnyOpLike",
"AnyOpNotLike",
"AnyOpILike",
"AnyOpNotILike",
"AnyOpRLike",
"AnyOpNotRLike",
"AllOpEq",
"AllOpNotEq",
"AtArrow",
Expand Down Expand Up @@ -170,6 +176,59 @@ def _inner_filter_operations(arr, operator, value):
if operator == "AllOpNotEq":
return list_ops.cython_allop_neq(arr[0], value)

if operator == "AnyOpLike":
patterns = value[0]
return numpy.array(
[
None
if row is None
else any(compute.match_like(row, pattern).true_count > 0 for pattern in patterns)
for row in arr
],
dtype=bool,
)
if operator == "AnyOpNotLike":
patterns = value[0]
matches = numpy.array(
[
None
if row is None
else any(compute.match_like(row, pattern).true_count > 0 for pattern in patterns)
for row in arr
],
dtype=bool,
)
return numpy.invert(matches)
if operator == "AnyOpILike":
patterns = value[0]
return numpy.array(
[
None
if row is None
else any(
compute.match_like(row, pattern, ignore_case=True).true_count > 0
for pattern in patterns
)
for row in arr
],
dtype=bool,
)
if operator == "AnyOpNotILike":
patterns = value[0]
matches = numpy.array(
[
None
if row is None
else any(
compute.match_like(row, pattern, ignore_case=True).true_count > 0
for pattern in patterns
)
for row in arr
],
dtype=bool,
)
return numpy.invert(matches)

if operator == "AtQuestion":
element = value[0]

Expand Down
7 changes: 0 additions & 7 deletions opteryx/models/serial_engine.py

This file was deleted.

12 changes: 6 additions & 6 deletions opteryx/planner/logical_planner/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def __str__(self):
# fmt:off
node_type = self.node_type
if node_type == LogicalPlanStepType.AggregateAndGroup:
return f"AGGREGATE ({', '.join(format_expression(col) for col in self.aggregates)}) GROUP BY ({', '.join(format_expression(col) for col in self.groups)})"
return f"AGGREGATE [{', '.join(format_expression(col) for col in self.aggregates)}] GROUP BY [{', '.join(format_expression(col) for col in self.groups)}]"
if node_type == LogicalPlanStepType.Aggregate:
return f"AGGREGATE ({', '.join(format_expression(col) for col in self.aggregates)})"
return f"AGGREGATE [{', '.join(format_expression(col) for col in self.aggregates)}]"
if node_type == LogicalPlanStepType.Distinct:
distinct_on = ""
if self.on is not None:
distinct_on = f" ON ({','.join(format_expression(col) for col in self.on)})"
distinct_on = f" ON [{','.join(format_expression(col) for col in self.on)}]"
return f"DISTINCT{distinct_on}"
if node_type == LogicalPlanStepType.Explain:
return f"EXPLAIN{' ANALYZE' if self.analyze else ''}{(' (' + self.format + ')') if self.format else ''}"
Expand All @@ -111,16 +111,16 @@ def __str__(self):
return f"{self.type.upper()} JOIN{distinct} (USING {','.join(map(format_expression, self.using))}){filters}"
return f"{self.type.upper()}{distinct} {filters}"
if node_type == LogicalPlanStepType.HeapSort:
return f"HEAP SORT (LIMIT {self.limit}, ORDER BY {', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)})"
return f"HEAP SORT (LIMIT {self.limit}, ORDER BY [{', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)}])"
if node_type == LogicalPlanStepType.Limit:
limit_str = f"LIMIT ({self.limit})" if self.limit is not None else ""
offset_str = f" OFFSET ({self.offset})" if self.offset is not None else ""
return (limit_str + offset_str).strip()
if node_type == LogicalPlanStepType.Order:
return f"ORDER BY ({', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)})"
return f"ORDER BY [{', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)}]"
if node_type == LogicalPlanStepType.Project:
order_by_indicator = f" + ({', '.join(format_expression(col) for col in self.order_by_columns)})" if self.order_by_columns else ""
return f"PROJECT ({', '.join(format_expression(col) for col in self.columns)}){order_by_indicator}"
return f"PROJECT [{', '.join(format_expression(col) for col in self.columns)}]{order_by_indicator}"
if node_type == LogicalPlanStepType.Scan:
io_async = "ASYNC " if hasattr(self.connector, "async_read_blob") else ""
date_range = ""
Expand Down
8 changes: 8 additions & 0 deletions opteryx/planner/logical_planner/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,18 @@ def pattern_match(branch, alias: Optional[List[str]] = None, key=None):
negated = branch["negated"]
left = build(branch["expr"])
right = build(branch["pattern"])
is_any = branch.get("any", False)
if key in ("PGRegexMatch", "SimilarTo"):
key = "RLike"
if negated:
key = f"Not{key}"
if is_any:
key = f"AnyOp{key}"
if right.node_type == NodeType.NESTED:
right = right.centre
if right.type != OrsoTypes.ARRAY:
right.value = (right.value,)
right.type = OrsoTypes.ARRAY
return Node(
NodeType.COMPARISON_OPERATOR,
value=key,
Expand Down
3 changes: 1 addition & 2 deletions opteryx/third_party/sqloxide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

from opteryx.compute import parse_sql
from opteryx.compute import restore_ast

# Explicitly define the API of this module for external consumers
__all__ = ["parse_sql", "restore_ast"]
__all__ = ["parse_sql"]
45 changes: 36 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
use pythonize::pythonize;



use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use pyo3::wrap_pyfunction;

mod sqloxide;
use sqloxide::{restore_ast, parse_sql};
use sqlparser::dialect::dialect_from_str;
use sqlparser::dialect::*;
use sqlparser::parser::Parser;


/// Function to parse SQL statements from a string. Returns a list with
/// one item per query statement.
///
/// Available `dialects`: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/mod.rs#L189-L206
#[pyfunction]
#[pyo3(text_signature = "(sql, dialect)")]
fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult<PyObject> {
let chosen_dialect = dialect_from_str(dialect).unwrap_or_else(|| {
println!("The dialect you chose was not recognized, falling back to 'generic'");
Box::new(GenericDialect {})
});
let parse_result = Parser::parse_sql(&*chosen_dialect, &sql);

let output = match parse_result {
Ok(statements) => pythonize(py, &statements).map_err(|e| {
let msg = e.to_string();
PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}"))
})?,
Err(e) => {
let msg = e.to_string();
return Err(PyValueError::new_err(format!(
"Query parsing failed.\n\t{msg}"
)));
}
};

Ok(output.into())
}


#[pymodule]
fn compute(_py: Python, m: &PyModule) -> PyResult<()> {
fn compute(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_sql, m)?)?;
m.add_function(wrap_pyfunction!(restore_ast, m)?)?;

Ok(())
}
93 changes: 0 additions & 93 deletions src/sqloxide.rs

This file was deleted.

16 changes: 16 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,6 +1805,22 @@
# Aggregate Functions with HAVING Clause
("SELECT name, COUNT(*) AS count FROM $satellites GROUP BY name HAVING count > 1", 0, 2, None),

("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY '%apoll%'", 0, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY '%apoll%'", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%apoll%')", 0, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%Apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%Apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%Apoll%', 'mission')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%Apoll%', 'mission')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY '%apoll%'", 0, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY '%apoll%'", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%apoll%')", 0, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%Apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%Apoll%')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%Apoll%', 'mission')", 34, 2, None),
("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%Apoll%', 'mission')", 34, 2, None),

# ****************************************************************************************

Expand Down

0 comments on commit 6bf5d70

Please sign in to comment.