Skip to content

Commit d2a8fbf

Browse files
authored
[App] Enable running with spawn context (#15923)
1 parent 2debd1c commit d2a8fbf

File tree

7 files changed

+27
-5
lines changed

7 files changed

+27
-5
lines changed

examples/app_installation_commands/app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def run(self):
1313
print("lmdb successfully installed")
1414
print("accessing a module in a Work or Flow body works!")
1515

16+
@property
17+
def ready(self) -> bool:
18+
return True
19+
1620

1721
print(f"accessing an object in main code body works!: version={lmdb.version()}")
1822

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
2121

22+
- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))
23+
2224

2325
### Changed
2426

src/lightning_app/core/flow.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,13 @@ def __init__(self, work):
763763
super().__init__()
764764
self.work = work
765765

766+
@property
767+
def ready(self) -> bool:
768+
ready = getattr(self.work, "ready", None)
769+
if ready:
770+
return ready
771+
return self.work.url != ""
772+
766773
def run(self):
767774
if self.work.has_succeeded:
768775
self.work.stop()

src/lightning_app/core/queues.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ class MultiProcessQueue(BaseQueue):
198198
def __init__(self, name: str, default_timeout: float):
199199
self.name = name
200200
self.default_timeout = default_timeout
201-
self.queue = multiprocessing.Queue()
201+
context = multiprocessing.get_context("spawn")
202+
self.queue = context.Queue()
202203

203204
def put(self, item):
204205
self.queue.put(item)

src/lightning_app/core/work.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import time
23
import warnings
34
from copy import deepcopy
@@ -46,6 +47,8 @@ class LightningWork:
4647
)
4748

4849
_run_executor_cls: Type[WorkRunExecutor] = WorkRunExecutor
50+
# TODO: Move to spawn for all Operating System.
51+
_start_method = "spawn" if sys.platform == "win32" else "fork"
4952

5053
def __init__(
5154
self,

src/lightning_app/runners/backends/mp_process.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ def start(self):
3131
flow_to_work_delta_queue=self.app.flow_to_work_delta_queues[self.work.name],
3232
run_executor_cls=self.work._run_executor_cls,
3333
)
34-
self._process = multiprocessing.Process(target=self._work_runner)
34+
35+
start_method = self.work._start_method
36+
context = multiprocessing.get_context(start_method)
37+
self._process = context.Process(target=self._work_runner)
3538
self._process.start()
3639

3740
def kill(self):

tests/tests_app/core/test_queues.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from unittest import mock
66

77
import pytest
8-
import redis
98
import requests_mock
109

1110
from lightning_app import LightningFlow
@@ -23,6 +22,7 @@ def test_queue_api(queue_type, monkeypatch):
2322
2423
This test run all the Queue implementation but we monkeypatch the Redis Queues to avoid external interaction
2524
"""
25+
import redis
2626

2727
blpop_out = (b"entry-id", pickle.dumps("test_entry"))
2828

@@ -104,12 +104,14 @@ def test_redis_queue_read_timeout(redis_mock):
104104

105105
@pytest.mark.parametrize(
106106
"queue_type, queue_process_mock",
107-
[(QueuingSystem.SINGLEPROCESS, queue), (QueuingSystem.MULTIPROCESS, multiprocessing)],
107+
[(QueuingSystem.MULTIPROCESS, multiprocessing)],
108108
)
109109
def test_process_queue_read_timeout(queue_type, queue_process_mock, monkeypatch):
110110

111+
context = mock.MagicMock()
111112
queue_mocked = mock.MagicMock()
112-
monkeypatch.setattr(queue_process_mock, "Queue", queue_mocked)
113+
context.Queue = queue_mocked
114+
monkeypatch.setattr(queue_process_mock, "get_context", mock.MagicMock(return_value=context))
113115
my_queue = queue_type.get_readiness_queue()
114116

115117
# default timeout

0 commit comments

Comments
 (0)