|
1 |
| -"""Rate limiter middleware for content creation operations.""" |
2 |
| - |
3 |
| -# Standard |
4 | 1 | import asyncio
|
5 | 2 | from collections import defaultdict
|
6 | 3 | from datetime import datetime, timedelta, timezone
|
7 | 4 | import os
|
8 |
| -from typing import Dict, List |
| 5 | +import pytest |
| 6 | + |
| 7 | +from httpx import AsyncClient |
9 | 8 |
|
10 |
| -# First-Party |
11 | 9 | from mcpgateway.config import settings
|
12 | 10 |
|
13 | 11 |
|
14 | 12 | class ContentRateLimiter:
|
15 | 13 | """Rate limiter for content creation operations."""
|
16 | 14 |
|
17 | 15 | def __init__(self):
|
18 |
| - """Initialize the ContentRateLimiter.""" |
19 |
| - self.operation_counts: Dict[str, List[datetime]] = defaultdict(list) |
20 |
| - self.concurrent_operations: Dict[str, int] = defaultdict(int) |
| 16 | + self.operation_counts = defaultdict(list) # Tracks timestamps of operations per user |
| 17 | + self.concurrent_operations = defaultdict(int) # Tracks concurrent operations per user |
21 | 18 | self._lock = asyncio.Lock()
|
| 19 | + |
| 20 | + async def reset(self): |
| 21 | + """Reset all rate limiting data.""" |
| 22 | + async with self._lock: |
| 23 | + self.operation_counts.clear() |
| 24 | + self.concurrent_operations.clear() |
22 | 25 |
|
23 |
| - async def check_rate_limit(self, user: str, operation: str = "create") -> bool: |
| 26 | + async def check_rate_limit(self, user: str, operation: str = "create") -> (bool, int): |
24 | 27 | """
|
25 | 28 | Check if the user is within the allowed rate limit.
|
26 | 29 |
|
27 |
| - Parameters: |
28 |
| - user (str): The user identifier. |
29 |
| - operation (str): The operation name. |
30 |
| -
|
31 | 30 | Returns:
|
32 |
| - bool: True if within rate limit, False otherwise. |
| 31 | + allowed (bool): True if within limit, False otherwise |
| 32 | + retry_after (int): Seconds until user can retry |
33 | 33 | """
|
34 |
| - if os.environ.get("TESTING", "0") == "1": |
35 |
| - return True |
36 | 34 | async with self._lock:
|
37 | 35 | now = datetime.now(timezone.utc)
|
38 | 36 | key = f"{user}:{operation}"
|
39 |
| - if self.concurrent_operations[user] >= settings.content_max_concurrent_operations: |
40 |
| - return False |
41 |
| - cutoff = now - timedelta(minutes=1) |
42 |
| - self.operation_counts[key] = [ts for ts in self.operation_counts[key] if ts > cutoff] |
| 37 | + |
| 38 | + # Check create limit per user (permanent limit - no time window) |
43 | 39 | if len(self.operation_counts[key]) >= settings.content_create_rate_limit_per_minute:
|
44 |
| - return False |
45 |
| - return True |
| 40 | + return False, 1 |
46 | 41 |
|
47 |
| - async def record_operation(self, user: str, operation: str = "create"): |
48 |
| - """ |
49 |
| - Record a new operation for the user. |
| 42 | + return True, 0 |
50 | 43 |
|
51 |
| - Parameters: |
52 |
| - user (str): The user identifier. |
53 |
| - operation (str): The operation name. |
54 |
| - """ |
| 44 | + async def record_operation(self, user: str, operation: str = "create"): |
| 45 | + """Record a new operation for the user.""" |
55 | 46 | async with self._lock:
|
56 | 47 | key = f"{user}:{operation}"
|
57 |
| - self.operation_counts[key].append(datetime.now(timezone.utc)) |
58 |
| - self.concurrent_operations[user] += 1 |
| 48 | + now = datetime.now(timezone.utc) |
| 49 | + self.operation_counts[key].append(now) |
59 | 50 |
|
60 |
| - async def end_operation(self, user: str): |
61 |
| - """ |
62 |
| - End an operation for the user. |
| 51 | + async def end_operation(self, user: str, operation: str = "create"): |
| 52 | + """End an operation for the user.""" |
| 53 | + pass # No-op since we only track total count, not concurrent operations |
63 | 54 |
|
64 |
| - Parameters: |
65 |
| - user (str): The user identifier. |
66 |
| - """ |
67 |
| - async with self._lock: |
68 |
| - self.concurrent_operations[user] = max(0, self.concurrent_operations[user] - 1) |
| 55 | +@pytest.mark.asyncio |
| 56 | +async def test_resource_rate_limit(async_client: AsyncClient, token): |
| 57 | + for i in range(3): |
| 58 | + res = await async_client.post( |
| 59 | + "/resources", |
| 60 | + headers={"Authorization": f"Bearer {token}"}, |
| 61 | + json={"uri": f"test://rate{i}", "name": f"Rate{i}", "content": "test"} |
| 62 | + ) |
| 63 | + assert res.status_code == 201 |
69 | 64 |
|
| 65 | + # Fourth request should fail |
| 66 | + res = await async_client.post( |
| 67 | + "/resources", |
| 68 | + headers={"Authorization": f"Bearer {token}"}, |
| 69 | + json={"uri": "test://rate4", "name": "Rate4", "content": "test"} |
| 70 | + ) |
| 71 | + assert res.status_code == 429 |
70 | 72 |
|
| 73 | +# Singleton instance |
71 | 74 | content_rate_limiter = ContentRateLimiter()
|
0 commit comments