Commit 757413c

[App] Accelerate Multi Node Startup Time (#15650)
1 parent 4e8cf85 commit 757413c

File tree: 13 files changed (+144, -152 lines)

examples/app_multi_node/README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -28,9 +28,9 @@ lightning run app train_lite.py
 
 Using Lite, you retain control over your loops while accessing in a minimal way all Lightning distributed strategies.
 
-## Multi Node with PyTorch Lightning
+## Multi Node with Lightning Trainer
 
-Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
+Lightning supports running Lightning Trainer from a script or within a Lightning Work.
 
 
 You can either run a script directly
```
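For context, a minimal sketch of what the updated sentence describes: running the Lightning Trainer inside a LightningWork. The model and Trainer arguments here are illustrative assumptions, not part of this commit:

```python
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

from lightning_app import LightningWork


class TrainerWork(LightningWork):
    def run(self):
        # Ordinary Trainer usage, simply placed inside a Work's run() method.
        trainer = pl.Trainer(max_epochs=1)
        trainer.fit(BoringModel())
```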

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -74,6 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
 
 
+- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
+
 
 ## [1.8.0] - 2022-11-01
```

src/lightning_app/components/database/server.py

Lines changed: 21 additions & 12 deletions
```diff
@@ -4,6 +4,7 @@
 import sys
 import tempfile
 import threading
+import traceback
 from typing import List, Optional, Type, Union
 
 import uvicorn
@@ -36,6 +37,9 @@ def install_signal_handlers(self):
         """Ignore Uvicorn Signal Handlers."""
 
 
+_lock = threading.Lock()
+
+
 class Database(LightningWork):
     def __init__(
         self,
@@ -146,25 +150,29 @@ class CounterModel(SQLModel, table=True):
         self._exit_event = None
 
     def store_database(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
+        try:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
 
-            source = sqlite3.connect(self.db_filename)
-            dest = sqlite3.connect(tmp_db_filename)
+                source = sqlite3.connect(self.db_filename)
+                dest = sqlite3.connect(tmp_db_filename)
 
-            source.backup(dest)
+                source.backup(dest)
 
-            source.close()
-            dest.close()
+                source.close()
+                dest.close()
 
-            drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
-            drive.put(os.path.basename(tmp_db_filename))
+                drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
+                drive.put(os.path.basename(tmp_db_filename))
 
-            print("Stored the database to the Drive.")
+                print("Stored the database to the Drive.")
+        except Exception:
+            print(traceback.print_exc())
 
     def periodic_store_database(self, store_interval):
         while not self._exit_event.is_set():
-            self.store_database()
+            with _lock:
+                self.store_database()
             self._exit_event.wait(store_interval)
 
     def run(self, token: Optional[str] = None) -> None:
@@ -210,4 +218,5 @@ def db_url(self) -> Optional[str]:
 
     def on_exit(self):
         self._exit_event.set()
-        self.store_database()
+        with _lock:
+            self.store_database()
```
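The new module-level `_lock` serializes snapshots between the periodic background thread and the final save in `on_exit`. A self-contained sketch of that locking pattern; the snapshot body is a stand-in for the real SQLite backup and Drive upload:

```python
import threading
import time

# Module-level lock, mirroring server.py: every code path that snapshots
# the database acquires it first, so the periodic thread and the shutdown
# hook can never run a backup concurrently.
_lock = threading.Lock()


def store_database() -> None:
    # Stand-in for the real sqlite3 backup + Drive.put() sequence.
    print("storing database snapshot")
    time.sleep(0.1)


def periodic_store_database(exit_event: threading.Event, store_interval: float) -> None:
    while not exit_event.is_set():
        with _lock:
            store_database()
        exit_event.wait(store_interval)


def on_exit(exit_event: threading.Event) -> None:
    exit_event.set()
    with _lock:
        store_database()  # the final snapshot cannot overlap a periodic one


if __name__ == "__main__":
    event = threading.Event()
    worker = threading.Thread(target=periodic_store_database, args=(event, 0.5))
    worker.start()
    time.sleep(1.2)
    on_exit(event)
    worker.join()
```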

src/lightning_app/components/multi_node/base.py

Lines changed: 17 additions & 33 deletions
```diff
@@ -3,7 +3,6 @@
 from lightning_app import structures
 from lightning_app.core.flow import LightningFlow
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.enum import WorkStageStatus
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
 
 
@@ -52,46 +51,31 @@ def run(
             work_kwargs: Keywords arguments to be provided to the work on instantiation.
         """
         super().__init__()
-        self.ws = structures.List()
-        self._work_cls = work_cls
-        self.num_nodes = num_nodes
-        self._cloud_compute = cloud_compute
-        self._work_args = work_args
-        self._work_kwargs = work_kwargs
-        self.has_started = False
+        self.ws = structures.List(
+            *[
+                work_cls(
+                    *work_args,
+                    cloud_compute=cloud_compute,
+                    **work_kwargs,
+                    parallel=True,
+                )
+                for _ in range(num_nodes)
+            ]
+        )
 
     def run(self) -> None:
-        if not self.has_started:
-
-            # 1. Create & start the works
-            if not self.ws:
-                for node_rank in range(self.num_nodes):
-                    self.ws.append(
-                        self._work_cls(
-                            *self._work_args,
-                            cloud_compute=self._cloud_compute,
-                            **self._work_kwargs,
-                            parallel=True,
-                        )
-                    )
-
-                    # Starting node `node_rank` ...
-                    self.ws[-1].start()
-
-            # 2. Wait for all machines to be started !
-            if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
-                return
-
-            self.has_started = True
+        # 1. Wait for all works to be started !
+        if not all(w.internal_ip for w in self.ws):
+            return
 
-        # Loop over all node machines
-        for node_rank in range(self.num_nodes):
+        # 2. Loop over all node machines
+        for node_rank in range(len(self.ws)):
 
             # 3. Run the user code in a distributed way !
             self.ws[node_rank].run(
                 main_address=self.ws[0].internal_ip,
                 main_port=self.ws[0].port,
-                num_nodes=self.num_nodes,
+                num_nodes=len(self.ws),
                 node_rank=node_rank,
             )
```
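The net effect: `MultiNode` now creates every node's work eagerly in `__init__`, so the dispatcher can request all machines up front, and `run()` only waits for each work to report an `internal_ip`. A usage sketch; the work class, its `run()` body, and the compute name are illustrative assumptions:

```python
from lightning_app import CloudCompute, LightningWork
from lightning_app.components.multi_node import MultiNode


class TrainWork(LightningWork):
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # Each node receives rank 0's address/port plus its own rank,
        # enough to initialize a distributed process group.
        print(f"node {node_rank}/{num_nodes} -> {main_address}:{main_port}")


# All `num_nodes` works now exist (and can be started with the flow) as
# soon as the component is constructed, rather than one per flow loop.
component = MultiNode(TrainWork, num_nodes=2, cloud_compute=CloudCompute("cpu"))
```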

src/lightning_app/core/app.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -472,6 +472,8 @@ def _run(self) -> bool:
         self._original_state = deepcopy(self.state)
         done = False
 
+        self._start_with_flow_works()
+
         if self.should_publish_changes_to_api and self.api_publish_state_queue:
             logger.debug("Publishing the state with changes")
             # Push two states to optimize start in the cloud.
@@ -668,3 +670,11 @@ def _send_flow_to_work_deltas(self, state) -> None:
             if deep_diff:
                 logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
                 self.flow_to_work_delta_queues[w.name].put(deep_diff)
+
+    def _start_with_flow_works(self):
+        for w in self.works:
+            if w._start_with_flow:
+                parallel = w.parallel
+                w._parallel = True
+                w.start()
+                w._parallel = parallel
```
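`_start_with_flow_works` kicks off every work flagged with `start_with_flow` before the event loop begins, temporarily forcing `parallel=True` so none of the starts block the flow. A minimal sketch of that flag-flip pattern using a hypothetical stand-in class, not the real `LightningWork`:

```python
class Work:
    """Hypothetical stand-in for LightningWork, reduced to the relevant flags."""

    def __init__(self, name: str, parallel: bool, start_with_flow: bool) -> None:
        self.name = name
        self._parallel = parallel
        self._start_with_flow = start_with_flow

    @property
    def parallel(self) -> bool:
        return self._parallel

    def start(self) -> None:
        # In the real app this requests the machine; a non-parallel work
        # would otherwise block the flow until the machine is ready.
        print(f"starting {self.name} (parallel={self._parallel})")


def start_with_flow_works(works: list) -> None:
    for w in works:
        if w._start_with_flow:
            parallel = w.parallel
            w._parallel = True      # force a non-blocking start
            w.start()
            w._parallel = parallel  # restore the user's original setting


start_with_flow_works([Work("root.train", parallel=False, start_with_flow=True)])
```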

src/lightning_app/runners/cloud.py

Lines changed: 64 additions & 65 deletions
```diff
@@ -142,78 +142,77 @@ def dispatch(
         v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
 
         works: List[V1Work] = []
-        for flow in self.app.flows:
-            for work in flow.works(recurse=False):
-                if not work._start_with_flow:
-                    continue
-
-                work_requirements = "\n".join(work.cloud_build_config.requirements)
-                build_spec = V1BuildSpec(
-                    commands=work.cloud_build_config.build_commands(),
-                    python_dependencies=V1PythonDependencyInfo(
-                        package_manager=V1PackageManager.PIP, packages=work_requirements
-                    ),
-                    image=work.cloud_build_config.image,
-                )
-                user_compute_config = V1UserRequestedComputeConfig(
-                    name=work.cloud_compute.name,
-                    count=1,
-                    disk_size=work.cloud_compute.disk_size,
-                    preemptible=work.cloud_compute.preemptible,
-                    shm_size=work.cloud_compute.shm_size,
-                )
+        for work in self.app.works:
+            if not work._start_with_flow:
+                continue
 
-                drive_specs: List[V1LightningworkDrives] = []
-                for drive_attr_name, drive in [
-                    (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
-                ]:
-                    if drive.protocol == "lit://":
-                        drive_type = V1DriveType.NO_MOUNT_S3
-                        source_type = V1SourceType.S3
-                    else:
-                        raise RuntimeError(
-                            f"unknown drive protocol `{drive.protocol}`. Please verify this "
-                            f"drive type has been configured for use in the cloud dispatcher."
-                        )
+            work_requirements = "\n".join(work.cloud_build_config.requirements)
+            build_spec = V1BuildSpec(
+                commands=work.cloud_build_config.build_commands(),
+                python_dependencies=V1PythonDependencyInfo(
+                    package_manager=V1PackageManager.PIP, packages=work_requirements
+                ),
+                image=work.cloud_build_config.image,
+            )
+            user_compute_config = V1UserRequestedComputeConfig(
+                name=work.cloud_compute.name,
+                count=1,
+                disk_size=work.cloud_compute.disk_size,
+                preemptible=work.cloud_compute.preemptible,
+                shm_size=work.cloud_compute.shm_size,
+            )
 
-                    drive_specs.append(
-                        V1LightningworkDrives(
-                            drive=V1Drive(
-                                metadata=V1Metadata(
-                                    name=f"{work.name}.{drive_attr_name}",
-                                ),
-                                spec=V1DriveSpec(
-                                    drive_type=drive_type,
-                                    source_type=source_type,
-                                    source=f"{drive.protocol}{drive.id}",
-                                ),
-                                status=V1DriveStatus(),
+            drive_specs: List[V1LightningworkDrives] = []
+            for drive_attr_name, drive in [
+                (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
+            ]:
+                if drive.protocol == "lit://":
+                    drive_type = V1DriveType.NO_MOUNT_S3
+                    source_type = V1SourceType.S3
+                else:
+                    raise RuntimeError(
+                        f"unknown drive protocol `{drive.protocol}`. Please verify this "
+                        f"drive type has been configured for use in the cloud dispatcher."
+                    )
+
+                drive_specs.append(
+                    V1LightningworkDrives(
+                        drive=V1Drive(
+                            metadata=V1Metadata(
+                                name=f"{work.name}.{drive_attr_name}",
+                            ),
+                            spec=V1DriveSpec(
+                                drive_type=drive_type,
+                                source_type=source_type,
+                                source=f"{drive.protocol}{drive.id}",
                             ),
-                            mount_location=str(drive.root_folder),
+                            status=V1DriveStatus(),
                         ),
-                    )
+                        mount_location=str(drive.root_folder),
+                    ),
+                )
 
-                # TODO: Move this to the CloudCompute class and update backend
-                if work.cloud_compute.mounts is not None:
-                    mounts = work.cloud_compute.mounts
-                    if isinstance(mounts, Mount):
-                        mounts = [mounts]
-                    for mount in mounts:
-                        drive_specs.append(
-                            _create_mount_drive_spec(
-                                work_name=work.name,
-                                mount=mount,
-                            )
+            # TODO: Move this to the CloudCompute class and update backend
+            if work.cloud_compute.mounts is not None:
+                mounts = work.cloud_compute.mounts
+                if isinstance(mounts, Mount):
+                    mounts = [mounts]
+                for mount in mounts:
+                    drive_specs.append(
+                        _create_mount_drive_spec(
+                            work_name=work.name,
+                            mount=mount,
                         )
+                    )
 
-                random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
-                work_spec = V1LightningworkSpec(
-                    build_spec=build_spec,
-                    drives=drive_specs,
-                    user_requested_compute_config=user_compute_config,
-                    network_config=[V1NetworkConfig(name=random_name, port=work.port)],
-                )
-                works.append(V1Work(name=work.name, spec=work_spec))
+            random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
+            work_spec = V1LightningworkSpec(
+                build_spec=build_spec,
+                drives=drive_specs,
+                user_requested_compute_config=user_compute_config,
+                network_config=[V1NetworkConfig(name=random_name, port=work.port)],
+            )
+            works.append(V1Work(name=work.name, spec=work_spec))
 
         # We need to collect a spec for each flow that contains a frontend so that the backend knows
         # for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
```
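Per the changelog entry above, flattening the loop to `self.app.works` fixes works held in structures: a per-flow `flow.works(recurse=False)` walk missed works stored in containers such as `structures.List` (exactly how `MultiNode` now holds its nodes). A hypothetical miniature of the difference, not the real traversal code:

```python
class Work:
    def __init__(self, name: str) -> None:
        self.name = name


class Flow:
    """Hypothetical flow with one direct work and two held in a structure."""

    def __init__(self) -> None:
        self.w = Work("root.w")
        self.ws = [Work("root.ws.0"), Work("root.ws.1")]  # structure-like list

    def works(self, recurse: bool = False) -> list:
        # Old-style walk: only direct work attributes are visible.
        return [self.w]


class App:
    def __init__(self, flow: Flow) -> None:
        self.flows = [flow]

    @property
    def works(self) -> list:
        # App-level flat collection: includes works held in structures.
        flow = self.flows[0]
        return [flow.w, *flow.ws]


app = App(Flow())
print([w.name for w in app.flows[0].works()])  # ['root.w'], misses the structure
print([w.name for w in app.works])             # all three works
```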

src/lightning_app/utilities/proxies.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -103,8 +103,6 @@ class ProxyWorkRun:
     caller_queue: "BaseQueue"
 
     def __post_init__(self):
-        self.cache_calls = self.work.cache_calls
-        self.parallel = self.work.parallel
         self.work_state = None
 
     def __call__(self, *args, **kwargs):
@@ -123,7 +121,7 @@ def __call__(self, *args, **kwargs):
 
         # The if/else conditions are left un-compressed to simplify readability
         # for the readers.
-        if self.cache_calls:
+        if self.work.cache_calls:
             if not entered or stopped_on_sigterm:
                 _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
             else:
@@ -137,7 +135,7 @@ def __call__(self, *args, **kwargs):
             # the previous task has completed and we can re-queue the next one.
             # overriding the return value for next loop iteration.
             _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
-            if not self.parallel:
+            if not self.work.parallel:
                 raise CacheMissException("Task never called before. Triggered now")
 
     def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
```
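Dropping the cached copies matters because `cache_calls` and `parallel` can change after the proxy is built, notably the temporary `_parallel = True` flip during `_start_with_flow_works`. A small sketch of stale caching versus live delegation, with hypothetical classes:

```python
from dataclasses import dataclass, field


class Work:
    def __init__(self) -> None:
        self.parallel = False


@dataclass
class CachingProxy:
    work: Work
    parallel: bool = field(init=False)

    def __post_init__(self) -> None:
        self.parallel = self.work.parallel  # frozen at construction time


@dataclass
class LiveProxy:
    work: Work

    @property
    def parallel(self) -> bool:
        return self.work.parallel  # always reflects the current flag


w = Work()
cached, live = CachingProxy(w), LiveProxy(w)
w.parallel = True  # e.g. the temporary flip while starting with the flow
print(cached.parallel, live.parallel)  # False True
```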

tests/tests_app/components/database/test_client_server.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -2,6 +2,7 @@
 import sys
 import tempfile
 import time
+import traceback
 from pathlib import Path
 from time import sleep
 from typing import List, Optional
@@ -197,7 +198,9 @@ def run(self):
         assert len(self._client.select_all()) == 1
         self._exit()
 
-    with tempfile.TemporaryDirectory() as tmpdir:
-
-        app = LightningApp(Flow(tmpdir))
-        MultiProcessRuntime(app).dispatch()
+    try:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            app = LightningApp(Flow(tmpdir))
+            MultiProcessRuntime(app).dispatch()
+    except Exception:
+        print(traceback.print_exc())
```

tests/tests_app/core/test_lightning_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -42,7 +42,7 @@
 
 class WorkA(LightningWork):
     def __init__(self):
-        super().__init__(parallel=True)
+        super().__init__(parallel=True, start_with_flow=False)
         self.var_a = 0
         self.drive = Drive("lit://test_app_state_api")
```
