Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Sep 12, 2024
1 parent 21aa252 commit 35503db
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 37 deletions.
30 changes: 18 additions & 12 deletions opteryx/compiled/functions/ip_address.pyx
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
# cython: nonecheck=False
# cython: overflowcheck=False

from libc.stdint cimport uint32_t
from libc.stdint cimport uint32_t, int8_t
from libc.stdlib cimport strtol
from libc.string cimport strchr
from libc.string cimport strlen
from libc.string cimport memset
import numpy as np
cimport numpy as cnp
from cpython cimport PyUnicode_AsUTF8String, PyBytes_GET_SIZE

import cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef uint32_t ip_to_int(char* ip):

cdef inline uint32_t ip_to_int(const char* ip):

# Check if the input string is at least 7 characters long
if strlen(ip) < 7:
raise ValueError("Invalid IP address: too short")

cdef uint32_t result = 0
cdef uint32_t num = 0
cdef int shift = 24 # Start with the leftmost byte
cdef int8_t shift = 24 # Start with the leftmost byte
cdef char* end

# Convert each part of the IP to an integer
Expand All @@ -39,8 +43,6 @@ cdef uint32_t ip_to_int(char* ip):

return result

@cython.boundscheck(False)
@cython.wraparound(False)
def ip_in_cidr(cnp.ndarray ip_addresses, str cidr):

# CIDR validation...
Expand All @@ -51,16 +53,20 @@ def ip_in_cidr(cnp.ndarray ip_addresses, str cidr):
cdef int mask_size
cdef str base_ip_str
cdef list cidr_parts = cidr.split('/')
cdef bytes ip_byte_string
cdef uint32_t arr_len = ip_addresses.shape[0]

base_ip_str, mask_size = cidr_parts[0], int(cidr_parts[1])
netmask = (0xFFFFFFFF << (32 - mask_size)) & 0xFFFFFFFF

base_ip = ip_to_int(base_ip_str.encode('utf-8'))
base_ip = ip_to_int(PyUnicode_AsUTF8String(base_ip_str))

cdef unsigned char[:] result = np.zeros(ip_addresses.shape[0], dtype=np.bool_)
cdef unsigned char[:] result = np.zeros(arr_len, dtype=np.bool_)

for i in range(ip_addresses.shape[0]):
ip_int = ip_to_int(ip_addresses[i].encode('utf-8'))
result[i] = (ip_int & netmask) == base_ip
for i in range(arr_len):
ip_address = ip_addresses[i]
if ip_address is not None:
ip_int = ip_to_int(PyUnicode_AsUTF8String(ip_address))
result[i] = (ip_int & netmask) == base_ip

return np.asarray(result, dtype=bool)
46 changes: 30 additions & 16 deletions opteryx/connectors/gcp_firestore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Generator

from orso.schema import FlatColumn
from orso.schema import RelationSchema
from orso.types import OrsoTypes

from opteryx import config
from opteryx.connectors.base.base_connector import INITIAL_CHUNK_SIZE
from opteryx.connectors.base.base_connector import BaseConnector
from opteryx.connectors.capabilities import PredicatePushable
from opteryx.exceptions import DatasetNotFoundError
from opteryx.exceptions import MissingDependencyError
from opteryx.exceptions import UnmetRequirementError
Expand Down Expand Up @@ -52,37 +55,48 @@ def _get_project_id(): # pragma: no cover
def _initialize(): # pragma: no cover
"""Create the connection to Firebase"""
try:
import firebase_admin
from firebase_admin import credentials
from google.cloud import firestore
except ImportError as err: # pragma: no cover
raise MissingDependencyError(err.name) from err
if not firebase_admin._apps:
# if we've not been given the ID, fetch it
project_id = GCP_PROJECT_ID
if project_id is None:
project_id = _get_project_id()
creds = credentials.ApplicationDefault()
firebase_admin.initialize_app(creds, {"projectId": project_id, "httpTimeout": 10})

project_id = GCP_PROJECT_ID
if project_id is None:
project_id = _get_project_id()
return firestore.Client(project=project_id)

class GcpFireStoreConnector(BaseConnector):

class GcpFireStoreConnector(BaseConnector, PredicatePushable):
__mode__ = "Collection"
__type__ = "FIRESTORE"

PUSHABLE_OPS: Dict[str, bool] = {"Eq": True}

