Skip to content

Commit

Permalink
Use dataclass methods in custom ratelimits and fix tests (#5036)
Browse files Browse the repository at this point in the history
# What this PR does

Follow up PR for #5004
Tests haven’t caught a bug, so the method and the tests are fixed

## Which issue(s) this PR closes

Related to [issue link here]

<!--
*Note*: If you want the issue to be auto-closed once the PR is merged,
change "Related to" to "Closes" in the line above.
If you have more than one GitHub issue that this PR closes, be sure to
preface
each issue link with a [closing
keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue).
This ensures that the issue(s) are auto-closed once the PR has been
merged.
-->

## Checklist

- [ ] Unit, integration, and e2e (if applicable) tests updated
- [ ] Documentation added (or `pr:no public docs` PR label added if not
required)
- [ ] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
    show up in the autogenerated release notes.
  • Loading branch information
iskhakov authored Sep 18, 2024
1 parent 61902d5 commit c7a7a3f
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 12 deletions.
4 changes: 2 additions & 2 deletions engine/apps/integrations/mixins/ratelimit_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def get_rate_limit(group, request):

if group == RATELIMIT_INTEGRATION_GROUP_NAME:
if organization_id in custom_ratelimits:
return custom_ratelimits[organization_id]["integration"]
return custom_ratelimits[organization_id].integration
return RATELIMIT_INTEGRATION
elif group == RATELIMIT_TEAM_GROUP_NAME:
if organization_id in custom_ratelimits:
return custom_ratelimits[organization_id]["organization"]
return custom_ratelimits[organization_id].organization
return RATELIMIT_TEAM
else:
raise Exception("Unknown group")
Expand Down
4 changes: 2 additions & 2 deletions engine/apps/integrations/tests/test_ratelimit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from unittest import mock

import pytest
Expand All @@ -10,6 +9,7 @@
from apps.alerts.models import AlertReceiveChannel
from apps.integrations.mixins import IntegrationRateLimitMixin
from apps.integrations.mixins.ratelimit_mixin import RATELIMIT_INTEGRATION
from common.api_helpers.custom_ratelimit import load_custom_ratelimits


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -197,7 +197,7 @@ def test_custom_throttling(make_organization, make_alert_receive_channel):
+ '": {"integration": "2/m","organization": "3/m","public_api": "1/m"}}'
)

with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)):
client = Client()

# Organization without custom ratelimit should use default ratelimit
Expand Down
5 changes: 3 additions & 2 deletions engine/apps/public_api/tests/test_ratelimit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from unittest.mock import PropertyMock, patch

import pytest
Expand All @@ -8,6 +7,8 @@
from rest_framework import status
from rest_framework.test import APIClient

from common.api_helpers.custom_ratelimit import load_custom_ratelimits


@pytest.mark.django_db
def test_throttling(make_organization_and_user_with_token):
Expand Down Expand Up @@ -45,7 +46,7 @@ def test_custom_throttling(make_organization_and_user_with_token):
+ '": {"integration": "10/5m","organization": "15/5m","public_api": "1/m"}}'
)

with override_settings(CUSTOM_RATELIMITS=json.loads(CUSTOM_RATELIMITS_STR)):
with override_settings(CUSTOM_RATELIMITS=load_custom_ratelimits(CUSTOM_RATELIMITS_STR)):
client = APIClient()

url = reverse("api-public:alert_groups-list")
Expand Down
2 changes: 1 addition & 1 deletion engine/common/api_helpers/custom_rate_scoped_throttler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def allow_request(self, request, view):
custom_ratelimits = settings.CUSTOM_RATELIMITS
organization_id = str(request.user.organization_id)
if organization_id in custom_ratelimits:
self.rate = custom_ratelimits[organization_id]["public_api"]
self.rate = custom_ratelimits[organization_id].public_api
self.num_requests, self.duration = self.parse_rate(self.rate)

return super().allow_request(request, view)
Expand Down
19 changes: 19 additions & 0 deletions engine/common/api_helpers/custom_ratelimit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json
import os
import typing
from dataclasses import dataclass


Expand All @@ -6,3 +9,19 @@ class CustomRateLimit:
integration: str
organization: str
public_api: str


def getenv_custom_ratelimit(variable_name: str, default: dict) -> typing.Dict[str, CustomRateLimit]:
custom_ratelimits_str = os.environ.get(variable_name)
if custom_ratelimits_str is None:
return default
value = load_custom_ratelimits(custom_ratelimits_str)
return value


def load_custom_ratelimits(custom_ratelimits_str: str) -> typing.Dict[str, CustomRateLimit]:
custom_ratelimits_dict = json.loads(custom_ratelimits_str)
# Convert the parsed JSON into a dictionary of RateLimit dataclasses
custom_ratelimits = {key: CustomRateLimit(**value) for key, value in custom_ratelimits_dict.items()}

return custom_ratelimits
8 changes: 3 additions & 5 deletions engine/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from celery.schedules import crontab
from firebase_admin import credentials, initialize_app

from common.api_helpers.custom_ratelimit import CustomRateLimit
from common.api_helpers.custom_ratelimit import getenv_custom_ratelimit
from common.utils import getenv_boolean, getenv_integer, getenv_list

VERSION = "dev-oss"
Expand Down Expand Up @@ -971,10 +971,8 @@ class BrokerTypes:
# CUSTOM_RATELIMITS={"1": {"integration": "10/5m", "organization": "15/5m", "public_api": "10/5m"}}
# Where, "1" is the pk of the organization

# Load the environment variable and parse it into a dictionary, falling back to an empty dictionary if not set.
CUSTOM_RATELIMITS: typing.Dict[str, CustomRateLimit] = json.loads(os.getenv("CUSTOM_RATELIMITS", "{}"))
# Convert the parsed JSON into a dictionary of RateLimit dataclasses
CUSTOM_RATELIMITS = {key: CustomRateLimit(**value) for key, value in CUSTOM_RATELIMITS.items()}
# Load the environment variable and parse it into a dictionary of custom ralimits, falling back to an empty dictionary if not set.
CUSTOM_RATELIMITS = getenv_custom_ratelimit("CUSTOM_RATELIMITS", default={})

SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6)
SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240)
Expand Down

0 comments on commit c7a7a3f

Please sign in to comment.