Skip to content

Commit 1cee84c

Browse files
authored
Replace LightningClient with import from lightning_cloud (#18544)
1 parent 8cee0b4 commit 1cee84c

File tree

3 files changed

+5
-103
lines changed

3 files changed

+5
-103
lines changed

requirements/app/app.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud >=0.5.37
1+
lightning-cloud >=0.5.38
22
packaging
33
typing-extensions >=4.0.0, <4.8.0
44
deepdiff >=5.7.0, <6.3.2

src/lightning/app/utilities/network.py

Lines changed: 3 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313
# limitations under the License.
1414

1515
import socket
16-
import time
1716
from functools import wraps
1817
from typing import Any, Callable, Dict, Optional
1918
from urllib.parse import urljoin
2019

21-
import lightning_cloud
2220
import requests
23-
import urllib3
24-
from lightning_cloud.rest_client import create_swagger_client, GridRestClient
21+
22+
# for backwards compatibility
23+
from lightning_cloud.rest_client import create_swagger_client, GridRestClient, LightningClient # noqa: F401
2524
from requests import Session
2625
from requests.adapters import HTTPAdapter
2726
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
@@ -87,7 +86,6 @@ def _find_free_network_port_cloudspace():
8786

8887
_CONNECTION_RETRY_TOTAL = 2880
8988
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
90-
_DEFAULT_BACKOFF_MAX = 5 * 60 # seconds
9189
_DEFAULT_REQUEST_TIMEOUT = 30 # seconds
9290

9391

@@ -119,75 +117,6 @@ def _check_service_url_is_ready(url: str, timeout: float = 5, metadata="") -> bo
119117
return False
120118

121119

122-
def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> float:
123-
next_backoff_value = backoff_value * (2 ** (num_retries - 1))
124-
return min(_DEFAULT_BACKOFF_MAX, next_backoff_value)
125-
126-
127-
def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable:
128-
"""Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.
129-
130-
The retries follow an exponential backoff.
131-
132-
"""
133-
134-
@wraps(func)
135-
def wrapped(*args: Any, **kwargs: Any) -> Any:
136-
consecutive_errors = 0
137-
138-
while True:
139-
try:
140-
return func(self, *args, **kwargs)
141-
except (lightning_cloud.openapi.rest.ApiException, urllib3.exceptions.HTTPError) as ex:
142-
# retry if the backend fails with all errors except 4xx but not 408 - (Request Timeout)
143-
if (
144-
isinstance(ex, urllib3.exceptions.HTTPError)
145-
or ex.status in (408, 409)
146-
or not str(ex.status).startswith("4")
147-
):
148-
consecutive_errors += 1
149-
backoff_time = _get_next_backoff_time(consecutive_errors)
150-
151-
msg = (
152-
f"error: {str(ex)}"
153-
if isinstance(ex, urllib3.exceptions.HTTPError)
154-
else f"response: {ex.status}"
155-
)
156-
logger.debug(
157-
f"The {func.__name__} request failed to reach the server, {msg}."
158-
f" Retrying after {backoff_time} seconds."
159-
)
160-
161-
if max_tries is not None and consecutive_errors == max_tries:
162-
raise Exception(f"The {func.__name__} request failed to reach the server, {msg}.")
163-
164-
time.sleep(backoff_time)
165-
else:
166-
raise ex
167-
168-
return wrapped
169-
170-
171-
class LightningClient(GridRestClient):
172-
"""The LightningClient is a wrapper around the GridRestClient.
173-
174-
It wraps all methods to monitor connection exceptions and employs a retry strategy.
175-
176-
Args:
177-
retry: Whether API calls should follow a retry mechanism with exponential backoff.
178-
max_tries: Maximum number of attempts (or -1 to retry forever).
179-
180-
"""
181-
182-
def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None:
183-
super().__init__(api_client=create_swagger_client())
184-
if retry:
185-
for base_class in GridRestClient.__mro__:
186-
for name, attribute in base_class.__dict__.items():
187-
if callable(attribute) and attribute.__name__ != "__init__":
188-
setattr(self, name, _retry_wrapper(self, attribute, max_tries=max_tries))
189-
190-
191120
class CustomRetryAdapter(HTTPAdapter):
192121
def __init__(self, *args: Any, **kwargs: Any):
193122
self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)

tests/tests_app/utilities/test_network.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import re
21
from unittest import mock
32

43
import pytest
5-
from urllib3.exceptions import HTTPError
64

75
from lightning.app.core import constants
8-
from lightning.app.utilities.network import _retry_wrapper, find_free_network_port, LightningClient
6+
from lightning.app.utilities.network import find_free_network_port
97

108

119
def test_find_free_network_port():
@@ -45,28 +43,3 @@ def test_find_free_network_port_cloudspace(_, patch_constants):
4543

4644
# Shouldn't use the APP_SERVER_PORT
4745
assert constants.APP_SERVER_PORT not in ports
48-
49-
50-
def test_lightning_client_retry_enabled():
51-
client = LightningClient() # default: retry=True
52-
assert hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")
53-
54-
client = LightningClient(retry=False)
55-
assert not hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")
56-
57-
client = LightningClient(retry=True)
58-
assert hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")
59-
60-
61-
@mock.patch("time.sleep")
62-
def test_retry_wrapper_max_tries(_):
63-
mock_client = mock.MagicMock()
64-
mock_client.test.__name__ = "test"
65-
mock_client.test.side_effect = HTTPError("failed")
66-
67-
wrapped_mock_client = _retry_wrapper(mock_client, mock_client.test, max_tries=3)
68-
69-
with pytest.raises(Exception, match=re.escape("The test request failed to reach the server, error: failed")):
70-
wrapped_mock_client()
71-
72-
assert mock_client.test.call_count == 3

0 commit comments

Comments
 (0)