PUSHABLE_TYPES = {OrsoTypes.BOOLEAN, OrsoTypes.DOUBLE, OrsoTypes.INTEGER, OrsoTypes.VARCHAR}

def __init__(self, **kwargs):
BaseConnector.__init__(self, **kwargs)
PredicatePushable.__init__(self, **kwargs)

def read_dataset(
self, columns: list = None, chunk_size: int = INITIAL_CHUNK_SIZE, **kwargs
self,
columns: list = None,
chunk_size: int = INITIAL_CHUNK_SIZE,
predicates: list = None,
**kwargs,
) -> Generator:
"""
Return a morsel of documents
"""
from firebase_admin import firestore
from google.cloud.firestore_v1.base_query import FieldFilter

_initialize()
database = firestore.client()
database = _initialize()
documents = database.collection(self.dataset)

# for predicate in self._predicates:
# documents = documents.where(*predicate)
if predicates:
for predicate in predicates:
documents = documents.where(
filter=FieldFilter(predicate.left.source_column, "==", predicate.right.value)
)

documents = documents.stream()

Expand Down
2 changes: 1 addition & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ zstandard

# different storage providers
fastavro
firebase-admin
firestore
google-cloud-storage
google-cloud-bigquery-storage
minio
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements_arm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ polars

# different storage providers
fastavro
firebase-admin
firestore
sqlalchemy
pymysql
psycopg2-binary
Expand Down
3 changes: 2 additions & 1 deletion tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,7 @@
("SELECT '192.168.1.1' | '0'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE '1' | '1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE 'abc' | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE null | '192.168.1.1/8'", 0, 1, None),
("SELECT name FROM $satellites WHERE 123 | '192.168.1.1'", None, None, IncorrectTypeError),
("SELECT name FROM $satellites WHERE '10.10.10.10' | '192.168.1.1'", 0, 1, IncorrectTypeError),
("SELECT name FROM $satellites WHERE 0 | 0", 0, 1, None),
Expand All @@ -1329,7 +1330,7 @@
("SELECT '192.168.1.*' | '192.168.1.1/8'", None, None, IncorrectTypeError),
("SELECT '!!' | '192.168.1.1/8'", None, None, IncorrectTypeError),
("SELECT null | '192.168.1.1'", 1, 1, IncorrectTypeError),
("SELECT null | '192.168.1.1/8'", 1, 1, IncorrectTypeError),
("SELECT null | '192.168.1.1/8'", 1, 1, None),

("SELECT * FROM testdata.flat.different", 196902, 15, None),
("SELECT * FROM testdata.flat.different WHERE following < 10", 7814, 15, None),
Expand Down
31 changes: 25 additions & 6 deletions tests/storage/test_collection_gcs_firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,33 @@ def test_firestore_storage():
cur.execute("SELECT actor, COUNT(*) FROM dwarves GROUP BY actor;")
assert cur.rowcount == 6, cur.rowcount

conn.close()

def test_predicate_pushdown():
opteryx.register_store("dwarves", GcpFireStoreConnector)
os.environ["GCP_PROJECT_ID"] = "mabeldev"

conn = opteryx.connect()

# TEST PREDICATE PUSHDOWN
# cur = conn.cursor()
# cur.execute("SELECT * FROM dwarves WHERE actor = 'Pinto Colvig';")
# # when pushdown is enabled, we only read the matching rows from the source
# assert cur.rowcount == 2, cur.rowcount
# assert cur.stats["rows_read"] == 2, cur.stats
cur = conn.cursor()
cur.execute("SELECT * FROM dwarves WHERE actor = 'Pinto Colvig';")
# when pushdown is enabled, we only read the matching rows from the source
assert cur.rowcount == 2, cur.rowcount
assert cur.stats["rows_read"] == 2, cur.stats

conn.close()
def test_predicate_pushdown_not_equals():
"""we don't push these, we get 5 records by Opteryx does the filtering not the source"""
opteryx.register_store("dwarves", GcpFireStoreConnector)
os.environ["GCP_PROJECT_ID"] = "mabeldev"

conn = opteryx.connect()

# TEST PREDICATE PUSHDOWN
cur = conn.cursor()
cur.execute("SELECT * FROM dwarves WHERE actor != 'Pinto Colvig';")
assert cur.rowcount == 5, cur.rowcount
assert cur.stats["rows_read"] == 7, cur.stats


if __name__ == "__main__": # pragma: no cover
Expand Down

0 comments on commit 35503db

Please sign in to comment.