Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions examples/app_server_with_auto_scaler/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ! pip install torch torchvision
from typing import Any, List

import torch
Expand All @@ -21,11 +22,13 @@ class BatchResponse(BaseModel):

class PyTorchServer(L.app.components.PythonServer):
def __init__(self, *args, **kwargs):
    """Create the server with this example's batch request/response models.

    All positional and keyword arguments are forwarded to the parent
    constructor. ``port`` and ``cloud_compute`` are filled in only when the
    caller did not pass them; hard-coding them next to ``**kwargs`` raised
    ``TypeError: got multiple values for keyword argument`` as soon as a
    caller supplied either one.
    """
    # Defaults only -- an explicit value from the caller always wins.
    if "port" not in kwargs:
        kwargs["port"] = L.app.utilities.network.find_free_network_port()
    if "cloud_compute" not in kwargs:
        kwargs["cloud_compute"] = L.CloudCompute("gpu")

    super().__init__(
        *args,
        input_type=BatchRequestModel,
        output_type=BatchResponse,
        **kwargs,
    )

def setup(self):
def scale(self, replicas: int, metrics: dict) -> int:
    """The default scaling logic that users can override.

    Args:
        replicas: Current replica count (used as the number of running works).
        metrics: Mapping that provides ``"pending_requests"`` and
            ``"pending_works"``.

    Returns:
        ``replicas + 1`` to scale out, ``replicas - 1`` to scale in, or
        ``replicas`` to keep the current size.
    """
    pending_requests = metrics["pending_requests"]
    running_or_pending_works = replicas + metrics["pending_works"]

    # Nothing is running or starting, so the per-work ratios are undefined.
    # Start one work if any request is waiting, otherwise stay as is.
    if running_or_pending_works == 0:
        return replicas + 1 if pending_requests > 0 else replicas

    # scale out if the number of pending requests exceeds max batch size.
    max_requests_per_work = self.max_batch_size
    if pending_requests / running_or_pending_works >= max_requests_per_work:
        return replicas + 1

    # Works are still starting but none is running yet: nothing to scale in,
    # and dividing by ``replicas`` below would raise ZeroDivisionError.
    if replicas == 0:
        return replicas

    # scale in if the number of pending requests is below 25% of max_requests_per_work
    min_requests_per_work = max_requests_per_work * 0.25
    if pending_requests / replicas < min_requests_per_work:
        return replicas - 1

    return replicas


# Application entry point: PyTorchServer works managed by MyAutoScaler.
# ``min_replicas`` is passed exactly once; repeating a keyword argument in a
# call is a SyntaxError, and it belongs with the autoscaler settings.
app = L.LightningApp(
    MyAutoScaler(
        # work class and args
        PyTorchServer,
        cloud_compute=L.CloudCompute("gpu"),
        # autoscaler specific args
        min_replicas=1,
        max_replicas=4,
        autoscale_interval=10,
        endpoint="predict",
        input_type=RequestModel,
        output_type=Any,
        timeout_batching=1,
        max_batch_size=8,
    )
)
11 changes: 9 additions & 2 deletions src/lightning_app/components/auto_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.network import find_free_network_port
from lightning_app.utilities.packaging.cloud_compute import CloudCompute

logger = Logger(__name__)
Expand Down Expand Up @@ -445,8 +446,14 @@ def workers(self) -> List[LightningWork]:

def create_work(self) -> LightningWork:
    """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.

    The stored keyword arguments are copied for each call and extended with a
    newly looked-up free ``port`` and ``start_with_flow=False``, so
    ``self._work_kwargs`` itself is never modified.

    Returns:
        A new instance of ``self._work_cls``.
    """
    work_kwargs = {
        **self._work_kwargs,
        "port": find_free_network_port(),
        # TODO: Remove `start_with_flow=False` for faster initialization on the cloud
        "start_with_flow": False,
    }
    return self._work_cls(*self._work_args, **work_kwargs)

def add_work(self, work) -> str:
"""Adds a new LightningWork instance.
Expand Down