Skip to content

Commit b4691ff

Browse files
committed
hotfix import torch (#15849)
* fix import torch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * plugin * fix * skip * patch require * seed * warn * . * .. * skip True * 0.0.3 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit ad4bd66)
1 parent 5d9d5d7 commit b4691ff

File tree

11 files changed

+41
-23
lines changed

11 files changed

+41
-23
lines changed

.github/workflows/ci-pkg-install.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
- name: DocTests actions
4747
working-directory: .actions/
4848
run: |
49-
pip install pytest -q
49+
pip install -q pytest
5050
python -m pytest setup_tools.py
5151
5252
- run: python -c "print('NB_DIRS=' + str(2 if '${{ matrix.pkg-name }}' == 'pytorch' else 1))" >> $GITHUB_ENV
@@ -67,7 +67,10 @@ jobs:
6767

6868
- name: DocTest package
6969
env:
70+
LIGHTING_TESTING: 1 # path for require wrapper
7071
PY_IGNORE_IMPORTMISMATCH: 1
7172
run: |
73+
pip install -q "pytest-doctestplus>=0.9.0"
74+
pip list
7275
PKG_NAME=$(python -c "print({'app': 'lightning_app', 'lite': 'lightning_lite', 'pytorch': 'pytorch_lightning', 'lightning': 'lightning'}['${{matrix.pkg-name}}'])")
7376
python -m pytest src/${PKG_NAME} --ignore-glob="**/cli/*-template/**"

requirements/app/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ beautifulsoup4>=4.8.0, <4.11.2
1212
inquirer>=2.10.0
1313
psutil<5.9.4
1414
click<=8.1.3
15-
lightning_api_access>=0.0.1
15+
lightning_api_access>=0.0.3
1616
s3fs>=2022.5.0, <2022.8.3

requirements/app/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ codecov==2.1.12
33
pytest==7.2.0
44
pytest-timeout==2.1.0
55
pytest-cov==4.0.0
6+
pytest-doctestplus>=0.9.0
67
playwright==1.27.1
78
httpx
89
trio<0.22.0

src/lightning_app/components/serve/python_server.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pathlib import Path
55
from typing import Any, Dict, Optional
66

7-
import torch
87
import uvicorn
98
from fastapi import FastAPI
109
from pydantic import BaseModel
@@ -13,16 +12,21 @@
1312
from lightning_app.core.queues import MultiProcessQueue
1413
from lightning_app.core.work import LightningWork
1514
from lightning_app.utilities.app_helpers import Logger
15+
from lightning_app.utilities.imports import _is_torch_available, requires
1616
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
1717

1818
logger = Logger(__name__)
1919

20+
# Skip doctests if requirements aren't available
21+
if not _is_torch_available():
22+
__doctest_skip__ = ["PythonServer", "PythonServer.*"]
23+
2024

2125
class _PyTorchSpawnRunExecutor(WorkRunExecutor):
2226

2327
"""This Executor enables to move PyTorch tensors on GPU.
2428
25-
Without this executor, it woud raise the following expection:
29+
Without this executor, it would raise the following exception:
2630
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
2731
To use CUDA with multiprocessing, you must use the 'spawn' start method
2832
"""
@@ -86,6 +90,7 @@ def _get_sample_data() -> Dict[Any, Any]:
8690

8791

