Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the default `LightningClient(retry=False)` to `retry=True` ([#16382](https://github.com/Lightning-AI/lightning/pull/16382))

- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))


Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/cmd_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class _AppManager:
"""_AppManager implements API calls specific to Lightning AI BYOC apps."""

def __init__(self) -> None:
self.api_client = LightningClient()
self.api_client = LightningClient(retry=False)

def get_cluster(self, cluster_id: str) -> V1GetClusterResponse:
return self.api_client.cluster_service_get_cluster(id=cluster_id)
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/cmd_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class AWSClusterManager:
is selected as the backend compute."""

def __init__(self) -> None:
self.api_client = LightningClient()
self.api_client = LightningClient(retry=False)

def create(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/cmd_ssh_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class _SSHKeyManager:
"""_SSHKeyManager implements API calls specific to Lightning AI SSH-Keys."""

def __init__(self) -> None:
self.api_client = LightningClient()
self.api_client = LightningClient(retry=False)

def get_ssh_keys(self) -> _SSHKeyList:
resp = self.api_client.s_sh_public_key_service_list_ssh_public_keys()
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/commands/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def connect(app_name_or_id: str):
retriever = _LightningAppOpenAPIRetriever(app_name_or_id)

if not retriever.api_commands:
client = LightningClient()
client = LightningClient(retry=False)
project = _get_project(client)
apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id)
click.echo(
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:


def _show_logs(app_name: str, components: List[str], follow: bool) -> None:
client = LightningClient()
client = LightningClient(retry=False)
project = _get_project(client)

apps = {
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/cli/lightning_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def cluster_logs(cluster_id: str, to_time: arrow.Arrow, from_time: arrow.Arrow,
$ lightning show cluster logs my-cluster --limit 10
"""

client = LightningClient()
client = LightningClient(retry=False)
cluster_manager = AWSClusterManager()
existing_cluster_list = cluster_manager.get_clusters()

Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/runners/backends/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class CloudBackend(Backend):
def __init__(self, entrypoint_file, queue_id: Optional[str] = None, status_update_interval: int = None):
super().__init__(entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id=queue_id)
self.client = LightningClient(retry=True)
self.client = LightningClient()

def create_work(self, app: "lightning_app.LightningApp", work: "lightning_app.LightningWork") -> None:
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/cli_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _maybe_find_url(self):

def _maybe_find_matching_cloud_app(self):
"""Tries to resolve the app url from the provided `app_id_or_name_or_url`."""
client = LightningClient()
client = LightningClient(retry=False)
project = _get_project(client)
list_apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id)

Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/commands/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _download_command(
if not debug_mode:
if app_id:
if not os.path.exists(target_file):
client = LightningClient()
client = LightningClient(retry=False)
project_id = _get_project(client).project_id
response = client.lightningapp_instance_service_list_lightningapp_instance_artifacts(
project_id=project_id, id=app_id
Expand Down
16 changes: 6 additions & 10 deletions src/lightning_app/utilities/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa
return min(_DEFAULT_BACKOFF_MAX, next_backoff_value)


def _retry_wrapper(func: Callable) -> Callable:
def _retry_wrapper(self, func: Callable) -> Callable:
"""Returns the function decorated by a wrapper that retries the call several times if a connection error
occurs.

Expand All @@ -77,7 +77,7 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
consecutive_errors = 0
while _get_next_backoff_time(consecutive_errors) != _DEFAULT_BACKOFF_MAX:
try:
return func(*args, **kwargs)
return func(self, *args, **kwargs)
except lightning_cloud.openapi.rest.ApiException as e:
# retry if the control plane fails with all errors except 4xx but not 408 - (Request Timeout)
if e.status == 408 or e.status == 409 or not str(e.status).startswith("4"):
Expand Down Expand Up @@ -113,17 +113,13 @@ class LightningClient(GridRestClient):
retry: Whether API calls should follow a retry mechanism with exponential backoff.
"""

def __new__(cls, *args: Any, **kwargs: Any) -> "LightningClient":
if kwargs.get("retry", False):
def __init__(self, retry: bool = True) -> None:
super().__init__(api_client=create_swagger_client())
if retry:
for base_class in GridRestClient.__mro__:
for name, attribute in base_class.__dict__.items():
if callable(attribute) and attribute.__name__ != "__init__":
setattr(cls, name, _retry_wrapper(attribute))
return super().__new__(cls)

def __init__(self, retry: bool = False) -> None:
super().__init__(api_client=create_swagger_client())
self._retry = retry
setattr(self, name, _retry_wrapper(self, attribute))


class CustomRetryAdapter(HTTPAdapter):
Expand Down
16 changes: 8 additions & 8 deletions tests/tests_app/utilities/test_network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import patch

from lightning_app.utilities.network import find_free_network_port, LightningClient


Expand All @@ -8,10 +6,12 @@ def test_port():


def test_lightning_client_retry_enabled():
with patch("lightning_app.utilities.network._retry_wrapper") as wrapper:
LightningClient() # default: retry=False
wrapper.assert_not_called()

with patch("lightning_app.utilities.network._retry_wrapper") as wrapper:
LightningClient(retry=True)
wrapper.assert_called()
client = LightningClient() # default: retry=True
assert hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")

client = LightningClient(retry=False)
assert not hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")

client = LightningClient(retry=True)
assert hasattr(client.auth_service_get_user_with_http_info, "__wrapped__")