From dcf2219b0bff647f45f14dc9897aae090a50fc49 Mon Sep 17 00:00:00 2001 From: fgibertoni <152909479+fgibertoni@users.noreply.github.com> Date: Fri, 29 Mar 2024 02:23:22 +0100 Subject: [PATCH] Add MongoDB storage implementation (#27) * Added MongoDB Storage Interface implementation * Added log level to INFO * Added mongoengine as optional dependency * Fixed tests and removed version from dependency * Removed connection string * Refactoring get_signals method * Changed requirements to remove connect from library * Added connection string to __init__ * Added table name for consistency * Improved update_or_create_signal method --- setup.cfg | 4 + src/cscapi/mongodb_storage.py | 211 ++++++++++++++++++++++++++ tests/test_mongodb_storage.py | 276 ++++++++++++++++++++++++++++++++++ 3 files changed, 491 insertions(+) create mode 100644 src/cscapi/mongodb_storage.py create mode 100644 tests/test_mongodb_storage.py diff --git a/setup.cfg b/setup.cfg index 43fe644..ba181f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,5 +32,9 @@ install_requires = pyjwt more-itertools +[options.extras_require] +mongodb = + mongoengine + [options.packages.find] where = src diff --git a/src/cscapi/mongodb_storage.py b/src/cscapi/mongodb_storage.py new file mode 100644 index 0000000..d3e1bcf --- /dev/null +++ b/src/cscapi/mongodb_storage.py @@ -0,0 +1,211 @@ +import logging +from dataclasses import asdict +from typing import List, Optional + +from dacite import from_dict +from mongoengine import ( + ConnectionFailure, + Document, + EmbeddedDocument, + Q, + connect, + fields, +) + +from cscapi.storage import MachineModel, SignalModel, StorageInterface + +logger = logging.getLogger(__name__) + + +class ContextDBModel(EmbeddedDocument): + value = fields.StringField() + key = fields.StringField() + + +class DecisionDBModel(EmbeddedDocument): + duration = fields.StringField() + uuid = fields.StringField() + scenario = fields.StringField() + origin = fields.StringField() + scope = fields.StringField() + simulated = fields.BooleanField() + until = fields.StringField() + type = fields.StringField() + value = fields.StringField() + + +class SourceDBModel(EmbeddedDocument): + scope = fields.StringField() + ip = fields.StringField() + latitude = fields.FloatField() + as_number = fields.StringField() + range = fields.StringField() + cn = fields.StringField() + value = fields.StringField() + as_name = fields.StringField() + longitude = fields.FloatField() + + +class SignalDBModel(Document): + alert_id = fields.SequenceField(unique=True) + created_at = fields.StringField() + machine_id = fields.StringField(max_length=128) + scenario_version = fields.StringField(null=True) + message = fields.StringField(null=True) + uuid = fields.StringField() + start_at = fields.StringField(null=True) + scenario_trust = fields.StringField(null=True) + scenario_hash = fields.StringField(null=True) + scenario = fields.StringField(null=True) + stop_at = fields.StringField(null=True) + sent = fields.BooleanField(default=False) + context = fields.EmbeddedDocumentListField(ContextDBModel) + decisions = fields.EmbeddedDocumentListField(DecisionDBModel) + source = fields.EmbeddedDocumentField(SourceDBModel) + + meta = {"collection": "signal_models"} + + +class MachineDBModel(Document): + machine_id = fields.StringField(max_length=128, unique=True) + token = fields.StringField() + password = fields.StringField() + scenarios = fields.StringField() + is_failing = fields.BooleanField(default=False) + + meta = {"collection": "machine_models"} + + +class MongoDBStorage(StorageInterface): + def __init__(self, connection_string="mongodb://127.0.0.1:27017/cscapi"): + try: + connect( + host="mongodb://127.0.0.1:27017/cscapi", + connect=False, + uuidRepresentation="standard", + ) + except ConnectionFailure: + logger.info( + "There is already an existing connection to MongoDB. Using that as default." + ) + + def mass_update_signals(self, signal_ids: List[int], changes: dict): + SignalDBModel.objects.filter(alert_id__in=signal_ids).update(**changes) + + def get_signals( + self, + limit: int, + offset: int = 0, + sent: Optional[bool] = None, + is_failing: Optional[bool] = None, + ) -> List[SignalModel]: + join_name = "joined" + filter_sent = Q() + filter_is_failing = {} + + if sent is not None: + if sent: + filter_sent = Q(sent=True) + else: + filter_sent = Q(sent=False) | Q(sent=None) + + if is_failing is not None: + if is_failing: + filter_is_failing = {"$match": {"is_failing": True}} + else: + filter_is_failing = { + "$match": {"$or": [{"is_failing": False}, {"is_failing": None}]} + } + + pipeline = [ + { # performs a left outer join and return an object called as join_name + "$lookup": { + "from": MachineDBModel._get_collection_name(), + "localField": "machine_id", + "foreignField": "machine_id", + "as": join_name, + } + }, + { # if a machine isn't found, fill the value of is_failing with None at object root level + # otherwise copy the value of the attribute from the matching machine to root level + "$set": { + "is_failing": { + "$cond": { + "if": {"$eq": [{"$size": f"${join_name}"}, 0]}, + "then": None, + "else": {"$arrayElemAt": [f"${join_name}.is_failing", 0]}, + } + } + } + }, + ] + if filter_is_failing: + pipeline.append(filter_is_failing) + pipeline.extend([{"$limit": limit + offset}, {"$skip": offset}]) + + results = SignalDBModel.objects.filter(filter_sent).aggregate(pipeline) + return [from_dict(SignalModel, res) for res in results] + + def get_machine_by_id(self, machine_id: str) -> Optional[MachineModel]: + machine = MachineDBModel.objects.filter(machine_id=machine_id).first() + return from_dict(MachineModel, machine) if machine else None + + def update_or_create_machine(self, machine: MachineModel) -> bool: + try: + result = MachineDBModel.objects.get(machine_id=machine.machine_id) + except MachineDBModel.DoesNotExist: + MachineDBModel.objects.create(**asdict(machine)) + return True + else: + result.update(**asdict(machine)) + return False + + def update_or_create_signal(self, signal: SignalModel) -> bool: + # Filter out the keys for embedded documents to handle them separately + signal_filtered = { + k: v + for k, v in asdict(signal).items() + if k not in ["source", "context", "decisions"] + } + + created = False # Flag to track if a new document is created + + # Only proceed if alert_id is not None + if signal.alert_id is not None: + try: + signal_db_model = SignalDBModel.objects.get(alert_id=signal.alert_id) + logger.info(signal_filtered) + signal_db_model.update(**signal_filtered) + except SignalDBModel.DoesNotExist: + # If it doesn't exist, create a new one with all the data including alert_id + signal_db_model = SignalDBModel(**signal_filtered) + created = True + else: + # If alert_id is None, directly create a new model with the provided data + signal_db_model = SignalDBModel(**signal_filtered) + created = True + + # Update or set the source, context, and decisions fields + if signal.source: + signal_db_model.source = SourceDBModel(**asdict(signal.source)) + + if signal.context: + signal_db_model.context = [ + ContextDBModel(**asdict(ctx)) for ctx in signal.context + ] + + if signal.decisions: + signal_db_model.decisions = [ + DecisionDBModel(**asdict(dec)) for dec in signal.decisions + ] + + # Save the model (works for both creating a new document and updating an existing one) + signal_db_model.save() + + return created + + def delete_signals(self, signal_ids: List[int]): + SignalDBModel.objects.filter(alert_id__in=signal_ids).delete() + + def delete_machines(self, machine_ids: List[str]): + MachineDBModel.objects.filter(machine_id__in=machine_ids).delete() diff --git a/tests/test_mongodb_storage.py b/tests/test_mongodb_storage.py new file mode 100644 index 0000000..e2145e8 --- /dev/null +++ b/tests/test_mongodb_storage.py @@ -0,0 +1,276 @@ +import random +import time +from unittest import TestCase + +from dacite import from_dict +from mongoengine import disconnect + +from cscapi.client import CAPIClient, CAPIClientConfig +from cscapi.mongodb_storage import MachineDBModel, MongoDBStorage, SignalDBModel +from cscapi.storage import MachineModel, SignalModel, SourceModel + + +def mock_signals(): + return [ + from_dict(SignalModel, z) + for z in [ + { + "decisions": [ + { + "duration": "59m49.264032632s", + "id": random.randint(0, 100000), + "origin": "crowdsec", + "scenario": "crowdsecurity/ssh-bf", + "scope": "Ip", + "simulated": False, + "type": "ban", + "value": "1.1.1.172", + } + ], + "context": [ + {"key": "target_user", "value": "netflix"}, + {"key": "service", "value": "ssh"}, + {"key": "target_user", "value": "netflix"}, + {"key": "service", "value": "ssh"}, + ], + "uuid": "1", + "machine_id": "test", + "message": "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761", + "scenario": "crowdsecurity/ssh-bf", + "scenario_hash": "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f", + "scenario_version": "0.1", + "scenario_trust": "trusted", + "source": { + "as_name": "Cloudflare Inc", + "cn": "AU", + "ip": "1.1.1.172", + "latitude": -37.7, + "longitude": 145.1833, + "range": "1.1.1.0/24", + "scope": "Ip", + "value": "1.1.1.172", + }, + "start_at": "2020-11-28 10:20:46.842701127 +0100 +0100", + "stop_at": "2020-11-28 10:20:46.845621385 +0100 +0100", + "created_at": "2020-11-28T10:20:47+01:00", + } + ] + ] + + +class TestMongoDBStorage(TestCase): + storage = None + + @classmethod + def setUpClass(cls): + cls.storage: MongoDBStorage = MongoDBStorage() + cls.client = CAPIClient( + cls.storage, + CAPIClientConfig( + scenarios=["crowdsecurity/http-bf", "crowdsecurity/ssh-bf"], + max_retries=1, + retry_delay=0, + ), + ) + + @classmethod + def tearDownClass(cls): + disconnect() + + def setUp(self): + SignalDBModel.objects.all().delete() + MachineDBModel.objects.all().delete() + + def tearDown(self): + SignalDBModel.objects.all().delete() + MachineDBModel.objects.all().delete() + + def test_get_signals_with_no_machine(self): + self.assertEqual(len(self.storage.get_signals(limit=1000)), 0) + for x in range(10): + self.client.add_signals(mock_signals()) + time.sleep(0.05) + self.assertEqual(len(self.storage.get_signals(limit=1000)), 10) + self.assertEqual(len(self.storage.get_signals(limit=5)), 5) + self.assertEqual(len(self.storage.get_signals(limit=5, offset=8)), 2) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=True)), 0) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=False)), 10) + self.assertEqual(len(self.storage.get_signals(limit=1000, is_failing=True)), 0) + self.assertEqual( + len(self.storage.get_signals(limit=1000, is_failing=False)), 10 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=False, is_failing=False)), 10 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=True, is_failing=False)), 0 + ) + + def test_get_signals_with_machine(self): + m1 = MachineModel( + machine_id="test", # Same machine_id as in mock_signals + token="1", + password="1", + scenarios="crowdsecurity/http-probing", + ) + self.assertTrue(self.storage.update_or_create_machine(m1)) + self.assertEqual(len(self.storage.get_signals(limit=1000)), 0) + for x in range(10): + self.client.add_signals(mock_signals()) + time.sleep(0.05) + self.assertEqual(len(self.storage.get_signals(limit=1000)), 10) + self.assertEqual(len(self.storage.get_signals(limit=5)), 5) + self.assertEqual(len(self.storage.get_signals(limit=5, offset=8)), 2) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=True)), 0) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=False)), 10) + self.assertEqual(len(self.storage.get_signals(limit=1000, is_failing=True)), 0) + self.assertEqual( + len(self.storage.get_signals(limit=1000, is_failing=False)), 10 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=False, is_failing=False)), 10 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=True, is_failing=False)), 0 + ) + + def test_get_signals_with_failing_machine(self): + m1 = MachineModel( + machine_id="test", # Same machine_id as in mock_signals + token="1", + password="1", + scenarios="crowdsecurity/http-probing", + is_failing=True, + ) + self.assertTrue(self.storage.update_or_create_machine(m1)) + self.assertEqual(len(self.storage.get_signals(limit=1000)), 0) + for x in range(10): + self.client.add_signals(mock_signals()) + time.sleep(0.05) + self.assertEqual(len(self.storage.get_signals(limit=1000)), 10) + self.assertEqual(len(self.storage.get_signals(limit=5)), 5) + self.assertEqual(len(self.storage.get_signals(limit=5, offset=8)), 2) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=True)), 0) + self.assertEqual(len(self.storage.get_signals(limit=1000, sent=False)), 10) + self.assertEqual(len(self.storage.get_signals(limit=1000, is_failing=True)), 10) + self.assertEqual(len(self.storage.get_signals(limit=1000, is_failing=False)), 0) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=False, is_failing=False)), 0 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=True, is_failing=False)), 0 + ) + self.assertEqual( + len(self.storage.get_signals(limit=1000, sent=True, is_failing=True)), 0 + ) + + def test_create_and_retrieve_machine(self): + m1 = MachineModel( + machine_id="1", + token="1", + password="1", + scenarios="crowdsecurity/http-probing", + ) + + # Should return true if db row is created, else return false + self.assertTrue(self.storage.update_or_create_machine(m1)) + self.assertFalse(self.storage.update_or_create_machine(m1)) + + retrieved = self.storage.get_machine_by_id("1") + + self.assertEqual(retrieved.machine_id, m1.machine_id) + self.assertEqual(retrieved.token, m1.token) + self.assertEqual(retrieved.password, m1.password) + self.assertEqual(retrieved.scenarios, m1.scenarios) + + def test_update_machine(self): + m1 = MachineModel( + machine_id="1", + token="1", + password="1", + scenarios="crowdsecurity/http-probing", + ) + self.storage.update_or_create_machine(m1) + + retrieved = self.storage.get_machine_by_id("1") + + self.assertEqual(retrieved.machine_id, m1.machine_id) + self.assertEqual(retrieved.token, m1.token) + self.assertEqual(retrieved.password, m1.password) + self.assertEqual(retrieved.scenarios, m1.scenarios) + + m2 = MachineModel( + machine_id="1", token="2", password="2", scenarios="crowdsecurity/http-bf" + ) + self.storage.update_or_create_machine(m2) + self.assertEqual(1, MachineDBModel.objects.count()) + + retrieved = self.storage.get_machine_by_id("1") + + self.assertEqual(retrieved.machine_id, m2.machine_id) + self.assertEqual(retrieved.token, m2.token) + self.assertEqual(retrieved.password, m2.password) + self.assertEqual(retrieved.scenarios, m2.scenarios) + + def test_create_signal(self): + self.assertEqual(self.storage.get_signals(limit=1000), []) + self.storage.update_or_create_signal(mock_signals()[0]) + signals = self.storage.get_signals(limit=1000) + self.assertEqual(len(signals), 1) + signal = signals[0] + + self.assertIsNotNone(signal.alert_id) + self.assertFalse(signal.sent) + + self.assertEqual(SignalDBModel.objects.count(), 1) + self.assertEqual(len(signal.context), 4) + + self.assertEqual(len(signal.decisions), 1) + + self.assertTrue(isinstance(signal.source, SourceModel)) + + def test_update_signal(self): + self.assertEqual(self.storage.get_signals(limit=1000), []) + + to_insert = mock_signals()[0] + self.storage.update_or_create_signal(to_insert) + signals = self.storage.get_signals(limit=1000) + + self.assertEqual(len(signals), 1) + signal = signals[0] + + self.assertFalse(signal.sent) + + signal.sent = True + + self.storage.update_or_create_signal(signal) + signals = self.storage.get_signals(limit=1000) + + self.assertEqual(len(signals), 1) + signal = signals[0] + + self.assertTrue(signal.sent) + + def test_mass_update_signals(self): + self.assertEqual(self.storage.get_signals(limit=1000), []) + + for x in range(10): + self.storage.update_or_create_signal(mock_signals()[0]) + + signals = self.storage.get_signals(limit=1000) + + self.assertEqual(len(signals), 10) + for s in signals: + self.assertFalse(s.sent) + self.assertEqual(s.scenario_trust, "trusted") + signal_ids = [s.alert_id for s in signals] + self.storage.mass_update_signals( + signal_ids, {"sent": True, "scenario_trust": "manual"} + ) + + signals = self.storage.get_signals(limit=1000) + + self.assertEqual(len(signals), 10) + for s in signals: + self.assertTrue(s.sent) + self.assertEqual(s.scenario_trust, "manual")