 import abc
 import base64
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional

 from pydantic import BaseModel
 from starlette.staticfiles import StaticFiles

+from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

 logger = Logger(__name__)


+class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+
+    """This executor enables moving PyTorch tensors to the GPU.
+
+    Without this executor, it would raise the following exception:
+        RuntimeError: Cannot re-initialize CUDA in forked subprocess.
+        To use CUDA with multiprocessing, you must use the 'spawn' start method
+    """
+
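+    # Presumably disables the base executor's automatic state observer;
+    # rank 0 starts one manually in dispatch_run below.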
+    enable_start_observer: bool = False
+
+    def __call__(self, *args: Any, **kwargs: Any):
+        import torch
+
+        with self.enable_spawn():
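+            # A MultiProcessQueue can be handed to the spawned process directly;
+            # other queue types are serialized to a dict and rebuilt in the child.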
+            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
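+            # The 'spawn' start method gives the child a fresh interpreter, so CUDA
+            # can be initialized there (forking would trigger the RuntimeError above).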
+            torch.multiprocessing.spawn(
+                self.dispatch_run,
+                args=(self.__class__, self.work, queue, args, kwargs),
+                nprocs=1,
+            )
+
+    @staticmethod
+    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
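+        # Entry point inside the spawned process; local_rank is supplied by
+        # torch.multiprocessing.spawn.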
+        if local_rank == 0:
+            if isinstance(delta_queue, dict):
+                delta_queue = cls.process_queue(delta_queue)
+                work._request_queue = cls.process_queue(work._request_queue)
+                work._response_queue = cls.process_queue(work._response_queue)
+
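+            # Watch for state changes made in this process and forward them to the
+            # app through the delta queue.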
+            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
+            state_observer.start()
+            _proxy_setattr(work, delta_queue, state_observer)
+
+        unwrap(work.run)(*args, **kwargs)
+
+        if local_rank == 0:
+            state_observer.join(0)
+
+
 class _DefaultInputData(BaseModel):
     payload: str

@@ -106,6 +149,11 @@ def predict(self, request):
         self._input_type = input_type
         self._output_type = output_type

+        # Note: Enables running inference on GPUs.
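+        # LIGHTNING_CLOUD_APP_ID is expected to be set only when running in the
+        # cloud, where the default executor is kept; locally, the spawn-based
+        # executor is used so CUDA can initialize in a fresh process.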
+        self._run_executor_cls = (
+            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
+        )
+
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.