Skip to content

Commit 8ec7ddf

Browse files
tchatonBordalantiga
authored
[App] Pass LightningWork to LightningApp (#15215)
* update * update * update * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * ll Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka <[email protected]> Co-authored-by: Luca Antiga <[email protected]>
1 parent 775e9eb commit 8ec7ddf

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1717
- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193)
1818
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187)
1919
- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995)
20-
20+
- Added support to pass a `LightningWork` to the `LightningApp` ([#15215](https://github.com/Lightning-AI/lightning/pull/15215)
2121

2222
### Fixed
2323

src/lightning_app/core/app.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
class LightningApp:
4747
def __init__(
4848
self,
49-
root: "lightning_app.LightningFlow",
49+
root: "t.Union[lightning_app.LightningFlow, lightning_app.LightningWork]",
5050
debug: bool = False,
5151
info: frontend.AppInfo = None,
5252
root_path: str = "",
@@ -62,8 +62,8 @@ def __init__(
6262
the :class:`~lightning_app.core.flow.LightningFlow` provided.
6363
6464
Arguments:
65-
root: The root LightningFlow component, that defines all the app's nested components, running infinitely.
66-
It must define a `run()` method that the app can call.
65+
root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
66+
components, running infinitely. It must define a `run()` method that the app can call.
6767
debug: Whether to activate the Lightning Logger debug mode.
6868
This can be helpful when reporting bugs on Lightning repo.
6969
info: Provide additional info about the app which will be used to update html title,
@@ -89,6 +89,10 @@ def __init__(
8989
"""
9090

9191
self.root_path = root_path # when running behind a proxy
92+
93+
if isinstance(root, lightning_app.LightningWork):
94+
root = lightning_app.core.flow._RootFlow(root)
95+
9296
_validate_root_flow(root)
9397
self._root = root
9498

src/lightning_app/core/flow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,15 @@ def load_state_dict(self, flow_state, children_states, strict) -> None:
763763
child.set_state(state)
764764
elif strict:
765765
raise ValueError(f"The component {child_name} wasn't instantiated for the component {self.name}")
766+
767+
768+
class _RootFlow(LightningFlow):
769+
def __init__(self, work):
770+
super().__init__()
771+
self.work = work
772+
773+
def run(self):
774+
self.work.run()
775+
776+
def configure_layout(self):
777+
return [{"name": "Main", "content": self.work}]

tests/tests_app/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def clear_app_state_state_variables():
7575
lightning_app.utilities.state._STATE = None
7676
lightning_app.utilities.state._LAST_STATE = None
7777
AppState._MY_AFFILIATION = ()
78-
cloud_compute._CLOUD_COMPUTE_STORE.clear()
78+
if hasattr(cloud_compute, "_CLOUD_COMPUTE_STORE"):
79+
cloud_compute._CLOUD_COMPUTE_STORE.clear()
7980

8081

8182
@pytest.fixture

tests/tests_app/core/test_lightning_work.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from lightning_app.runners import MultiProcessRuntime
1111
from lightning_app.storage import Path
1212
from lightning_app.testing.helpers import EmptyFlow, EmptyWork, MockQueue
13+
from lightning_app.testing.testing import LightningTestApp
1314
from lightning_app.utilities.enum import WorkStageStatus
1415
from lightning_app.utilities.proxies import ProxyWorkRun, WorkRunner
1516

@@ -327,3 +328,20 @@ def run(self, *args, **kwargs):
327328

328329
w = Work()
329330
w.run()
331+
332+
333+
class WorkCounter(LightningWork):
334+
def run(self):
335+
pass
336+
337+
338+
class LightningTestAppWithWork(LightningTestApp):
339+
def on_before_run_once(self):
340+
if self.root.work.has_succeeded:
341+
return True
342+
return super().on_before_run_once()
343+
344+
345+
def test_lightning_app_with_work():
346+
app = LightningTestAppWithWork(WorkCounter())
347+
MultiProcessRuntime(app, start_server=False).dispatch()

0 commit comments

Comments
 (0)