Commit 347565a

tchaton authored and thomas committed
Add support for async method and remove context PythonServer (#16453)
Co-authored-by: thomas <[email protected]> (cherry picked from commit 48e1c9c)
1 parent a2788a5 commit 347565a

File tree

src/lightning_app/CHANGELOG.md
src/lightning_app/components/serve/python_server.py

2 files changed: +10 -9 lines changed


src/lightning_app/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

--
+- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453))


 ### Deprecated
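
To make the changelog entry concrete, here is a minimal usage sketch (not part of this commit) of a PythonServer subclass whose predict is a coroutine. The configure_input_type / configure_output_type hooks and the /predict route come from the diff below; the pydantic models, class name, and field names are illustrative assumptions.

# Minimal sketch (assumed API usage, not code from this commit):
# a PythonServer subclass with an async predict method.
import asyncio

from pydantic import BaseModel

from lightning_app.components.serve import PythonServer


class TextInput(BaseModel):   # illustrative request model
    text: str


class TextOutput(BaseModel):  # illustrative response model
    text: str


class AsyncEchoServer(PythonServer):
    def configure_input_type(self) -> type:
        return TextInput

    def configure_output_type(self) -> type:
        return TextOutput

    # With this change, predict may be declared async; the generated
    # FastAPI route awaits it instead of calling it synchronously.
    async def predict(self, request: TextInput) -> TextOutput:
        await asyncio.sleep(0)  # stand-in for real async work (I/O, remote calls, ...)
        return TextOutput(text=request.text)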

src/lightning_app/components/serve/python_server.py

Lines changed: 9 additions & 8 deletions

@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import base64
 import os
 import platform
@@ -252,19 +253,19 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
         return out

     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
-        from torch import inference_mode, no_grad
-
         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()

-        device = _get_device()
-        context = no_grad if device.type == "mps" else inference_mode
+        def predict_fn_sync(request: input_type):  # type: ignore
+            return self.predict(request)

-        def predict_fn(request: input_type):  # type: ignore
-            with context():
-                return self.predict(request)
+        async def async_predict_fn(request: input_type):  # type: ignore
+            return await self.predict(request)

-        fastapi_app.post("/predict", response_model=output_type)(predict_fn)
+        if asyncio.iscoroutinefunction(self.predict):
+            fastapi_app.post("/predict", response_model=output_type)(async_predict_fn)
+        else:
+            fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync)

     def get_code_sample(self, url: str) -> Optional[str]:
         input_type: Any = self.configure_input_type()
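
For reference, a self-contained sketch of the dispatch pattern introduced above, outside of Lightning: FastAPI accepts both plain and async callables as route handlers, and asyncio.iscoroutinefunction decides which wrapper gets registered. The Item model, build_app helper, and handler bodies are illustrative placeholders, not Lightning code.

# Standalone sketch of the sync/async dispatch used by _attach_predict_fn.
# Everything here (Item, build_app, the predict functions) is illustrative.
import asyncio

from fastapi import FastAPI
from pydantic import BaseModel


class Item(BaseModel):
    value: int


def build_app(predict) -> FastAPI:
    app = FastAPI()

    def predict_fn_sync(request: Item) -> Item:
        return predict(request)

    async def async_predict_fn(request: Item) -> Item:
        return await predict(request)

    # Register the awaiting wrapper only when predict is a coroutine function.
    if asyncio.iscoroutinefunction(predict):
        app.post("/predict", response_model=Item)(async_predict_fn)
    else:
        app.post("/predict", response_model=Item)(predict_fn_sync)
    return app


def sync_predict(request: Item) -> Item:
    return Item(value=request.value + 1)


async def async_predict(request: Item) -> Item:
    await asyncio.sleep(0)  # stand-in for real async work
    return Item(value=request.value + 1)


# Either flavor ends up with a working POST /predict route:
sync_app = build_app(sync_predict)
async_app = build_app(async_predict)

Note that the diff also drops the automatic torch no_grad / inference_mode wrapper around predict, so an implementation that relies on gradient-free execution would presumably need to apply that context itself (for example by decorating its predict with torch.inference_mode()).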
