Skip to content

Commit

Permalink
Merge pull request #23 from mohamadkhalaj/main
Browse files Browse the repository at this point in the history
Fix issue #20
  • Loading branch information
seyed-dev authored Oct 31, 2023
2 parents d6e35c9 + 39bdb40 commit 0e94799
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 109 deletions.
127 changes: 82 additions & 45 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import functools
from typing import Any, Literal, Dict, Type
from typing import Any, Dict, Type

from mongoengine import Document, EmbeddedDocument, fields
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.compiler import F, Match, Q, Operators # noqa keep
from aggify.exceptions import (
Expand All @@ -10,13 +11,16 @@
InvalidField,
InvalidEmbeddedField,
OutStageError,
InvalidArgument,
)
from aggify.types import QueryParams
from aggify.utilty import (
to_mongo_positive_index,
check_fields_exist,
replace_values_recursive,
convert_match_query,
check_field_exists,
get_db_field,
)


Expand Down Expand Up @@ -174,12 +178,17 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
split_query = key.split("__")

# Retrieve the field definition from the model.
join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore # noqa

join_field = self.get_model_field(self.base_model, split_query[0]) # type: ignore
# Check conditions for creating a 'match' pipeline stage.
if (
"document_type_obj" not in join_field.__dict__
or issubclass(join_field.document_type, EmbeddedDocument)
isinstance(
join_field, TopLevelDocumentMetaclass
) # check whether field is added by lookup stage or not
or "document_type_obj"
not in join_field.__dict__ # Check whether this field is a join field or not.
or issubclass(
join_field.document_type, EmbeddedDocument
) # Check whether this field is embedded field or not
or len(split_query) == 1
or (len(split_query) == 2 and split_query[1] in Operators.ALL_OPERATORS)
):
Expand All @@ -191,7 +200,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
self.pipelines.append(match)

else:
from_collection = join_field.document_type # noqa
from_collection = join_field.document_type
local_field = join_field.db_field
as_name = join_field.name
matches = []
Expand All @@ -210,7 +219,7 @@ def __to_aggregate(self, query: dict[str, Any]) -> None:
as_name=as_name,
)
)
self.unwind(as_name)
self.unwind(as_name, preserve=True)
self.pipelines.extend([{"$match": match} for match in matches])

