Skip to content

Commit d48aa03

Browse files
tchatonlantiga
andauthored
Slightly safer multi node (#15538)
update Co-authored-by: Luca Antiga <[email protected]>
1 parent dcfaa06 commit d48aa03

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

src/lightning_app/components/multi_node.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,30 @@ def run(
5858
self._cloud_compute = cloud_compute
5959
self._work_args = work_args
6060
self._work_kwargs = work_kwargs
61-
self.has_initialized = False
61+
self.has_started = False
6262

6363
def run(self) -> None:
64-
# 1. Create & start the works
65-
if not self.has_initialized:
66-
for node_rank in range(self.nodes):
67-
self.ws.append(
68-
self._work_cls(
69-
*self._work_args,
70-
cloud_compute=self._cloud_compute,
71-
**self._work_kwargs,
72-
parallel=True,
64+
if not self.has_started:
65+
66+
# 1. Create & start the works
67+
if not self.ws:
68+
for node_rank in range(self.nodes):
69+
self.ws.append(
70+
self._work_cls(
71+
*self._work_args,
72+
cloud_compute=self._cloud_compute,
73+
**self._work_kwargs,
74+
parallel=True,
75+
)
7376
)
74-
)
75-
# Starting node `node_rank`` ...
76-
self.ws[-1].start()
77-
self.has_initialized = True
77+
# Starting node `node_rank`` ...
78+
self.ws[-1].start()
79+
80+
# 2. Wait for all machines to be started !
81+
if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
82+
return
7883

79-
# 2. Wait for all machines to be started !
80-
if all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
81-
return
84+
self.has_started = True
8285

8386
# Loop over all node machines
8487
for node_rank in range(self.nodes):

0 commit comments

Comments
 (0)