Skip to content

Commit f6bcdb5

Browse files
ethanwharrisBorda
authored andcommitted
[App] Add support for running with multiprocessing in the cloud (#16624)
(cherry picked from commit fd61ed0)
1 parent 182010c commit f6bcdb5

File tree

8 files changed

+135
-12
lines changed

8 files changed

+135
-12
lines changed

src/lightning_app/core/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ def get_lightning_cloud_url() -> str:
6666
LIGHTNING_COMPONENT_PUBLIC_REGISTRY = "https://lightning.ai/v1/components"
6767
LIGHTNING_APPS_PUBLIC_REGISTRY = "https://lightning.ai/v1/apps"
6868

69+
LIGHTNING_CLOUDSPACE_HOST = os.getenv("LIGHTNING_CLOUDSPACE_HOST")
70+
LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT = int(os.getenv("LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT", "0"))
71+
6972
# EXPERIMENTAL: ENV VARIABLES TO ENABLE MULTIPLE WORKS IN THE SAME MACHINE
7073
DEFAULT_NUMBER_OF_EXPOSED_PORTS = int(os.getenv("DEFAULT_NUMBER_OF_EXPOSED_PORTS", "50"))
7174
ENABLE_MULTIPLE_WORKS_IN_NON_DEFAULT_CONTAINER = bool(

src/lightning_app/runners/backends/mp_process.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Optional
1717

1818
import lightning_app
19+
from lightning_app.core import constants
1920
from lightning_app.core.queues import QueuingSystem
2021
from lightning_app.runners.backends.backend import Backend, WorkManager
2122
from lightning_app.utilities.enum import WorkStageStatus
@@ -76,6 +77,12 @@ def __init__(self, entrypoint_file: str):
7677
super().__init__(entrypoint_file=entrypoint_file, queues=QueuingSystem.MULTIPROCESS, queue_id="0")
7778

7879
def create_work(self, app, work) -> None:
80+
if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
81+
# Override the port if set by the user
82+
work._port = find_free_network_port()
83+
work._host = "0.0.0.0"
84+
work._future_url = f"https://{work.port}-{constants.LIGHTNING_CLOUDSPACE_HOST}"
85+
7986
app.processes[work.name] = MultiProcessWorkManager(app, work)
8087
app.processes[work.name].start()
8188
self.resolve_url(app)

src/lightning_app/runners/multiprocess.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import click
2121

2222
from lightning_app.api.http_methods import _add_tags_to_api, _validate_api
23+
from lightning_app.core import constants
2324
from lightning_app.core.api import start_server
24-
from lightning_app.core.constants import APP_SERVER_IN_CLOUD
2525
from lightning_app.runners.backends import Backend
2626
from lightning_app.runners.runtime import Runtime
2727
from lightning_app.storage.orchestrator import StorageOrchestrator
@@ -50,7 +50,8 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
5050
_set_flow_context()
5151

5252
# Note: In case the runtime is used in the cloud.
53-
self.host = "0.0.0.0" if APP_SERVER_IN_CLOUD else self.host
53+
in_cloudspace = constants.LIGHTNING_CLOUDSPACE_HOST is not None
54+
self.host = "0.0.0.0" if constants.APP_SERVER_IN_CLOUD or in_cloudspace else self.host
5455

5556
self.app.backend = self.backend
5657
self.backend._prepare_queues(self.app)
@@ -116,7 +117,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
116117
# wait for server to be ready
117118
has_started_queue.get()
118119

119-
if open_ui and not _is_headless(self.app):
120+
if open_ui and not _is_headless(self.app) and constants.LIGHTNING_CLOUDSPACE_HOST is None:
120121
click.launch(self._get_app_url())
121122

122123
# Connect the runtime to the application.
@@ -134,7 +135,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any):
134135
self.terminate()
135136

136137
def terminate(self):
137-
if APP_SERVER_IN_CLOUD:
138+
if constants.APP_SERVER_IN_CLOUD:
138139
# Close all the ports open for the App within the App.
139140
ports = [self.port] + getattr(self.backend, "ports", [])
140141
for port in ports:

src/lightning_app/utilities/network.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,64 @@
2727
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout
2828
from urllib3.util.retry import Retry
2929

30+
from lightning_app.core import constants
3031
from lightning_app.utilities.app_helpers import Logger
3132

3233
logger = Logger(__name__)
3334

3435

36+
# Global record to track ports that have been allocated in this session.
37+
_reserved_ports = set()
38+
39+
3540
def find_free_network_port() -> int:
3641
"""Finds a free port on localhost."""
37-
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
38-
s.bind(("", 0))
39-
port = s.getsockname()[1]
40-
s.close()
42+
if constants.LIGHTNING_CLOUDSPACE_HOST is not None:
43+
return _find_free_network_port_cloudspace()
44+
45+
port = None
46+
47+
for _ in range(10):
48+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
49+
sock.bind(("", 0))
50+
port = sock.getsockname()[1]
51+
sock.close()
52+
53+
if port not in _reserved_ports:
54+
break
55+
56+
if port in _reserved_ports:
57+
# Prevent an infinite loop, if we tried 10 times and didn't get a free port then something is wrong
58+
raise RuntimeError(
59+
"Couldn't find a free port. Please open an issue at `https://github.com/Lightning-AI/lightning/issues`."
60+
)
61+
62+
_reserved_ports.add(port)
4163
return port
4264

4365

66+
def _find_free_network_port_cloudspace():
67+
"""Finds a free port in the exposed range when running in a cloudspace."""
68+
for port in range(
69+
constants.APP_SERVER_PORT,
70+
constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT,
71+
):
72+
if port in _reserved_ports:
73+
continue
74+
75+
try:
76+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
77+
sock.bind(("", port))
78+
sock.close()
79+
_reserved_ports.add(port)
80+
return port
81+
except OSError:
82+
continue
83+
84+
# This error should never happen. An app using this many ports would probably fail on a single machine anyway.
85+
raise RuntimeError(f"All {constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT} ports are already in use.")
86+
87+
4488
_CONNECTION_RETRY_TOTAL = 2880
4589
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
4690
_DEFAULT_BACKOFF_MAX = 5 * 60 # seconds

src/lightning_app/utilities/proxies.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from deepdiff import DeepDiff, Delta
3232
from lightning_utilities.core.apply_func import apply_to_collection
3333

34+
from lightning_app.core import constants
3435
from lightning_app.core.queues import MultiProcessQueue
3536
from lightning_app.storage import Path
3637
from lightning_app.storage.copier import _Copier, _copy_files
@@ -500,7 +501,8 @@ def run_once(self):
500501
# Set the internal IP address.
501502
# Set this here after the state observer is initialized, since it needs to record it as a change and send
502503
# it back to the flow
503-
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", "127.0.0.1")
504+
default_internal_ip = "127.0.0.1" if constants.LIGHTNING_CLOUDSPACE_HOST is None else "0.0.0.0"
505+
self.work._internal_ip = os.environ.get("LIGHTNING_NODE_IP", default_internal_ip)
504506

505507
# 8. Patch the setattr method of the work. This needs to be done after step 4, so we don't
506508
# send delta while calling `set_state`.

tests/tests_app/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import py
1111
import pytest
1212

13+
from lightning_app.core import constants
1314
from lightning_app.storage.path import _storage_root_dir
1415
from lightning_app.utilities.app_helpers import _collect_child_process_pids
1516
from lightning_app.utilities.component import _set_context
@@ -115,3 +116,26 @@ def caplog(caplog):
115116
root_logger.propagate = root_propagate
116117
for name, propagate in propagation_dict.items():
117118
logging.getLogger(name).propagate = propagate
119+
120+
121+
@pytest.fixture
122+
def patch_constants(request):
123+
"""This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for
124+
the duration of a test.
125+
126+
Example::
127+
128+
@pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True)
129+
def test_my_stuff(patch_constants):
130+
...
131+
"""
132+
# Set constants
133+
old_constants = {}
134+
for constant, value in request.param.items():
135+
old_constants[constant] = getattr(constants, constant)
136+
setattr(constants, constant, value)
137+
138+
yield
139+
140+
for constant, value in old_constants.items():
141+
setattr(constants, constant, value)

tests/tests_app/utilities/test_network.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,45 @@
1+
from unittest import mock
2+
3+
import pytest
4+
15
from lightning_app.utilities.network import find_free_network_port, LightningClient
26

37

4-
def test_port():
8+
def test_find_free_network_port():
9+
"""Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found."""
510
assert find_free_network_port()
611

12+
with mock.patch("lightning.app.utilities.network.socket") as mock_socket:
13+
mock_socket.socket().getsockname.return_value = [0, 8888]
14+
assert find_free_network_port() == 8888
15+
16+
with pytest.raises(RuntimeError, match="Couldn't find a free port."):
17+
find_free_network_port()
18+
19+
mock_socket.socket().getsockname.return_value = [0, 9999]
20+
assert find_free_network_port() == 9999
21+
22+
23+
@mock.patch("lightning.app.utilities.network.socket")
24+
@pytest.mark.parametrize(
25+
"patch_constants",
26+
[{"LIGHTNING_CLOUDSPACE_HOST": "any", "LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT": 10}],
27+
indirect=True,
28+
)
29+
def test_find_free_network_port_cloudspace(_, patch_constants):
30+
"""Tests that `find_free_network_port` gives expected outputs and raises if a free port couldn't be found when
31+
cloudspace env variables are set."""
32+
ports = set()
33+
num_ports = 0
34+
35+
with pytest.raises(RuntimeError, match="All 10 ports are already in use."):
36+
for _ in range(11):
37+
ports.add(find_free_network_port())
38+
num_ports = num_ports + 1
39+
40+
# Check that all ports are unique
41+
assert len(ports) == num_ports
42+
743

844
def test_lightning_client_retry_enabled():
945

tests/tests_app/utilities/test_proxies.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,15 @@ def test_state_observer():
644644

645645

646646
@pytest.mark.parametrize(
647-
"environment, expected_ip_addr", [({}, "127.0.0.1"), ({"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5")]
647+
"patch_constants, environment, expected_ip_addr",
648+
[
649+
({}, {}, "127.0.0.1"),
650+
({"LIGHTNING_CLOUDSPACE_HOST": "any"}, {}, "0.0.0.0"),
651+
({}, {"LIGHTNING_NODE_IP": "10.10.10.5"}, "10.10.10.5"),
652+
],
653+
indirect=["patch_constants"],
648654
)
649-
def test_work_runner_sets_internal_ip(environment, expected_ip_addr):
655+
def test_work_runner_sets_internal_ip(patch_constants, environment, expected_ip_addr):
650656
"""Test that the WorkRunner updates the internal ip address as soon as the Work starts running."""
651657

652658
class Work(LightningWork):

0 commit comments

Comments
 (0)