Skip to content

Commit 9085db4

Browse files
authored
App: Limit rate of requests to http queue (#18981)
1 parent a9d427c commit 9085db4

File tree

3 files changed

+79
-8
lines changed

3 files changed

+79
-8
lines changed

src/lightning/app/core/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def get_lightning_cloud_url() -> str:
5353
HTTP_QUEUE_URL = os.getenv("LIGHTNING_HTTP_QUEUE_URL", "http://localhost:9801")
5454
HTTP_QUEUE_REFRESH_INTERVAL = float(os.getenv("LIGHTNING_HTTP_QUEUE_REFRESH_INTERVAL", "1"))
5555
HTTP_QUEUE_TOKEN = os.getenv("LIGHTNING_HTTP_QUEUE_TOKEN", None)
56+
HTTP_QUEUE_REQUESTS_PER_SECOND = float(os.getenv("LIGHTNING_HTTP_QUEUE_REQUESTS_PER_SECOND", "0.5"))
5657

5758
USER_ID = os.getenv("USER_ID", "1234")
5859
FRONTEND_DIR = str(Path(__file__).parent.parent / "ui")

src/lightning/app/core/queues.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from lightning.app.core.constants import (
3131
HTTP_QUEUE_REFRESH_INTERVAL,
32+
HTTP_QUEUE_REQUESTS_PER_SECOND,
3233
HTTP_QUEUE_TOKEN,
3334
HTTP_QUEUE_URL,
3435
LIGHTNING_DIR,
@@ -77,7 +78,9 @@ def get_queue(self, queue_name: str) -> "BaseQueue":
7778
return MultiProcessQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
7879
if self == QueuingSystem.REDIS:
7980
return RedisQueue(queue_name, default_timeout=REDIS_QUEUES_READ_DEFAULT_TIMEOUT)
80-
return HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT)
81+
return RateLimitedQueue(
82+
HTTPQueue(queue_name, default_timeout=STATE_UPDATE_TIMEOUT), HTTP_QUEUE_REQUESTS_PER_SECOND
83+
)
8184

8285
def get_api_response_queue(self, queue_id: Optional[str] = None) -> "BaseQueue":
8386
queue_name = f"{queue_id}_{API_RESPONSE_QUEUE_CONSTANT}" if queue_id else API_RESPONSE_QUEUE_CONSTANT
@@ -347,6 +350,45 @@ def from_dict(cls, state: dict) -> "RedisQueue":
347350
return cls(**state)
348351

349352

353+
class RateLimitedQueue(BaseQueue):
354+
def __init__(self, queue: BaseQueue, requests_per_second: float):
355+
"""This is a queue wrapper that will block on get or put calls if they are made too quickly.
356+
357+
Args:
358+
queue: The queue to wrap.
359+
requests_per_second: The target number of get or put requests per second.
360+
361+
"""
362+
self.name = queue.name
363+
self.default_timeout = queue.default_timeout
364+
365+
self._queue = queue
366+
self._seconds_per_request = 1 / requests_per_second
367+
368+
self._last_get = 0.0
369+
self._last_put = 0.0
370+
371+
@property
372+
def is_running(self) -> bool:
373+
return self._queue.is_running
374+
375+
def _wait_until_allowed(self, last_time: float) -> None:
376+
t = time.time()
377+
diff = t - last_time
378+
if diff < self._seconds_per_request:
379+
time.sleep(self._seconds_per_request - diff)
380+
381+
def get(self, timeout: Optional[float] = None) -> Any:
382+
self._wait_until_allowed(self._last_get)
383+
self._last_get = time.time()
384+
return self._queue.get(timeout=timeout)
385+
386+
def put(self, item: Any) -> None:
387+
self._wait_until_allowed(self._last_put)
388+
self._last_put = time.time()
389+
return self._queue.put(item)
390+
391+
350392
class HTTPQueue(BaseQueue):
351393
def __init__(self, name: str, default_timeout: float) -> None:
352394
"""

tests/tests_app/core/test_queues.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,15 @@
88
import requests_mock
99
from lightning.app import LightningFlow
1010
from lightning.app.core import queues
11-
from lightning.app.core.constants import HTTP_QUEUE_URL
12-
from lightning.app.core.queues import READINESS_QUEUE_CONSTANT, BaseQueue, QueuingSystem, RedisQueue
11+
from lightning.app.core.constants import HTTP_QUEUE_URL, STATE_UPDATE_TIMEOUT
12+
from lightning.app.core.queues import (
13+
READINESS_QUEUE_CONSTANT,
14+
BaseQueue,
15+
HTTPQueue,
16+
QueuingSystem,
17+
RateLimitedQueue,
18+
RedisQueue,
19+
)
1320
from lightning.app.utilities.imports import _is_redis_available
1421
from lightning.app.utilities.redis import check_if_redis_running
1522

@@ -162,7 +169,7 @@ def test_redis_raises_error_if_failing(redis_mock):
162169

163170
class TestHTTPQueue:
164171
def test_http_queue_failure_on_queue_name(self):
165-
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test")
172+
test_queue = HTTPQueue("test", STATE_UPDATE_TIMEOUT)
166173
with pytest.raises(ValueError, match="App ID couldn't be extracted"):
167174
test_queue.put("test")
168175

@@ -174,7 +181,7 @@ def test_http_queue_failure_on_queue_name(self):
174181

175182
def test_http_queue_put(self, monkeypatch):
176183
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
177-
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
184+
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
178185
test_obj = LightningFlow()
179186

180187
# mocking requests and responses
@@ -200,8 +207,7 @@ def test_http_queue_put(self, monkeypatch):
200207

201208
def test_http_queue_get(self, monkeypatch):
202209
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
203-
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
204-
210+
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
205211
adapter = requests_mock.Adapter()
206212
test_queue.client.session.mount("http://", adapter)
207213

@@ -218,7 +224,7 @@ def test_http_queue_get(self, monkeypatch):
218224
def test_unreachable_queue(monkeypatch):
219225
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
220226

221-
test_queue = QueuingSystem.HTTP.get_queue(queue_name="test_http_queue")
227+
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
222228

223229
resp1 = mock.MagicMock()
224230
resp1.status_code = 204
@@ -235,3 +241,25 @@ def test_unreachable_queue(monkeypatch):
235241
# Test backoff on queue.put
236242
test_queue.put("foo")
237243
assert test_queue.client.post.call_count == 3
244+
245+
246+
@mock.patch("lightning.app.core.queues.time.sleep")
247+
def test_rate_limited_queue(mock_sleep):
248+
sleeps = []
249+
mock_sleep.side_effect = lambda sleep_time: sleeps.append(sleep_time)
250+
251+
mock_queue = mock.MagicMock()
252+
253+
mock_queue.name = "inner_queue"
254+
mock_queue.default_timeout = 10.0
255+
256+
rate_limited_queue = RateLimitedQueue(mock_queue, requests_per_second=1)
257+
258+
assert rate_limited_queue.name == "inner_queue"
259+
assert rate_limited_queue.default_timeout == 10.0
260+
261+
timeout = time.perf_counter() + 1
262+
while time.perf_counter() + sum(sleeps) < timeout:
263+
rate_limited_queue.get()
264+
265+
assert mock_queue.get.call_count == 2

0 commit comments

Comments
 (0)