Skip to content

Commit a0f8f70

Browse files
akihironitta authored and Borda committed
[App] Improve the autoscaler UI (#16063)
[App] Improve the autoscaler UI (#16063) (cherry picked from commit 39d27f6)
1 parent 6996dc8 commit a0f8f70

File tree

8 files changed

+110
-26
lines changed

8 files changed

+110
-26
lines changed

docs/source-app/api_references.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ ___________________
4545
~multi_node.lite.LiteMultiNode
4646
~multi_node.pytorch_spawn.PyTorchSpawnMultiNode
4747
~multi_node.trainer.LightningTrainerMultiNode
48-
~auto_scaler.AutoScaler
48+
~serve.auto_scaler.AutoScaler
4949

5050
----
5151

examples/app_server_with_auto_scaler/app.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ! pip install torch torchvision
2-
from typing import Any, List
2+
from typing import List
33

44
import torch
55
import torchvision
@@ -8,16 +8,12 @@
88
import lightning as L
99

1010

11-
class RequestModel(BaseModel):
12-
image: str # bytecode
13-
14-
1511
class BatchRequestModel(BaseModel):
16-
inputs: List[RequestModel]
12+
inputs: List[L.app.components.Image]
1713

1814

1915
class BatchResponse(BaseModel):
20-
outputs: List[Any]
16+
outputs: List[L.app.components.Number]
2117

2218

2319
class PyTorchServer(L.app.components.PythonServer):
@@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int:
8177
max_replicas=4,
8278
autoscale_interval=10,
8379
endpoint="predict",
84-
input_type=RequestModel,
85-
output_type=Any,
80+
input_type=L.app.components.Image,
81+
output_type=L.app.components.Number,
8682
timeout_batching=1,
8783
max_batch_size=8,
8884
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ module = [
7979
"lightning_app.components.serve.types.image",
8080
"lightning_app.components.serve.types.type",
8181
"lightning_app.components.serve.python_server",
82+
"lightning_app.components.serve.auto_scaler",
8283
"lightning_app.components.training",
83-
"lightning_app.components.auto_scaler",
8484
"lightning_app.core.api",
8585
"lightning_app.core.app",
8686
"lightning_app.core.flow",

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
1313

14+
- Added a nicer UI with URL and examples for the autoscaler component ([#16063](https://github.com/Lightning-AI/lightning/pull/16063))
15+
1416
- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))
1517

1618
- Added `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103))

src/lightning_app/components/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from lightning_app.components.auto_scaler import AutoScaler
21
from lightning_app.components.database.client import DatabaseClient
32
from lightning_app.components.database.server import Database
43
from lightning_app.components.multi_node import (
@@ -9,6 +8,7 @@
98
)
109
from lightning_app.components.python.popen import PopenPythonScript
1110
from lightning_app.components.python.tracer import Code, TracerPythonScript
11+
from lightning_app.components.serve.auto_scaler import AutoScaler
1212
from lightning_app.components.serve.gradio import ServeGradio
1313
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
1414
from lightning_app.components.serve.serve import ModelInferenceAPI
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from lightning_app.components.serve.auto_scaler import AutoScaler
12
from lightning_app.components.serve.gradio import ServeGradio
23
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
34
from lightning_app.components.serve.streamlit import ServeStreamlit
45

5-
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"]
6+
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"]

src/lightning_app/components/auto_scaler.py renamed to src/lightning_app/components/serve/auto_scaler.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import uuid
77
from base64 import b64encode
88
from itertools import cycle
9-
from typing import Any, Dict, List, Tuple, Type
9+
from typing import Any, Dict, List, Optional, Tuple, Type
1010

1111
import requests
1212
import uvicorn
@@ -15,11 +15,13 @@
1515
from fastapi.responses import RedirectResponse
1616
from fastapi.security import HTTPBasic, HTTPBasicCredentials
1717
from pydantic import BaseModel
18+
from starlette.staticfiles import StaticFiles
1819
from starlette.status import HTTP_401_UNAUTHORIZED
1920

2021
from lightning_app.core.flow import LightningFlow
2122
from lightning_app.core.work import LightningWork
2223
from lightning_app.utilities.app_helpers import Logger
24+
from lightning_app.utilities.cloud import is_running_in_cloud
2325
from lightning_app.utilities.imports import _is_aiohttp_available, requires
2426
from lightning_app.utilities.packaging.cloud_compute import CloudCompute
2527

@@ -114,20 +116,21 @@ class _LoadBalancer(LightningWork):
114116
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
115117
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
116118
timeout_inference_request: The number of seconds to wait for inference.
117-
\**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
119+
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
118120
"""
119121

120122
@requires(["aiohttp"])
121123
def __init__(
122124
self,
123-
input_type: BaseModel,
124-
output_type: BaseModel,
125+
input_type: Type[BaseModel],
126+
output_type: Type[BaseModel],
125127
endpoint: str,
126128
max_batch_size: int = 8,
127129
# all timeout args are in seconds
128-
timeout_batching: int = 1,
130+
timeout_batching: float = 1,
129131
timeout_keep_alive: int = 60,
130132
timeout_inference_request: int = 60,
133+
work_name: Optional[str] = "API", # used for displaying the name in the UI
131134
**kwargs: Any,
132135
) -> None:
133136
super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
@@ -142,6 +145,7 @@ def __init__(
142145
self._batch = []
143146
self._responses = {} # {request_id: response}
144147
self._last_batch_sent = 0
148+
self._work_name = work_name
145149

146150
if not endpoint.startswith("/"):
147151
endpoint = "/" + endpoint
@@ -280,6 +284,14 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
280284
async def balance_api(inputs: self._input_type):
281285
return await self.process_request(inputs)
282286

287+
endpoint_info_page = self._get_endpoint_info_page()
288+
if endpoint_info_page:
289+
fastapi_app.mount(
290+
"/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static"
291+
)
292+
293+
logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
294+
283295
uvicorn.run(
284296
fastapi_app,
285297
host=self.host,
@@ -332,6 +344,60 @@ def send_request_to_update_servers(self, servers: List[str]):
332344
response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
333345
response.raise_for_status()
334346

347+
@staticmethod
348+
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
349+
if not hasattr(datatype, "schema"):
350+
# not a pydantic model
351+
raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}")
352+
353+
if hasattr(datatype, "_get_sample_data"):
354+
return datatype._get_sample_data()
355+
356+
datatype_props = datatype.schema()["properties"]
357+
out: Dict[str, Any] = {}
358+
lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False}
359+
for k, v in datatype_props.items():
360+
if v["type"] not in lut:
361+
raise TypeError("Unsupported type")
362+
out[k] = lut[v["type"]]
363+
return out
364+
365+
def get_code_sample(self, url: str) -> Optional[str]:
366+
input_type: Any = self._input_type
367+
output_type: Any = self._output_type
368+
369+
if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
370+
return None
371+
return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
372+
373+
def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F821
374+
try:
375+
from lightning_api_access import APIAccessFrontend
376+
except ModuleNotFoundError:
377+
logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI")
378+
return
379+
380+
if is_running_in_cloud():
381+
url = f"{self._future_url}{self.endpoint}"
382+
else:
383+
url = f"http://localhost:{self.port}{self.endpoint}"
384+
385+
frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None}
386+
code_samples = self.get_code_sample(url)
387+
if code_samples:
388+
frontend_objects["code_samples"] = code_samples
389+
# TODO also set request/response for JS UI
390+
else:
391+
try:
392+
request = self._get_sample_dict_from_datatype(self._input_type)
393+
response = self._get_sample_dict_from_datatype(self._output_type)
394+
except TypeError:
395+
return None
396+
else:
397+
frontend_objects["request"] = request
398+
frontend_objects["response"] = response
399+
return APIAccessFrontend(apis=[frontend_objects])
400+
335401

336402
class AutoScaler(LightningFlow):
337403
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
@@ -403,8 +469,8 @@ def __init__(
403469
max_batch_size: int = 8,
404470
timeout_batching: float = 1,
405471
endpoint: str = "api/predict",
406-
input_type: BaseModel = Dict,
407-
output_type: BaseModel = Dict,
472+
input_type: Type[BaseModel] = Dict,
473+
output_type: Type[BaseModel] = Dict,
408474
*work_args: Any,
409475
**work_kwargs: Any,
410476
) -> None:
@@ -438,6 +504,7 @@ def __init__(
438504
timeout_batching=timeout_batching,
439505
cache_calls=True,
440506
parallel=True,
507+
work_name=self._work_cls.__name__,
441508
)
442509
for _ in range(min_replicas):
443510
work = self.create_work()
@@ -574,5 +641,8 @@ def autoscale(self) -> None:
574641
self._last_autoscale = time.time()
575642

576643
def configure_layout(self):
577-
tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
644+
tabs = [
645+
{"name": "Endpoint Info", "content": f"{self.load_balancer}/endpoint-info"},
646+
{"name": "Swagger", "content": self.load_balancer.url},
647+
]
578648
return tabs

tests/tests_app/components/test_auto_scaler.py renamed to tests/tests_app/components/serve/test_auto_scaler.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import time
2+
from unittest import mock
23
from unittest.mock import patch
34

45
import pytest
56

67
from lightning_app import CloudCompute, LightningWork
7-
from lightning_app.components import AutoScaler
8+
from lightning_app.components import AutoScaler, Text
89

910

1011
class EmptyWork(LightningWork):
@@ -32,8 +33,8 @@ def test_num_replicas_after_init():
3233

3334

3435
@patch("uvicorn.run")
35-
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
36-
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
36+
@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
37+
@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
3738
def test_num_replicas_not_above_max_replicas(*_):
3839
"""Test self.num_replicas doesn't exceed max_replicas."""
3940
max_replicas = 6
@@ -52,8 +53,8 @@ def test_num_replicas_not_above_max_replicas(*_):
5253

5354

5455
@patch("uvicorn.run")
55-
@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
56-
@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
56+
@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
57+
@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
5758
def test_num_replicas_not_belo_min_replicas(*_):
5859
"""Test self.num_replicas doesn't exceed max_replicas."""
5960
min_replicas = 1
@@ -98,3 +99,17 @@ def test_create_work_cloud_compute_cloned():
9899
auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
99100
_ = auto_scaler.create_work()
100101
assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
102+
103+
104+
fastapi_mock = mock.MagicMock()
105+
mocked_fastapi_creater = mock.MagicMock(return_value=fastapi_mock)
106+
107+
108+
@patch("lightning_app.components.serve.auto_scaler._create_fastapi", mocked_fastapi_creater)
109+
@patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock())
110+
def test_API_ACCESS_ENDPOINT_creation():
111+
auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text)
112+
assert auto_scaler.load_balancer._work_name == "EmptyWork"
113+
114+
auto_scaler.load_balancer.run()
115+
fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")

0 commit comments

Comments (0)