diff --git a/aggify/aggify.py b/aggify/aggify.py index e4956c2..be28cab 100644 --- a/aggify/aggify.py +++ b/aggify/aggify.py @@ -309,7 +309,7 @@ def __getitem__(self, index: Union[slice, int]) -> "Aggify": def unwind( self, path: str, - include_index_array: Union[str, None] = None, + include_array_index: Union[str, None] = None, preserve: bool = False, ) -> "Aggify": """Generates a MongoDB unwind pipeline stage. @@ -319,7 +319,7 @@ def unwind( To specify a field path, prefix the field name with a dollar sign $ and enclose in quotes. - include_index_array: The name of a new field to hold the array index of the element. + include_array_index: The name of a new field to hold the array index of the element. The name cannot start with a dollar sign $. preserve: Whether to preserve null and empty arrays. @@ -344,18 +344,17 @@ def unwind( docs: https://www.mongodb.com/docs/manual/reference/operator/aggregation/unwind/ """ path = self.get_field_name_recursively(path) - if include_index_array is None and preserve is False: - self.pipelines.append({"$unwind": f"${path}"}) - return self - self.pipelines.append( - { - "$unwind": { - "path": f"${path}", - "includeArrayIndex": include_index_array, - "preserveNullAndEmptyArrays": preserve, - } - } - ) + if include_array_index is None and preserve is False: + unwind_stage = {"$unwind": f"${path}"} + else: + unwind_stage = {"$unwind": {"path": f"${path}"}} + if preserve: + unwind_stage["$unwind"]["preserveNullAndEmptyArrays"] = preserve + if include_array_index: + unwind_stage["$unwind"][ + "includeArrayIndex" + ] = include_array_index.replace("$", "") + self.pipelines.append(unwind_stage) return self def aggregate(self): diff --git a/aggify/exceptions.py b/aggify/exceptions.py index ff01758..951a123 100644 --- a/aggify/exceptions.py +++ b/aggify/exceptions.py @@ -28,7 +28,9 @@ class AnnotationError(InvalidPipelineStageError): class OutStageError(InvalidPipelineStageError): def __init__(self, stage): - self.message = f"You cannot add a {self!r} pipeline after $out stage! stage : {stage}" + self.message = ( + f"You cannot add a {self!r} pipeline after $out stage! stage : {stage}" + ) super().__init__(self.message) @@ -45,9 +47,7 @@ def __init__(self, expected_list: List[Type], result: Type): class InvalidOperator(AggifyBaseException): def __init__(self, operator: str): - self.message = ( - f"Operator {operator} does not exists, please refer to documentation to see all supported operators." - ) + self.message = f"Operator {operator} does not exists, please refer to documentation to see all supported operators." super().__init__(self.message) diff --git a/tests/test_aggify.py b/tests/test_aggify.py index 734b89f..57c5734 100644 --- a/tests/test_aggify.py +++ b/tests/test_aggify.py @@ -8,6 +8,9 @@ OutStageError, InvalidArgument, InvalidField, + InvalidOperator, + AlreadyExistsField, + InvalidEmbeddedField, ) @@ -412,16 +415,16 @@ def test_unwind_just_path(self): @pytest.mark.parametrize( "params", ( - {"include_index_array": "Mahdi"}, + {"include_array_index": "Mahdi"}, {"preserve": True}, - {"include_index_array": "Mahdi", "preserve": True}, + {"include_array_index": "Mahdi", "preserve": True}, ), ) def test_unwind_with_parameters(self, params): aggify = Aggify(BaseModel) thing = aggify.unwind("name", **params) assert len(thing.pipelines) == 1 - include = params.get("include_index_array") + include = params.get("include_array_index") preserve = params.get("preserve") if include is not None: assert thing.pipelines[-1]["$unwind"]["includeArrayIndex"] == "Mahdi" @@ -538,3 +541,59 @@ def test_unwind_invalid_field(self): aggify = Aggify(BaseModel) with pytest.raises(InvalidField): aggify.unwind("invalid") + + def test_in_operator(self): + thing = list(Aggify(BaseModel).filter(name__in=[])) + assert thing[0]["$match"] == {"name": {"$in": []}} + + def test_nin_operator(self): + thing = list(Aggify(BaseModel).filter(name__nin=[])) + assert thing[0]["$match"] == {"name": {"$nin": []}} + + def test_eq_operator(self): + thing = list(Aggify(BaseModel).filter(name__exact=[])) + assert thing[0]["$match"] == {"name": {"$eq": []}} + + def test_invalid_operator(self): + aggify = Aggify(BaseModel) + with pytest.raises(InvalidOperator): + aggify.filter(name__aggify="test") + + def test_lookup_with_duplicate_as_name(self): + aggify = Aggify(BaseModel) + with pytest.raises(AlreadyExistsField): + aggify.lookup( + BaseModel, local_field="name", foreign_field="name", as_name="name" + ) + + def test_project_delete_id(self): + thing = list(Aggify(BaseModel).project(id=0)) + assert thing[0]["$project"] == {"_id": 0} + + def test_add_field_list_as_expression(self): + thing = list(Aggify(BaseModel).add_fields(new=[])) + assert thing[0]["$addFields"] == {"new": []} + + def test_add_field_cond_as_expression(self): + thing = list(Aggify(BaseModel).add_fields(new=Cond("name", "==", "name", 0, 1))) + assert thing[0]["$addFields"] == { + "new": {"$cond": {"if": {"$eq": ["name", "name"]}, "then": 0, "else": 1}} + } + + def test_annotate_int_field(self): + thing = list(Aggify(BaseModel).group("name").annotate("name", "first", 2)) + assert thing[0]["$group"] == {"_id": "$name", "name": {"$first": 2}} + + def test_sequential_matches_combine(self): + thing = list(Aggify(BaseModel).filter(name=123).filter(age=123)) + assert thing[0]["$match"] == {"name": 123, "age": 123} + + def test_get_model_field_invalid_field(self): + aggify = Aggify(BaseModel) + with pytest.raises(InvalidField): + aggify.get_model_field(BaseModel, "tttttt") + + def test_replace_base_invalid_embedded_field(self): + aggify = Aggify(BaseModel) + with pytest.raises(InvalidEmbeddedField): + aggify._replace_base("name") diff --git a/tests/test_query.py b/tests/test_query.py index c47f515..0bce2c7 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -72,7 +72,6 @@ class ParameterTestCase: }, { "$unwind": { - "includeArrayIndex": None, "path": "$owner_id", "preserveNullAndEmptyArrays": True, }