|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import base64 |
15 | 16 | import multiprocessing
|
16 | 17 | import pickle
|
17 | 18 | import queue # needed as import instead from/import for mocking in tests
|
|
22 | 23 | from pathlib import Path
|
23 | 24 | from typing import Any, Optional, Tuple
|
24 | 25 | from urllib.parse import urljoin
|
25 |
| -import numpy as np |
| 26 | + |
26 | 27 | import backoff
|
27 | 28 | import requests
|
28 | 29 | from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
|
29 |
| -import base64 |
30 | 30 |
|
31 | 31 | from lightning.app.core.constants import (
|
32 | 32 | HTTP_QUEUE_REFRESH_INTERVAL,
|
@@ -191,7 +191,7 @@ def get(self, timeout: Optional[float] = None) -> Any:
|
191 | 191 | pass
|
192 | 192 |
|
193 | 193 | @abstractmethod
|
194 |
| - def get_all(self, timeout: Optional[float] = None) -> Any: |
| 194 | + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: |
195 | 195 | """Returns the left most elements of the queue.
|
196 | 196 |
|
197 | 197 | Parameters
|
@@ -228,9 +228,10 @@ def get(self, timeout: Optional[float] = None) -> Any:
|
228 | 228 | timeout = self.default_timeout
|
229 | 229 | return self.queue.get(timeout=timeout, block=(timeout is None))
|
230 | 230 |
|
231 |
| - def get_all(self, timeout: Optional[float] = None) -> Any: |
| 231 | + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: |
232 | 232 | if timeout == 0:
|
233 | 233 | timeout = self.default_timeout
|
| 234 | + # For multiprocessing, we can simply collect the latest upmost element |
234 | 235 | return [self.queue.get(timeout=timeout, block=(timeout is None))]
|
235 | 236 |
|
236 | 237 |
|
@@ -331,8 +332,8 @@ def get(self, timeout: Optional[float] = None) -> Any:
|
331 | 332 | raise queue.Empty
|
332 | 333 | return pickle.loads(out[1])
|
333 | 334 |
|
334 |
| - def get_all(self, timeout: Optional[float] = None) -> Any: |
335 |
| - raise NotImplementedError |
| 335 | + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: |
| 336 | + raise NotImplementedError("The batch_get method isn't implemented.") |
336 | 337 |
|
337 | 338 | def clear(self) -> None:
|
338 | 339 | """Clear all elements in the queue."""
|
@@ -404,10 +405,10 @@ def get(self, timeout: Optional[float] = None) -> Any:
|
404 | 405 | self._last_get = time.time()
|
405 | 406 | return self._queue.get(timeout=timeout)
|
406 | 407 |
|
407 |
| - def get_all(self, timeout: Optional[float] = None) -> Any: |
| 408 | + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: |
408 | 409 | self._wait_until_allowed(self._last_get)
|
409 | 410 | self._last_get = time.time()
|
410 |
| - return self._queue.get_all(timeout=timeout) |
| 411 | + return self._queue.batch_get(timeout=timeout) |
411 | 412 |
|
412 | 413 | def put(self, item: Any) -> None:
|
413 | 414 | return self._queue.put(item)
|
@@ -501,13 +502,53 @@ def _get(self) -> Any:
|
501 | 502 | # we consider the queue is empty to avoid failing the app.
|
502 | 503 | raise queue.Empty
|
503 | 504 |
|
504 |
| - def get_all(self, timeout: Optional[float] = None) -> Any: |
| 505 | + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> list[Any]: |
505 | 506 | if not self.app_id:
|
506 | 507 | raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}")
|
507 | 508 |
|
| 509 | + # it's a blocking call, we need to loop and call the backend to mimic this behavior |
| 510 | + if timeout is None: |
| 511 | + while True: |
| 512 | + try: |
| 513 | + try: |
| 514 | + return self._batch_get(count=count) |
| 515 | + except requests.exceptions.HTTPError: |
| 516 | + pass |
| 517 | + except queue.Empty: |
| 518 | + time.sleep(HTTP_QUEUE_REFRESH_INTERVAL) |
| 519 | + |
| 520 | + # make one request and return the result |
| 521 | + if timeout == 0: |
| 522 | + try: |
| 523 | + return self._batch_get(count=count) |
| 524 | + except requests.exceptions.HTTPError: |
| 525 | + return [] |
| 526 | + |
| 527 | + # timeout is some value - loop until the timeout is reached |
| 528 | + start_time = time.time() |
| 529 | + while (time.time() - start_time) < timeout: |
| 530 | + try: |
| 531 | + try: |
| 532 | + return self._batch_get(count=count) |
| 533 | + except requests.exceptions.HTTPError: |
| 534 | + if timeout > self.default_timeout: |
| 535 | + return [] |
| 536 | + raise queue.Empty |
| 537 | + except queue.Empty: |
| 538 | + # Note: In theory, there isn't a need for a sleep as the queue shouldn't |
| 539 | + # block the flow if the queue is empty. |
| 540 | + # However, as the Http Server can saturate, |
| 541 | + # let's add a sleep here if a higher timeout is provided |
| 542 | + # than the default timeout |
| 543 | + if timeout > self.default_timeout: |
| 544 | + time.sleep(0.05) |
| 545 | + return [] |
| 546 | + |
| 547 | + def _batch_get(self, count: Optional[int] = 64) -> list[Any]: |
508 | 548 | try:
|
509 |
| - print("HERE") |
510 |
| - resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "popCount", "count": "64"}) |
| 549 | + resp = self.client.post( |
| 550 | + f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "popCount", "count": str(count)} |
| 551 | + ) |
511 | 552 | if resp.status_code == 204:
|
512 | 553 | raise queue.Empty
|
513 | 554 | return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
|
|
0 commit comments