2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -62,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where `AutoScaler` would fail with `min_replica=0` ([#16092](https://github.com/Lightning-AI/lightning/pull/16092))

- Fixed auto-batching so that requests still waiting in the queue after the batching interval has elapsed are batched on the next pass ([#16110](https://github.com/Lightning-AI/lightning/pull/16110))


- Fixed a non-thread safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))

62 changes: 50 additions & 12 deletions src/lightning_app/components/serve/auto_scaler.py
@@ -194,13 +194,13 @@ def __init__(
self._batch = []
self._responses = {} # {request_id: response}
self._last_batch_sent = 0
self._server_status = {}
self._api_name = api_name

if not endpoint.startswith("/"):
endpoint = "/" + endpoint

self.endpoint = endpoint

self._fastapi_app = None

self._cold_start_proxy = None
@@ -212,23 +212,29 @@ def __init__(
else:
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")

async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
server = next(self._iter) # round-robin
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)

try:
self._server_status[server_url] = False
async with aiohttp.ClientSession() as session:
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
async with session.post(
f"{server}{self.endpoint}",
f"{server_url}{self.endpoint}",
json=batch_request_data.dict(),
timeout=self._timeout_inference_request,
headers=headers,
) as response:
# resetting the server status so other requests can be
# scheduled on this node
if server_url in self._server_status:
# TODO - if the server returns an error, track that so
# we don't send more requests to it
self._server_status[server_url] = True
if response.status == 408:
raise HTTPException(408, "Request timed out")
response.raise_for_status()
@@ -241,19 +247,39 @@ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
except Exception as ex:
result = {request[0]: ex for request in batch}
self._responses.update(result)
finally:
self._server_status[server_url] = True
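
For orientation, a minimal standalone sketch of the batched POST that `send_batch` performs, assuming an aiohttp-style client; `post_batch` and the `{"inputs": ...}` payload shape are illustrative, not the component's public API:

```python
import aiohttp

async def post_batch(server_url: str, endpoint: str, inputs: list, timeout_s: float = 60.0):
    """Send one JSON body carrying the whole batch to a single replica."""
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{server_url}{endpoint}",
            json={"inputs": inputs},
            timeout=aiohttp.ClientTimeout(total=timeout_s),
            headers=headers,
        ) as response:
            if response.status == 408:
                raise TimeoutError("Request timed out")
            response.raise_for_status()
            return await response.json()
```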

def _find_free_server(self) -> Optional[str]:
existing = set(self._server_status.keys())
for server in existing:
status = self._server_status.get(server, None)
if status is None:
logger.error("Server is not found in the status list. This should not happen.")
if status:
return server
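
The `_server_status` map replaces the old round-robin `next(self._iter)` selection. A condensed, standalone sketch of the same busy/free bookkeeping (the module-level names here are hypothetical):

```python
import asyncio
from typing import Optional

server_status = {"http://replica-a:8000": True, "http://replica-b:8000": True}

def find_free_server() -> Optional[str]:
    # snapshot the keys: other coroutines may mutate the dict while we scan
    for url in list(server_status.keys()):
        if server_status.get(url):
            return url
    return None

async def dispatch(url: str, payload: dict) -> None:
    server_status[url] = False  # busy until the response (or an error) lands
    try:
        await asyncio.sleep(0.1)  # stand-in for the real HTTP round trip
    finally:
        server_status[url] = True  # always release, even on failure
```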

async def consumer(self):
self._last_batch_sent = time.time()
while True:
await asyncio.sleep(0.05)

batch = self._batch[: self.max_batch_size]
while batch and (
(len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
):
asyncio.create_task(self.send_batch(batch))

self._batch = self._batch[self.max_batch_size :]
batch = self._batch[: self.max_batch_size]
is_batch_ready = len(batch) == self.max_batch_size
is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
server_url = self._find_free_server()
# send_batch marks the chosen server busy and releases it once
# the server responds (or the request fails)
if server_url is None:
# TODO - a timeout until we try looking for servers
logger.error("No servers available")
continue
if batch and (is_batch_ready or is_batch_timeout):
# dispatch the batch to the free server found above
# TODO - multiple consumer instances should not run without
# locking the server array
asyncio.create_task(self.send_batch(batch, server_url))
# reset the batch array; TODO - the array itself is not locked
self._batch = self._batch[len(batch) :]
self._last_batch_sent = time.time()
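
The behavioral fix recorded in the changelog (#16110) is visible here: a partial batch is now flushed once `timeout_batching` elapses, rather than only when the batch fills. A self-contained sketch of that flush rule (class and parameter names are illustrative):

```python
import asyncio
import time

class BatchFlusher:
    """Flush when the batch is full OR the batching interval has elapsed."""

    def __init__(self, max_batch_size: int = 8, timeout_batching: float = 1.0):
        self.max_batch_size = max_batch_size
        self.timeout_batching = timeout_batching
        self.queue: list = []
        self._last_flush = time.time()

    async def run(self, send) -> None:
        while True:
            await asyncio.sleep(0.05)
            batch = self.queue[: self.max_batch_size]
            is_full = len(batch) == self.max_batch_size
            timed_out = (time.time() - self._last_flush) > self.timeout_batching
            if batch and (is_full or timed_out):
                asyncio.create_task(send(batch))
                # drop exactly the dispatched items (not a fixed-size slice)
                self.queue = self.queue[len(batch):]
                self._last_flush = time.time()
```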

async def process_request(self, data: BaseModel, request_id=uuid.uuid4().hex):
@@ -359,6 +385,18 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
async with lock:
self._servers = servers
self._iter = cycle(self._servers)
updated_servers = set()
# do not try to loop over the dict keys as the dict might change from other places
existing_servers = list(self._server_status.keys())
for server in servers:
updated_servers.add(server)
if server not in existing_servers:
self._server_status[server] = True
logger.info(f"Registering server {server}", self._server_status)
for existing in existing_servers:
if existing not in updated_servers:
logger.info(f"De-Registering server {existing}", self._server_status)
del self._server_status[existing]
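
Distilled, the register/deregister reconciliation above looks like the sketch below; the key point is snapshotting the dict keys before iterating, since other coroutines can mutate the status map concurrently (the function name is hypothetical):

```python
def sync_server_status(server_status: dict, servers: list) -> None:
    existing = list(server_status.keys())  # snapshot; never iterate the live dict
    for url in servers:
        if url not in existing:
            server_status[url] = True  # newly registered replica starts out free
    for url in existing:
        if url not in servers:
            del server_status[url]  # replica was scaled down; stop routing to it
```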

@fastapi_app.post(self.endpoint, response_model=self._output_type)
async def balance_api(inputs: input_type):
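
Finally, a toy driver wiring the sketches above together; it assumes `BatchFlusher`, `find_free_server`, and `dispatch` from the earlier snippets and is not the component's real entry point:

```python
import asyncio

async def main() -> None:
    flusher = BatchFlusher(max_batch_size=4, timeout_batching=0.5)

    async def send(batch: list) -> None:
        url = find_free_server()
        if url is None:
            print("No servers available")  # the real consumer logs and retries
            return
        await dispatch(url, {"inputs": batch})
        print(f"sent {len(batch)} item(s) to {url}")

    consumer = asyncio.create_task(flusher.run(send))
    flusher.queue.extend(["a", "b", "c"])  # a partial batch...
    await asyncio.sleep(1.0)               # ...flushed by the timeout rule
    consumer.cancel()

asyncio.run(main())
```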