|
3 | 3 | from pathlib import Path
|
4 | 4 | from typing import Any, Dict, Optional
|
5 | 5 |
|
| 6 | +import torch |
6 | 7 | import uvicorn
|
7 | 8 | from fastapi import FastAPI
|
8 | 9 | from pydantic import BaseModel
|
@@ -105,7 +106,7 @@ def predict(self, request):
|
105 | 106 | self._input_type = input_type
|
106 | 107 | self._output_type = output_type
|
107 | 108 |
|
108 |
| - def setup(self) -> None: |
| 109 | + def setup(self, *args, **kwargs) -> None: |
109 | 110 | """This method is called before the server starts. Override this if you need to download the model or
|
110 | 111 | initialize the weights, setting up pipelines etc.
|
111 | 112 |
|
@@ -154,7 +155,8 @@ def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
|
154 | 155 | output_type: type = self.configure_output_type()
|
155 | 156 |
|
156 | 157 | def predict_fn(request: input_type): # type: ignore
|
157 |
| - return self.predict(request) |
| 158 | + with torch.inference_mode(): |
| 159 | + return self.predict(request) |
158 | 160 |
|
159 | 161 | fastapi_app.post("/predict", response_model=output_type)(predict_fn)
|
160 | 162 |
|
@@ -207,7 +209,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
|
207 | 209 |
|
208 | 210 | Normally, you don't need to override this method.
|
209 | 211 | """
|
210 |
| - self.setup() |
| 212 | + self.setup(*args, **kwargs) |
211 | 213 |
|
212 | 214 | fastapi_app = FastAPI()
|
213 | 215 | self._attach_predict_fn(fastapi_app)
|
|
0 commit comments