Skip to content

Commit 77c929f

Browse files
thomasthomas
authored and committed
update
1 parent 36a67d8 commit 77c929f

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

src/lightning/app/core/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from lightning.app import _console
3030
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
3131
from lightning.app.core.constants import (
32+
BATCH_DELTA_COUNT,
3233
DEBUG_ENABLED,
3334
FLOW_DURATION_SAMPLES,
3435
FLOW_DURATION_THRESHOLD,
@@ -312,7 +313,7 @@ def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None)
312313
def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:
    """Fetch a batch of pending items from ``q``.

    A falsy ``timeout`` (``None`` or ``0``) is replaced by ``q.default_timeout``.
    At most ``BATCH_DELTA_COUNT`` items are requested from the queue.

    Returns whatever ``q.batch_get`` returns, or an empty list when it raises ``queue.Empty``.
    """
    effective_timeout = timeout if timeout else q.default_timeout
    try:
        batch = q.batch_get(timeout=effective_timeout, count=BATCH_DELTA_COUNT)
    except queue.Empty:
        # Nothing was available: report "no deltas" rather than propagating.
        return []
    return batch
318319

@@ -353,7 +354,6 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque
353354
self.delta_queue # type: ignore[assignment,arg-type]
354355
)
355356
for delta in received_deltas:
356-
print(delta)
357357
if isinstance(delta, _DeltaRequest):
358358
deltas.append(delta.delta)
359359
elif isinstance(delta, ComponentDelta):

src/lightning/app/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def get_lightning_cloud_url() -> str:
9898
# directory where system customization sync files will be copied to be packed into app tarball
9999
SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"
100100

101+
BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))
102+
101103

102104
def enable_multiple_works_in_default_container() -> bool:
103105
return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))

src/lightning/app/core/queues.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import base64
1516
import multiprocessing
1617
import pickle
1718
import queue # needed as import instead from/import for mocking in tests
@@ -22,11 +23,10 @@
2223
from pathlib import Path
2324
from typing import Any, Optional, Tuple
2425
from urllib.parse import urljoin
25-
import numpy as np
26+
2627
import backoff
2728
import requests
2829
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
29-
import base64
3030