8892
class PythonServer(LightningWork, abc.ABC):
93+
@requires("torch")
8994
def __init__( # type: ignore
9095
self,
9196
host: str = "127.0.0.1",
@@ -127,15 +132,16 @@ def predict(self, request):
127132
and this can be accessed as `response.json()["prediction"]` in the client if
128133
you are using requests library
129134
130-
.. doctest::
135+
Example:
131136
132137
>>> from lightning_app.components.serve.python_server import PythonServer
133138
>>> from lightning_app import LightningApp
134-
>>>
135139
...
136140
>>> class SimpleServer(PythonServer):
141+
...
137142
... def setup(self):
138143
... self._model = lambda x: x + " " + x
144+
...
139145
... def predict(self, request):
140146
... return {"prediction": self._model(request.image)}
141147
...
@@ -199,11 +205,13 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
199205
return out
200206

201207
def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
208+
from torch import inference_mode
209+
202210
input_type: type = self.configure_input_type()
203211
output_type: type = self.configure_output_type()
204212

205213
def predict_fn(request: input_type): # type: ignore
206-
with torch.inference_mode():
214+
with inference_mode():
207215
return self.predict(request)
208216

209217
fastapi_app.post("/predict", response_model=output_type)(predict_fn)

src/lightning_app/utilities/imports.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""General utilities."""
15+
1516
import functools
1617
import os
18+
import warnings
1719
from typing import List, Union
1820

1921
from lightning_utilities.core.imports import module_available
@@ -52,10 +54,13 @@ def decorator(func):
5254
@functools.wraps(func)
5355
def wrapper(*args, **kwargs):
5456
unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
55-
if any(unavailable_modules) and not bool(int(os.getenv("LIGHTING_TESTING", "0"))):
56-
raise ModuleNotFoundError(
57-
f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
58-
)
57+
if any(unavailable_modules):
58+
is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
59+
msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
60+
if is_lit_testing:
61+
warnings.warn(msg)
62+
else:
63+
raise ModuleNotFoundError(msg)
5964
return func(*args, **kwargs)
6065

6166
return wrapper

src/lightning_app/utilities/name_generator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,12 +1332,13 @@ def get_unique_name():
13321332
Original source:
13331333
https://raw.githubusercontent.com/moby/moby/master/pkg/namesgenerator/names-generator.go
13341334
1335-
Examples
1336-
--------
1337-
>>> get_unique_name() # doctest: +SKIP
1338-
'focused-turing-23'
1339-
>>> get_unique_name() # doctest: +SKIP
1340-
'thirsty-allen-9200'
1335+
Examples:
1336+
1337+
>>> import random ; random.seed(42)
1338+
>>> get_unique_name()
1339+
'meek-ardinghelli-4506'
1340+
>>> get_unique_name()
1341+
'truthful-dijkstra-2286'
13411342
"""
13421343
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999)
13431344
return f"{adjective}-{surname}-{i}"

tests/tests_app/core/test_lightning_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def run(self):
106106

107107

108108
# TODO: Find why this test is flaky.
109-
@pytest.mark.skipif(True, reason="flaky test.")
109+
@pytest.mark.skip(reason="flaky test.")
110110
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
111111
def test_app_state_api_with_flows(runtime_cls, tmpdir):
112112
"""This test validates the AppState can properly broadcast changes from flows."""
@@ -180,7 +180,7 @@ def maybe_apply_changes(self):
180180

181181

182182
# FIXME: This test doesn't assert anything
183-
@pytest.mark.skipif(True, reason="TODO: Resolve flaky test.")
183+
@pytest.mark.skip(reason="TODO: Resolve flaky test.")
184184
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
185185
def test_app_stage_from_frontend(runtime_cls):
186186
"""This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would

tests/tests_app/core/test_lightning_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def run(self):
582582

583583

584584
# TODO (tchaton) Resolve this test.
585-
@pytest.mark.skipif(True, reason="flaky test which never terminates")
585+
@pytest.mark.skip(reason="flaky test which never terminates")
586586
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
587587
@pytest.mark.parametrize("use_same_args", [False, True])
588588
def test_state_wait_for_all_all_works(tmpdir, runtime_cls, use_same_args):

tests/tests_app/structures/test_structures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def run(self):
308308
self.counter += 1
309309

310310

311-
@pytest.mark.skipif(True, reason="tchaton: Resolve this test.")
311+
@pytest.mark.skip(reason="tchaton: Resolve this test.")
312312
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
313313
@pytest.mark.parametrize("run_once_iterable", [False, True])
314314
@pytest.mark.parametrize("cache_calls", [False, True])

tests/tests_app/utilities/packaging/test_docker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from lightning_app.utilities.redis import check_if_redis_running
1313

1414

15-
@pytest.mark.skipif(True, reason="FIXME (tchaton)")
15+
@pytest.mark.skip(reason="FIXME (tchaton)")
1616
@pytest.mark.skipif(not _is_docker_available(), reason="docker is required for this test.")
1717
@pytest.mark.skipif(not check_if_redis_running(), reason="redis is required for this test.")
1818
@_RunIf(skip_windows=True)

0 commit comments

Comments
 (0)