Skip to content

Commit 57c27a8

Browse files
ethanwharrisSherin Thomas
authored andcommitted
[App] Add configure_layout method for works (#15926)
* Add `configure_layout` method for works * Check for api access availability * Updates from review * Update CHANGELOG.md * Apply suggestions from code review Co-authored-by: Sherin Thomas <[email protected]> (cherry picked from commit d5b9c67)
1 parent 94b2ac5 commit 57c27a8

File tree

10 files changed

+362
-71
lines changed

10 files changed

+362
-71
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616
- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))
1717
- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))
1818

19+
- Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926))
20+
1921

2022
### Changed
2123

src/lightning_app/components/serve/gradio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@ def run(self, *args, **kwargs):
7878
server_port=self.port,
7979
enable_queue=self.enable_queue,
8080
)
81+
82+
def configure_layout(self) -> str:
83+
return self.url

src/lightning_app/components/serve/python_server.py

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from fastapi import FastAPI
99
from lightning_utilities.core.imports import module_available
1010
from pydantic import BaseModel
11-
from starlette.staticfiles import StaticFiles
1211

1312
from lightning_app.core.queues import MultiProcessQueue
1413
from lightning_app.core.work import LightningWork
@@ -222,49 +221,30 @@ def predict_fn(request: input_type): # type: ignore
222221

223222
fastapi_app.post("/predict", response_model=output_type)(predict_fn)
224223

225-
def _attach_frontend(self, fastapi_app: FastAPI) -> None:
226-
from lightning_api_access import APIAccessFrontend
227-
228-
class_name = self.__class__.__name__
229-
url = self._future_url if self._future_url else self.url
230-
if not url:
231-
# if the url is still empty, point it to localhost
232-
url = f"http://127.0.0.1:{self.port}"
233-
url = f"{url}/predict"
234-
datatype_parse_error = False
235-
try:
236-
request = self._get_sample_dict_from_datatype(self.configure_input_type())
237-
except TypeError:
238-
datatype_parse_error = True
239-
240-
try:
241-
response = self._get_sample_dict_from_datatype(self.configure_output_type())
242-
except TypeError:
243-
datatype_parse_error = True
244-
245-
if datatype_parse_error:
246-
247-
@fastapi_app.get("/")
248-
def index() -> str:
249-
return (
250-
"Automatic generation of the UI is only supported for simple, "
251-
"non-nested datatype with types string, integer, float and boolean"
252-
)
253-
254-
return
255-
256-
frontend = APIAccessFrontend(
257-
apis=[
258-
{
259-
"name": class_name,
260-
"url": url,
261-
"method": "POST",
262-
"request": request,
263-
"response": response,
264-
}
265-
]
266-
)
267-
fastapi_app.mount("/", StaticFiles(directory=frontend.serve_dir, html=True), name="static")
224+
def configure_layout(self) -> None:
225+
if module_available("lightning_api_access"):
226+
from lightning_api_access import APIAccessFrontend
227+
228+
class_name = self.__class__.__name__
229+
url = f"{self.url}/predict"
230+
231+
try:
232+
request = self._get_sample_dict_from_datatype(self.configure_input_type())
233+
response = self._get_sample_dict_from_datatype(self.configure_output_type())
234+
except TypeError:
235+
return None
236+
237+
return APIAccessFrontend(
238+
apis=[
239+
{
240+
"name": class_name,
241+
"url": url,
242+
"method": "POST",
243+
"request": request,
244+
"response": response,
245+
}
246+
]
247+
)
268248

269249
def run(self, *args: Any, **kwargs: Any) -> Any:
270250
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
@@ -275,7 +255,6 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
275255

276256
fastapi_app = FastAPI()
277257
self._attach_predict_fn(fastapi_app)
278-
self._attach_frontend(fastapi_app)
279258

280259
logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
281260
uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error")

src/lightning_app/components/serve/serve.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import uvicorn
1111
from fastapi import FastAPI
1212
from fastapi.responses import JSONResponse
13-
from starlette.responses import RedirectResponse
1413

1514
from lightning_app.components.serve.types import _DESERIALIZER, _SERIALIZER
1615
from lightning_app.core.work import LightningWork
@@ -37,10 +36,6 @@ async def run(self, data) -> Any:
3736
return self.serialize(self.predict(self.deserialize(data)))
3837

3938

