#1261 #1263

Merged (1 commit, Nov 14, 2023)
36 changes: 18 additions & 18 deletions opteryx/components/binder/binder_visitor.py
@@ -640,6 +640,24 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:

return node, context

def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
order_by = []
columns = []
for column, direction in node.order_by:
bound_column, context = inner_binder(column, context)

order_by.append(
(
bound_column,
"ascending" if direction else "descending",
)
)
columns.append(bound_column)

node.order_by = order_by
node.columns = columns
return node, context

def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
columns = []

@@ -714,24 +732,6 @@ def visit_project(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:

return node, context

def visit_order(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
order_by = []
columns = []
for column, direction in node.order_by:
bound_column, context = inner_binder(column, context)

order_by.append(
(
bound_column,
"ascending" if direction else "descending",
)
)
columns.append(bound_column)

node.order_by = order_by
node.columns = columns
return node, context

def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, BindingContext]:
from opteryx.connectors import connector_factory
from opteryx.connectors.capabilities import Cacheable
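To make the ORDER BY binding above easier to follow, here is a small standalone sketch of the same shape: each (column, direction) pair is bound, and the bound columns are also recorded so later stages (such as projection pushdown) can see which columns ORDER BY needs. `fake_inner_binder` is a stand-in for Opteryx's `inner_binder`, not the real API.

```python
# Toy illustration of the ORDER BY binding shown above: each (column, direction)
# pair is bound and the bound columns are kept so downstream stages can see
# which columns ORDER BY needs. fake_inner_binder is a stand-in, not Opteryx's API.

def fake_inner_binder(column, context):
    # pretend binding just attaches an identity to the column
    bound = {"name": column, "identity": f"{column}#0"}
    return bound, context


def bind_order_by(order_by_pairs, context):
    order_by = []
    columns = []
    for column, direction in order_by_pairs:
        bound_column, context = fake_inner_binder(column, context)
        order_by.append((bound_column, "ascending" if direction else "descending"))
        columns.append(bound_column)
    return order_by, columns, context


if __name__ == "__main__":
    bound, cols, _ = bind_order_by([("name", True), ("mass", False)], context={})
    for bound_column, direction in bound:
        print(bound_column["name"], direction)  # name ascending / mass descending
```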
27 changes: 14 additions & 13 deletions opteryx/components/heuristic_optimizer.py
@@ -56,6 +56,8 @@
"""
from opteryx.components.logical_planner import LogicalPlan
from opteryx.components.logical_planner import LogicalPlanStepType
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import get_all_nodes_of_type


# Context object to carry state
@@ -70,7 +72,7 @@ def __init__(self, tree: LogicalPlan):

# We collect column identities so we can push column selection as close to the
# read as possible, including off to remote systems
self.collected_identities = []
self.collected_identities = set()


# Optimizer Visitor
@@ -80,8 +82,13 @@ def rewrite_predicates(self, node):

def collect_columns(self, node):
if node.columns:
return [col.schema_column.identity for col in node.columns if col.schema_column]
return []
return {
col.schema_column.identity
for column in node.columns
for col in get_all_nodes_of_type(column, (NodeType.IDENTIFIER,))
if col.schema_column
}
return set()

def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
# collect column references to push PROJECTION
@@ -93,7 +100,7 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):

# do this before any transformations
if node.node_type != LogicalPlanStepType.Scan:
context.collected_identities.extend(self.collect_columns(node))
context.collected_identities.update(self.collect_columns(node))

if node.node_type == LogicalPlanStepType.Filter:
# rewrite predicates, to favor conjunctions and reduce negations
Expand All @@ -103,16 +110,10 @@ def visit(self, parent: str, nid: str, context: HeuristicOptimizerContext):
if node.node_type == LogicalPlanStepType.Scan:
# push projections
node_columns = [
col
for col in node.schema.columns
if col.identity in set(context.collected_identities)
col for col in node.schema.columns if col.identity in context.collected_identities
]
# print("FOUND")
# print(node_columns)
# print("NOT FOUND")
# print([col for col in node.schema.columns if col.identity not in set(context.collected_identities)])
# push selections
pass
# these are the pushed columns
node.columns = node_columns

context.optimized_tree.add_node(nid, node)
if parent:
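The optimizer change above switches collected_identities to a set, walks expressions for identifier nodes, and then trims the Scan node to the schema columns whose identity was actually referenced. A self-contained sketch of that pushdown idea, using stand-in classes rather than Opteryx's real plan nodes:

```python
# Standalone sketch of projection pushdown: identifiers referenced anywhere
# above a Scan are collected into a set, and the Scan keeps only the schema
# columns whose identity was seen. The classes are illustrative stand-ins.
from dataclasses import dataclass, field


@dataclass
class Column:
    name: str
    identity: str


@dataclass
class ScanNode:
    schema_columns: list
    columns: list = field(default_factory=list)


def collect_identities(referenced_columns) -> set:
    # mirrors collect_columns(): gather the identity of every referenced identifier
    return {col.identity for col in referenced_columns}


def push_projection(scan: ScanNode, collected: set) -> ScanNode:
    # these are the pushed columns
    scan.columns = [col for col in scan.schema_columns if col.identity in collected]
    return scan


if __name__ == "__main__":
    schema = [Column("id", "c1"), Column("name", "c2"), Column("payload", "c3")]
    seen = collect_identities([Column("name", "c2")])
    print([c.name for c in push_projection(ScanNode(schema), seen).columns])  # ['name']
```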
5 changes: 4 additions & 1 deletion opteryx/connectors/arrow_connector.py
@@ -23,6 +23,7 @@
from opteryx.connectors.base.base_connector import DEFAULT_MORSEL_SIZE
from opteryx.connectors.base.base_connector import BaseConnector
from opteryx.shared import MaterializedDatasets
from opteryx.utils import arrow


class ArrowConnector(BaseConnector):
@@ -45,11 +46,13 @@ def get_dataset_schema(self) -> RelationSchema:

return self.schema

def read_dataset(self, **kwargs) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
dataset = self._datasets[self.dataset]

batch_size = DEFAULT_MORSEL_SIZE // (dataset.nbytes / dataset.num_rows)

for batch in dataset.to_batches(max_chunksize=batch_size):
morsel = pyarrow.Table.from_batches([batch], schema=dataset.schema)
if columns:
morsel = arrow.post_read_projector(morsel, columns)
yield morsel
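For context, a standalone illustration of the read pattern used here: the rows per batch are derived from a target morsel byte size, and unwanted columns are dropped after the read. The 64 MB target and the column names below are illustrative assumptions, not Opteryx's constants.

```python
# Standalone sketch of the morsel-sized read with a post-read projection.
import pyarrow

TARGET_MORSEL_BYTES = 64 * 1024 * 1024  # assumption: DEFAULT_MORSEL_SIZE is a byte budget like this


def read_in_morsels(table: pyarrow.Table, wanted: list):
    bytes_per_row = table.nbytes / table.num_rows
    batch_size = int(TARGET_MORSEL_BYTES // bytes_per_row) or 1
    for batch in table.to_batches(max_chunksize=batch_size):
        morsel = pyarrow.Table.from_batches([batch], schema=table.schema)
        # post-read projection: keep only the requested columns
        yield morsel.select([c for c in wanted if c in morsel.column_names])


if __name__ == "__main__":
    data = pyarrow.table({"id": list(range(1000)), "name": ["x"] * 1000, "junk": [0.0] * 1000})
    for morsel in read_in_morsels(data, ["id", "name"]):
        print(morsel.column_names, morsel.num_rows)
```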
4 changes: 2 additions & 2 deletions opteryx/connectors/aws_s3_connector.py
@@ -73,7 +73,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:

return [blob for blob in blobs if ("." + blob.split(".")[-1].lower()) in VALID_EXTENSIONS]

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
blob_names = self.partition_scheme.get_blobs_in_partition(
start_date=self.start_date,
end_date=self.end_date,
@@ -91,7 +91,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass

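The blob connectors now pass projection=columns to the decoder. As a standalone sketch of the underlying idea, using pyarrow's Parquet reader directly (this is not Opteryx's decoder API): when the format supports it, column selection is applied while the bytes are decoded rather than afterwards.

```python
# Sketch of decode-time projection using pyarrow Parquet; not Opteryx's decoder API.
import io

import pyarrow
import pyarrow.parquet as pq


def decode_parquet(blob_bytes: bytes, projection: list = None) -> pyarrow.Table:
    # only the requested columns are materialised from the blob
    return pq.read_table(io.BytesIO(blob_bytes), columns=projection)


if __name__ == "__main__":
    table = pyarrow.table({"id": [1, 2], "name": ["a", "b"], "junk": [0.1, 0.2]})
    buffer = io.BytesIO()
    pq.write_table(table, buffer)
    print(decode_parquet(buffer.getvalue(), projection=["id", "name"]).column_names)
```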
5 changes: 4 additions & 1 deletion opteryx/connectors/base/base_connector.py
@@ -85,15 +85,18 @@ def read_schema_from_metastore(self):
def chunk_dictset(
self,
dictset: typing.Iterable[dict],
columns: list,
morsel_size: int = DEFAULT_MORSEL_SIZE,
initial_chunk_size: int = INITIAL_CHUNK_SIZE,
):
) -> pyarrow.Table:
chunk = []
self.chunk_size = initial_chunk_size # we reset each time
morsel = None

for index, record in enumerate(dictset):
_id = record.pop("_id", None)
# column selection
record = {k: record.get(k) for k in columns}
record["id"] = None if _id is None else str(_id)

chunk.append(record)
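A simplified, standalone version of the chunk_dictset change: each record is trimmed to the requested columns before it is accumulated into a chunk, so unwanted fields never reach Arrow. The fixed chunk size and sample data below are illustrative; the real implementation grows its chunk size adaptively.

```python
# Simplified chunk_dictset: trim each record to the requested columns, then
# batch the trimmed records into Arrow tables.
import pyarrow


def chunk_dictset(dictset, columns, chunk_size=500):
    chunk = []
    for record in dictset:
        _id = record.pop("_id", None)
        record = {k: record.get(k) for k in columns}  # column selection
        record["id"] = None if _id is None else str(_id)
        chunk.append(record)
        if len(chunk) == chunk_size:
            yield pyarrow.Table.from_pylist(chunk)
            chunk = []
    if chunk:
        yield pyarrow.Table.from_pylist(chunk)


if __name__ == "__main__":
    docs = ({"_id": i, "name": f"n{i}", "junk": i * 2} for i in range(1200))
    for morsel in chunk_dictset(docs, columns=["name"]):
        print(morsel.column_names, morsel.num_rows)
```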
4 changes: 2 additions & 2 deletions opteryx/connectors/disk_connector.py
@@ -87,7 +87,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
]
return files

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
"""
Read the entire dataset from disk.

@@ -110,7 +110,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass # Skip unsupported file types

5 changes: 3 additions & 2 deletions opteryx/connectors/file_connector.py
@@ -16,6 +16,7 @@
"""
from typing import Optional

import pyarrow
from orso.schema import RelationSchema

from opteryx.connectors.base.base_connector import BaseConnector
@@ -46,15 +47,15 @@ def _read_file(self) -> None:
with open(self.dataset, mode="br") as file:
self._byte_array = bytes(file.read())

def read_dataset(self) -> iter:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
"""
Reads the dataset file and decodes it.

Returns:
An iterator containing a single decoded pyarrow.Table.
"""
self._read_file()
return iter([self.decoder(self._byte_array)])
return iter([self.decoder(self._byte_array, projection=columns)])

def get_dataset_schema(self) -> RelationSchema:
"""
4 changes: 2 additions & 2 deletions opteryx/connectors/gcp_cloudstorage_connector.py
@@ -93,7 +93,7 @@ def get_list_of_blob_names(self, *, prefix: str) -> List[str]:
blobs = (bucket + "/" + blob.name for blob in blobs if not blob.name.endswith("/"))
return [blob for blob in blobs if ("." + blob.split(".")[-1].lower()) in VALID_EXTENSIONS]

def read_dataset(self) -> pyarrow.Table:
def read_dataset(self, columns: list = None) -> pyarrow.Table:
blob_names = self.partition_scheme.get_blobs_in_partition(
start_date=self.start_date,
end_date=self.end_date,
@@ -111,7 +111,7 @@ def read_dataset(self) -> pyarrow.Table:
try:
decoder = get_decoder(blob_name)
blob_bytes = self.read_blob(blob_name=blob_name, statistics=self.statistics)
yield decoder(blob_bytes)
yield decoder(blob_bytes, projection=columns)
except UnsupportedFileTypeError:
pass

8 changes: 6 additions & 2 deletions opteryx/connectors/gcp_firestore_connector.py
@@ -66,7 +66,9 @@ def _initialize(): # pragma: no cover
class GcpFireStoreConnector(BaseConnector):
__mode__ = "Collection"

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
"""
Return a morsel of documents
"""
@@ -82,7 +84,9 @@ def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
documents = documents.stream()

for morsel in self.chunk_dictset(
({**doc.to_dict(), "_id": doc.id} for doc in documents), initial_chunk_size=chunk_size
({**doc.to_dict(), "_id": doc.id} for doc in documents),
columns=columns,
initial_chunk_size=chunk_size,
):
yield morsel

6 changes: 4 additions & 2 deletions opteryx/connectors/mongodb_connector.py
@@ -63,13 +63,15 @@ def __init__(self, *args, database: str = None, connection: str = None, **kwargs):
"MongoDB connector requires 'database' set in register_stpre or MONGODB_DATABASE set in environment variables."
)

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
import pymongo

client = pymongo.MongoClient(self.connection) # type:ignore
database = client[self.database]
documents = database[self.dataset].find()
for morsel in self.chunk_dictset(documents, initial_chunk_size=chunk_size):
for morsel in self.chunk_dictset(documents, columns=columns, initial_chunk_size=chunk_size):
yield morsel

def get_dataset_schema(self) -> RelationSchema:
5 changes: 4 additions & 1 deletion opteryx/connectors/sql_connector.py
@@ -52,14 +52,17 @@ def __init__(self, *args, connection: str = None, engine=None, **kwargs):
self.schema = None
self.metadata = MetaData()

def read_dataset(self, chunk_size: int = INITIAL_CHUNK_SIZE) -> "DatasetReader":
def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE
) -> "DatasetReader":
from sqlalchemy import Table
from sqlalchemy import select

self.chunk_size = chunk_size

# get the schema from the dataset
table = Table(self.dataset, self.metadata, autoload_with=self._engine)
print("SQL push projection")
query = select(table)
morsel = DataFrame(schema=self.schema)

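The SQL connector only gains the columns parameter in this diff; the print call is a placeholder. As a hypothetical sketch (an assumption, not what the PR implements), the projection could be pushed into the generated SELECT with SQLAlchemy along these lines:

```python
# Hypothetical sketch of SQL projection pushdown with SQLAlchemy; this is an
# assumption about a possible implementation, not what the PR does.
from sqlalchemy import MetaData, Table, create_engine, select


def build_query(engine, dataset: str, columns: list = None):
    table = Table(dataset, MetaData(), autoload_with=engine)
    if columns:
        # select only the requested columns instead of the whole table
        return select(*[table.c[name] for name in columns])
    return select(table)


if __name__ == "__main__":
    engine = create_engine("sqlite:///:memory:")  # example connection string
    with engine.begin() as conn:
        conn.exec_driver_sql("CREATE TABLE planets (id INTEGER, name TEXT, mass REAL)")
    print(build_query(engine, "planets", columns=["id", "name"]))
```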
14 changes: 11 additions & 3 deletions opteryx/connectors/virtual_data.py
@@ -29,6 +29,7 @@
from opteryx.connectors.base.base_connector import DatasetReader
from opteryx.connectors.capabilities import Partitionable
from opteryx.exceptions import DatasetNotFoundError
from opteryx.utils import arrow

WELL_KNOWN_DATASETS = {
"$astronauts": (virtual_datasets.astronauts, True),
@@ -67,9 +68,13 @@ def __init__(self, *args, **kwargs):
def interal_only(self):
return True

def read_dataset(self) -> "DatasetReader":
def read_dataset(self, columns: list = None) -> "DatasetReader":
return SampleDatasetReader(
self.dataset, config=self.config, date=self.end_date, variables=self.variables
self.dataset,
columns=columns,
config=self.config,
date=self.end_date,
variables=self.variables,
)

def get_dataset_schema(self) -> RelationSchema:
@@ -84,6 +89,7 @@ class SampleDatasetReader(DatasetReader):
def __init__(
self,
dataset_name: str,
columns: list,
config: typing.Optional[typing.Dict[str, typing.Any]] = None,
date: typing.Union[datetime.datetime, datetime.date, None] = None,
variables: typing.Dict = None,
@@ -95,6 +101,7 @@ def __init__(
config: Configuration information specific to the reader.
"""
super().__init__(dataset_name=dataset_name, config=config)
self.columns = columns
self.exhausted = False
self.date = date
self.variables = variables
@@ -116,4 +123,5 @@ def __next__(self) -> pyarrow.Table:
if data_provider is None:
suggestion = suggest(self.dataset_name.lower())
raise DatasetNotFoundError(suggestion=suggestion, dataset=self.dataset_name)
return data_provider.read(self.date, self.variables)
table = data_provider.read(self.date, self.variables)
return arrow.post_read_projector(table, self.columns)
3 changes: 2 additions & 1 deletion opteryx/operators/scanner_node.py
@@ -50,6 +50,7 @@ def __init__(self, properties: QueryProperties, **parameters):
self.start_date = parameters.get("start_date")
self.end_date = parameters.get("end_date")
self.hints = parameters.get("hints", [])
self.columns = parameters.get("columns", [])

if len(self.hints) != 0:
self.statistics.add_message("All HINTS are currently ignored")
@@ -82,7 +83,7 @@ def execute(self) -> Iterable:
morsel = None
schema = self.parameters["schema"]
start_clock = time.monotonic_ns()
reader = self.parameters["connector"].read_dataset()
reader = self.parameters["connector"].read_dataset(columns=self.columns)
for morsel in reader:
self.statistics.blobs_read += 1
self.statistics.rows_read += morsel.num_rows
25 changes: 25 additions & 0 deletions opteryx/utils/arrow.py
@@ -105,3 +105,28 @@ def restore_null_columns(removed, table):
for column in removed: # pragma: no cover
table = table.append_column(column, pyarrow.array([None] * table.num_rows))
return table


def post_read_projector(table: pyarrow.Table, columns: list) -> pyarrow.Table:
"""
This is the near-read projection for data sources where the projection can't be
done as part of the read itself.
"""
if not columns:
# this should happen when there's no relation in the query
return table

schema_columns = table.column_names

columns_to_keep = []
column_names = []

for projection_column in columns:
for schema_column in schema_columns:
if schema_column in projection_column.all_names:
columns_to_keep.append(schema_column)
column_names.append(projection_column.name)
break

table = table.select(columns_to_keep)
return table.rename_columns(column_names)
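A minimal usage sketch of post_read_projector: each projection column is assumed to carry the set of aliases it is known by (all_names) plus a preferred output name, and the helper selects the matching physical columns and renames them. ProjectionColumn is a stand-in for Opteryx's column objects, and the helper body is repeated so the snippet runs on its own.

```python
# Usage sketch for post_read_projector; ProjectionColumn is an illustrative stand-in.
import pyarrow


class ProjectionColumn:
    # a preferred output name plus every alias the column is known by
    def __init__(self, name, all_names):
        self.name = name
        self.all_names = all_names


def post_read_projector(table, columns):
    if not columns:
        return table
    keep, names = [], []
    for projection_column in columns:
        for schema_column in table.column_names:
            if schema_column in projection_column.all_names:
                keep.append(schema_column)
                names.append(projection_column.name)
                break
    return table.select(keep).rename_columns(names)


table = pyarrow.table({"surface_pressure": [92.0, 0.006], "name": ["Venus", "Mars"]})
projection = [
    ProjectionColumn("pressure", {"surface_pressure", "p.surface_pressure"}),
    ProjectionColumn("name", {"name", "p.name"}),
]
print(post_read_projector(table, projection).column_names)  # ['pressure', 'name']
```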