Skip to content

Commit a72bcc6

Browse files
authored
Merge pull request #160 from colin99d/master
feat: allow Requests to be sent to exempt_when
2 parents 7769a13 + 42330fc commit a72bcc6

File tree

4 files changed

+83
-9
lines changed

4 files changed

+83
-9
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Change Log
22

3+
## [0.1.10] - 2024-06-04
4+
5+
### Changed
6+
7+
- Breaking change: allow usage of the request object in the except_when function (thanks @colin99d)
8+
39
## [0.1.9] - 2024-02-05
410

511
### Added

slowapi/extension.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
The starlette extension to rate-limit requests
33
"""
4+
45
import asyncio
56
import functools
67
import inspect
@@ -486,7 +487,7 @@ def __evaluate_limits(
486487
limit_for_header = None
487488
for lim in limits:
488489
limit_scope = lim.scope or endpoint
489-
if lim.is_exempt:
490+
if lim.is_exempt(request):
490491
continue
491492
if lim.methods is not None and request.method.lower() not in lim.methods:
492493
continue
@@ -703,11 +704,9 @@ def decorator(func: Callable[..., Response]):
703704
else:
704705
self._route_limits.setdefault(name, []).extend(static_limits)
705706

706-
connection_type: Optional[str] = None
707707
sig = inspect.signature(func)
708708
for idx, parameter in enumerate(sig.parameters.values()):
709709
if parameter.name == "request" or parameter.name == "websocket":
710-
connection_type = parameter.name
711710
break
712711
else:
713712
raise Exception(
@@ -736,7 +735,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Response:
736735
if not isinstance(response, Response):
737736
# get the response object from the decorated endpoint function
738737
self._inject_headers(
739-
kwargs.get("response"), request.state.view_rate_limit # type: ignore
738+
kwargs.get("response"), # type: ignore
739+
request.state.view_rate_limit,
740740
)
741741
else:
742742
self._inject_headers(
@@ -768,7 +768,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Response:
768768
if not isinstance(response, Response):
769769
# get the response object from the decorated endpoint function
770770
self._inject_headers(
771-
kwargs.get("response"), request.state.view_rate_limit # type: ignore
771+
kwargs.get("response"),
772+
request.state.view_rate_limit, # type: ignore
772773
)
773774
else:
774775
self._inject_headers(
@@ -805,7 +806,7 @@ def limit(
805806
* **error_message**: string (or callable that returns one) to override the
806807
error message used in the response.
807808
* **exempt_when**: function returning a boolean indicating whether to exempt
808-
the route from the limit
809+
the route from the limit. This function can optionally use a Request object.
809810
* **cost**: integer (or callable that returns one) which is the cost of a hit
810811
* **override_defaults**: whether to override the default limits (default: True)
811812
"""

slowapi/wrappers.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, Iterator, List, Optional, Union
33

44
from limits import RateLimitItem, parse_many # type: ignore
5+
from starlette.requests import Request
56

67

78
class Limit(object):
@@ -28,16 +29,27 @@ def __init__(
2829
self.methods = methods
2930
self.error_message = error_message
3031
self.exempt_when = exempt_when
32+
self._exempt_when_takes_request = (
33+
self.exempt_when
34+
and len(inspect.signature(self.exempt_when).parameters) == 1
35+
)
3136
self.cost = cost
3237
self.override_defaults = override_defaults
3338

34-
@property
35-
def is_exempt(self) -> bool:
39+
def is_exempt(self, request: Optional[Request] = None) -> bool:
3640
"""
3741
Check if the limit is exempt.
42+
43+
** parameter **
44+
* **request**: the request object
45+
3846
Return True to exempt the route from the limit.
3947
"""
40-
return self.exempt_when() if self.exempt_when is not None else False
48+
if self.exempt_when is None:
49+
return False
50+
if self._exempt_when_takes_request and request:
51+
return self.exempt_when(request)
52+
return self.exempt_when()
4153

4254
@property
4355
def scope(self) -> str:

tests/test_starlette_extension.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,61 @@ def t1(request: Request):
4343
if i < 5:
4444
assert response.text == "test"
4545

46+
def test_exempt_when_argument(self, build_starlette_app):
47+
app, limiter = build_starlette_app(key_func=get_ipaddr)
48+
49+
def return_true():
50+
return True
51+
52+
def return_false():
53+
return False
54+
55+
def dynamic(request: Request):
56+
user_agent = request.headers.get("User-Agent")
57+
if user_agent is None:
58+
return False
59+
return user_agent == "exempt"
60+
61+
@limiter.limit("1/minute", exempt_when=return_true)
62+
def always_true(request: Request):
63+
return PlainTextResponse("test")
64+
65+
@limiter.limit("1/minute", exempt_when=return_false)
66+
def always_false(request: Request):
67+
return PlainTextResponse("test")
68+
69+
@limiter.limit("1/minute", exempt_when=dynamic)
70+
def always_dynamic(request: Request):
71+
return PlainTextResponse("test")
72+
73+
app.add_route("/true", always_true)
74+
app.add_route("/false", always_false)
75+
app.add_route("/dynamic", always_dynamic)
76+
77+
client = TestClient(app)
78+
# Test always true always exempting
79+
for i in range(0, 2):
80+
response = client.get("/true")
81+
assert response.status_code == 200
82+
assert response.text == "test"
83+
# Test always false hitting the limit after one hit
84+
for i in range(0, 2):
85+
response = client.get("/false")
86+
assert response.status_code == 200 if i < 1 else 429
87+
if i < 1:
88+
assert response.text == "test"
89+
# Test dynamic not exempting with the correct header
90+
for i in range(0, 2):
91+
response = client.get("/dynamic", headers={"User-Agent": "exempt"})
92+
assert response.status_code == 200
93+
assert response.text == "test"
94+
# Test dynamic exempting with the incorrect header
95+
for i in range(0, 2):
96+
response = client.get("/dynamic")
97+
assert response.status_code == 200 if i < 1 else 429
98+
if i < 1:
99+
assert response.text == "test"
100+
46101
def test_shared_decorator(self, build_starlette_app):
47102
app, limiter = build_starlette_app(key_func=get_ipaddr)
48103

0 commit comments

Comments
 (0)