
Commit 904323b

tchaton, thomas, and pre-commit-ci[bot] authored
[App] Resolve PythonServer on M1 (#15949)
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 36aecde commit 904323b

6 files changed: +27 −56 lines changed


requirements/app/ui.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-streamlit>=1.3.1, <=1.11.1
+streamlit>=1.0.0, <=1.15.2
 panel>=0.12.7, <=0.13.1

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 3 deletions
@@ -56,17 +56,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed MPS error for multinode component (defaults to cpu on mps devices now as distributed operations are not supported by pytorch on mps) ([#15748](https://github.com/Ligtning-AI/lightning/pull/15748))

-
-
 - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

-
 - Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812)

 - Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881))

 - Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911))

+- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))
+

 ## [1.8.3] - 2022-11-22


src/lightning_app/components/auto_scaler.py

Lines changed: 0 additions & 2 deletions
@@ -206,7 +206,6 @@ async def process_request(self, data: BaseModel):
         return result

     def run(self):
-
         logger.info(f"servers: {self.servers}")
         lock = asyncio.Lock()

@@ -271,7 +270,6 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))
         async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
             async with lock:
                 self.servers = servers
-
             self._iter = cycle(self.servers)

         @fastapi_app.post(self.endpoint, response_model=self._output_type)

src/lightning_app/components/python/tracer.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ class Code(TypedDict):


 class TracerPythonScript(LightningWork):
+
+    _start_method = "spawn"
+
     def on_before_run(self):
         """Called before the python script is executed."""

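
The `_start_method = "spawn"` attribute added here (and in the two components below) replaces the per-component executor override: the work process is started with the `spawn` start method instead of `fork`, which is what CUDA and MPS require. A minimal sketch of the underlying issue, using plain `multiprocessing` rather than Lightning's executor machinery (the helper names here are illustrative, not from the repo):

```python
# Minimal sketch of why "spawn" matters for accelerator code; this uses plain
# multiprocessing rather than Lightning's WorkRunExecutor internals.
import multiprocessing as mp


def run_inference(rank: int) -> None:
    import torch  # imported inside the child process

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.ones(2, 2, device=device)
    print(f"rank {rank}: sum = {x.sum().item()} on {device}")


if __name__ == "__main__":
    # A forked child inherits an already-initialized CUDA context and fails with
    # "Cannot re-initialize CUDA in forked subprocess"; spawn starts a fresh interpreter.
    ctx = mp.get_context("spawn")
    process = ctx.Process(target=run_inference, args=(0,))
    process.start()
    process.join()
```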

src/lightning_app/components/serve/gradio.py

Lines changed: 2 additions & 6 deletions
@@ -1,10 +1,8 @@
 import abc
-import os
 from functools import partial
 from types import ModuleType
 from typing import Any, List, Optional

-from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.imports import _is_gradio_available, requires

@@ -36,15 +34,13 @@ class ServeGradio(LightningWork, abc.ABC):
     title: Optional[str] = None
     description: Optional[str] = None

+    _start_method = "spawn"
+
     def __init__(self, *args, **kwargs):
         requires("gradio")(super().__init__(*args, **kwargs))
         assert self.inputs
         assert self.outputs
         self._model = None
-        # Note: Enable to run inference on GPUs.
-        self._run_executor_cls = (
-            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
-        )

     @property
     def model(self):
src/lightning_app/components/serve/python_server.py

Lines changed: 19 additions & 44 deletions
@@ -1,19 +1,18 @@
 import abc
 import base64
 import os
+import platform
 from pathlib import Path
 from typing import Any, Dict, Optional

 import uvicorn
 from fastapi import FastAPI
-from lightning_utilities.core.imports import module_available
+from lightning_utilities.core.imports import compare_version, module_available
 from pydantic import BaseModel

-from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
 from lightning_app.utilities.imports import _is_torch_available, requires
-from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

 logger = Logger(__name__)

@@ -27,44 +26,19 @@
 __doctest_skip__ += ["PythonServer", "PythonServer.*"]


-class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+def _get_device():
+    import operator

-    """This Executor enables to move PyTorch tensors on GPU.
+    import torch

-    Without this executor, it would raise the following exception:
-    RuntimeError: Cannot re-initialize CUDA in forked subprocess.
-    To use CUDA with multiprocessing, you must use the 'spawn' start method
-    """
+    _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")

-    enable_start_observer: bool = False
+    local_rank = int(os.getenv("LOCAL_RANK", "0"))

-    def __call__(self, *args: Any, **kwargs: Any):
-        import torch
-
-        with self.enable_spawn():
-            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
-            torch.multiprocessing.spawn(
-                self.dispatch_run,
-                args=(self.__class__, self.work, queue, args, kwargs),
-                nprocs=1,
-            )
-
-    @staticmethod
-    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
-        if local_rank == 0:
-            if isinstance(delta_queue, dict):
-                delta_queue = cls.process_queue(delta_queue)
-            work._request_queue = cls.process_queue(work._request_queue)
-            work._response_queue = cls.process_queue(work._response_queue)
-
-        state_observer = WorkStateObserver(work, delta_queue=delta_queue)
-        state_observer.start()
-        _proxy_setattr(work, delta_queue, state_observer)
-
-        unwrap(work.run)(*args, **kwargs)
-
-        if local_rank == 0:
-            state_observer.join(0)
+    if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
+        return torch.device("mps", local_rank)
+    else:
+        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")


 class _DefaultInputData(BaseModel):

@@ -95,6 +69,9 @@ def _get_sample_data() -> Dict[Any, Any]:


 class PythonServer(LightningWork, abc.ABC):
+
+    _start_method = "spawn"
+
     @requires(["torch", "lightning_api_access"])
     def __init__(  # type: ignore
         self,

@@ -160,11 +137,6 @@ def predict(self, request):
         self._input_type = input_type
         self._output_type = output_type

-        # Note: Enable to run inference on GPUs.
-        self._run_executor_cls = (
-            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
-        )
-
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.

@@ -210,13 +182,16 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
         return out

     def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
-        from torch import inference_mode
+        from torch import inference_mode, no_grad

         input_type: type = self.configure_input_type()
         output_type: type = self.configure_output_type()

+        device = _get_device()
+        context = no_grad if device.type == "mps" else inference_mode
+
         def predict_fn(request: input_type):  # type: ignore
-            with inference_mode():
+            with context():
                 return self.predict(request)

         fastapi_app.post("/predict", response_model=output_type)(predict_fn)
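
The net effect of this last diff: `PythonServer` now picks its device through the module-level `_get_device()` helper and, on Apple silicon, wraps `predict` in `torch.no_grad()` rather than `torch.inference_mode()`, which is what resolved the garbage outputs reported on M1. A standalone sketch of that selection logic, assuming only that `torch` is installed (the `pick_device` name is illustrative; the real helper is `_get_device` above):

```python
# Standalone sketch mirroring _get_device() and the predict-context selection above.
import os
import platform

import torch


def pick_device() -> torch.device:
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if has_mps and platform.processor() in ("arm", "arm64"):
        return torch.device("mps", local_rank)
    return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")


device = pick_device()
# inference_mode is avoided on MPS, where it produced incorrect ("noisy") outputs
# at the time of this commit; no_grad still disables autograd tracking for serving.
context = torch.no_grad if device.type == "mps" else torch.inference_mode

with context():
    prediction = (torch.ones(2, 2, device=device) * 3).sum()
print(device, prediction.item())
```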
