From 12287997c7ca3006b5a886d19f9faeb93b9d7498 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 20 Jan 2023 15:46:42 +0000 Subject: [PATCH 1/2] update --- .../components/serve/python_server.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index caae6f584cbf4..19a088b86c599 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -1,4 +1,5 @@ import abc +import asyncio import base64 import os import platform @@ -252,19 +253,19 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict: return out def _attach_predict_fn(self, fastapi_app: FastAPI) -> None: - from torch import inference_mode, no_grad - input_type: type = self.configure_input_type() output_type: type = self.configure_output_type() - device = _get_device() - context = no_grad if device.type == "mps" else inference_mode + def predict_fn_sync(request: input_type): # type: ignore + return self.predict(request) - def predict_fn(request: input_type): # type: ignore - with context(): - return self.predict(request) + async def async_predict_fn(request: input_type): # type: ignore + return await self.predict(request) - fastapi_app.post("/predict", response_model=output_type)(predict_fn) + if asyncio.iscoroutinefunction(self.predict): + fastapi_app.post("/predict", response_model=output_type)(async_predict_fn) + else: + fastapi_app.post("/predict", response_model=output_type)(predict_fn_sync) def get_code_sample(self, url: str) -> Optional[str]: input_type: Any = self.configure_input_type() From f8aedfc34fcae25d0c80dcf012e49402ec708942 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 20 Jan 2023 15:49:29 +0000 Subject: [PATCH 2/2] add changelog --- src/lightning_app/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index f3f1800c8cd6c..72f79f4d1f771 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed -- +- Add support for async predict method in PythonServer and remove torch context ([#16453](https://github.com/Lightning-AI/lightning/pull/16453)) ### Deprecated