@last_out_stage_check
Expand Down Expand Up @@ -356,7 +365,7 @@ def annotate(
else:
if isinstance(f, str):
try:
self.get_model_field(self.base_model, f) # noqa
self.get_model_field(self.base_model, f)
value = f"${f}"
except InvalidField:
value = f
Expand Down Expand Up @@ -429,66 +438,94 @@ def __combine_sequential_matches(self) -> list[dict[str, dict | Any]]:

@last_out_stage_check
def lookup(
self, from_collection: Document, let: list[str], query: list[Q], as_name: str
self,
from_collection: Document,
as_name: str,
query: list[Q] | Q | None = None,
let: list[str] | None = None,
local_field: str | None = None,
foreign_field: str | None = None,
) -> "Aggify":
"""
Generates a MongoDB lookup pipeline stage.
Args:
from_collection (Document): The name of the collection to lookup.
let (list): The local field(s) to join on.
query (list[Q]): List of desired queries with Q function.
from_collection (Document): The document representing the collection to perform the lookup on.
as_name (str): The name of the new field to create.
query (list[Q] | Q | None, optional): List of desired queries with Q function or a single query.
let (list[str] | None, optional): The local field(s) to join on. If provided, localField and foreignField are not used.
local_field (str | None, optional): The local field to join on when let is not provided.
foreign_field (str | None, optional): The foreign field to join on when let is not provided.
Returns:
Aggify: A MongoDB lookup pipeline stage.
Aggify: An instance of the Aggify class representing a MongoDB lookup pipeline stage.
"""
check_fields_exist(self.base_model, let) # noqa

let_dict = {
field: f"${self.base_model._fields[field].db_field}"
for field in let # noqa
}
from_collection = from_collection._meta.get("collection") # noqa

lookup_stages = []
check_field_exists(self.base_model, as_name)
from_collection_name = from_collection._meta.get("collection") # noqa

for q in query:
# Construct the match stage for each query
if isinstance(q, Q):
replaced_values = replace_values_recursive(
convert_match_query(dict(q)), # noqa
{field: f"$${field}" for field in let},
)
match_stage = {"$match": {"$expr": replaced_values.get("$match")}}
lookup_stages.append(match_stage)
elif isinstance(q, Aggify):
lookup_stages.extend(
replace_values_recursive(
convert_match_query(q.pipelines), # noqa
if not let and not (local_field and foreign_field):
raise InvalidArgument(
expected_list=[["local_field", "foreign_field"], "let"]
)
elif not let:
if not (local_field and foreign_field):
raise InvalidArgument(expected_list=["local_field", "foreign_field"])
lookup_stage = {
"$lookup": {
"from": from_collection_name,
"localField": get_db_field(self.base_model, local_field), # noqa
"foreignField": get_db_field(
from_collection, foreign_field
), # noqa
"as": as_name,
}
}
else:
if not query:
raise InvalidArgument(expected_list=["query"])
check_fields_exist(self.base_model, let) # noqa
let_dict = {
field: f"${get_db_field(self.base_model, field)}"
for field in let # noqa
}
for q in query:
# Construct the match stage for each query
if isinstance(q, Q):
replaced_values = replace_values_recursive(
convert_match_query(dict(q)),
{field: f"$${field}" for field in let},
)
)
match_stage = {"$match": {"$expr": replaced_values.get("$match")}}
lookup_stages.append(match_stage)
elif isinstance(q, Aggify):
lookup_stages.extend(
replace_values_recursive(
convert_match_query(q.pipelines), # noqa
{field: f"$${field}" for field in let},
)
)

# Append the lookup stage with multiple match stages to the pipeline
lookup_stage = {
"$lookup": {
"from": from_collection,
"let": let_dict,
"pipeline": lookup_stages, # List of match stages
"as": as_name,
# Append the lookup stage with multiple match stages to the pipeline
lookup_stage = {
"$lookup": {
"from": from_collection_name,
"let": let_dict,
"pipeline": lookup_stages, # List of match stages
"as": as_name,
}
}
}

self.pipelines.append(lookup_stage)

# Add this new field to base model fields, which we can use it in the next stages.
self.base_model._fields[as_name] = fields.StringField() # noqa
self.base_model._fields[as_name] = from_collection # noqa

return self

@staticmethod
def get_model_field(model: Document, field: str) -> fields:
def get_model_field(model: Type[Document], field: str) -> fields:
"""
Get the field definition of a specified field in a MongoDB model.
Expand Down Expand Up @@ -520,7 +557,7 @@ def _replace_base(self, embedded_field) -> str:
Raises:
InvalidEmbeddedField: If the specified embedded field is not found or is not of the correct type.
"""
model_field = self.get_model_field(self.base_model, embedded_field) # noqa
model_field = self.get_model_field(self.base_model, embedded_field)

if not hasattr(model_field, "document_type") or not issubclass(
model_field.document_type, EmbeddedDocument
Expand Down
26 changes: 20 additions & 6 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Type

from mongoengine import Document, EmbeddedDocumentField
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.exceptions import InvalidOperator
from aggify.utilty import get_db_field


class Operators:
Expand Down Expand Up @@ -232,15 +234,27 @@ def validate_operator(key: str):
raise InvalidOperator(operator)

def is_base_model_field(self, field) -> bool:
return self.base_model is not None and isinstance(
self.base_model._fields.get(field), # type: ignore # noqa
EmbeddedDocumentField,
"""
Check if a field in the base model class is of a specific type.
EmbeddedDocumentField: Field which is embedded.
TopLevelDocumentMetaclass: Field which is added by lookup stage.
Args:
field (str): The name of the field to check.
Returns:
bool: True if the field is of type EmbeddedDocumentField or TopLevelDocumentMetaclass
and the base_model is not None, otherwise False.
"""
return self.base_model is not None and (
isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa
)

def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
match_query = {}
for key, value in self.matches.items():
if "__" not in key:
key = get_db_field(self.base_model, key)
match_query[key] = value
continue

Expand All @@ -249,7 +263,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
raise InvalidOperator(key)

field, operator, *_ = key.split("__")
if self.is_base_model_field(field):
if self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS:
pipelines.append(
Match({key.replace("__", ".", 1): value}, self.base_model).compile(
[]
Expand All @@ -259,7 +273,7 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:

if operator not in Operators.ALL_OPERATORS:
raise InvalidOperator(operator)

match_query = Operators(match_query).compile_match(operator, value, field)
db_field = get_db_field(self.base_model, field)
match_query = Operators(match_query).compile_match(operator, value, db_field)

return {"$match": match_query}
13 changes: 13 additions & 0 deletions aggify/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,16 @@ class InvalidEmbeddedField(AggifyBaseException):
def __init__(self, field: str):
self.message = f"Field {field} is not embedded."
super().__init__(self.message)


class AlreadyExistsField(AggifyBaseException):
def __init__(self, field: str):
self.message = f"Field {field} already exists."
super().__init__(self.message)


class InvalidArgument(AggifyBaseException):
def __init__(self, expected_list: list):
self.message = f"Input is not correctly passed, expected {[expected for expected in expected_list]}"
self.expecteds = expected_list
super().__init__(self.message)
37 changes: 35 additions & 2 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any
from typing import Any, Type

from mongoengine import Document

from aggify.exceptions import MongoIndexError, InvalidField
from aggify.exceptions import MongoIndexError, InvalidField, AlreadyExistsField


def int_to_slice(final_index: int) -> slice:
Expand Down Expand Up @@ -113,3 +113,36 @@ def convert_match_query(
return [convert_match_query(item) for item in d]
else:
return d


def check_field_exists(model: Type[Document], field: str) -> None:
"""
Check if a field exists in the given model.
Args:
model (Document): The model to check for the field.
field (str): The name of the field to check.
Raises:
AlreadyExistsField: If the field already exists in the model.
"""
if model._fields.get(field): # noqa
raise AlreadyExistsField(field=field)


def get_db_field(model: Type[Document], field: str) -> str:
"""
Get the database field name for a given field in the model.
Args:
model (Document): The model containing the field.
field (str): The name of the field.
Returns:
str: The database field name if available, otherwise the original field name.
"""
try:
db_field = model._fields.get(field).db_field # noqa
return field if db_field is None else db_field
except AttributeError:
return field
Loading

0 comments on commit 0e94799

Please sign in to comment.