From c03b344a378e6ce2509a7999598a5163395f6df6 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 25 Nov 2022 07:12:00 -0500 Subject: [PATCH 1/5] update --- MANIFEST.in | 99 +++++++++++++++++++ .../components/serve/python_server.py | 49 ++++++++- 2 files changed, 147 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index ac8c2556d4f02..2d53517a0dff0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,3 +5,102 @@ include .actions/setup_tools.py include .actions/assistant.py include src/version.info include *.cff # citation info +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +include src/lightning/version.info + +include src/lightning/app/components/serve/catimage.png + +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 731bf1c37e969..652403d69db1a 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -2,7 +2,7 @@ import base64 from pathlib import Path from typing import Any, Dict, Optional - +import os import torch import uvicorn from fastapi import FastAPI @@ -11,10 +11,53 @@ from lightning_app.core.work import LightningWork from lightning_app.utilities.app_helpers import Logger +from typing import Any, Callable, Type +from lightning_app.core.queues import MultiProcessQueue +from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver logger = Logger(__name__) +class _PyTorchSpawnRunExecutor(WorkRunExecutor): + + """This Executor enables to move PyTorch tensors on GPU. + + Without this executor, it woud raise the following expection: + RuntimeError: Cannot re-initialize CUDA in forked subprocess. + To use CUDA with multiprocessing, you must use the 'spawn' start method + """ + + enable_start_observer: bool = False + + def __call__(self, *args: Any, **kwargs: Any): + import torch + + with self.enable_spawn(): + queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict() + torch.multiprocessing.spawn( + self.dispatch_run, + args=(self.__class__, self.work, queue, args, kwargs), + nprocs=1, + ) + + @staticmethod + def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs): + if local_rank == 0: + if isinstance(delta_queue, dict): + delta_queue = cls.process_queue(delta_queue) + work._request_queue = cls.process_queue(work._request_queue) + work._response_queue = cls.process_queue(work._response_queue) + + state_observer = WorkStateObserver(work, delta_queue=delta_queue) + state_observer.start() + _proxy_setattr(work, delta_queue, state_observer) + + unwrap(work.run)(*args, **kwargs) + + if local_rank == 0: + state_observer.join(0) + + class _DefaultInputData(BaseModel): payload: str @@ -43,6 +86,7 @@ def _get_sample_data() -> Dict[Any, Any]: class PythonServer(LightningWork, abc.ABC): + def __init__( # type: ignore self, host: str = "127.0.0.1", @@ -105,6 +149,9 @@ def predict(self, request): raise TypeError("output_type must be a pydantic BaseModel class") self._input_type = input_type self._output_type = output_type + + # Note: Enable to run inference on GPUs. + self._run_executor_cls = WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor def setup(self, *args, **kwargs) -> None: """This method is called before the server starts. Override this if you need to download the model or From 00f10beb67fc49d9c922dda94daac8feda18fa72 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 25 Nov 2022 12:13:46 +0000 Subject: [PATCH 2/5] update --- .../components/serve/python_server.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 652403d69db1a..9ce1b23701059 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -1,18 +1,18 @@ import abc import base64 +import os from pathlib import Path from typing import Any, Dict, Optional -import os + import torch import uvicorn from fastapi import FastAPI from pydantic import BaseModel from starlette.staticfiles import StaticFiles +from lightning_app.core.queues import MultiProcessQueue from lightning_app.core.work import LightningWork from lightning_app.utilities.app_helpers import Logger -from typing import Any, Callable, Type -from lightning_app.core.queues import MultiProcessQueue from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver logger = Logger(__name__) @@ -21,9 +21,9 @@ class _PyTorchSpawnRunExecutor(WorkRunExecutor): """This Executor enables to move PyTorch tensors on GPU. - + Without this executor, it woud raise the following expection: - RuntimeError: Cannot re-initialize CUDA in forked subprocess. + RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method """ @@ -86,7 +86,6 @@ def _get_sample_data() -> Dict[Any, Any]: class PythonServer(LightningWork, abc.ABC): - def __init__( # type: ignore self, host: str = "127.0.0.1", @@ -149,9 +148,11 @@ def predict(self, request): raise TypeError("output_type must be a pydantic BaseModel class") self._input_type = input_type self._output_type = output_type - - # Note: Enable to run inference on GPUs. - self._run_executor_cls = WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor + + # Note: Enable to run inference on GPUs. + self._run_executor_cls = ( + WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor + ) def setup(self, *args, **kwargs) -> None: """This method is called before the server starts. Override this if you need to download the model or From 0393e4d4d5c4d5f1646efc0906055574c3f15a42 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 25 Nov 2022 12:15:40 +0000 Subject: [PATCH 3/5] update --- src/lightning_app/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 305ee591b0257..4c5da2c96e2e4 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801)) +- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813)) + ## [1.8.2] - 2022-11-17 From 2aa7e59faf060e9a022dfe71930997d120b01d40 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 25 Nov 2022 12:16:28 +0000 Subject: [PATCH 4/5] update --- MANIFEST.in | 99 ----------------------------------------------------- 1 file changed, 99 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 2d53517a0dff0..ac8c2556d4f02 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,102 +5,3 @@ include .actions/setup_tools.py include .actions/assistant.py include src/version.info include *.cff # citation info -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -include src/lightning/version.info - -include src/lightning/app/components/serve/catimage.png - -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning From 0f0d6bd6412221777b662b6dabb54e609ac9f185 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 25 Nov 2022 13:10:56 +0000 Subject: [PATCH 5/5] update --- src/lightning_app/components/serve/gradio.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py index 328e70e743b43..6e9b1d8777f67 100644 --- a/src/lightning_app/components/serve/gradio.py +++ b/src/lightning_app/components/serve/gradio.py @@ -1,8 +1,10 @@ import abc +import os from functools import partial from types import ModuleType from typing import Any, List, Optional +from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.imports import _is_gradio_available, requires @@ -39,6 +41,10 @@ def __init__(self, *args, **kwargs): assert self.inputs assert self.outputs self._model = None + # Note: Enable to run inference on GPUs. + self._run_executor_cls = ( + WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor + ) @property def model(self):