diff --git a/settings/common.py b/settings/common.py index 51849f79..8bd85f62 100644 --- a/settings/common.py +++ b/settings/common.py @@ -440,7 +440,7 @@ REST_FRAMEWORK = { "user": None, "import-mode": None, "import-dump-mode": "1/minute", - "memberships": None, + "create-memberships": None }, "FILTER_BACKEND": "taiga.base.filters.FilterBackend", "EXCEPTION_HANDLER": "taiga.base.exceptions.exception_handler", diff --git a/settings/testing.py b/settings/testing.py index 9af0e6c8..c8875026 100644 --- a/settings/testing.py +++ b/settings/testing.py @@ -33,5 +33,5 @@ REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = { "user": None, "import-mode": None, "import-dump-mode": None, - "memberships": None, + "create-memberships": None, } diff --git a/taiga/base/api/throttling.py b/taiga/base/api/throttling.py index b23bea09..74178922 100644 --- a/taiga/base/api/throttling.py +++ b/taiga/base/api/throttling.py @@ -153,11 +153,15 @@ class SimpleRateThrottle(BaseThrottle): # throttle duration while self.history and self.history[-1] <= self.now - self.duration: self.history.pop() - if len(self.history) >= self.num_requests: - return self.throttle_failure() - return self.throttle_success() - def throttle_success(self): + if self.exceeded_throttling_restriction(request, view): + return self.throttle_failure() + return self.throttle_success(request, view) + + def exceeded_throttling_restriction(self, request, view): + return len(self.history) >= self.num_requests + + def throttle_success(self, request, view): """ Inserts the current request's timestamp along with the key into the cache. diff --git a/taiga/projects/throttling.py b/taiga/projects/throttling.py index f3da1ba3..7694b274 100644 --- a/taiga/projects/throttling.py +++ b/taiga/projects/throttling.py @@ -20,5 +20,20 @@ from taiga.base import throttling class MembershipsRateThrottle(throttling.UserRateThrottle): - scope = "memberships" + scope = "create-memberships" throttled_methods = ["POST", "PUT"] + + def exceeded_throttling_restriction(self, request, view): + self.created_memberships = 0 + if view.action in ["create", "resend_invitation"]: + self.created_memberships = 1 + elif view.action == "bulk_create": + self.created_memberships = len(request.DATA.get("bulk_memberships", [])) + return len(self.history) + self.created_memberships > self.num_requests + + def throttle_success(self, request, view): + for i in range(self.created_memberships): + self.history.insert(0, self.now) + + self.cache.set(self.key, self.history, self.duration) + return True diff --git a/tests/integration/test_memberships.py b/tests/integration/test_memberships.py index 5eb75486..4f3568b9 100644 --- a/tests/integration/test_memberships.py +++ b/tests/integration/test_memberships.py @@ -701,7 +701,7 @@ def test_api_create_bulk_members_max_pending_memberships(client, settings): def test_create_memberhips_throttling(client, settings): - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = "1/minute" + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = "1/minute" membership = f.MembershipFactory(is_admin=True) role = f.RoleFactory.create(project=membership.project) @@ -720,11 +720,11 @@ def test_create_memberhips_throttling(client, settings): response = client.json.post(url, json.dumps(data)) assert response.status_code == 429 - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = None + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = None def test_api_resend_invitation_throttling(client, outbox, settings): - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = "1/minute" + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = "1/minute" invitation = f.create_invitation(user=None) f.MembershipFactory(project=invitation.project, user=invitation.project.owner, is_admin=True) @@ -742,11 +742,11 @@ def test_api_resend_invitation_throttling(client, outbox, settings): assert response.status_code == 429 assert len(outbox) == 1 assert outbox[0].to == [invitation.email] - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = None + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = None def test_api_create_bulk_members_throttling(client, settings): - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = "1/minute" + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = "2/minute" project = f.ProjectFactory() john = f.UserFactory.create() @@ -781,4 +781,4 @@ def test_api_create_bulk_members_throttling(client, settings): response = client.json.post(url, json.dumps(data)) assert response.status_code == 429 - settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["memberships"] = None + settings.REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"]["create-memberships"] = None