12
12
13
13
import lightning_app
14
14
from lightning_app import CloudCompute , LightningApp
15
- from lightning_app .core .flow import LightningFlow
15
+ from lightning_app .core .flow import _RootFlow , LightningFlow
16
16
from lightning_app .core .work import LightningWork
17
17
from lightning_app .runners import MultiProcessRuntime
18
18
from lightning_app .storage import Path
@@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
868
868
class WorkReady (LightningWork ):
869
869
def __init__ (self ):
870
870
super ().__init__ (parallel = True )
871
- self .counter = 0
871
+ self .ready = False
872
872
873
873
def run (self ):
874
- self .counter += 1
874
+ self .ready = True
875
875
876
876
877
877
class FlowReady (LightningFlow ):
@@ -890,7 +890,13 @@ def run(self):
890
890
self ._exit ()
891
891
892
892
893
- def test_flow_ready ():
893
+ class RootFlowReady (_RootFlow ):
894
+ def __init__ (self ):
895
+ super ().__init__ (WorkReady ())
896
+
897
+
898
+ @pytest .mark .parametrize ("flow" , [FlowReady , RootFlowReady ])
899
+ def test_flow_ready (flow ):
894
900
"""This test validates that the app status queue is populated correctly."""
895
901
896
902
mock_queue = _MockQueue ("api_publish_state_queue" )
@@ -910,7 +916,7 @@ def lagged_run_once(method):
910
916
state ["done" ] = new_done
911
917
return False
912
918
913
- app = LightningApp (FlowReady ())
919
+ app = LightningApp (flow ())
914
920
app ._run = partial (run_patch , method = app ._run )
915
921
app .run_once = partial (lagged_run_once , method = app .run_once )
916
922
MultiProcessRuntime (app , start_server = False ).dispatch ()
0 commit comments