  1    1   import abc
  2    2   import base64
  3    3   import os
       4  +import platform
  4    5   from pathlib import Path
  5    6   from typing import Any, Dict, Optional
  6    7
  7    8   import uvicorn
  8    9   from fastapi import FastAPI
  9        -from lightning_utilities.core.imports import module_available
      10  +from lightning_utilities.core.imports import compare_version, module_available
 10   11   from pydantic import BaseModel
 11   12
 12        -from lightning_app.core.queues import MultiProcessQueue
 13   13   from lightning_app.core.work import LightningWork
 14   14   from lightning_app.utilities.app_helpers import Logger
 15   15   from lightning_app.utilities.imports import _is_torch_available, requires
 16        -from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
 17   16
 18   17   logger = Logger(__name__)
 19   18
 27   26       __doctest_skip__ += ["PythonServer", "PythonServer.*"]
 28   27
 29   28
 30        -class _PyTorchSpawnRunExecutor(WorkRunExecutor):
      29  +def _get_device():
      30  +    import operator
 31   31
 32        -    """This Executor enables to move PyTorch tensors on GPU.
      32  +    import torch
 33   33
 34        -    Without this executor, it would raise the following exception:
 35        -    RuntimeError: Cannot re-initialize CUDA in forked subprocess.
 36        -    To use CUDA with multiprocessing, you must use the 'spawn' start method
 37        -    """
      34  +    _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
 38   35
 39        -    enable_start_observer: bool = False
      36  +    local_rank = int(os.getenv("LOCAL_RANK", "0"))
 40   37
 41        -    def __call__(self, *args: Any, **kwargs: Any):
 42        -        import torch
 43        -
 44        -        with self.enable_spawn():
 45        -            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
 46        -            torch.multiprocessing.spawn(
 47        -                self.dispatch_run,
 48        -                args=(self.__class__, self.work, queue, args, kwargs),
 49        -                nprocs=1,
 50        -            )
 51        -
 52        -    @staticmethod
 53        -    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
 54        -        if local_rank == 0:
 55        -            if isinstance(delta_queue, dict):
 56        -                delta_queue = cls.process_queue(delta_queue)
 57        -                work._request_queue = cls.process_queue(work._request_queue)
 58        -                work._response_queue = cls.process_queue(work._response_queue)
 59        -
 60        -            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
 61        -            state_observer.start()
 62        -            _proxy_setattr(work, delta_queue, state_observer)
 63        -
 64        -        unwrap(work.run)(*args, **kwargs)
 65        -
 66        -        if local_rank == 0:
 67        -            state_observer.join(0)
      38  +    if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
      39  +        return torch.device("mps", local_rank)
      40  +    else:
      41  +        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
 68   42
 69   43
 70   44   class _DefaultInputData(BaseModel):
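Note: this hunk drops the custom _PyTorchSpawnRunExecutor, which re-launched work.run through torch.multiprocessing.spawn so CUDA could be initialized in a child process, and instead resolves a device up front via the new _get_device() helper. The following is a minimal illustrative sketch of the device resolution that helper encodes, not code from the diff; it assumes torch >= 1.12 is installed and that LOCAL_RANK may be set by a launcher, and resolve_device is a hypothetical name:

    import os
    import platform

    import torch

    def resolve_device() -> torch.device:
        # Prefer MPS on Apple Silicon (torch >= 1.12), then CUDA, then CPU,
        # using LOCAL_RANK as the device index when provided.
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
            return torch.device("mps", local_rank)
        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    print(resolve_device())  # e.g. device(type='mps', index=0) on an M1 machine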
@@ -95,6 +69,9 @@ def _get_sample_data() -> Dict[Any, Any]:
 95   69
 96   70
 97   71   class PythonServer(LightningWork, abc.ABC):
      72  +
      73  +    _start_method = "spawn"
      74  +
 98   75       @requires(["torch", "lightning_api_access"])
 99   76       def __init__(  # type: ignore
100   77           self,
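Here the executor-based GPU support is replaced by declaring the "spawn" start method directly on the work class. A hedged example of the same pattern, assuming (as this diff implies) that LightningWork honors a _start_method class attribute; MyModelServer and its body are invented for illustration:

    from lightning_app import LightningWork

    class MyModelServer(LightningWork):
        # "spawn" starts the work in a fresh process instead of a fork, which
        # avoids "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
        _start_method = "spawn"

        def run(self):
            import torch

            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            print(f"running on {device}")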
@@ -160,11 +137,6 @@ def predict(self, request):
160  137           self._input_type = input_type
161  138           self._output_type = output_type
162  139
163        -        # Note: Enable to run inference on GPUs.
164        -        self._run_executor_cls = (
165        -            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
166        -        )
167        -
168  140       def setup(self, *args, **kwargs) -> None:
169  141           """This method is called before the server starts. Override this if you need to download the model or
170  142           initialize the weights, setting up pipelines etc.
@@ -210,13 +182,16 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
210  182           return out
211  183
212  184       def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
213        -        from torch import inference_mode
     185  +        from torch import inference_mode, no_grad
214  186
215  187           input_type: type = self.configure_input_type()
216  188           output_type: type = self.configure_output_type()
217  189
     190  +        device = _get_device()
     191  +        context = no_grad if device.type == "mps" else inference_mode
     192  +
218  193           def predict_fn(request: input_type):  # type: ignore
219        -            with inference_mode():
     194  +            with context():
220  195                   return self.predict(request)
221  196
222  197           fastapi_app.post("/predict", response_model=output_type)(predict_fn)
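Finally, the inference context is now chosen per device: inference_mode stays the default, but on MPS the server falls back to no_grad, presumably because inference-mode tensors were not well supported on the MPS backend at the time. A small self-contained sketch of the same pattern; run_inference is a hypothetical helper, not code from the diff:

    import torch

    def run_inference(model: torch.nn.Module, batch: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Mirror the diff's choice of gradient-disabling context:
        # no_grad on MPS, inference_mode everywhere else.
        context = torch.no_grad if device.type == "mps" else torch.inference_mode
        with context():
            return model.to(device)(batch.to(device))

    model = torch.nn.Linear(4, 2)
    out = run_inference(model, torch.randn(1, 4), torch.device("cpu"))
    print(out.shape)  # torch.Size([1, 2])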