Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/app/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def get_lightning_cloud_url() -> str:
LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
LIGHTNING_MODELS_PUBLIC_REGISTRY = "https://lightning.ai/v1/models"

# Read from the environment; `None` when the variable is unset. Callers compare this against `None` to decide
# whether the app is running inside a cloudspace.
LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST")
# Read from the environment and parsed as an int; defaults to 0 when the variable is unset.
# NOTE(review): a non-integer value makes `int(...)` raise `ValueError` at import time.
LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0"))

# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(
Expand Down
9 changes: 8 additions & 1 deletion src/lightning/app/runners/backends/mp_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from typing import List, Optional

import lightning.app
from lightning.app.core import constants
from lightning.app.core.queues import QueuingSystem
from lightning.app.runners.backends.backend import Backend, WorkManager
from lightning.app.utilities.enum import WorkStageStatus
from lightning.app.utilities.network import _check_service_url_is_ready
from lightning.app.utilities.network import _check_service_url_is_ready, find_free_network_port
from lightning.app.utilities.port import disable_port, enable_port
from lightning.app.utilities.proxies import ProxyWorkRun, WorkRunner

Expand Down Expand Up @@ -76,6 +77,12 @@ def __init__(self, entrypoint_file: str):
super().__init__(entrypoint_file=entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id="0")

def create_work(self, app, work) -> None:
if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
# Override the port if set by the user
work._port = find_free_network_port()
work._host = "0.0.0.0"
work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"

app.processes[work.name] = MultiProcessWorkManager(app, work)
app.processes[work.name].start()
self.resolve_url(app)
Expand Down
9 changes: 5 additions & 4 deletions src/lightning/app/runners/multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import click

from lightning.app.api.http_methods import _add_tags_to_api, _validate_api
from lightning.app.core import constants
from lightning.app.core.api import start_server
from lightning.app.core.constants import APP_SERVER_IN_CLOUD
from lightning.app.runners.backends import Backend
from lightning.app.runners.runtime import Runtime
from lightning.app.storage.orchestrator import StorageOrchestrator
Expand Down Expand Up @@ -50,7 +50,8 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
_set_flow_context()

# Note: In case the runtime is used in the cloud.
self.host = "0.0.0.0" if APP_SERVER_IN_CLOUD else self.host
in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host

self.app.backend = self.backend
self.backend._prepare_queues(self.app)
Expand Down Expand Up @@ -116,7 +117,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
# wait for server to be ready
has_started_queue.get()

if open_ui and not _is_headless(self.app):
if open_ui and not _is_headless(self.app) and constants.LIGHTNING_CLOUDSPACE_HOST is None:
click.launch(self._get_app_url())

# Connect the runtime to the application.
Expand All @@ -134,7 +135,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
self.terminate()

def terminate(self):
if APP_SERVER_IN_CLOUD:
if constants.APP_SERVER_IN_CLOUD:
# Close all the ports open for the App within the App.
ports = [self.port] + getattr(self.backend, "ports", [])
for port in ports:
Expand Down
52 changes: 48 additions & 4 deletions src/lightning/app/utilities/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,64 @@
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
from urllib3.util.retry import Retry

from lightning.app.core import constants
from lightning.app.utilities.app_helpers import Logger

logger = Logger(__name__)


# Global record to track ports that have been allocated in this session. The port-finding helpers in this module
# skip any port found here and add every port they return, so the same port is not handed out twice.
_reserved_ports = set()


def find_free_network_port() -> int:
    """Finds a free port on localhost and records it as reserved for this session.

    When ``constants.LIGHTNING_CLOUDSPACE_HOST`` is set, the lookup is delegated to
    ``_find_free_network_port_cloudspace``. Otherwise the OS is asked to pick a port (bind to port ``0``) up to
    10 times, until it returns one that is not already in ``_reserved_ports``.

    Returns:
        The port number. It is added to ``_reserved_ports`` before being returned. The probing socket is closed,
        so the reservation only exists in this module's bookkeeping.

    Raises:
        RuntimeError: If every one of the 10 attempts produced a port that was already reserved.
    """
    if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
        return _find_free_network_port_cloudspace()

    for _ in range(10):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Binding to port 0 lets the OS choose the port for us.
            sock.bind(("", 0))
            port = sock.getsockname()[1]
        finally:
            # Always release the socket, even if bind/getsockname fails.
            sock.close()

        if port not in _reserved_ports:
            _reserved_ports.add(port)
            return port

    # Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong
    raise RuntimeError(
        "Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`."
    )


