diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py
index 0275596288ff0..d2d8773ee2178 100644
--- a/src/lightning_app/components/__init__.py
+++ b/src/lightning_app/components/__init__.py
@@ -8,7 +8,8 @@
 )
 from lightning_app.components.python.popen import PopenPythonScript
 from lightning_app.components.python.tracer import Code, TracerPythonScript
-from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
+from lightning_app.components.serve.auto_scaler import AutoScaler
+from lightning_app.components.serve.cold_start_proxy import ColdStartProxy
 from lightning_app.components.serve.gradio import ServeGradio
 from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
 from lightning_app.components.serve.serve import ModelInferenceAPI
diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py
index 39dafe2f7ff1b..feebac39d018b 100644
--- a/src/lightning_app/components/serve/__init__.py
+++ b/src/lightning_app/components/serve/__init__.py
@@ -1,4 +1,5 @@
-from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
+from lightning_app.components.serve.auto_scaler import AutoScaler
+from lightning_app.components.serve.cold_start_proxy import ColdStartProxy
 from lightning_app.components.serve.gradio import ServeGradio
 from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
 from lightning_app.components.serve.streamlit import ServeStreamlit
diff --git a/src/lightning_app/components/serve/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py
index 3c1edb0e33b9d..e29822081527d 100644
--- a/src/lightning_app/components/serve/auto_scaler.py
+++ b/src/lightning_app/components/serve/auto_scaler.py
@@ -1,23 +1,21 @@
 import asyncio
 import logging
-import os
-import secrets
 import time
 import uuid
-from base64 import b64encode
 from itertools import cycle
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional
+from typing import SupportsFloat as Numeric
+from typing import Tuple, Type, Union

 import requests
 import uvicorn
-from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import RedirectResponse
-from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from pydantic import BaseModel
 from starlette.staticfiles import StaticFiles
-from starlette.status import HTTP_401_UNAUTHORIZED

+from lightning_app.components.serve.cold_start_proxy import ColdStartProxy
 from lightning_app.core.flow import LightningFlow
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
@@ -32,50 +30,14 @@
 logger = Logger(__name__)


-class ColdStartProxy:
-    """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
-    starting. This is useful with services that gets realtime requests but startup time for workers is high.
+class _TrackableFastAPI(FastAPI):
+    """A FastAPI subclass that tracks the request metadata."""
-
-    If the request body is same and the method is POST for the proxy service,
-    then the default implementation of `handle_request` can be used. In that case
-    initialize the proxy with the proxy url. Otherwise, the user can override the `handle_request`
-
-    Args:
-        proxy_url (str): The url of the proxy service
-    """
-
-    def __init__(self, proxy_url):
-        self.proxy_url = proxy_url
-        self.proxy_timeout = 50
-        # checking `asyncio.iscoroutinefunction` instead of `inspect.iscoroutinefunction`
-        # because AsyncMock in the tests requres the former to pass
-        if not asyncio.iscoroutinefunction(self.handle_request):
-            raise TypeError("handle_request must be an `async` function")
-
-    async def handle_request(self, request: BaseModel) -> Any:
-        """This method is called when the request is received while the work is cold starting. The default
-        implementation of this method is to forward the request body to the proxy service with POST method but the
-        user can override this method to handle the request in any way.
-
-        Args:
-            request (BaseModel): The request body, a pydantic model that is being
-                forwarded by load balancer which is a FastAPI service
-        """
-        try:
-            async with aiohttp.ClientSession() as session:
-                headers = {
-                    "accept": "application/json",
-                    "Content-Type": "application/json",
-                }
-                async with session.post(
-                    self.proxy_url,
-                    json=request.dict(),
-                    timeout=self.proxy_timeout,
-                    headers=headers,
-                ) as response:
-                    return await response.json()
-        except Exception as ex:
-            raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}")
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.global_request_count = 0
+        self.num_current_requests = 0
+        self.last_processing_time = 0


 def _maybe_raise_granular_exception(exception: Exception) -> None:
@@ -116,8 +78,8 @@ class _BatchRequestModel(BaseModel):
     inputs: List[Any]


-def _create_fastapi(title: str) -> FastAPI:
-    fastapi_app = FastAPI(title=title)
+def _create_fastapi(title: str) -> _TrackableFastAPI:
+    fastapi_app = _TrackableFastAPI(title=title)

     fastapi_app.add_middleware(
         CORSMiddleware,
@@ -127,10 +89,6 @@ def _create_fastapi(title: str) -> FastAPI:
         allow_headers=["*"],
     )

-    fastapi_app.global_request_count = 0
-    fastapi_app.num_current_requests = 0
-    fastapi_app.last_processing_time = 0
-
     @fastapi_app.get("/", include_in_schema=False)
     async def docs():
         return RedirectResponse("/docs")
@@ -146,11 +104,6 @@ class _LoadBalancer(LightningWork):
     r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
     asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.

-    The LoadBalancer exposes system endpoints with a basic HTTP authentication, in order to activate the authentication
-    you need to provide a system password from environment variable::
-
-        lightning run app app.py --env AUTO_SCALER_AUTH_PASSWORD=PASSWORD
-
-    After enabling you will require to send username and password from the request header for the private endpoints.
     Args:
@@ -187,15 +140,16 @@ def __init__(
         self._output_type = output_type
         self._timeout_keep_alive = timeout_keep_alive
         self._timeout_inference_request = timeout_inference_request
-        self._servers = []
+        self.servers = []
         self.max_batch_size = max_batch_size
         self.timeout_batching = timeout_batching
         self._iter = None
         self._batch = []
         self._responses = {}  # {request_id: response}
-        self._last_batch_sent = 0
+        self._last_batch_sent = None
         self._server_status = {}
         self._api_name = api_name
+        self.ready = False

         if not endpoint.startswith("/"):
             endpoint = "/" + endpoint
@@ -212,7 +166,10 @@ def __init__(
         else:
             raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")

-        self.ready = False
+    def get_internal_url(self) -> str:
+        if not self._internal_ip:
+            raise ValueError("Internal IP not set")
+        return f"http://{self._internal_ip}:{self._port}"

     async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
         request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
@@ -267,30 +224,40 @@ async def consumer(self):
         Two instances of this function should not be running with shared `_state_server` as that would create race
         conditions
         """
-        self._last_batch_sent = time.time()
         while True:
             await asyncio.sleep(0.05)

             batch = self._batch[: self.max_batch_size]
             is_batch_ready = len(batch) == self.max_batch_size
-            is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
+            if len(batch) > 0 and self._last_batch_sent is None:
+                self._last_batch_sent = time.time()
+
+            if self._last_batch_sent:
+                is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
+            else:
+                is_batch_timeout = False
+
             server_url = self._find_free_server()
             # setting the server status to be busy! This will be reset by
             # the send_batch function after the server responds
             if server_url is None:
                 continue
             if batch and (is_batch_ready or is_batch_timeout):
+                self._server_status[server_url] = False
                 # find server with capacity
                 asyncio.create_task(self.send_batch(batch, server_url))
                 # resetting the batch array, TODO - not locking the array
                 self._batch = self._batch[len(batch) :]
                 self._last_batch_sent = time.time()

-    async def process_request(self, data: BaseModel, request_id=uuid.uuid4().hex):
-        if not self._servers and not self._cold_start_proxy:
-            raise HTTPException(500, "None of the workers are healthy!")
+    async def process_request(self, data: BaseModel, request_id=None):
+        if request_id is None:
+            request_id = uuid.uuid4().hex
+        if not self.servers and not self._cold_start_proxy:
+            # returning 503 prompts the client to retry while the autoscaler scales up
+            raise HTTPException(503, "None of the workers are healthy! Try again in a few seconds")

         # if no servers are available, proxy the request to cold start proxy handler
-        if not self._servers and self._cold_start_proxy:
+        if not self.servers and self._cold_start_proxy:
             return await self._cold_start_proxy.handle_request(data)

         # if out of capacity, proxy the request to cold start proxy handler
@@ -314,20 +281,17 @@ def _has_processing_capacity(self):
         """
         if not self._fastapi_app:
             return False
-        active_server_count = len(self._servers)
+        active_server_count = len(self.servers)
         max_processable = self.max_batch_size * active_server_count
         current_req_count = self._fastapi_app.num_current_requests
         return current_req_count < max_processable

     def run(self):
-        logger.info(f"servers: {self._servers}")
-        lock = asyncio.Lock()
+        logger.info(f"servers: {self.servers}")

-        self._iter = cycle(self._servers)
-        self._last_batch_sent = time.time()
+        self._iter = cycle(self.servers)

         fastapi_app = _create_fastapi("Load Balancer")
-        security = HTTPBasic()
         fastapi_app.SEND_TASK = None
         self._fastapi_app = fastapi_app
@@ -354,40 +318,20 @@ async def startup_event():
         def shutdown_event():
             fastapi_app.SEND_TASK.cancel()

-        def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(security)):
-            AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "")
-            if len(AUTO_SCALER_AUTH_PASSWORD) == 0:
-                logger.warn(
-                    "You have not set a password for private endpoints! To set a password, add "
-                    "`--env AUTO_SCALER_AUTH_PASSWORD=` to your lightning run command."
-                )
-            current_password_bytes = credentials.password.encode("utf8")
-            is_correct_password = secrets.compare_digest(
-                current_password_bytes, AUTO_SCALER_AUTH_PASSWORD.encode("utf8")
-            )
-            if not is_correct_password:
-                raise HTTPException(
-                    status_code=401,
-                    detail="Incorrect password",
-                    headers={"WWW-Authenticate": "Basic"},
-                )
-            return True
-
         @fastapi_app.get("/system/info", response_model=_SysInfo)
-        async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
+        async def sys_info():
             return _SysInfo(
-                num_workers=len(self._servers),
-                servers=self._servers,
+                num_workers=len(self.servers),
+                servers=self.servers,
                 num_requests=fastapi_app.num_current_requests,
                 processing_time=fastapi_app.last_processing_time,
                 global_request_count=fastapi_app.global_request_count,
             )

         @fastapi_app.put("/system/update-servers")
-        async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
-            async with lock:
-                self._servers = servers
-                self._iter = cycle(self._servers)
+        async def update_servers(servers: List[str]):
+            self.servers = servers
+            self._iter = cycle(self.servers)
             updated_servers = set()
             # do not try to loop over the dict keys as the dict might change from other places
             existing_servers = list(self._server_status.keys())
@@ -413,7 +357,6 @@ async def balance_api(inputs: input_type):
         logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")

         self.ready = True
-
         uvicorn.run(
             fastapi_app,
             host=self.host,
@@ -428,54 +371,43 @@ def update_servers(self, server_works: List[LightningWork]):

         AutoScaler uses this method to increase/decrease the number of works.
""" - old_servers = set(self._servers) - server_urls: List[str] = [ + old_server_urls = set(self.servers) + current_server_urls = { f"http://{server._internal_ip}:{server.port}" for server in server_works if server._internal_ip - ] - new_servers = set(server_urls) + } - if new_servers == old_servers: + # doing nothing if no server work has been added/removed + if old_server_urls == current_server_urls: return - if new_servers - old_servers: - logger.info(f"servers added: {new_servers - old_servers}") + # checking if the url is ready or not + available_urls = set() + for url in current_server_urls: + try: + _ = requests.get(url) + except requests.exceptions.ConnectionError: + continue + else: + available_urls.add(url) + if old_server_urls == available_urls: + return - deleted_servers = old_servers - new_servers - if deleted_servers: - logger.info(f"servers deleted: {deleted_servers}") + newly_added = available_urls - old_server_urls + if newly_added: + logger.info(f"servers added: {newly_added}") - self.send_request_to_update_servers(server_urls) + deleted = old_server_urls - available_urls + if deleted: + logger.info(f"servers deleted: {deleted}") + self.send_request_to_update_servers(list(available_urls)) def send_request_to_update_servers(self, servers: List[str]): - AUTHORIZATION_TYPE = "Basic" - USERNAME = "lightning" - AUTO_SCALER_AUTH_PASSWORD = os.environ.get("AUTO_SCALER_AUTH_PASSWORD", "") - try: - param = f"{USERNAME}:{AUTO_SCALER_AUTH_PASSWORD}".encode() - data = b64encode(param).decode("utf-8") - except (ValueError, UnicodeDecodeError) as e: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", - headers={"WWW-Authenticate": "Basic"}, - ) from e - - if not self._internal_ip: + internal_url = self.get_internal_url() + except ValueError: + logger.warn("Cannot update servers as internal_url is not set") return - - headers = { - "accept": "application/json", - "username": USERNAME, - "Authorization": AUTHORIZATION_TYPE + " " + data, - } - - response = requests.put( - f"http://{self._internal_ip}:{self.port}/system/update-servers", - json=servers, - headers=headers, - timeout=10, - ) + response = requests.put(f"{internal_url}/system/update-servers", json=servers, timeout=10) response.raise_for_status() @staticmethod @@ -519,7 +451,7 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F82 frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None} code_samples = self.get_code_sample(url) if code_samples: - frontend_objects["code_samples"] = code_samples + frontend_objects["code_sample"] = code_samples # TODO also set request/response for JS UI else: try: @@ -603,8 +535,8 @@ def __init__( work_cls: Type[LightningWork], min_replicas: int = 1, max_replicas: int = 4, - scale_out_interval: int = 10, - scale_in_interval: int = 10, + scale_out_interval: Numeric = 10, + scale_in_interval: Numeric = 10, max_batch_size: int = 8, timeout_batching: float = 1, endpoint: str = "api/predict", @@ -648,26 +580,21 @@ def __init__( api_name=self._work_cls.__name__, cold_start_proxy=cold_start_proxy, ) - for _ in range(min_replicas): - work = self.create_work() - self.add_work(work) - - @property - def workers(self) -> List[LightningWork]: - return [self.get_work(i) for i in range(self.num_replicas)] @property def ready(self) -> bool: return self.load_balancer.ready + @property + def workers(self) -> List[LightningWork]: + return [self.get_work(i) for i in 
+
     def create_work(self) -> LightningWork:
         """Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
         cloud_compute = self._work_kwargs.get("cloud_compute", None)
         self._work_kwargs.update(
             dict(
-                # TODO: Remove `start_with_flow=False` for faster initialization on the cloud
                 start_with_flow=False,
-                # don't try to create multiple works in a single machine
                 cloud_compute=cloud_compute.clone() if cloud_compute else None,
             )
         )
@@ -704,10 +631,8 @@ def get_work(self, index: int) -> LightningWork:
     def run(self):
         if not self.load_balancer.is_running:
             self.load_balancer.run()
-
         for work in self.workers:
             work.run()
-
         if self.load_balancer.url:
             self.fake_trigger += 1  # Note: change state to keep calling `run`.
         self.autoscale()
@@ -747,11 +672,12 @@ def scale(self, replicas: int, metrics: dict) -> int:
     @property
     def num_pending_requests(self) -> int:
         """Fetches the number of pending requests via load balancer."""
-        if not self.load_balancer._internal_ip:
+        try:
+            load_balancer_url = self.load_balancer.get_internal_url()
+        except ValueError:
+            logger.warn("Cannot fetch pending requests as internal_url is not set")
             return 0
-        return int(
-            requests.get(f"http://{self.load_balancer._internal_ip}:{self.load_balancer.port}/num-requests").json()
-        )
+        return int(requests.get(f"{load_balancer_url}/num-requests").json())

     @property
     def num_pending_works(self) -> int:
@@ -773,6 +699,8 @@ def autoscale(self) -> None:
         # scale-out
         if time.time() - self._last_autoscale > self.scale_out_interval:
+            # TODO figuring out number of workers to add only based on num_replicas isn't right because pending works
+            # are not added to num_replicas
             num_workers_to_add = num_target_workers - self.num_replicas
             for _ in range(num_workers_to_add):
                 logger.info(f"Scaling out from {self.num_replicas} to {self.num_replicas + 1}")
@@ -785,6 +713,8 @@ def autoscale(self) -> None:
         # scale-in
         if time.time() - self._last_autoscale > self.scale_in_interval:
+            # TODO figuring out number of workers to remove only based on num_replicas isn't right because pending works
+            # are not added to num_replicas
             num_workers_to_remove = self.num_replicas - num_target_workers
             for _ in range(num_workers_to_remove):
                 logger.info(f"Scaling in from {self.num_replicas} to {self.num_replicas - 1}")
diff --git a/src/lightning_app/components/serve/cold_start_proxy.py b/src/lightning_app/components/serve/cold_start_proxy.py
new file mode 100644
index 0000000000000..6bfe8bc01f275
--- /dev/null
+++ b/src/lightning_app/components/serve/cold_start_proxy.py
@@ -0,0 +1,56 @@
+import asyncio
+from typing import Any
+
+from fastapi import HTTPException
+from pydantic import BaseModel
+
+from lightning_app.utilities.imports import _is_aiohttp_available, requires
+
+if _is_aiohttp_available():
+    import aiohttp
+    import aiohttp.client_exceptions
+
+
+class ColdStartProxy:
+    """ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
+    starting. This is useful with services that get realtime requests but startup time for workers is high.
+
+    If the request body is the same and the method is POST for the proxy service,
+    then the default implementation of `handle_request` can be used. In that case,
+    initialize the proxy with the proxy url. Otherwise, the user can override the `handle_request` method.
+
+    Args:
+        proxy_url (str): The url of the proxy service
+    """
+
+    @requires(["aiohttp"])
+    def __init__(self, proxy_url: str):
+        self.proxy_url = proxy_url
+        self.proxy_timeout = 50
+        if not asyncio.iscoroutinefunction(self.handle_request):
+            raise TypeError("handle_request must be an `async` function")
+
+    async def handle_request(self, request: BaseModel) -> Any:
+        """This method is called when the request is received while the work is cold starting. The default
+        implementation of this method is to forward the request body to the proxy service with the POST method,
+        but the user can override this method to handle the request in any way.
+
+        Args:
+            request (BaseModel): The request body, a pydantic model that is being
+                forwarded by the load balancer, which is a FastAPI service
+        """
+        try:
+            async with aiohttp.ClientSession() as session:
+                headers = {
+                    "accept": "application/json",
+                    "Content-Type": "application/json",
+                }
+                async with session.post(
+                    self.proxy_url,
+                    json=request.dict(),
+                    timeout=self.proxy_timeout,
+                    headers=headers,
+                ) as response:
+                    return await response.json()
+        except Exception as ex:
+            raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}")
diff --git a/tests/tests_app/components/serve/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py
index e53c7890696a4..5bb28fd8a0ae0 100644
--- a/tests/tests_app/components/serve/test_auto_scaler.py
+++ b/tests/tests_app/components/serve/test_auto_scaler.py
@@ -28,13 +28,6 @@ def scale(self, replicas: int, metrics) -> int:
         return replicas - 1


-def test_num_replicas_after_init():
-    """Test the number of works is the same as min_replicas after initialization."""
-    min_replicas = 2
-    auto_scaler = AutoScaler(EmptyWork, min_replicas=min_replicas)
-    assert auto_scaler.num_replicas == min_replicas
-
-
 @patch("uvicorn.run")
 @patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
 @patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
@@ -59,7 +52,7 @@ def test_num_replicas_not_above_max_replicas(*_):
 @patch("uvicorn.run")
 @patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
 @patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
-def test_num_replicas_not_belo_min_replicas(*_):
+def test_num_replicas_not_below_min_replicas(*_):
     """Test self.num_replicas doesn't exceed max_replicas."""
     min_replicas = 1
     auto_scaler = AutoScaler2(
@@ -210,7 +203,7 @@ async def test_workers_have_no_capacity_with_cold_start_proxy(self, monkeypatch)
         )
         load_balancer._fastapi_app = mock.MagicMock()
         load_balancer._fastapi_app.num_current_requests = 1000
-        load_balancer._servers.append(mock.MagicMock())
+        load_balancer.servers.append(mock.MagicMock())
         req_id = uuid.uuid4().hex
         await load_balancer.process_request("test", req_id)
         load_balancer._cold_start_proxy.handle_request.assert_called_once_with("test")
@@ -222,7 +215,7 @@ async def test_workers_are_free(self):
             output_type=Text,
             endpoint="/predict",
         )
-        load_balancer._servers.append(mock.MagicMock())
+        load_balancer.servers.append(mock.MagicMock())
         req_id = uuid.uuid4().hex
         # populating the responses so the while loop exists
         load_balancer._responses = {req_id: "Dummy"}
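
Usage sketch (not part of the patch): a minimal sketch of how an app might wire the relocated ColdStartProxy into AutoScaler, assuming only the constructor arguments visible in this diff. MyServer, WarmupProxy, and the proxy URL are hypothetical placeholders.

# app.py -- hypothetical example, assumes the API shown in the diff above
from lightning_app import LightningApp
from lightning_app.components import AutoScaler, ColdStartProxy
from lightning_app.components.serve import PythonServer, Text


class MyServer(PythonServer):
    # placeholder work; a real server would load a model in setup()
    def setup(self):
        self._model = lambda text: text[::-1]

    def predict(self, request: Text):
        return {"text": self._model(request.text)}


class WarmupProxy(ColdStartProxy):
    # optional override: handle_request must stay async (checked in __init__);
    # here we return a canned response instead of forwarding the POST body
    async def handle_request(self, request):
        return {"text": "workers are cold starting, please retry"}


app = LightningApp(
    AutoScaler(
        MyServer,
        min_replicas=1,
        max_replicas=4,
        # a plain string is also accepted and wrapped as ColdStartProxy(proxy_url=...),
        # per the __init__ branch shown in the auto_scaler.py diff
        cold_start_proxy=WarmupProxy(proxy_url="http://example.com/predict"),
    )
)

Until the first replica reports ready, the load balancer routes traffic through the proxy; once workers come up, requests are batched and round-robined across them as before.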