Skip to content

Commit 065db7e

Browse files
committed
wip clean up autoscaler ui
1 parent 6745531 commit 065db7e

File tree

2 files changed

+54
-11
lines changed

2 files changed

+54
-11
lines changed

examples/app_server_with_auto_scaler/app.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# ! pip install torch torchvision
2-
from typing import Any, List
2+
from typing import List
33

44
import torch
55
import torchvision
@@ -8,16 +8,12 @@
88
import lightning as L
99

1010

11-
class RequestModel(BaseModel):
12-
image: str # bytecode
13-
14-
1511
class BatchRequestModel(BaseModel):
16-
inputs: List[RequestModel]
12+
inputs: List[L.app.components.Image]
1713

1814

1915
class BatchResponse(BaseModel):
20-
outputs: List[Any]
16+
outputs: List[L.app.components.Number]
2117

2218

2319
class PyTorchServer(L.app.components.PythonServer):
@@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int:
8177
max_replicas=4,
8278
autoscale_interval=10,
8379
endpoint="predict",
84-
input_type=RequestModel,
85-
output_type=Any,
80+
input_type=L.app.components.Image,
81+
output_type=L.app.components.Number,
8682
timeout_batching=1,
8783
max_batch_size=8,
8884
)

src/lightning_app/components/auto_scaler.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
280280
async def balance_api(inputs: self._input_type):
281281
return await self.process_request(inputs)
282282

283+
logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
284+
283285
uvicorn.run(
284286
fastapi_app,
285287
host=self.host,
@@ -332,6 +334,51 @@ def send_request_to_update_servers(self, servers: List[str]):
332334
response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
333335
response.raise_for_status()
334336

337+
@staticmethod
338+
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
339+
if hasattr(datatype, "_get_sample_data"):
340+
return datatype._get_sample_data()
341+
342+
datatype_props = datatype.schema()["properties"]
343+
out: Dict[str, Any] = {}
344+
for k, v in datatype_props.items():
345+
if v["type"] == "string":
346+
out[k] = "data string"
347+
elif v["type"] == "number":
348+
out[k] = 0.0
349+
elif v["type"] == "integer":
350+
out[k] = 0
351+
elif v["type"] == "boolean":
352+
out[k] = False
353+
else:
354+
raise TypeError("Unsupported type")
355+
return out
356+
357+
def configure_layout(self) -> None:
358+
try:
359+
from lightning_api_access import APIAccessFrontend
360+
except ModuleNotFoundError:
361+
logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI")
362+
return
363+
364+
try:
365+
request = self._get_sample_dict_from_datatype(self._input_type)
366+
response = self._get_sample_dict_from_datatype(self._output_type)
367+
except (AttributeError, TypeError):
368+
return
369+
370+
return APIAccessFrontend(
371+
apis=[
372+
{
373+
"name": self.__class__.__name__,
374+
"url": f"{self.url}{self.endpoint}",
375+
"method": "POST",
376+
"request": request,
377+
"response": response,
378+
}
379+
]
380+
)
381+
335382

336383
class AutoScaler(LightningFlow):
337384
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
@@ -574,5 +621,5 @@ def autoscale(self) -> None:
574621
self._last_autoscale = time.time()
575622

576623
def configure_layout(self):
577-
tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
578-
return tabs
624+
layout = self.load_balancer.configure_layout()
625+
return layout if layout else super().configure_layout()

0 commit comments

Comments
 (0)