From f3bae99de6aaeaf2abc4b4df33b2a2b5ee2153d9 Mon Sep 17 00:00:00 2001
From: joocer
Date: Tue, 14 Nov 2023 23:31:30 +0000
Subject: [PATCH] #1261

---
 opteryx/components/binder/binder_visitor.py   | 36 +++++++++----------
 opteryx/components/heuristic_optimizer.py     | 27 +++++++-------
 opteryx/connectors/arrow_connector.py         |  5 ++-
 opteryx/connectors/aws_s3_connector.py        |  4 +--
 opteryx/connectors/base/base_connector.py     |  5 ++-
 opteryx/connectors/disk_connector.py          |  4 +--
 opteryx/connectors/file_connector.py          |  5 +--
 .../connectors/gcp_cloudstorage_connector.py  |  4 +--
 opteryx/connectors/gcp_firestore_connector.py |  8 +++--
 opteryx/connectors/mongodb_connector.py       |  6 ++--
 opteryx/connectors/sql_connector.py           |  5 ++-
 opteryx/connectors/virtual_data.py            | 14 ++++++--
 opteryx/operators/scanner_node.py             |  3 +-
 opteryx/utils/arrow.py                        | 25 +++++++++++++
 14 files changed, 101 insertions(+), 50 deletions(-)

diff --git a/opteryx/components/binder/binder_visitor.py b/opteryx/components/binder/binder_visitor.py
index 1a7234621..897a518da 100644
--- a/opteryx/components/binder/binder_visitor.py
+++ b/opteryx/components/binder/binder_visitor.py
@@ -640,6 +640,24 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
 
         return node, context
 
+    def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
+        order_by = []
+        columns = []
+        for column, direction in node.order_by:
+            bound_column, context = inner_binder(column, context)
+
+            order_by.append(
+                (
+                    bound_column,
+                    "ascending" if direction else "descending",
+                )
+            )
+            columns.append(bound_column)
+
+        node.order_by = order_by
+        node.columns = columns
+        return node, context
+
     def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
         columns = []
 
@@ -714,24 +732,6 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, Bind
 
         return node, context
 
-    def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
-        order_by = []
-        columns = []
-        for column, direction in node.order_by:
-            bound_column, context = inner_binder(column, context)
-
-            order_by.append(
-                (
-                    bound_column,
-                    "ascending" if direction else "descending",
-                )
-            )
-            columns.append(bound_column)
-
-        node.order_by = order_by
-        node.columns = columns
-        return node, context
-
     def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
         from opteryx.connectors import connector_factory
         from opteryx.connectors.capabilities import Cacheable
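The "ascending" / "descending" strings produced by visit_order match the direction values an Arrow-based sort expects. A minimal illustration (an assumption about downstream use, not part of this patch) of how such (column, direction) pairs drive a pyarrow sort:

    import pyarrow

    table = pyarrow.table({"name": ["Mars", "Earth"], "gravity": [3.7, 9.8]})
    # (column name, direction) pairs in the same form visit_order emits
    table = table.sort_by([("gravity", "descending")])
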
diff --git a/opteryx/components/heuristic_optimizer.py b/opteryx/components/heuristic_optimizer.py
index f39cb5766..d4a04383a 100644
--- a/opteryx/components/heuristic_optimizer.py
+++ b/opteryx/components/heuristic_optimizer.py
@@ -56,6 +56,8 @@
 """
 from opteryx.components.logical_planner import LogicalPlan
 from opteryx.components.logical_planner import LogicalPlanStepType
+from opteryx.managers.expression import NodeType
+from opteryx.managers.expression import get_all_nodes_of_type
 
 
 # Context object to carry state
@@ -70,7 +72,7 @@ def __init__(self, tree: LogicalPlan):
 
         # We collect column identities so we can push column selection as close to the
         # read as possible, including off to remote systems
-        self.collected_identities = []
+        self.collected_identities = set()
 
 
 # Optimizer Visitor
@@ -80,8 +82,13 @@ def rewrite_predicates(self, node):
 
     def collect_columns(self, node):
         if node.columns:
-            return [col.schema_column.identity for col in node.columns if col.schema_column]
-        return []
+            return {
+                col.schema_column.identity
+                for column in node.columns
+                for col in get_all_nodes_of_type(column, (NodeType.IDENTIFIER,))
+                if col.schema_column
+            }
+        return set()
 
     def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
         # collect column references to push PROJECTION
@@ -93,7 +100,7 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
 
         # do this before any transformations
         if node.node_type != LogicalPlanStepType.Scan:
-            context.collected_identities.extend(self.collect_columns(node))
+            context.collected_identities.update(self.collect_columns(node))
 
         if node.node_type == LogicalPlanStepType.Filter:
             # rewrite predicates, to favor conjuctions and reduce negations
@@ -103,16 +110,10 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
         if node.node_type == LogicalPlanStepType.Scan:
             # push projections
             node_columns = [
-                col
-                for col in node.schema.columns
-                if col.identity in set(context.collected_identities)
+                col for col in node.schema.columns if col.identity in context.collected_identities
            ]
-            # print("FOUND")
-            # print(node_columns)
-            # print("NOT FOUND")
-            # print([col for col in node.schema.columns if col.identity not in set(context.collected_identities)])
-            # push selections
-            pass
+            # these are the pushed columns
+            node.columns = node_columns
 
             context.optimized_tree.add_node(nid, node)
             if parent:
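Because collected_identities is now a set that accumulates across every visited node, the in-place set.update is what keeps the accumulator growing; set.union returns a new set and leaves the accumulator untouched. A tiny sketch of the accumulation pattern:

    collected = set()
    for identities in ({"col-1", "col-2"}, {"col-2", "col-3"}):
        collected.update(identities)  # mutates in place
    # collected == {"col-1", "col-2", "col-3"}
    # collected.union(identities) on its own would leave `collected` unchanged
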
+ blob.split(".")[-1].lower()) in VALID_EXTENSIONS] - def read_dataset(self) -> pyarrow.Table: + def read_dataset(self, columns: list = None) -> pyarrow.Table: blob_names = self.partition_scheme.get_blobs_in_partition( start_date=self.start_date, end_date=self.end_date, @@ -91,7 +91,7 @@ def read_dataset(self) -> pyarrow.Table: try: decoder = get_decoder(blob_name) blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics) - yield decoder(blob_bytes) + yield decoder(blob_bytes, projection=columns) except UnsupportedFileTypeError: pass diff --git a/opteryx/connectors/base/base_connector.py b/opteryx/connectors/base/base_connector.py index 59c14c2db..42cc21d7f 100644 --- a/opteryx/connectors/base/base_connector.py +++ b/opteryx/connectors/base/base_connector.py @@ -85,15 +85,18 @@ def read_schema_from_metastore(self): def chunk_dictset( self, dictset: typing.Iterable[dict], + columns: list, morsel_size: int = DEFAULT_MORSEL_SIZE, initial_chunk_size: int = INITIAL_CHUNK_SIZE, - ): + ) -> pyarrow.Table: chunk = [] self.chunk_size = initial_chunk_size # we reset each time morsel = None for index, record in enumerate(dictset): _id = record.pop("_id", None) + # column selection + record = {k: record.get(k) for k in columns} record["id"] = None if _id is None else str(_id) chunk.append(record) diff --git a/opteryx/connectors/disk_connector.py b/opteryx/connectors/disk_connector.py index cdf3a927b..6c056e036 100644 --- a/opteryx/connectors/disk_connector.py +++ b/opteryx/connectors/disk_connector.py @@ -87,7 +87,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]: ] return files - def read_dataset(self) -> pyarrow.Table: + def read_dataset(self, columns: list = None) -> pyarrow.Table: """ Read the entire dataset from disk. @@ -110,7 +110,7 @@ def read_dataset(self) -> pyarrow.Table: try: decoder = get_decoder(blob_name) blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics) - yield decoder(blob_bytes) + yield decoder(blob_bytes, projection=columns) except UnsupportedFileTypeError: pass # Skip unsupported file types diff --git a/opteryx/connectors/file_connector.py b/opteryx/connectors/file_connector.py index a5f423364..a81fbc188 100644 --- a/opteryx/connectors/file_connector.py +++ b/opteryx/connectors/file_connector.py @@ -16,6 +16,7 @@ """ from typing import Optional +import pyarrow from orso.schema import RelationSchema from opteryx.connectors.base.base_connector import BaseConnector @@ -46,7 +47,7 @@ def _read_file(self) -> None: with open(self.dataset, mode="br") as file: self._byte_array = bytes(file.read()) - def read_dataset(self) -> iter: + def read_dataset(self, columns: list = None) -> pyarrow.Table: """ Reads the dataset file and decodes it. @@ -54,7 +55,7 @@ def read_dataset(self) -> iter: An iterator containing a single decoded pyarrow.Table. """ self._read_file() - return iter([self.decoder(self._byte_array)]) + return iter([self.decoder(self._byte_array, projection=columns)]) def get_dataset_schema(self) -> RelationSchema: """ diff --git a/opteryx/connectors/gcp_cloudstorage_connector.py b/opteryx/connectors/gcp_cloudstorage_connector.py index a4421d21d..9fb8f46aa 100644 --- a/opteryx/connectors/gcp_cloudstorage_connector.py +++ b/opteryx/connectors/gcp_cloudstorage_connector.py @@ -93,7 +93,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]: blobs = (bucket + "/" + blob.name for blob in blobs if not blob.name.endswith("/")) return [blob for blob in blobs if ("." 
+ blob.split(".")[-1].lower()) in VALID_EXTENSIONS] - def read_dataset(self) -> pyarrow.Table: + def read_dataset(self, columns: list = None) -> pyarrow.Table: blob_names = self.partition_scheme.get_blobs_in_partition( start_date=self.start_date, end_date=self.end_date, @@ -111,7 +111,7 @@ def read_dataset(self) -> pyarrow.Table: try: decoder = get_decoder(blob_name) blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics) - yield decoder(blob_bytes) + yield decoder(blob_bytes, projection=columns) except UnsupportedFileTypeError: pass diff --git a/opteryx/connectors/gcp_firestore_connector.py b/opteryx/connectors/gcp_firestore_connector.py index 1a630034c..d3c51bd8b 100644 --- a/opteryx/connectors/gcp_firestore_connector.py +++ b/opteryx/connectors/gcp_firestore_connector.py @@ -66,7 +66,9 @@ def _initialize(): # pragma: no cover class GcpFireStoreConnector(BaseConnector): __mode__ = "Collection" - def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader": + def read_dataset( + self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE + ) -> "DatasetReader": """ Return a morsel of documents """ @@ -82,7 +84,9 @@ def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader": documents = documents.stream() for morsel in self.chunk_dictset( - ({**doc.to_dict(), "_id": doc.id} for doc in documents), initial_chunk_size=chunk_size + ({**doc.to_dict(), "_id": doc.id} for doc in documents), + columns=columns, + initial_chunk_size=chunk_size, ): yield morsel diff --git a/opteryx/connectors/mongodb_connector.py b/opteryx/connectors/mongodb_connector.py index 80ff6e42e..99bea2742 100644 --- a/opteryx/connectors/mongodb_connector.py +++ b/opteryx/connectors/mongodb_connector.py @@ -63,13 +63,15 @@ def __init__(self, *args, database: str = None, connection: str = None, **kwargs "MongoDB connector requires 'database' set in register_stpre or MONGODB_DATABASE set in environment variables." 
diff --git a/opteryx/connectors/mongodb_connector.py b/opteryx/connectors/mongodb_connector.py
index 80ff6e42e..99bea2742 100644
--- a/opteryx/connectors/mongodb_connector.py
+++ b/opteryx/connectors/mongodb_connector.py
@@ -63,13 +63,15 @@ def __init__(self, *args, database: str = None, connection: str = None, **kwargs
                 "MongoDB connector requires 'database' set in register_stpre or MONGODB_DATABASE set in environment variables."
             )
 
-    def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
+    def read_dataset(
+        self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
+    ) -> "DatasetReader":
         import pymongo
 
         client = pymongo.MongoClient(self.connection)  # type:ignore
         database = client[self.database]
         documents = database[self.dataset].find()
-        for morsel in self.chunk_dictset(documents, initial_chunk_size=chunk_size):
+        for morsel in self.chunk_dictset(documents, columns=columns, initial_chunk_size=chunk_size):
             yield morsel
 
     def get_dataset_schema(self) -> RelationSchema:
diff --git a/opteryx/connectors/sql_connector.py b/opteryx/connectors/sql_connector.py
index 4e77bae78..6d932e9a9 100644
--- a/opteryx/connectors/sql_connector.py
+++ b/opteryx/connectors/sql_connector.py
@@ -52,7 +52,9 @@ def __init__(self, *args, connection: str = None, engine=None, **kwargs):
         self.schema = None
         self.metadata = MetaData()
 
-    def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
+    def read_dataset(
+        self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
+    ) -> "DatasetReader":
         from sqlalchemy import Table
         from sqlalchemy import select
 
@@ -60,6 +62,7 @@ def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
 
         # get the schema from the dataset
         table = Table(self.dataset, self.metadata, autoload_with=self._engine)
+        # TODO: push the column projection into the generated SQL query
         query = select(table)
 
         morsel = DataFrame(schema=self.schema)
diff --git a/opteryx/connectors/virtual_data.py b/opteryx/connectors/virtual_data.py
index 95ebf3809..d09904091 100644
--- a/opteryx/connectors/virtual_data.py
+++ b/opteryx/connectors/virtual_data.py
@@ -29,6 +29,7 @@
 from opteryx.connectors.base.base_connector import DatasetReader
 from opteryx.connectors.capabilities import Partitionable
 from opteryx.exceptions import DatasetNotFoundError
+from opteryx.utils import arrow
 
 WELL_KNOWN_DATASETS = {
     "$astronauts": (virtual_datasets.astronauts, True),
@@ -67,9 +68,13 @@ def __init__(self, *args, **kwargs):
     def interal_only(self):
         return True
 
-    def read_dataset(self) -> "DatasetReader":
+    def read_dataset(self, columns: list = None) -> "DatasetReader":
         return SampleDatasetReader(
-            self.dataset, config=self.config, date=self.end_date, variables=self.variables
+            self.dataset,
+            columns=columns,
+            config=self.config,
+            date=self.end_date,
+            variables=self.variables,
         )
 
     def get_dataset_schema(self) -> RelationSchema:
@@ -84,6 +89,7 @@ class SampleDatasetReader(DatasetReader):
     def __init__(
         self,
         dataset_name: str,
+        columns: list,
         config: typing.Optional[typing.Dict[str, typing.Any]] = None,
         date: typing.Union[datetime.datetime, datetime.date, None] = None,
         variables: typing.Dict = None,
@@ -95,6 +101,7 @@ def __init__(
             config: Configuration information specific to the reader.
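The TODO left in the SQL connector points at the obvious next step: build the SELECT from the requested columns instead of the whole table. A sketch of what that could look like with SQLAlchemy (the table and column names are placeholders; this is not what the patch implements):

    from sqlalchemy import Column, Float, Integer, MetaData, String, Table, select

    planets = Table(
        "planets",
        MetaData(),
        Column("id", Integer),
        Column("name", String),
        Column("gravity", Float),
    )
    requested = ["name", "gravity"]
    query = select(*[planets.c[name] for name in requested]) if requested else select(planets)
    print(query)  # SELECT planets.name, planets.gravity FROM planets
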
""" super().__init__(dataset_name=dataset_name, config=config) + self.columns = columns self.exhausted = False self.date = date self.variables = variables @@ -116,4 +123,5 @@ def __next__(self) -> pyarrow.Table: if data_provider is None: suggestion = suggest(self.dataset_name.lower()) raise DatasetNotFoundError(suggestion=suggestion, dataset=self.dataset_name) - return data_provider.read(self.date, self.variables) + table = data_provider.read(self.date, self.variables) + return arrow.post_read_projector(table, self.columns) diff --git a/opteryx/operators/scanner_node.py b/opteryx/operators/scanner_node.py index e13a6f7b1..2600185c0 100644 --- a/opteryx/operators/scanner_node.py +++ b/opteryx/operators/scanner_node.py @@ -50,6 +50,7 @@ def __init__(self, properties: QueryProperties, **parameters): self.start_date = parameters.get("start_date") self.end_date = parameters.get("end_date") self.hints = parameters.get("hints", []) + self.columns = parameters.get("columns", []) if len(self.hints) != 0: self.statistics.add_message("All HINTS are currently ignored") @@ -82,7 +83,7 @@ def execute(self) -> Iterable: morsel = None schema = self.parameters["schema"] start_clock = time.monotonic_ns() - reader = self.parameters["connector"].read_dataset() + reader = self.parameters["connector"].read_dataset(columns=self.columns) for morsel in reader: self.statistics.blobs_read += 1 self.statistics.rows_read += morsel.num_rows diff --git a/opteryx/utils/arrow.py b/opteryx/utils/arrow.py index 39868e324..e984a17a1 100644 --- a/opteryx/utils/arrow.py +++ b/opteryx/utils/arrow.py @@ -105,3 +105,28 @@ def restore_null_columns(removed, table): for column in removed: # pragma: no cover table = table.append_column(column, pyarrow.array([None] * table.num_rows)) return table + + +def post_read_projector(table: pyarrow.Table, columns: list) -> pyarrow.Table: + """ + This is the near-read projection for data sources that the projection can't be + done as part of the read. + """ + if not columns: + # this should happen when there's no relation in the query + return table + + schema_columns = table.column_names + + columns_to_keep = [] + column_names = [] + + for projection_column in columns: + for schema_column in schema_columns: + if schema_column in projection_column.all_names: + columns_to_keep.append(schema_column) + column_names.append(projection_column.name) + break + + table = table.select(columns_to_keep) + return table.rename_columns(column_names)