6
6
import uuid
7
7
from base64 import b64encode
8
8
from itertools import cycle
9
- from typing import Any , Dict , List , Tuple , Type
9
+ from typing import Any , Dict , List , Optional , Tuple , Type
10
10
11
11
import requests
12
12
import uvicorn
15
15
from fastapi .responses import RedirectResponse
16
16
from fastapi .security import HTTPBasic , HTTPBasicCredentials
17
17
from pydantic import BaseModel
18
+ from starlette .staticfiles import StaticFiles
18
19
from starlette .status import HTTP_401_UNAUTHORIZED
19
20
20
21
from lightning_app .core .flow import LightningFlow
21
22
from lightning_app .core .work import LightningWork
22
23
from lightning_app .utilities .app_helpers import Logger
24
+ from lightning_app .utilities .cloud import is_running_in_cloud
23
25
from lightning_app .utilities .imports import _is_aiohttp_available , requires
24
26
from lightning_app .utilities .packaging .cloud_compute import CloudCompute
25
27
@@ -114,20 +116,21 @@ class _LoadBalancer(LightningWork):
114
116
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
115
117
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
116
118
timeout_inference_request: The number of seconds to wait for inference.
117
- \ **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
119
+ **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
118
120
"""
119
121
120
122
@requires (["aiohttp" ])
121
123
def __init__ (
122
124
self ,
123
- input_type : BaseModel ,
124
- output_type : BaseModel ,
125
+ input_type : Type [ BaseModel ] ,
126
+ output_type : Type [ BaseModel ] ,
125
127
endpoint : str ,
126
128
max_batch_size : int = 8 ,
127
129
# all timeout args are in seconds
128
- timeout_batching : int = 1 ,
130
+ timeout_batching : float = 1 ,
129
131
timeout_keep_alive : int = 60 ,
130
132
timeout_inference_request : int = 60 ,
133
+ work_name : Optional [str ] = "API" , # used for displaying the name in the UI
131
134
** kwargs : Any ,
132
135
) -> None :
133
136
super ().__init__ (cloud_compute = CloudCompute ("default" ), ** kwargs )
@@ -142,6 +145,7 @@ def __init__(
142
145
self ._batch = []
143
146
self ._responses = {} # {request_id: response}
144
147
self ._last_batch_sent = 0
148
+ self ._work_name = work_name
145
149
146
150
if not endpoint .startswith ("/" ):
147
151
endpoint = "/" + endpoint
@@ -280,6 +284,14 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
280
284
async def balance_api (inputs : self ._input_type ):
281
285
return await self .process_request (inputs )
282
286
287
+ endpoint_info_page = self ._get_endpoint_info_page ()
288
+ if endpoint_info_page :
289
+ fastapi_app .mount (
290
+ "/endpoint-info" , StaticFiles (directory = endpoint_info_page .serve_dir , html = True ), name = "static"
291
+ )
292
+
293
+ logger .info (f"Your load balancer has started. The endpoint is 'http://{ self .host } :{ self .port } { self .endpoint } '" )
294
+
283
295
uvicorn .run (
284
296
fastapi_app ,
285
297
host = self .host ,
@@ -332,6 +344,60 @@ def send_request_to_update_servers(self, servers: List[str]):
332
344
response = requests .put (f"{ self .url } /system/update-servers" , json = servers , headers = headers , timeout = 10 )
333
345
response .raise_for_status ()
334
346
347
+ @staticmethod
348
+ def _get_sample_dict_from_datatype (datatype : Any ) -> dict :
349
+ if not hasattr (datatype , "schema" ):
350
+ # not a pydantic model
351
+ raise TypeError (f"datatype must be a pydantic model, for the UI to be generated. but got { datatype } " )
352
+
353
+ if hasattr (datatype , "_get_sample_data" ):
354
+ return datatype ._get_sample_data ()
355
+
356
+ datatype_props = datatype .schema ()["properties" ]
357
+ out : Dict [str , Any ] = {}
358
+ lut = {"string" : "data string" , "number" : 0.0 , "integer" : 0 , "boolean" : False }
359
+ for k , v in datatype_props .items ():
360
+ if v ["type" ] not in lut :
361
+ raise TypeError ("Unsupported type" )
362
+ out [k ] = lut [v ["type" ]]
363
+ return out
364
+
365
+ def get_code_sample (self , url : str ) -> Optional [str ]:
366
+ input_type : Any = self ._input_type
367
+ output_type : Any = self ._output_type
368
+
369
+ if not (hasattr (input_type , "request_code_sample" ) and hasattr (output_type , "response_code_sample" )):
370
+ return None
371
+ return f"{ input_type .request_code_sample (url )} \n { output_type .response_code_sample ()} "
372
+
373
+ def _get_endpoint_info_page (self ) -> Optional ["APIAccessFrontend" ]: # noqa: F821
374
+ try :
375
+ from lightning_api_access import APIAccessFrontend
376
+ except ModuleNotFoundError :
377
+ logger .warn ("APIAccessFrontend not found. Please install lightning-api-access to enable the UI" )
378
+ return
379
+
380
+ if is_running_in_cloud ():
381
+ url = f"{ self ._future_url } { self .endpoint } "
382
+ else :
383
+ url = f"http://localhost:{ self .port } { self .endpoint } "
384
+
385
+ frontend_objects = {"name" : self ._work_name , "url" : url , "method" : "POST" , "request" : None , "response" : None }
386
+ code_samples = self .get_code_sample (url )
387
+ if code_samples :
388
+ frontend_objects ["code_samples" ] = code_samples
389
+ # TODO also set request/response for JS UI
390
+ else :
391
+ try :
392
+ request = self ._get_sample_dict_from_datatype (self ._input_type )
393
+ response = self ._get_sample_dict_from_datatype (self ._output_type )
394
+ except TypeError :
395
+ return None
396
+ else :
397
+ frontend_objects ["request" ] = request
398
+ frontend_objects ["response" ] = response
399
+ return APIAccessFrontend (apis = [frontend_objects ])
400
+
335
401
336
402
class AutoScaler (LightningFlow ):
337
403
"""The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
@@ -403,8 +469,8 @@ def __init__(
403
469
max_batch_size : int = 8 ,
404
470
timeout_batching : float = 1 ,
405
471
endpoint : str = "api/predict" ,
406
- input_type : BaseModel = Dict ,
407
- output_type : BaseModel = Dict ,
472
+ input_type : Type [ BaseModel ] = Dict ,
473
+ output_type : Type [ BaseModel ] = Dict ,
408
474
* work_args : Any ,
409
475
** work_kwargs : Any ,
410
476
) -> None :
@@ -438,6 +504,7 @@ def __init__(
438
504
timeout_batching = timeout_batching ,
439
505
cache_calls = True ,
440
506
parallel = True ,
507
+ work_name = self ._work_cls .__name__ ,
441
508
)
442
509
for _ in range (min_replicas ):
443
510
work = self .create_work ()
@@ -574,5 +641,8 @@ def autoscale(self) -> None:
574
641
self ._last_autoscale = time .time ()
575
642
576
643
def configure_layout (self ):
577
- tabs = [{"name" : "Swagger" , "content" : self .load_balancer .url }]
644
+ tabs = [
645
+ {"name" : "Endpoint Info" , "content" : f"{ self .load_balancer } /endpoint-info" },
646
+ {"name" : "Swagger" , "content" : self .load_balancer .url },
647
+ ]
578
648
return tabs
0 commit comments