2 files changed, +10 −9 lines changed

First file — the project changelog:

@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))
 
 
 ### Deprecated
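For orientation, the sketch below shows the kind of user code this entry enables: a `PythonServer` subclass whose `predict` is a coroutine. The import path, the `input_type`/`output_type` constructor keywords, the `setup` hook, and the request/response models are assumptions for illustration, not something this PR defines; only the `async def predict` behaviour is what the entry describes.

```python
# Hypothetical usage sketch (not part of this PR): a PythonServer subclass
# with an async `predict`. Import path and constructor keywords are assumed.
import asyncio

from pydantic import BaseModel

from lightning.app.components.serve import PythonServer


class EchoRequest(BaseModel):
    text: str


class EchoResponse(BaseModel):
    text: str


class AsyncEchoServer(PythonServer):
    def __init__(self, **kwargs):
        super().__init__(input_type=EchoRequest, output_type=EchoResponse, **kwargs)

    def setup(self, *args, **kwargs) -> None:
        # Optional hook for loading a model or opening client connections.
        pass

    async def predict(self, request: EchoRequest) -> EchoResponse:
        # With this change, the coroutine is awaited directly by the /predict
        # route instead of being called inside a torch no_grad/inference_mode context.
        await asyncio.sleep(0.01)  # stand-in for real async I/O
        return EchoResponse(text=request.text)
```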
Second file — the PythonServer implementation:

@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import base64
 import os
 import platform
@@ -252,19 +253,19 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
         return out
 
     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
-        from torch import inference_mode, no_grad
-
         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()
 
-        device = _get_device()
-        context = no_grad if device.type == "mps" else inference_mode
+        def predict_fn_sync(request: input_type):  # type: ignore
+            return self.predict(request)
 
-        def predict_fn(request: input_type):  # type: ignore
-            with context():
-                return self.predict(request)
+        async def async_predict_fn(request: input_type):  # type: ignore
+            return await self.predict(request)
 
-        fastapi_app.post("/predict", response_model=output_type)(predict_fn)
+        if asyncio.iscoroutinefunction(self.predict):
+            fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
+        else:
+            fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)
 
     def get_code_sample(self, url: str) -> Optional[str]:
         input_type: Any = self.configure_input_type()
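Stripped of the Lightning internals, the registration logic in `_attach_predict_fn` reduces to the pattern sketched below: probe the user's callable with `asyncio.iscoroutinefunction` (which also works on bound methods such as `self.predict`) and register whichever wrapper matches. The `EchoRequest`/`EchoResponse` models and the module-level `predict` are placeholders; the `asyncio` check and the `fastapi_app.post(...)` registration mirror the diff above.

```python
# Standalone sketch of the sync/async dispatch above. Model names, the
# module-level `predict`, and the endpoint are placeholders for illustration.
import asyncio

from fastapi import FastAPI
from pydantic import BaseModel


class EchoRequest(BaseModel):
    text: str


class EchoResponse(BaseModel):
    text: str


async def predict(request: EchoRequest) -> EchoResponse:
    # Change this to a plain `def` to exercise the synchronous branch below.
    return EchoResponse(text=request.text)


def attach_predict_fn(fastapi_app: FastAPI) -> None:
    def predict_fn_sync(request: EchoRequest) -> EchoResponse:
        return predict(request)  # only registered when `predict` is not a coroutine fn

    async def async_predict_fn(request: EchoRequest) -> EchoResponse:
        return await predict(request)

    # iscoroutinefunction is True for `async def` callables, so the matching
    # wrapper is registered for POST /predict.
    if asyncio.iscoroutinefunction(predict):
        fastapi_app.post("/predict", response_model=EchoResponse)(async_predict_fn)
    else:
        fastapi_app.post("/predict", response_model=EchoResponse)(predict_fn_sync)


app = FastAPI()
attach_predict_fn(app)
# Try it with: uvicorn sketch:app --reload
```

FastAPI awaits async handlers on the event loop and runs plain `def` handlers in a threadpool, so either branch keeps the `/predict` route from blocking the server.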