3131
from lightning.app.core.constants import (
3232
HTTP_QUEUE_REFRESH_INTERVAL,
@@ -191,7 +191,7 @@ def get(self, timeout: Optional[float] = None) -> Any:
191191
pass
192192

193193
@abstractmethod
194-
def get_all(self, timeout: Optional[float] = None) -> Any:
194+
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
195195
"""Returns the left most elements of the queue.
196196
197197
Parameters
@@ -228,9 +228,10 @@ def get(self, timeout: Optional[float] = None) -> Any:
228228
timeout = self.default_timeout
229229
return self.queue.get(timeout=timeout, block=(timeout is None))
230230

231-
def get_all(self, timeout: Optional[float] = None) -> Any:
231+
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
    """Return a one-element list holding the next item of the underlying queue.

    ``timeout == 0`` is replaced by ``self.default_timeout``. The underlying ``get`` only blocks
    when the resulting timeout is ``None``. ``count`` is accepted for interface compatibility
    and is not used here: exactly one element is fetched per call.
    """
    effective_timeout = self.default_timeout if timeout == 0 else timeout
    should_block = effective_timeout is None
    item = self.queue.get(timeout=effective_timeout, block=should_block)
    return [item]
235236

236237

@@ -331,8 +332,8 @@ def get(self, timeout: Optional[float] = None) -> Any:
331332
raise queue.Empty
332333
return pickle.loads(out[1])
333334

334-
def get_all(self, timeout: Optional[float] = None) -> Any:
335-
raise NotImplementedError
335+
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
    """Batch retrieval is not provided by this queue; calling it always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally, regardless of ``timeout`` and ``count``.
    """
    message = "The batch_get method isn't implemented."
    raise NotImplementedError(message)
336337

337338
def clear(self) -> None:
338339
"""Clear all elements in the queue."""
@@ -404,10 +405,10 @@ def get(self, timeout: Optional[float] = None) -> Any:
404405
self._last_get = time.time()
405406
return self._queue.get(timeout=timeout)
406407

407-
def get_all(self, timeout: Optional[float] = None) -> Any:
408+
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
    """Delegate a batched read to the wrapped queue.

    Calls ``self._wait_until_allowed`` with the time of the previous read, records the current
    time as the new last read, then forwards the request to ``self._queue.batch_get``.

    Parameters
    ----------
    timeout:
        Passed through unchanged to the wrapped queue.
    count:
        Passed through unchanged to the wrapped queue, so the caller's batch size is honoured
        instead of being silently dropped.

    Returns
    -------
    Whatever the wrapped queue's ``batch_get`` returns.
    """
    self._wait_until_allowed(self._last_get)
    self._last_get = time.time()
    return self._queue.batch_get(timeout=timeout, count=count)
411412

412413
def put(self, item: Any) -> None:
413414
return self._queue.put(item)
@@ -501,13 +502,53 @@ def _get(self) -> Any:
501502
# we consider the queue is empty to avoid failing the app.
502503
raise queue.Empty
503504

504-
def get_all(self, timeout: Optional[float] = None) -> Any:
505+
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> list[Any]:
    """Pop a batch of elements from the HTTP queue via ``self._batch_get``.

    Parameters
    ----------
    timeout:
        ``None`` keeps calling the backend until a request succeeds, sleeping
        ``HTTP_QUEUE_REFRESH_INTERVAL`` after each attempt that found the queue empty.
        ``0`` issues exactly one request. Any other value retries until that many
        seconds have elapsed.
    count:
        Maximum number of elements to request. ``None`` defers to the default of
        ``_batch_get`` instead of being forwarded (it would otherwise be sent as the
        literal string ``"None"``).

    Returns
    -------
    The list returned by ``_batch_get``. An empty list is returned when the single request
    (``timeout == 0``) hits an HTTP error, when a timed call with ``timeout`` greater than
    ``self.default_timeout`` hits an HTTP error, or when the timeout elapses.

    Raises
    ------
    ValueError
        If ``self.app_id`` is falsy.
    queue.Empty
        Only for ``timeout == 0``, when ``_batch_get`` reports an empty queue.
    """
    if not self.app_id:
        raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")

    # Only forward ``count`` when the caller supplied one, so ``_batch_get`` keeps its own default.
    batch_kwargs = {} if count is None else {"count": count}

    # it's a blocking call, we need to loop and call the backend to mimic this behavior
    if timeout is None:
        while True:
            try:
                return self._batch_get(**batch_kwargs)
            except requests.exceptions.HTTPError:
                # HTTP errors are retried immediately, without sleeping.
                continue
            except queue.Empty:
                time.sleep(HTTP_QUEUE_REFRESH_INTERVAL)

    # make one request and return the result
    if timeout == 0:
        try:
            return self._batch_get(**batch_kwargs)
        except requests.exceptions.HTTPError:
            return []

    # timeout is some value - loop until the timeout is reached
    start_time = time.time()
    while (time.time() - start_time) < timeout:
        try:
            return self._batch_get(**batch_kwargs)
        except requests.exceptions.HTTPError:
            # With a timeout above the default, give up on the first HTTP error;
            # otherwise treat it like an empty queue and retry.
            if timeout > self.default_timeout:
                return []
        except queue.Empty:
            # Note: In theory there is no need to sleep, as an empty queue shouldn't
            # block the flow. However, the HTTP server can saturate, so sleep briefly
            # when a timeout higher than the default one is provided.
            if timeout > self.default_timeout:
                time.sleep(0.05)
    return []
546+
547+
def _batch_get(self, count: Optional[int] = 64) -> list[Any]:
508548
try:
509-
print("HERE")
510-
resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "popCount", "count": "64"})
549+
resp = self.client.post(
550+
f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "popCount", "count": str(count)}
551+
)
511552
if resp.status_code == 204:
512553
raise queue.Empty
513554
return [pickle.loads(base64.b64decode(data)) for data in resp.json()]

0 commit comments

Comments
 (0)