diff --git a/Cargo.toml b/Cargo.toml index 950a56e9f..0130e0db2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,17 +2,20 @@ name = "compute" version = "0.1.0" authors = ["@joocer"] -edition = "2018" +edition = "2021" [lib] name = "compute" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.20.3", features = ["extension-module", "abi3-py39"] } -pythonize = "0.20" +pythonize = "0.22" serde = "1.0.171" +[dependencies.pyo3] +version = "0.22" +features = ["extension-module"] + [dependencies.sqlparser] version = "0.52.0" features = ["serde", "visitor"] \ No newline at end of file diff --git a/opteryx/managers/expression/ops.py b/opteryx/managers/expression/ops.py index 474a961fd..58615fc16 100644 --- a/opteryx/managers/expression/ops.py +++ b/opteryx/managers/expression/ops.py @@ -32,6 +32,12 @@ def filter_operations(arr, left_type, operator, value, right_type): "AnyOpGtEq", "AnyOpLt", "AnyOpLtEq", + "AnyOpLike", + "AnyOpNotLike", + "AnyOpILike", + "AnyOpNotILike", + "AnyOpRLike", + "AnyOpNotRLike", "AllOpEq", "AllOpNotEq", "AtArrow", @@ -170,6 +176,59 @@ def _inner_filter_operations(arr, operator, value): if operator == "AllOpNotEq": return list_ops.cython_allop_neq(arr[0], value) + if operator == "AnyOpLike": + patterns = value[0] + return numpy.array( + [ + None + if row is None + else any(compute.match_like(row, pattern).true_count > 0 for pattern in patterns) + for row in arr + ], + dtype=bool, + ) + if operator == "AnyOpNotLike": + patterns = value[0] + matches = numpy.array( + [ + None + if row is None + else any(compute.match_like(row, pattern).true_count > 0 for pattern in patterns) + for row in arr + ], + dtype=bool, + ) + return numpy.invert(matches) + if operator == "AnyOpILike": + patterns = value[0] + return numpy.array( + [ + None + if row is None + else any( + compute.match_like(row, pattern, ignore_case=True).true_count > 0 + for pattern in patterns + ) + for row in arr + ], + dtype=bool, + ) + if operator == "AnyOpNotILike": + patterns = value[0] + matches = numpy.array( + [ + None + if row is None + else any( + compute.match_like(row, pattern, ignore_case=True).true_count > 0 + for pattern in patterns + ) + for row in arr + ], + dtype=bool, + ) + return numpy.invert(matches) + if operator == "AtQuestion": element = value[0] diff --git a/opteryx/models/serial_engine.py b/opteryx/models/serial_engine.py deleted file mode 100644 index 9a7dc5d05..000000000 --- a/opteryx/models/serial_engine.py +++ /dev/null @@ -1,7 +0,0 @@ -import gc - -import pyarrow - -from opteryx.constants import ResultType -from opteryx.exceptions import InvalidInternalStateError -from opteryx.third_party.travers import Graph diff --git a/opteryx/planner/logical_planner/logical_planner.py b/opteryx/planner/logical_planner/logical_planner.py index 49c70115a..510e34d3d 100644 --- a/opteryx/planner/logical_planner/logical_planner.py +++ b/opteryx/planner/logical_planner/logical_planner.py @@ -79,13 +79,13 @@ def __str__(self): # fmt:off node_type = self.node_type if node_type == LogicalPlanStepType.AggregateAndGroup: - return f"AGGREGATE ({', '.join(format_expression(col) for col in self.aggregates)}) GROUP BY ({', '.join(format_expression(col) for col in self.groups)})" + return f"AGGREGATE [{', '.join(format_expression(col) for col in self.aggregates)}] GROUP BY [{', '.join(format_expression(col) for col in self.groups)}]" if node_type == LogicalPlanStepType.Aggregate: - return f"AGGREGATE ({', '.join(format_expression(col) for col in self.aggregates)})" + return f"AGGREGATE [{', '.join(format_expression(col) for col in self.aggregates)}]" if node_type == LogicalPlanStepType.Distinct: distinct_on = "" if self.on is not None: - distinct_on = f" ON ({','.join(format_expression(col) for col in self.on)})" + distinct_on = f" ON [{','.join(format_expression(col) for col in self.on)}]" return f"DISTINCT{distinct_on}" if node_type == LogicalPlanStepType.Explain: return f"EXPLAIN{' ANALYZE' if self.analyze else ''}{(' (' + self.format + ')') if self.format else ''}" @@ -111,16 +111,16 @@ def __str__(self): return f"{self.type.upper()} JOIN{distinct} (USING {','.join(map(format_expression, self.using))}){filters}" return f"{self.type.upper()}{distinct} {filters}" if node_type == LogicalPlanStepType.HeapSort: - return f"HEAP SORT (LIMIT {self.limit}, ORDER BY {', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)})" + return f"HEAP SORT (LIMIT {self.limit}, ORDER BY [{', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)}])" if node_type == LogicalPlanStepType.Limit: limit_str = f"LIMIT ({self.limit})" if self.limit is not None else "" offset_str = f" OFFSET ({self.offset})" if self.offset is not None else "" return (limit_str + offset_str).strip() if node_type == LogicalPlanStepType.Order: - return f"ORDER BY ({', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)})" + return f"ORDER BY [{', '.join(format_expression(item[0]) + (' DESC' if item[1] =='descending' else '') for item in self.order_by)}]" if node_type == LogicalPlanStepType.Project: order_by_indicator = f" + ({', '.join(format_expression(col) for col in self.order_by_columns)})" if self.order_by_columns else "" - return f"PROJECT ({', '.join(format_expression(col) for col in self.columns)}){order_by_indicator}" + return f"PROJECT [{', '.join(format_expression(col) for col in self.columns)}]{order_by_indicator}" if node_type == LogicalPlanStepType.Scan: io_async = "ASYNC " if hasattr(self.connector, "async_read_blob") else "" date_range = "" diff --git a/opteryx/planner/logical_planner/logical_planner_builders.py b/opteryx/planner/logical_planner/logical_planner_builders.py index a52c69e9a..7d853b5b0 100644 --- a/opteryx/planner/logical_planner/logical_planner_builders.py +++ b/opteryx/planner/logical_planner/logical_planner_builders.py @@ -549,10 +549,18 @@ def pattern_match(branch, alias: Optional[List[str]] = None, key=None): negated = branch["negated"] left = build(branch["expr"]) right = build(branch["pattern"]) + is_any = branch.get("any", False) if key in ("PGRegexMatch", "SimilarTo"): key = "RLike" if negated: key = f"Not{key}" + if is_any: + key = f"AnyOp{key}" + if right.node_type == NodeType.NESTED: + right = right.centre + if right.type != OrsoTypes.ARRAY: + right.value = (right.value,) + right.type = OrsoTypes.ARRAY return Node( NodeType.COMPARISON_OPERATOR, value=key, diff --git a/opteryx/third_party/sqloxide/__init__.py b/opteryx/third_party/sqloxide/__init__.py index aa2c47b68..36517ff0e 100644 --- a/opteryx/third_party/sqloxide/__init__.py +++ b/opteryx/third_party/sqloxide/__init__.py @@ -8,7 +8,6 @@ """ from opteryx.compute import parse_sql -from opteryx.compute import restore_ast # Explicitly define the API of this module for external consumers -__all__ = ["parse_sql", "restore_ast"] +__all__ = ["parse_sql"] diff --git a/src/lib.rs b/src/lib.rs index 02b0f0a48..420a4cd3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,18 +1,45 @@ +use pythonize::pythonize; - - +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use pyo3::wrap_pyfunction; - -mod sqloxide; -use sqloxide::{restore_ast, parse_sql}; +use sqlparser::dialect::dialect_from_str; +use sqlparser::dialect::*; +use sqlparser::parser::Parser; + + +/// Function to parse SQL statements from a string. Returns a list with +/// one item per query statement. +/// +/// Available `dialects`: https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/mod.rs#L189-L206 +#[pyfunction] +#[pyo3(text_signature = "(sql, dialect)")] +fn parse_sql(py: Python, sql: String, dialect: String) -> PyResult { + let chosen_dialect = dialect_from_str(dialect).unwrap_or_else(|| { + println!("The dialect you chose was not recognized, falling back to 'generic'"); + Box::new(GenericDialect {}) + }); + let parse_result = Parser::parse_sql(&*chosen_dialect, &sql); + + let output = match parse_result { + Ok(statements) => pythonize(py, &statements).map_err(|e| { + let msg = e.to_string(); + PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}")) + })?, + Err(e) => { + let msg = e.to_string(); + return Err(PyValueError::new_err(format!( + "Query parsing failed.\n\t{msg}" + ))); + } + }; + + Ok(output.into()) +} #[pymodule] -fn compute(_py: Python, m: &PyModule) -> PyResult<()> { +fn compute(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(parse_sql, m)?)?; - m.add_function(wrap_pyfunction!(restore_ast, m)?)?; - Ok(()) } \ No newline at end of file diff --git a/src/sqloxide.rs b/src/sqloxide.rs deleted file mode 100644 index bf2a5d11c..000000000 --- a/src/sqloxide.rs +++ /dev/null @@ -1,93 +0,0 @@ -use pythonize::pythonize; - -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -use pythonize::PythonizeError; - -use sqlparser::ast::Statement; -use sqlparser::dialect::*; -use sqlparser::parser::Parser; - -fn string_to_dialect(dialect: &str) -> Box { - match dialect.to_lowercase().as_str() { - "ansi" => Box::new(AnsiDialect {}), - "bigquery" | "bq" => Box::new(BigQueryDialect {}), - "clickhouse" => Box::new(ClickHouseDialect {}), - "generic" => Box::new(GenericDialect {}), - "hive" => Box::new(HiveDialect {}), - "ms" | "mssql" => Box::new(MsSqlDialect {}), - "mysql" => Box::new(MySqlDialect {}), - "postgres" => Box::new(PostgreSqlDialect {}), - "redshift" => Box::new(RedshiftSqlDialect {}), - "snowflake" => Box::new(SnowflakeDialect {}), - "sqlite" => Box::new(SQLiteDialect {}), - _ => { - println!("The dialect you chose was not recognized, falling back to 'generic'"); - Box::new(GenericDialect {}) - } - } -} - -/// Function to parse SQL statements from a string. Returns a list with -/// one item per query statement. -/// -/// Available `dialects`: -/// - generic -/// - ansi -/// - hive -/// - ms (mssql) -/// - mysql -/// - postgres -/// - snowflake -/// - sqlite -/// - clickhouse -/// - redshift -/// - bigquery (bq) -/// -#[pyfunction] -#[pyo3(text_signature = "(sql, dialect)")] -pub fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult { - let chosen_dialect = string_to_dialect(dialect); - let parse_result = Parser::parse_sql(&*chosen_dialect, sql); - - let output = match parse_result { - Ok(statements) => { - pythonize(py, &statements).map_err(|e| { - let msg = e.to_string(); - PyValueError::new_err(format!("Python object serialization failed.\n\t{msg}")) - })? - } - Err(e) => { - let msg = e.to_string(); - return Err(PyValueError::new_err(format!( - "Query parsing failed.\n\t{msg}" - ))); - } - }; - - Ok(output) -} - -/// This utility function allows reconstituing a modified AST back into list of SQL queries. -#[pyfunction] -#[pyo3(text_signature = "(ast)")] -pub fn restore_ast(_py: Python, ast: &PyAny) -> PyResult> { - let parse_result: Result, PythonizeError> = pythonize::depythonize(ast); - - let output = match parse_result { - Ok(statements) => statements, - Err(e) => { - let msg = e.to_string(); - return Err(PyValueError::new_err(format!( - "Query serialization failed.\n\t{msg}" - ))); - } - }; - - Ok(output - .iter() - .map(std::string::ToString::to_string) - .collect::>()) -} - diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 479d338b1..7c26aeeac 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -1805,6 +1805,22 @@ # Aggregate Functions with HAVING Clause ("SELECT name, COUNT(*) AS count FROM $satellites GROUP BY name HAVING count > 1", 0, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY '%apoll%'", 0, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY '%apoll%'", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%apoll%')", 0, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%Apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%Apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions LIKE ANY ('%Apoll%', 'mission')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions ILIKE ANY ('%Apoll%', 'mission')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY '%apoll%'", 0, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY '%apoll%'", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%apoll%')", 0, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%Apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%Apoll%')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT LIKE ANY ('%Apoll%', 'mission')", 34, 2, None), + ("SELECT name, missions FROM $astronauts WHERE missions NOT ILIKE ANY ('%Apoll%', 'mission')", 34, 2, None), # ****************************************************************************************