Commit f3bae99
joocer committed Nov 14, 2023
1 parent 457c604 commit f3bae99
Showing 14 changed files with 101 additions and 50 deletions.
36 changes: 18 additions & 18 deletions opteryx/components/binder/binder_visitor.py
@@ -640,6 +640,24 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:

return node, context

def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
order_by = []
columns = []
for column, direction in node.order_by:
bound_column, context = inner_binder(column, context)

order_by.append(
(
bound_column,
"ascending" if direction else "descending",
)
)
columns.append(bound_column)

node.order_by = order_by
node.columns = columns
return node, context

def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
columns = []

@@ -714,24 +732,6 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:

return node, context

def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
order_by = []
columns = []
for column, direction in node.order_by:
bound_column, context = inner_binder(column, context)

order_by.append(
(
bound_column,
"ascending" if direction else "descending",
)
)
columns.append(bound_column)

node.order_by = order_by
node.columns = columns
return node, context

def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
from opteryx.connectors import connector_factory
from opteryx.connectors.capabilities import Cacheable
27 changes: 14 additions & 13 deletions opteryx/components/heuristic_optimizer.py
@@ -56,6 +56,8 @@
"""
from opteryx.components.logical_planner import LogicalPlan
from opteryx.components.logical_planner import LogicalPlanStepType
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import get_all_nodes_of_type


# Context object to carry state
@@ -70,7 +72,7 @@ def __init__(self, tree: LogicalPlan):

# We collect column identities so we can push column selection as close to the
# read as possible, including off to remote systems
self.collected_identities = []
self.collected_identities = set()


# Optimizer Visitor
@@ -80,8 +82,13 @@ def rewrite_predicates(self, node):

def collect_columns(self, node):
if node.columns:
return [col.schema_column.identity for col in node.columns if col.schema_column]
return []
return {
col.schema_column.identity
for column in node.columns
for col in get_all_nodes_of_type(column, (NodeType.IDENTIFIER,))
if col.schema_column
}
return set()

def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
# collect column references to push PROJECTION
@@ -93,7 +100,7 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):

# do this before any transformations
if node.node_type != LogicalPlanStepType.Scan:
context.collected_identities.extend(self.collect_columns(node))
context.collected_identities.union(self.collect_columns(node))

if node.node_type == LogicalPlanStepType.Filter:
# rewrite predicates, to favor conjuctions and reduce negations
Expand All @@ -103,16 +110,10 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
if node.node_type == LogicalPlanStepType.Scan:
# push projections
node_columns = [
col
for col in node.schema.columns
if col.identity in set(context.collected_identities)
col for col in node.schema.columns if col.identity in context.collected_identities
]
# print("FOUND")
# print(node_columns)
# print("NOT FOUND")
# print([col for col in node.schema.columns if col.identity not in set(context.collected_identities)])
# push selections
pass
# these are the pushed columns
node.columns = node_columns

context.optimized_tree.add_node(nid, node)
if parent:
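The optimizer change above collects column identities with `get_all_nodes_of_type` into a set instead of a flat list, so identifiers nested inside functions or operators are still found and duplicate references collapse. Below is a minimal sketch of that pattern; the `Node`, `NodeType` and `get_all_nodes_of_type` stand-ins are simplified assumptions for illustration, not opteryx's real classes.

# Sketch: collect the identities of all identifier columns referenced by a
# projection, walking each expression tree so nested identifiers are found.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional, Set, Tuple


class NodeType(Enum):
    IDENTIFIER = auto()
    FUNCTION = auto()


@dataclass
class SchemaColumn:
    identity: str  # unique identity assigned by the binder


@dataclass
class Node:
    node_type: NodeType
    schema_column: Optional[SchemaColumn] = None
    parameters: List["Node"] = field(default_factory=list)


def get_all_nodes_of_type(root: Node, types: Tuple[NodeType, ...]) -> List[Node]:
    """Depth-first walk returning every node whose type is in `types`."""
    found = [root] if root.node_type in types else []
    for child in root.parameters:
        found.extend(get_all_nodes_of_type(child, types))
    return found


def collect_columns(columns: List[Node]) -> Set[str]:
    # a set de-duplicates identities when several expressions reference the
    # same column
    return {
        node.schema_column.identity
        for column in columns
        for node in get_all_nodes_of_type(column, (NodeType.IDENTIFIER,))
        if node.schema_column
    }


# SUM(price): the identifier sits inside a function node, so a flat scan of
# the projection list would miss it, but the tree walk does not.
price = Node(NodeType.IDENTIFIER, schema_column=SchemaColumn("col-price-001"))
projection = [Node(NodeType.FUNCTION, parameters=[price])]

collected: Set[str] = set()
collected |= collect_columns(projection)  # accumulate in place across plan nodes
print(collected)  # {'col-price-001'}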
5 changes: 4 additions & 1 deletion opteryx/connectors/arrow_connector.py
@@ -23,6 +23,7 @@
from opteryx.connectors.base.base_connector import DEFAULT_MORSEL_SIZE
from opteryx.connectors.base.base_connector import BaseConnector
from opteryx.shared import MaterializedDatasets
from opteryx.utils import arrow


class ArrowConnector(BaseConnector):
@@ -45,11 +46,13 @@ def get_dataset_schema(self) -> RelationSchema:

return self.schema

def read_dataset(self, **kwargs) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
dataset = self._datasets[self.dataset]

batch_size = DEFAULT_MORSEL_SIZE // (dataset.nbytes / dataset.num_rows)

for batch in dataset.to_batches(max_chunksize=batch_size):
morsel = pyarrow.Table.from_batches([batch], schema=dataset.schema)
if columns:
morsel = arrow.post_read_projector(morsel, columns)
yield morsel
4 changes: 2 additions & 2 deletions opteryx/connectors/aws_s3_connector.py
@@ -73,7 +73,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:

return [blob for blob in blobs if ("." + blob.split(".")[-1].lower()) in VALID_EXTENSIONS]

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
blob_names = self.partition_scheme.get_blobs_in_partition(
start_date=self.start_date,
end_date=self.end_date,
@@ -91,7 +91,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass

5 changes: 4 additions & 1 deletion opteryx/connectors/base/base_connector.py
@@ -85,15 +85,18 @@ def read_schema_from_metastore(self):
def chunk_dictset(
self,
dictset: typing.Iterable[dict],
columns: list,
morsel_size: int = DEFAULT_MORSEL_SIZE,
initial_chunk_size: int = INITIAL_CHUNK_SIZE,
):
) -> pyarrow.Table:
chunk = []
self.chunk_size = initial_chunk_size # we reset each time
morsel = None

for index, record in enumerate(dictset):
_id = record.pop("_id", None)
# column selection
record = {k: record.get(k) for k in columns}
record["id"] = None if _id is None else str(_id)

chunk.append(record)
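The new `columns` argument to `chunk_dictset` means each document is cut down to the requested keys before it is turned into Arrow data. A rough, self-contained illustration of that selection step follows; the sample documents and column names are invented for the example.

# Sketch: per-record column selection as performed inside chunk_dictset.
import pyarrow

documents = [
    {"_id": "a1", "name": "Saturn V", "stages": 3, "origin": "USA"},
    {"_id": "b2", "name": "Ariane 5", "stages": 2},  # 'origin' missing here
]
columns = ["name", "origin"]

records = []
for record in documents:
    _id = record.pop("_id", None)
    # keep only the requested columns; absent keys become None rather than
    # raising, so ragged documents still line up in one table
    record = {k: record.get(k) for k in columns}
    record["id"] = None if _id is None else str(_id)
    records.append(record)

table = pyarrow.Table.from_pylist(records)
print(table.column_names)  # ['name', 'origin', 'id'] - 'stages' was never requested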
4 changes: 2 additions & 2 deletions opteryx/connectors/disk_connector.py
@@ -87,7 +87,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
]
return files

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
"""
Read the entire dataset from disk.
@@ -110,7 +110,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass # Skip unsupported file types

5 changes: 3 additions & 2 deletions opteryx/connectors/file_connector.py
@@ -16,6 +16,7 @@
"""
from typing import Optional

import pyarrow
from orso.schema import RelationSchema

from opteryx.connectors.base.base_connector import BaseConnector
@@ -46,15 +47,15 @@ def _read_file(self) -> None:
with open(self.dataset, mode="br") as file:
self._byte_array = bytes(file.read())

def read_dataset(self) -> iter:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
"""
Reads the dataset file and decodes it.
Returns:
An iterator containing a single decoded pyarrow.Table.
"""
self._read_file()
return iter([self.decoder(self._byte_array)])
return iter([self.decoder(self._byte_array, projection=columns)])

def get_dataset_schema(self) -> RelationSchema:
"""
4 changes: 2 additions & 2 deletions opteryx/connectors/gcp_cloudstorage_connector.py
@@ -93,7 +93,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
blobs = (bucket + "/" + blob.name for blob in blobs if not blob.name.endswith("/"))
return [blob for blob in blobs if ("." + blob.split(".")[-1].lower()) in VALID_EXTENSIONS]

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
blob_names = self.partition_scheme.get_blobs_in_partition(
start_date=self.start_date,
end_date=self.end_date,
@@ -111,7 +111,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass

8 changes: 6 additions & 2 deletions opteryx/connectors/gcp_firestore_connector.py
@@ -66,7 +66,9 @@ def _initialize(): # pragma: no cover
class GcpFireStoreConnector(BaseConnector):
__mode__ = "Collection"

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
"""
Return a morsel of documents
"""
@@ -82,7 +84,9 @@ def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
documents = documents.stream()

for morsel in self.chunk_dictset(
({**doc.to_dict(), "_id": doc.id} for doc in documents), initial_chunk_size=chunk_size
({**doc.to_dict(), "_id": doc.id} for doc in documents),
columns=columns,
initial_chunk_size=chunk_size,
):
yield morsel

6 changes: 4 additions & 2 deletions opteryx/connectors/mongodb_connector.py
@@ -63,13 +63,15 @@ def __init__(self, *args, database: str = None, connection: str = None, **kwargs):
"MongoDB connector requires 'database' set in register_stpre or MONGODB_DATABASE set in environment variables."
)

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
import pymongo

client = pymongo.MongoClient(self.connection) # type:ignore
database = client[self.database]
documents = database[self.dataset].find()
for morsel in self.chunk_dictset(documents, initial_chunk_size=chunk_size):
for morsel in self.chunk_dictset(documents, columns=columns, initial_chunk_size=chunk_size):
yield morsel

def get_dataset_schema(self) -> RelationSchema:
5 changes: 4 additions & 1 deletion opteryx/connectors/sql_connector.py
@@ -52,14 +52,17 @@ def __init__(self, *args, connection: str = None, engine=None, **kwargs):
self.schema = None
self.metadata = MetaData()

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
from sqlalchemy import Table
from sqlalchemy import select

self.chunk_size = chunk_size

# get the schema from the dataset
table = Table(self.dataset, self.metadata, autoload_with=self._engine)
print("SQL push projection")
query = select(table)
morsel = DataFrame(schema=self.schema)

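The commit only marks the projection-push point in the SQL connector with a print statement. Below is a hedged sketch, not the project's implementation, of how the requested column names could be applied to the SQLAlchemy select at that point; the table definition, engine URL, and column list are invented for the example.

# Sketch (assumption): pushing a column projection into the SQLAlchemy query.
from sqlalchemy import Column, Float, Integer, MetaData, String, Table, create_engine, select

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
# illustrative table; the real connector reflects it with autoload_with=engine
planets = Table(
    "planets",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String),
    Column("gravity", Float),
)
metadata.create_all(engine)

columns = ["name", "gravity"]  # the pushed projection, invented for the example
# select only the requested columns when a projection is supplied,
# otherwise fall back to selecting the whole table
query = select(*[planets.c[name] for name in columns]) if columns else select(planets)

with engine.connect() as conn:
    rows = conn.execute(query).fetchall()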
14 changes: 11 additions & 3 deletions opteryx/connectors/virtual_data.py
@@ -29,6 +29,7 @@
from opteryx.connectors.base.base_connector import DatasetReader
from opteryx.connectors.capabilities import Partitionable
from opteryx.exceptions import DatasetNotFoundError
from opteryx.utils import arrow

WELL_KNOWN_DATASETS = {
"$astronauts": (virtual_datasets.astronauts, True),
@@ -67,9 +68,13 @@ def __init__(self, *args, **kwargs):
def interal_only(self):
return True

def read_dataset(self) -> "DatasetReader":
def read_dataset(self, columns: list = None) -> "DatasetReader":
return SampleDatasetReader(
self.dataset, config=self.config, date=self.end_date, variables=self.variables
self.dataset,
columns=columns,
config=self.config,
date=self.end_date,
variables=self.variables,
)

def get_dataset_schema(self) -> RelationSchema:
@@ -84,6 +89,7 @@ class SampleDatasetReader(DatasetReader):
def __init__(
self,
dataset_name: str,
columns: list,
config: typing.Optional[typing.Dict[str, typing.Any]] = None,
date: typing.Union[datetime.datetime, datetime.date, None] = None,
variables: typing.Dict = None,
@@ -95,6 +101,7 @@ def __init__(
config: Configuration information specific to the reader.
"""
super().__init__(dataset_name=dataset_name, config=config)
self.columns = columns
self.exhausted = False
self.date = date
self.variables = variables
@@ -116,4 +123,5 @@ def __next__(self) -> pyarrow.Table:
if data_provider is None:
suggestion = suggest(self.dataset_name.lower())
raise DatasetNotFoundError(suggestion=suggestion, dataset=self.dataset_name)
return data_provider.read(self.date, self.variables)
table = data_provider.read(self.date, self.variables)
return arrow.post_read_projector(table, self.columns)
3 changes: 2 additions & 1 deletion opteryx/operators/scanner_node.py
@@ -50,6 +50,7 @@ def __init__(self, properties: QueryProperties, **parameters):
self.start_date = parameters.get("start_date")
self.end_date = parameters.get("end_date")
self.hints = parameters.get("hints", [])
self.columns = parameters.get("columns", [])

if len(self.hints) != 0:
self.statistics.add_message("All HINTS are currently ignored")
@@ -82,7 +83,7 @@ def execute(self) -> Iterable:
morsel = None
schema = self.parameters["schema"]
start_clock = time.monotonic_ns()
reader = self.parameters["connector"].read_dataset()
reader = self.parameters["connector"].read_dataset(columns=self.columns)
for morsel in reader:
self.statistics.blobs_read += 1
self.statistics.rows_read += morsel.num_rows
25 changes: 25 additions & 0 deletions opteryx/utils/arrow.py
@@ -105,3 +105,28 @@ def restore_null_columns(removed, table):
for column in removed: # pragma: no cover
table = table.append_column(column, pyarrow.array([None] * table.num_rows))
return table


def post_read_projector(table: pyarrow.Table, columns: list) -> pyarrow.Table:
"""
This is the near-read projection for data sources where the projection can't be
done as part of the read.
"""
if not columns:
# this should happen when there's no relation in the query
return table

schema_columns = table.column_names

columns_to_keep = []
column_names = []

for projection_column in columns:
for schema_column in schema_columns:
if schema_column in projection_column.all_names:
columns_to_keep.append(schema_column)
column_names.append(projection_column.name)
break

table = table.select(columns_to_keep)
return table.rename_columns(column_names)
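A short usage sketch for the helper above: a connector passes the bound projection columns and gets back a table restricted to them and renamed to the query's names. The `BoundColumn` class is a hypothetical stand-in that only mimics the `name` and `all_names` attributes the function actually reads; the table and column names are invented for the example.

# Sketch: applying post_read_projector to a freshly read pyarrow table.
from dataclasses import dataclass
from typing import List

import pyarrow

from opteryx.utils import arrow


@dataclass
class BoundColumn:
    name: str             # the name the query uses for the column
    all_names: List[str]  # aliases / qualified names that may match the file schema


table = pyarrow.table({"id": [1, 2], "surface_gravity": [9.8, 3.7], "mass": [5.9, 0.6]})

projection = [
    BoundColumn(name="gravity", all_names=["gravity", "surface_gravity"]),
    BoundColumn(name="id", all_names=["id", "planets.id"]),
]

projected = arrow.post_read_projector(table, projection)
print(projected.column_names)  # ['gravity', 'id'] - 'mass' was never requested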
