diff --git a/CHANGELOG.md b/CHANGELOG.md
index c1382ab..15a3b78 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,9 +16,15 @@
 functions provided by the `src/cscapi` folder.
 
 ### Changed
 
-- **Breaking change**: Modify `StorageInterface::get_all_signals` to accept a `limit` and `offset` argument
+- **Breaking change**: Modify `StorageInterface::get_all_signals` to accept `limit` and `offset` arguments
+- **Breaking change**: Change `StorageInterface::delete_signals` signature to require a list of signal ids
+- **Breaking change**: Change `StorageInterface::delete_machines` signature to require a list of machine ids
 - **Breaking change**: `SQLStorage::get_all_signals` requires now a `limit` argument
 
+### Removed
+
+- **Breaking change**: Remove `CAPIClient::_prune_sent_signals` method
+
 ---
 
diff --git a/src/cscapi/client.py b/src/cscapi/client.py
index 3cc73ff..a69491f 100644
--- a/src/cscapi/client.py
+++ b/src/cscapi/client.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 from dataclasses import asdict, replace, dataclass
 from importlib import metadata
-from typing import Dict, Iterable, List
+from typing import Dict, Iterable, List, Tuple
 
 import httpx
 import jwt
@@ -95,10 +95,13 @@ def prune_failing_machines_signals(self):
             if not signals:
                 break
-            for machine_id, signals in _group_signals_by_machine_id(signals).items():
+            for machine_id, grouped_signals in _group_signals_by_machine_id(
+                signals
+            ).items():
                 machine = self.storage.get_machine_by_id(machine_id)
                 if machine.is_failing:
-                    self.storage.delete_signals(signals)
+                    signal_ids = [signal.alert_id for signal in grouped_signals]
+                    self.storage.delete_signals(signal_ids)
 
             offset += SIGNAL_BATCH_LIMIT
 
     def send_signals(self, prune_after_send: bool = True):
@@ -161,8 +164,9 @@ def _send_signals_by_machine_id(
             self.logger.info(
                 f"sending signals for machine {machine_to_process.machine_id}"
             )
+            sent_signal_ids = []
             try:
-                self._send_signals(
+                sent_signal_ids = self._send_signals(
                     machine_to_process.token,
                     signals_by_machineid[machine_to_process.machine_id],
                 )
@@ -179,11 +183,11 @@ def _send_signals_by_machine_id(
                 machine_to_process.token = None
                 retry_machines_to_process_attempts.append(machine_to_process)
                 continue
-            if prune_after_send:
+            if prune_after_send and sent_signal_ids:
                 self.logger.info(
                     f"pruning sent signals for machine {machine_to_process.machine_id}"
                 )
-                self._prune_sent_signals()
+                self.storage.delete_signals(sent_signal_ids)
             self.logger.info(
                 f"sending metrics for machine {machine_to_process.machine_id}"
             )
@@ -206,7 +210,8 @@ def _send_signals_by_machine_id(
             )
             time.sleep(self.retry_delay)
 
-    def _send_signals(self, token: str, signals: SignalModel):
+    def _send_signals(self, token: str, signals: List[SignalModel]) -> List[int]:
+        result = []
         for signal_batch in batched(signals, 250):
             body = [asdict(signal) for signal in signal_batch]
             resp = self.http_client.post(
@@ -215,11 +220,17 @@ def _send_signals(self, token: str, signals: SignalModel):
                 headers={"Authorization": token},
             )
             resp.raise_for_status()
-            self._mark_signals_as_sent(signal_batch)
+            result.extend(self._mark_signals_as_sent(signal_batch))
 
-    def _mark_signals_as_sent(self, signals: List[SignalModel]):
+        return result
+
+    def _mark_signals_as_sent(self, signals: Tuple[SignalModel]) -> List[int]:
+        result = []
         for signal in signals:
             self.storage.update_or_create_signal(replace(signal, sent=True))
+            result.append(signal.alert_id)
+
+        return result
 
     def _send_metrics_for_machine(self, machine: MachineModel):
         for _ in range(self.max_retries + 1):
@@ -246,19 +257,6 @@ def _send_metrics_for_machine(self, machine: MachineModel):
                 f"received error {exc} while sending metrics for machine {machine.machine_id}"
             )
 
-    def _prune_sent_signals(self):
-        offset = 0
-        while True:
-            signals = self.storage.get_all_signals(
-                limit=SIGNAL_BATCH_LIMIT, offset=offset
-            )
-            if not signals:
-                break
-            signals = list(filter(lambda signal: signal.sent, signals))
-
-            self.storage.delete_signals(signals)
-            offset += SIGNAL_BATCH_LIMIT
-
     def _refresh_machine_token(self, machine: MachineModel) -> MachineModel:
         machine.scenarios = self.scenarios
         resp = self.http_client.post(
diff --git a/src/cscapi/sql_storage.py b/src/cscapi/sql_storage.py
index a18ced0..427eaba 100644
--- a/src/cscapi/sql_storage.py
+++ b/src/cscapi/sql_storage.py
@@ -237,16 +237,12 @@ def update_or_create_signal(self, signal: storage.SignalModel) -> bool:
 
         return False
 
-    def delete_signals(self, signals: List[storage.SignalModel]):
-        stmt = delete(SignalDBModel).where(
-            SignalDBModel.alert_id.in_((signal.alert_id for signal in signals))
-        )
+    def delete_signals(self, signal_ids: List[int]):
+        stmt = delete(SignalDBModel).where(SignalDBModel.alert_id.in_(signal_ids))
         with self.session.begin() as session:
             session.execute(stmt)
 
-    def delete_machines(self, machines: List[storage.MachineModel]):
-        stmt = delete(MachineDBModel).where(
-            MachineDBModel.machine_id.in_((machine.machine_id for machine in machines))
-        )
+    def delete_machines(self, machine_ids: List[str]):
+        stmt = delete(MachineDBModel).where(MachineDBModel.machine_id.in_(machine_ids))
         with self.session.begin() as session:
             session.execute(stmt)
diff --git a/src/cscapi/storage.py b/src/cscapi/storage.py
index ded4f09..8c30352 100644
--- a/src/cscapi/storage.py
+++ b/src/cscapi/storage.py
@@ -105,9 +105,9 @@ def update_or_create_signal(self, signal: SignalModel) -> bool:
         raise NotImplementedError
 
     @abstractmethod
-    def delete_signals(self, signals: List[SignalModel]):
+    def delete_signals(self, signal_ids: List[int]):
         raise NotImplementedError
 
     @abstractmethod
-    def delete_machines(self, machines: List[MachineModel]):
+    def delete_machines(self, machine_ids: List[str]):
        raise NotImplementedError
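
Migration note: downstream code that still holds `SignalModel` / `MachineModel` objects can bridge to the new id-based `delete_signals` / `delete_machines` signatures with a small adapter. The sketch below is illustrative only; the helper names are hypothetical, and only the attribute names (`alert_id`, `machine_id`) and the new signatures come from the diff above.

```python
from typing import Iterable

from cscapi.storage import MachineModel, SignalModel, StorageInterface


def delete_signal_models(
    storage: StorageInterface, signals: Iterable[SignalModel]
) -> None:
    # The old API accepted SignalModel objects; the new one expects alert ids.
    storage.delete_signals([signal.alert_id for signal in signals])


def delete_machine_models(
    storage: StorageInterface, machines: Iterable[MachineModel]
) -> None:
    # The old API accepted MachineModel objects; the new one expects machine_id strings.
    storage.delete_machines([machine.machine_id for machine in machines])
```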