Skip to content

Commit 711aec5

Browse files
authored
[App] Implement ready for components (#16129)
1 parent ae14f9d commit 711aec5

File tree

6 files changed

+29
-6
lines changed

6 files changed

+29
-6
lines changed

src/lightning_app/components/serve/auto_scaler.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ def __init__(
212212
else:
213213
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
214214

215+
self.ready = False
216+
215217
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
216218
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
217219
batch_request_data = _BatchRequestModel(inputs=request_data)
@@ -410,6 +412,7 @@ async def balance_api(inputs: input_type):
410412
)
411413

412414
logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
415+
self.ready = True
413416

414417
uvicorn.run(
415418
fastapi_app,
@@ -641,6 +644,10 @@ def __init__(
641644
def workers(self) -> List[LightningWork]:
642645
return [self.get_work(i) for i in range(self.num_replicas)]
643646

647+
@property
648+
def ready(self) -> bool:
649+
return self.load_balancer.ready
650+
644651
def create_work(self) -> LightningWork:
645652
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
646653
cloud_compute = self._work_kwargs.get("cloud_compute", None)

src/lightning_app/components/serve/gradio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def __init__(self, *args, **kwargs):
4242
assert self.outputs
4343
self._model = None
4444

45+
self.ready = False
46+
4547
@property
4648
def model(self):
4749
return self._model
@@ -62,6 +64,7 @@ def run(self, *args, **kwargs):
6264
self._model = self.build_model()
6365
fn = partial(self.predict, *args, **kwargs)
6466
fn.__name__ = self.predict.__name__
67+
self.ready = True
6568
gradio.Interface(
6669
fn=fn,
6770
inputs=self.inputs,

src/lightning_app/components/serve/python_server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def predict(self, request):
193193
self._input_type = input_type
194194
self._output_type = output_type
195195

196+
self.ready = False
197+
196198
def setup(self, *args, **kwargs) -> None:
197199
"""This method is called before the server starts. Override this if you need to download the model or
198200
initialize the weights, setting up pipelines etc.
@@ -300,6 +302,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
300302
fastapi_app = FastAPI()
301303
self._attach_predict_fn(fastapi_app)
302304

305+
self.ready = True
303306
logger.info(
304307
f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
305308
)

src/lightning_app/components/serve/serve.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(
6464
self.workers = workers
6565
self._model = None
6666

67+
self.ready = False
68+
6769
@property
6870
def model(self):
6971
return self._model
@@ -108,9 +110,11 @@ def run(self):
108110
"serve:fastapi_service",
109111
]
110112
process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
113+
self.ready = True
111114
process.wait()
112115
else:
113116
self._populate_app(fastapi_service)
117+
self.ready = True
114118
self._launch_server(fastapi_service)
115119

116120
def _populate_app(self, fastapi_service: FastAPI):

src/lightning_app/core/flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def __init__(self, work):
797797
@property
798798
def ready(self) -> bool:
799799
ready = getattr(self.work, "ready", None)
800-
if ready:
800+
if ready is not None:
801801
return ready
802802
return self.work.url != ""
803803

tests/tests_app/core/test_lightning_flow.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import lightning_app
1414
from lightning_app import CloudCompute, LightningApp
15-
from lightning_app.core.flow import LightningFlow
15+
from lightning_app.core.flow import _RootFlow, LightningFlow
1616
from lightning_app.core.work import LightningWork
1717
from lightning_app.runners import MultiProcessRuntime
1818
from lightning_app.storage import Path
@@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
868868
class WorkReady(LightningWork):
869869
def __init__(self):
870870
super().__init__(parallel=True)
871-
self.counter = 0
871+
self.ready = False
872872

873873
def run(self):
874-
self.counter += 1
874+
self.ready = True
875875

876876

877877
class FlowReady(LightningFlow):
@@ -890,7 +890,13 @@ def run(self):
890890
self._exit()
891891

892892

893-
def test_flow_ready():
893+
class RootFlowReady(_RootFlow):
894+
def __init__(self):
895+
super().__init__(WorkReady())
896+
897+
898+
@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
899+
def test_flow_ready(flow):
894900
"""This test validates that the app status queue is populated correctly."""
895901

896902
mock_queue = _MockQueue("api_publish_state_queue")
@@ -910,7 +916,7 @@ def lagged_run_once(method):
910916
state["done"] = new_done
911917
return False
912918

913-
app = LightningApp(FlowReady())
919+
app = LightningApp(flow())
914920
app._run = partial(run_patch, method=app._run)
915921
app.run_once = partial(lagged_run_once, method=app.run_once)
916922
MultiProcessRuntime(app, start_server=False).dispatch()

0 commit comments

Comments
 (0)