@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any, Dict, Optional
 
-import torch
 import uvicorn
 from fastapi import FastAPI
 from pydantic import BaseModel
@@ -13,16 +12,21 @@
 from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.imports import _is_torch_available, requires
 from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
 
 logger = Logger(__name__)
 
+# Skip doctests if requirements aren't available
+if not _is_torch_available():
+    __doctest_skip__ = ["PythonServer", "PythonServer.*"]
+
 
 class _PyTorchSpawnRunExecutor(WorkRunExecutor):
 
     """This Executor enables to move PyTorch tensors on GPU.
 
-    Without this executor, it woud raise the following expection:
+    Without this executor, it would raise the following exception:
         RuntimeError: Cannot re-initialize CUDA in forked subprocess.
         To use CUDA with multiprocessing, you must use the 'spawn' start method
     """
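The exception quoted in the docstring is CUDA's complaint about fork-based multiprocessing: a forked child inherits an already-initialized CUDA context and cannot re-initialize it. A minimal sketch of the spawn-based workaround the executor's name refers to (an illustration, not the executor's actual code):

    # Illustrative only: shows why the 'spawn' start method avoids the error.
    import torch
    import torch.multiprocessing as mp


    def _worker(rank: int) -> None:
        # Each spawned process starts fresh, so CUDA initializes cleanly here.
        if torch.cuda.is_available():
            print(rank, torch.ones(2, 2).to("cuda:0").device)


    if __name__ == "__main__":
        # fork (the default on Linux) would raise
        # "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
        mp.spawn(_worker, nprocs=1)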
@@ -86,6 +90,7 @@ def _get_sample_data() -> Dict[Any, Any]:
 
 
 class PythonServer(LightningWork, abc.ABC):
+    @requires("torch")
     def __init__(  # type: ignore
         self,
         host: str = "127.0.0.1",
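`@requires("torch")` moves the dependency check from import time to call time: the module imports cleanly without torch, but constructing a `PythonServer` fails fast if torch is missing. A rough sketch of how such a guard can be written (a hypothetical stand-in, not lightning_app's actual implementation):

    import functools
    import importlib.util
    from typing import Any, Callable


    def requires_sketch(package: str) -> Callable:
        # Hypothetical stand-in for lightning_app.utilities.imports.requires.
        def decorator(fn: Callable) -> Callable:
            @functools.wraps(fn)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # Fail at call time, not import time, if the package is absent.
                if importlib.util.find_spec(package) is None:
                    raise ModuleNotFoundError(
                        f"{fn.__qualname__} requires {package!r}. Run: pip install {package}"
                    )
                return fn(*args, **kwargs)

            return wrapper

        return decorator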
@@ -127,15 +132,16 @@ def predict(self, request):
         and this can be accessed as `response.json()["prediction"]` in the client if
         you are using requests library
 
-        .. doctest::
+        Example:
 
             >>> from lightning_app.components.serve.python_server import PythonServer
             >>> from lightning_app import LightningApp
-            >>>
             ...
             >>> class SimpleServer(PythonServer):
+            ...
             ...     def setup(self):
             ...         self._model = lambda x: x + " " + x
+            ...
             ...     def predict(self, request):
             ...         return {"prediction": self._model(request.image)}
             ...
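As the docstring says, a client reads the result with `response.json()["prediction"]`. A usage sketch against the `SimpleServer` example above (the port and the `image` payload key are assumptions based on that doctest):

    import requests

    # Assumes a SimpleServer is running locally; 8000 is a placeholder port.
    response = requests.post(
        "http://127.0.0.1:8000/predict",
        json={"image": "hello"},
    )
    print(response.json()["prediction"])  # "hello hello" with the lambda model above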
@@ -199,11 +205,13 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
         return out
 
     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
+        from torch import inference_mode
+
         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()
 
         def predict_fn(request: input_type):  # type: ignore
-            with torch.inference_mode():
+            with inference_mode():
                 return self.predict(request)
 
         fastapi_app.post("/predict", response_model=output_type)(predict_fn)
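With the module-level `import torch` removed, this function-local `from torch import inference_mode` is what keeps the server working: torch is imported only when `_attach_predict_fn` actually runs. The deferred-import pattern in isolation (a sketch, not the file's code):

    def attach_sketch() -> None:
        # Deferred import: the module containing this function can be imported
        # even when torch is not installed; the cost is paid only at call time.
        from torch import inference_mode

        with inference_mode():  # disables autograd tracking for cheaper inference
            ...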