Skip to content

Commit 08d14ec

Browse files
author
Sherin Thomas
authored
Torch inference mode for prediction (#15719)
torch inference mode for prediction
1 parent f40eb2c commit 08d14ec

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/lightning_app/components/serve/python_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Any, Dict, Optional
55

6+
import torch
67
import uvicorn
78
from fastapi import FastAPI
89
from pydantic import BaseModel
@@ -105,7 +106,7 @@ def predict(self, request):
105106
self._input_type = input_type
106107
self._output_type = output_type
107108

108-
def setup(self) -> None:
109+
def setup(self, *args, **kwargs) -> None:
109110
"""This method is called before the server starts. Override this if you need to download the model or
110111
initialize the weights, setting up pipelines etc.
111112
@@ -154,7 +155,8 @@ def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
154155
output_type: type = self.configure_output_type()
155156

156157
def predict_fn(request: input_type): # type: ignore
157-
return self.predict(request)
158+
with torch.inference_mode():
159+
return self.predict(request)
158160

159161
fastapi_app.post("/predict", response_model=output_type)(predict_fn)
160162

@@ -207,7 +209,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
207209
208210
Normally, you don't need to override this method.
209211
"""
210-
self.setup()
212+
self.setup(*args, **kwargs)
211213

212214
fastapi_app = FastAPI()
213215
self._attach_predict_fn(fastapi_app)

0 commit comments

Comments
 (0)