def _find_free_network_port_cloudspace() -> int:
    """Finds a free port in the exposed range when running in a cloudspace.

    The candidate range starts at ``constants.APP_SERVER_PORT`` and contains
    ``constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT`` consecutive ports. Ports already present in
    ``_reserved_ports`` are skipped, as are ports that cannot be bound.

    Returns:
        The first port in the range that could be bound. It is added to ``_reserved_ports`` before being returned.

    Raises:
        RuntimeError: If no port in the range is available. An empty range (a port count of 0) raises immediately.
    """
    first_port = constants.APP_SERVER_PORT
    last_port = first_port + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT

    for port in range(first_port, last_port):
        if port in _reserved_ports:
            continue

        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind(("", port))
            finally:
                # Close the probe socket whether or not the bind succeeded, so a busy port does not leak it.
                sock.close()
        except OSError:
            # The port could not be bound (e.g. already in use); try the next one.
            continue

        _reserved_ports.add(port)
        return port

    # This error should never happen. An app using this many ports would probably fail on a single machine anyway.
    raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.")


_CONNECTION_RETRY_TOTAL = 2880
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
_DEFAULT_BACKOFF_MAX = 5 * 60 # seconds
Expand Down
4 changes: 3 additions & 1 deletion src/lightning/app/utilities/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from deepdiff import DeepDiff, Delta
from lightning_utilities.core.apply_func import apply_to_collection

from lightning.app.core import constants
from lightning.app.core.queues import MultiProcessQueue
from lightning.app.storage import Path
from lightning.app.storage.copier import _Copier, _copy_files
Expand Down Expand Up @@ -500,7 +501,8 @@ def run_once(self):
# Set the internal IP address.
# Set this here after the state observer is initialized, since it needs to record it as a change and send
# it back to the flow
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0"
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip)

# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
# send delta while calling `set_state`.
Expand Down
24 changes: 24 additions & 0 deletions tests/tests_app/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import py
import pytest

from lightning.app.core import constants
from lightning.app.storage.path import _storage_root_dir
from lightning.app.utilities.app_helpers import _collect_child_process_pids
from lightning.app.utilities.component import _set_context
Expand Down Expand Up @@ -115,3 +116,26 @@ def caplog(caplog):
root_logger.propagate = root_propagate
for name, propagate in propagation_dict.items():
logging.getLogger(name).propagate = propagate


@pytest.fixture
def patch_constants(request):
    """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for
    the duration of a test.

    ``request.param`` must be a dict mapping constant names to the values they should take during the test. The
    previous values are restored afterwards, including when patching fails partway through (for example because a
    name does not exist on the constants module).

    Example::

        @pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True)
        def test_my_stuff(patch_constants):
            ...
    """
    old_constants = {}
    try:
        # Set constants, remembering each original value before overwriting it
        for constant, value in request.param.items():
            old_constants[constant] = getattr(constants, constant)
            setattr(constants, constant, value)

        yield
    finally:
        # Restore whatever was patched, even if the setup loop above raised midway
        for constant, value in old_constants.items():
            setattr(constants, constant, value)
38 changes: 37 additions & 1 deletion tests/tests_app/utilities/test_network.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,45 @@
from unittest import mock

import pytest

from lightning.app.utilities.network import find_free_network_port, LightningClient


def test_port():
def test_find_free_network_port():
    """Checks the ports returned by `find_free_network_port` and the error raised when none can be found."""
    # With real sockets, any non-zero port is acceptable.
    assert find_free_network_port()

    with mock.patch("lightning.app.utilities.network.socket") as mock_socket:
        fake_sock = mock_socket.socket.return_value

        # First time the mocked port is seen: it is handed out.
        fake_sock.getsockname.return_value = [0, 8888]
        assert find_free_network_port() == 8888

        # The mock keeps answering with the same, now reserved, port, so the lookup must give up.
        with pytest.raises(RuntimeError, match="Couldn't find a free port."):
            find_free_network_port()

        # A different port is accepted again.
        fake_sock.getsockname.return_value = [0, 9999]
        assert find_free_network_port() == 9999


@mock.patch("lightning.app.utilities.network.socket")
@pytest.mark.parametrize(
    "patch_constants",
    [{"LIGHTNING_CLOUDSPACE_HOST": "any", "LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT": 10}],
    indirect=True,
)
def test_find_free_network_port_cloudspace(_, patch_constants):
    """Checks that, with the cloudspace constants patched in, `find_free_network_port` hands out distinct ports
    and raises once the exposed range is exhausted."""
    allocated = []

    # Ask for one port more than the range holds; the request that finds nothing left must raise.
    with pytest.raises(RuntimeError, match="All 10 ports are already in use."):
        for _attempt in range(11):
            allocated.append(find_free_network_port())

    # Check that all ports are unique
    assert len(set(allocated)) == len(allocated)


def test_lightning_client_retry_enabled():

Expand Down
10 changes: 8 additions & 2 deletions tests/tests_app/utilities/test_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,15 @@ def test_state_observer():


@pytest.mark.parametrize(
"environment, expected_ip_addr", [({}, "127.0.0.1"), ({"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5")]
"patch_constants, environment, expected_ip_addr",
[
({}, {}, "127.0.0.1"),
({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"),
({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"),
],
indirect=["patch_constants"],
)
def test_work_runner_sets_internal_ip(environment, expected_ip_addr):
def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr):
"""Test that the WorkRunner updates the internal ip address as soon as the Work starts running."""

class Work(LightningWork):
Expand Down