40-
async def _redirect():
41-
return RedirectResponse("/docs")
42-
43-
4439
class ModelInferenceAPI(LightningWork, abc.ABC):
4540
def __init__(
4641
self,
@@ -121,7 +116,6 @@ def run(self):
121116
def _populate_app(self, fastapi_service: FastAPI):
122117
self._model = self.build_model()
123118

124-
fastapi_service.get("/")(_redirect)
125119
fastapi_service.post("/predict", response_class=JSONResponse)(
126120
_InferenceCallable(
127121
deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize,
@@ -134,6 +128,9 @@ def _launch_server(self, fastapi_service: FastAPI):
134128
logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}")
135129
uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error")
136130

131+
def configure_layout(self) -> str:
132+
return f"{self.url}/docs"
133+
137134

138135
def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
139136
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi

src/lightning_app/components/serve/streamlit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def on_exit(self) -> None:
6363
if self._process is not None:
6464
self._process.kill()
6565

66+
def configure_layout(self) -> str:
67+
return self.url
68+
6669

6770
class _PatchedWork:
6871
"""The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the

src/lightning_app/core/flow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from lightning_app.frontend import Frontend
1111
from lightning_app.storage import Path
1212
from lightning_app.storage.drive import _maybe_create_drive, Drive
13-
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name
13+
from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name, is_overridden
1414
from lightning_app.utilities.component import _sanitize_state
1515
from lightning_app.utilities.exceptions import ExitAppException
1616
from lightning_app.utilities.introspection import _is_init_context, _is_run_context
@@ -777,4 +777,6 @@ def run(self):
777777
self.work.run()
778778

779779
def configure_layout(self):
780-
return [{"name": "Main", "content": self.work}]
780+
if is_overridden("configure_layout", self.work):
781+
return [{"name": "Main", "content": self.work}]
782+
return []

src/lightning_app/core/work.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from copy import deepcopy
55
from functools import partial, wraps
6-
from typing import Any, Callable, Dict, List, Optional, Type, Union
6+
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
77

88
from deepdiff import DeepHash, Delta
99

@@ -33,6 +33,9 @@
3333
)
3434
from lightning_app.utilities.proxies import Action, LightningWorkSetAttrProxy, ProxyWorkRun, unwrap, WorkRunExecutor
3535

36+
if TYPE_CHECKING:
37+
from lightning_app.frontend import Frontend
38+
3639

3740
class LightningWork:
3841

@@ -629,3 +632,45 @@ def apply_flow_delta(self, delta: Delta):
629632
property_object.fset(self, value)
630633
else:
631634
self._default_setattr(name, value)
635+
636+
def configure_layout(self) -> Union[None, str, "Frontend"]:
637+
"""Configure the UI of this LightningWork.
638+
639+
You can either
640+
641+
1. Return a single :class:`~lightning_app.frontend.frontend.Frontend` object to serve a user interface
642+
for this Work.
643+
2. Return a string containing a URL to act as the user interface for this Work.
644+
3. Return ``None`` to indicate that this Work doesn't currently have a user interface.
645+
646+
**Example:** Serve a static directory (with at least a file index.html inside).
647+
648+
.. code-block:: python
649+
650+
from lightning_app.frontend import StaticWebFrontend
651+
652+
653+
class Work(LightningWork):
654+
def configure_layout(self):
655+
return StaticWebFrontend("path/to/folder/to/serve")
656+
657+
**Example:** Arrange the UI of my children in tabs (default UI by Lightning).
658+
659+
.. code-block:: python
660+
661+
class Work(LightningWork):
662+
def configure_layout(self):
663+
return [
664+
dict(name="First Tab", content=self.child0),
665+
dict(name="Second Tab", content=self.child1),
666+
dict(name="Lightning", content="https://lightning.ai"),
667+
]
668+
669+
If you don't implement ``configure_layout``, Lightning will use ``self.url``.
670+
671+
Note:
672+
This hook gets called at the time of app creation and then again as part of the loop. If desired, a
673+
returned URL can depend on the state. This is not the case if the work returns a
674+
:class:`~lightning_app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
675+
in order for the runtime to start the server.
676+
"""

src/lightning_app/utilities/layout.py

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import lightning_app
66
from lightning_app.frontend.frontend import Frontend
7-
from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable
7+
from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable, is_overridden
88
from lightning_app.utilities.cloud import is_running_in_cloud
99

1010

@@ -45,9 +45,9 @@ def _collect_layout(app: "lightning_app.LightningApp", flow: "lightning_app.Ligh
4545
app.frontends.setdefault(flow.name, "mock")
4646
return flow._layout
4747
elif isinstance(layout, dict):
48-
layout = _collect_content_layout([layout], flow)
48+
layout = _collect_content_layout([layout], app, flow)
4949
elif isinstance(layout, (list, tuple)) and all(isinstance(item, dict) for item in layout):
50-
layout = _collect_content_layout(layout, flow)
50+
layout = _collect_content_layout(layout, app, flow)
5151
else:
5252
lines = _add_comment_to_literal_code(flow.configure_layout, contains="return", comment=" <------- this guy")
5353
m = f"""
@@ -76,7 +76,9 @@ def configure_layout(self):
7676
return layout
7777

7878

79-
def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFlow") -> List[Dict]:
79+
def _collect_content_layout(
80+
layout: List[Dict], app: "lightning_app.LightningApp", flow: "lightning_app.LightningFlow"
81+
) -> Union[List[Dict], Dict]:
8082
"""Process the layout returned by the ``configure_layout()`` method if the returned format represents an
8183
aggregation of child layouts."""
8284
for entry in layout:
@@ -102,12 +104,43 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl
102104
entry["content"] = entry["content"].name
103105

104106
elif isinstance(entry["content"], lightning_app.LightningWork):
105-
if entry["content"].url and not entry["content"].url.startswith("/"):
106-
entry["content"] = entry["content"].url
107-
entry["target"] = entry["content"]
108-
else:
107+
work = entry["content"]
108+
work_layout = _collect_work_layout(work)
109+
110+
if work_layout is None:
109111
entry["content"] = ""
110-
entry["target"] = ""
112+
elif isinstance(work_layout, str):
113+
entry["content"] = work_layout
114+
entry["target"] = work_layout
115+
elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
116+
if len(layout) > 1:
117+
lines = _add_comment_to_literal_code(
118+
flow.configure_layout, contains="return", comment=" <------- this guy"
119+
)
120+
m = f"""
121+
The return value of configure_layout() in `{flow.__class__.__name__}` is an
122+
unsupported format:
123+
\n{lines}
124+
125+
The tab containing a `{work.__class__.__name__}` must be the only tab in the
126+
layout of this flow.
127+
128+
(see the docs for `LightningWork.configure_layout`).
129+
"""
130+
raise TypeError(m)
131+
132+
if isinstance(work_layout, Frontend):
133+
# If the work returned a frontend, treat it as belonging to the flow.
134+
# NOTE: This could evolve in the future to run the Frontend directly in the work machine.
135+
frontend = work_layout
136+
frontend.flow = flow
137+
elif isinstance(work_layout, _MagicMockJsonSerializable):
138+
# The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI.
139+
frontend = "mock"
140+
141+
app.frontends.setdefault(flow.name, frontend)
142+
return flow._layout
143+
111144
elif isinstance(entry["content"], _MagicMockJsonSerializable):
112145
# The import was mocked, we just record dummy content so that `is_headless` knows there is a UI
113146
entry["content"] = "mock"
@@ -126,3 +159,43 @@ def configure_layout(self):
126159
"""
127160
raise ValueError(m)
128161
return layout
162+
163+
164+
def _collect_work_layout(work: "lightning_app.LightningWork") -> Union[None, str, Frontend, _MagicMockJsonSerializable]:
165+
"""Check if ``configure_layout`` is overridden on the given work and return the work layout (either a string, a
166+
``Frontend`` object, or an instance of a mocked import).
167+
168+
Args:
169+
work: The work to collect the layout for.
170+
171+
Raises:
172+
TypeError: If the value returned by ``configure_layout`` is not of a supported format.
173+
"""
174+
if is_overridden("configure_layout", work):
175+
work_layout = work.configure_layout()
176+
else:
177+
work_layout = work.url
178+
179+
if work_layout is None:
180+
return None
181+
elif isinstance(work_layout, str):
182+
url = work_layout
183+
# The URL isn't fully defined yet. Looks something like ``self.work.url + /something``.
184+
if url and not url.startswith("/"):
185+
return url
186+
return ""
187+
elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)):
188+
return work_layout
189+
else:
190+
m = f"""
191+
The value returned by `{work.__class__.__name__}.configure_layout()` is of an unsupported type.
192+
193+
{repr(work_layout)}
194+
195+
Return a `Frontend` or a URL string, for example:
196+
197+
class {work.__class__.__name__}(LightningWork):
198+
def configure_layout(self):
199+
return MyFrontend() OR 'http://some/url'
200+
"""
201+
raise TypeError(m)

0 commit comments

Comments
 (0)