diff --git a/src/lightning/app/core/constants.py b/src/lightning/app/core/constants.py index 35b0b44a073c3..ca482c4cb9de6 100644 --- a/src/lightning/app/core/constants.py +++ b/src/lightning/app/core/constants.py @@ -67,6 +67,9 @@ def get_lightning_cloud_url() -> str: LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps" LIGHTNING_MODELS_PUBLIC_REGISTRY = "https://lightning.ai/v1/models" +LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST") +LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0")) + # EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50")) ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool( diff --git a/src/lightning/app/runners/backends/mp_process.py b/src/lightning/app/runners/backends/mp_process.py index 3c914fc92828d..882d26e3dfa99 100644 --- a/src/lightning/app/runners/backends/mp_process.py +++ b/src/lightning/app/runners/backends/mp_process.py @@ -16,10 +16,11 @@ from typing import List, Optional import lightning.app +from lightning.app.core import constants from lightning.app.core.queues import QueuingSystem from lightning.app.runners.backends.backend import Backend, WorkManager from lightning.app.utilities.enum import WorkStageStatus -from lightning.app.utilities.network import _check_service_url_is_ready +from lightning.app.utilities.network import _check_service_url_is_ready, find_free_network_port from lightning.app.utilities.port import disable_port, enable_port from lightning.app.utilities.proxies import ProxyWorkRun, WorkRunner @@ -76,6 +77,12 @@ def __init__(self, entrypoint_file: str): super().__init__(entrypoint_file=entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id="0") def create_work(self, app, work) -> None: + if constants.LIGHTNING_CLOUDSPACE_HOST is not None: + # Override the port if set by the user + work._port = find_free_network_port() + work._host = "0.0.0.0" + work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}" + app.processes[work.name] = MultiProcessWorkManager(app, work) app.processes[work.name].start() self.resolve_url(app) diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py index 5b219e0e79693..9e016047e4b3e 100644 --- a/src/lightning/app/runners/multiprocess.py +++ b/src/lightning/app/runners/multiprocess.py @@ -20,8 +20,8 @@ import click from lightning.app.api.http_methods import _add_tags_to_api, _validate_api +from lightning.app.core import constants from lightning.app.core.api import start_server -from lightning.app.core.constants import APP_SERVER_IN_CLOUD from lightning.app.runners.backends import Backend from lightning.app.runners.runtime import Runtime from lightning.app.storage.orchestrator import StorageOrchestrator @@ -50,7 +50,8 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): _set_flow_context() # Note: In case the runtime is used in the cloud. - self.host = "0.0.0.0" if APP_SERVER_IN_CLOUD else self.host + in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None + self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host self.app.backend = self.backend self.backend._prepare_queues(self.app) @@ -116,7 +117,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): # wait for server to be ready has_started_queue.get() - if open_ui and not _is_headless(self.app): + if open_ui and not _is_headless(self.app) and constants.LIGHTNING_CLOUDSPACE_HOST is None: click.launch(self._get_app_url()) # Connect the runtime to the application. @@ -134,7 +135,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): self.terminate() def terminate(self): - if APP_SERVER_IN_CLOUD: + if constants.APP_SERVER_IN_CLOUD: # Close all the ports open for the App within the App. ports = [self.port] + getattr(self.backend, "ports", []) for port in ports: diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index 3265d8fb57844..99b09d77a94b7 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -27,20 +27,64 @@ from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout from urllib3.util.retry import Retry +from lightning.app.core import constants from lightning.app.utilities.app_helpers import Logger logger = Logger(__name__) +# Global record to track ports that have been allocated in this session. +_reserved_ports = set() + + def find_free_network_port() -> int: """Finds a free port on localhost.""" - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() + if constants.LIGHTNING_CLOUDSPACE_HOST is not None: + return _find_free_network_port_cloudspace() + + port = None + + for _ in range(10): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + + if port not in _reserved_ports: + break + + if port in _reserved_ports: + # Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong + raise RuntimeError( + "Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`." + ) + + _reserved_ports.add(port) return port +def _find_free_network_port_cloudspace(): + """Finds a free port in the exposed range when running in a cloudspace.""" + for port in range( + constants.APP_SERVER_PORT, + constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT, + ): + if port in _reserved_ports: + continue + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", port)) + sock.close() + _reserved_ports.add(port) + return port + except OSError: + continue + + # This error should never happen. An app using this many ports would probably fail on a single machine anyway. + raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.") + + _CONNECTION_RETRY_TOTAL = 2880 _CONNECTION_RETRY_BACKOFF_FACTOR = 0.5 _DEFAULT_BACKOFF_MAX = 5 * 60 # seconds diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py index 973ffc3c4ca54..6cd176104a45b 100644 --- a/src/lightning/app/utilities/proxies.py +++ b/src/lightning/app/utilities/proxies.py @@ -31,6 +31,7 @@ from deepdiff import DeepDiff, Delta from lightning_utilities.core.apply_func import apply_to_collection +from lightning.app.core import constants from lightning.app.core.queues import MultiProcessQueue from lightning.app.storage import Path from lightning.app.storage.copier import _Copier, _copy_files @@ -500,7 +501,8 @@ def run_once(self): # Set the internal IP address. # Set this here after the state observer is initialized, since it needs to record it as a change and send # it back to the flow - self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1") + default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0" + self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip) # 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't # send delta while calling `set_state`. diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index 1cdd98fe20555..f3319e7b98d66 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -10,6 +10,7 @@ import py import pytest +from lightning.app.core import constants from lightning.app.storage.path import _storage_root_dir from lightning.app.utilities.app_helpers import _collect_child_process_pids from lightning.app.utilities.component import _set_context @@ -115,3 +116,26 @@ def caplog(caplog): root_logger.propagate = root_propagate for name, propagate in propagation_dict.items(): logging.getLogger(name).propagate = propagate + + +@pytest.fixture +def patch_constants(request): + """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for + the duration of a test. + + Example:: + + @pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True) + def test_my_stuff(patch_constants): + ... + """ + # Set constants + old_constants = {} + for constant, value in request.param.items(): + old_constants[constant] = getattr(constants, constant) + setattr(constants, constant, value) + + yield + + for constant, value in old_constants.items(): + setattr(constants, constant, value) diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py index 995ca92eb80c7..1795d5d524966 100644 --- a/tests/tests_app/utilities/test_network.py +++ b/tests/tests_app/utilities/test_network.py @@ -1,9 +1,45 @@ +from unittest import mock + +import pytest + from lightning.app.utilities.network import find_free_network_port, LightningClient -def test_port(): +def test_find_free_network_port(): + """Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found.""" assert find_free_network_port() + with mock.patch("lightning.app.utilities.network.socket") as mock_socket: + mock_socket.socket().getsockname.return_value = [0, 8888] + assert find_free_network_port() == 8888 + + with pytest.raises(RuntimeError, match="Couldn't find a free port."): + find_free_network_port() + + mock_socket.socket().getsockname.return_value = [0, 9999] + assert find_free_network_port() == 9999 + + +@mock.patch("lightning.app.utilities.network.socket") +@pytest.mark.parametrize( + "patch_constants", + [{"LIGHTNING_CLOUDSPACE_HOST": "any", "LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT": 10}], + indirect=True, +) +def test_find_free_network_port_cloudspace(_, patch_constants): + """Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found when + cloudspace env variables are set.""" + ports = set() + num_ports = 0 + + with pytest.raises(RuntimeError, match="All 10 ports are already in use."): + for _ in range(11): + ports.add(find_free_network_port()) + num_ports = num_ports + 1 + + # Check that all ports are unique + assert len(ports) == num_ports + def test_lightning_client_retry_enabled(): diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index e858407134e54..05e5dd3d875e4 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -644,9 +644,15 @@ def test_state_observer(): @pytest.mark.parametrize( - "environment, expected_ip_addr", [({}, "127.0.0.1"), ({"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5")] + "patch_constants, environment, expected_ip_addr", + [ + ({}, {}, "127.0.0.1"), + ({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"), + ({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"), + ], + indirect=["patch_constants"], ) -def test_work_runner_sets_internal_ip(environment, expected_ip_addr): +def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr): """Test that the WorkRunner updates the internal ip address as soon as the Work starts running.""" class Work(LightningWork):