diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index e92e80c71deb0..093f6a778e84d 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -62,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where `AutoScaler` would fail with min_replica=0 ([#16092](https://github.com/Lightning-AI/lightning/pull/16092)
 
+- Fixed auto-batching so that requests arriving after the batch interval, but still waiting in the queue, are batched ([#16110](https://github.com/Lightning-AI/lightning/pull/16110))
+
 - Fixed a non-thread safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))
diff --git a/src/lightning_app/components/serve/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py
index 035b38f1bd5d8..f1ab081e12d8b 100644
--- a/src/lightning_app/components/serve/auto_scaler.py
+++ b/src/lightning_app/components/serve/auto_scaler.py
@@ -194,13 +194,13 @@ def __init__(
         self._batch = []
         self._responses = {}  # {request_id: response}
         self._last_batch_sent = 0
+        self._server_status = {}
 
         self._api_name = api_name
 
         if not endpoint.startswith("/"):
             endpoint = "/" + endpoint
 
         self.endpoint = endpoint
-        self._fastapi_app = None
 
         self._cold_start_proxy = None
@@ -212,23 +212,29 @@ def __init__(
         else:
             raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")
 
-    async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
-        server = next(self._iter)  # round-robin
+    async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
         request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
         batch_request_data = _BatchRequestModel(inputs=request_data)
 
         try:
+            self._server_status[server_url] = False
             async with aiohttp.ClientSession() as session:
                 headers = {
                     "accept": "application/json",
                     "Content-Type": "application/json",
                 }
                 async with session.post(
-                    f"{server}{self.endpoint}",
+                    f"{server_url}{self.endpoint}",
                     json=batch_request_data.dict(),
                     timeout=self._timeout_inference_request,
                     headers=headers,
                 ) as response:
+                    # resetting the server status so other requests can be
+                    # scheduled on this node
+                    if server_url in self._server_status:
+                        # TODO - if the server returns an error, track that so
+                        #  we don't send more requests to it
+                        self._server_status[server_url] = True
                     if response.status == 408:
                         raise HTTPException(408, "Request timed out")
                     response.raise_for_status()
@@ -241,19 +247,39 @@ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
         except Exception as ex:
             result = {request[0]: ex for request in batch}
             self._responses.update(result)
+        finally:
+            self._server_status[server_url] = True
+
+    def _find_free_server(self) -> Optional[str]:
+        existing = set(self._server_status.keys())
+        for server in existing:
+            status = self._server_status.get(server, None)
+            if status is None:
+                logger.error("Server is not found in the status list. This should not happen.")
+            if status:
+                return server
 
     async def consumer(self):
+        self._last_batch_sent = time.time()
         while True:
             await asyncio.sleep(0.05)
-
             batch = self._batch[: self.max_batch_size]
-            while batch and (
-                (len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
-            ):
-                asyncio.create_task(self.send_batch(batch))
-
-                self._batch = self._batch[self.max_batch_size :]
-                batch = self._batch[: self.max_batch_size]
+            is_batch_ready = len(batch) == self.max_batch_size
+            is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
+            server_url = self._find_free_server()
+            # setting the server status to be busy! This will be reset by
+            # the send_batch function after the server responds
+            if server_url is None:
+                # TODO - a timeout until we try looking for servers
+                logger.error("No servers available")
+                continue
+            if batch and (is_batch_ready or is_batch_timeout):
+                # find server with capacity
+                # TODO multiple instances of consumer should not be running
+                #  without locking the server array
+                asyncio.create_task(self.send_batch(batch, server_url))
+                # resetting the batch array, TODO - not locking the array
+                self._batch = self._batch[len(batch) :]
                 self._last_batch_sent = time.time()
 
     async def process_request(self, data: BaseModel, request_id=uuid.uuid4().hex):
@@ -359,6 +385,18 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
             async with lock:
                 self._servers = servers
                 self._iter = cycle(self._servers)
+                updated_servers = set()
+                # do not try to loop over the dict keys as the dict might change from other places
+                existing_servers = list(self._server_status.keys())
+                for server in servers:
+                    updated_servers.add(server)
+                    if server not in existing_servers:
+                        self._server_status[server] = True
+                        logger.info(f"Registering server {server}", self._server_status)
+                for existing in existing_servers:
+                    if existing not in updated_servers:
+                        logger.info(f"De-Registering server {existing}", self._server_status)
+                        del self._server_status[existing]
 
         @fastapi_app.post(self.endpoint, response_model=self._output_type)
         async def balance_api(inputs: input_type):
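
The diff replaces round-robin dispatch with a free-server scan plus a size-or-timeout batching condition. Below is a minimal, runnable sketch of that pattern, not Lightning's actual implementation: MiniBalancer, the fake server URLs, the module-level constants, and the sleep-based send_batch are all hypothetical stand-ins for the real FastAPI/aiohttp wiring.

import asyncio
import time
from typing import Dict, List, Optional

MAX_BATCH_SIZE = 4       # flush as soon as this many requests are queued
TIMEOUT_BATCHING = 1.0   # ...or after this many seconds, whichever comes first


class MiniBalancer:
    def __init__(self, servers: List[str]) -> None:
        # True = free, False = busy; mirrors the diff's `_server_status` dict
        self._server_status: Dict[str, bool] = {url: True for url in servers}
        self._batch: List[str] = []
        self._last_batch_sent = time.time()

    def _find_free_server(self) -> Optional[str]:
        # same linear scan as the diff: first server currently marked free
        for server, is_free in self._server_status.items():
            if is_free:
                return server
        return None

    async def send_batch(self, batch: List[str], server_url: str) -> None:
        self._server_status[server_url] = False  # mark busy before the call
        try:
            await asyncio.sleep(0.3)  # stands in for the aiohttp POST
            print(f"{server_url} processed {batch}")
        finally:
            # always release the server, even if the request failed
            self._server_status[server_url] = True

    async def consumer(self) -> None:
        self._last_batch_sent = time.time()
        while True:
            await asyncio.sleep(0.05)
            batch = self._batch[: MAX_BATCH_SIZE]
            is_batch_ready = len(batch) == MAX_BATCH_SIZE
            is_batch_timeout = time.time() - self._last_batch_sent > TIMEOUT_BATCHING
            if not batch or not (is_batch_ready or is_batch_timeout):
                continue
            server_url = self._find_free_server()
            if server_url is None:
                continue  # no capacity right now; retry on the next tick
            asyncio.create_task(self.send_batch(batch, server_url))
            self._batch = self._batch[len(batch):]
            self._last_batch_sent = time.time()


async def main() -> None:
    balancer = MiniBalancer(["http://server-a", "http://server-b"])
    asyncio.create_task(balancer.consumer())
    for i in range(10):
        balancer._batch.append(f"req-{i}")
        await asyncio.sleep(0.2)  # arrivals slower than a full batch
    await asyncio.sleep(3)  # lets the timeout path flush the stragglers


if __name__ == "__main__":
    asyncio.run(main())

Running this shows both paths that #16110 is about: full batches flush via is_batch_ready, and leftover requests that arrived after the interval still flush via is_batch_timeout instead of sitting in the queue. One deliberate difference from the diff: the sketch checks the batch before scanning for a server, so an idle queue does not hit the "No servers available" branch on every 50 ms tick, whereas the diff calls _find_free_server unconditionally on each iteration.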