
Commit 36aecde

Multinode on MPS (#15748)

Authored by justusschock, Borda, pre-commit-ci[bot], carmocca, and awaelchli

* Fix restarting attribute for lr finder
* update lite executor
* update trainer executor
* update spawn executor
* add multinode component tests
* add testing helpers
* add lite tests
* add trainer tests
* update changelog
* update trainer
* update workflow
* update tests
* debug
* add reason for skipif
* Apply suggestions from code review
* switch skipif

Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 8475f85 commit 36aecde

File tree: 9 files changed, +268 −20 lines


.github/workflows/ci-app-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ jobs:

       - name: Adjust tests
        if: ${{ matrix.pkg-name == 'lightning' }}
-        run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app" --target_import="lightning.app"
+        run: python .actions/assistant.py copy_replace_imports --source_dir="./tests" --source_import="lightning_app,lightning_lite,pytorch_lightning" --target_import="lightning.app,lightning.lite,lightning.pytorch"

      - name: Adjust examples
        if: ${{ matrix.pkg-name != 'lightning' }}

src/lightning_app/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))

+- Fixed MPS error for multinode component (now defaults to CPU on MPS devices, as distributed operations are not supported by PyTorch on MPS) ([#15748](https://github.com/Lightning-AI/lightning/pull/15748))
+
+
 - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

@@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
 - Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
 - Fixed catimage import ([#15712](https://github.com/Lightning-AI/lightning/pull/15712))
+- Fixed setting property to the LightningFlow ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
 - Parse all lines in app file looking for shebangs to run commands ([#15714](https://github.com/Lightning-AI/lightning/pull/15714))
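For context on the MPS entry above: PyTorch's distributed process groups (gloo, nccl) cannot operate on Apple's MPS backend, so any multi-process run on an MPS machine has to fall back to CPU. A quick availability check with plain PyTorch, shown only for illustration and not part of this commit:

import torch

# MPS can run single-process training, but torch.distributed process groups
# (gloo/nccl) do not support MPS tensors, so the multinode components force CPU.
if torch.backends.mps.is_available():
    print("MPS available: multinode components will fall back to accelerator='cpu'.")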

src/lightning_app/components/multi_node/lite.py

Lines changed: 30 additions & 7 deletions
@@ -1,4 +1,6 @@
+import importlib
 import os
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Type

@@ -30,8 +32,16 @@ def run(
         node_rank: int,
         nprocs: int,
     ):
-        from lightning.lite import LightningLite
-        from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
+        lites = []
+        strategies = []
+        mps_accelerators = []
+
+        for pkg_name in ("lightning.lite", "lightning_" + "lite"):
+            pkg = importlib.import_module(pkg_name)
+            lites.append(pkg.LightningLite)
+            strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
+            strategies.append(pkg.strategies.DDPSpawnStrategy)
+            mps_accelerators.append(pkg.accelerators.MPSAccelerator)

         # Used to configure the PyTorch process group
         os.environ["MASTER_ADDR"] = main_address

@@ -52,23 +62,36 @@ def run(
         def pre_fn(lite, *args, **kwargs):
             kwargs["devices"] = nprocs
             kwargs["num_nodes"] = num_nodes
-            kwargs["accelerator"] = "auto"
+
+            if any(acc.is_available() for acc in mps_accelerators):
+                old_acc_value = kwargs.get("accelerator", "auto")
+                kwargs["accelerator"] = "cpu"
+
+                if old_acc_value != kwargs["accelerator"]:
+                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
+            else:
+                kwargs["accelerator"] = "auto"
             strategy = kwargs.get("strategy", None)
             if strategy:
                 if isinstance(strategy, str):
                     if strategy == "ddp_spawn":
                         strategy = "ddp"
                     elif strategy == "ddp_sharded_spawn":
                         strategy = "ddp_sharded"
-                elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
-                    raise Exception("DDP Spawned strategies aren't supported yet.")
+                elif isinstance(strategy, tuple(strategies)):
+                    raise ValueError("DDP Spawned strategies aren't supported yet.")
+
+                kwargs["strategy"] = strategy
+
             return {}, args, kwargs

         tracer = Tracer()
-        tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn)
+        for ll in lites:
+            tracer.add_traced(ll, "__init__", pre_fn=pre_fn)
         tracer._instrument()
-        work_run()
+        ret_val = work_run()
         tracer._restore()
+        return ret_val


 class LiteMultiNode(MultiNode):
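
The heart of this change is in `pre_fn`: whatever accelerator the caller asked for is overridden to `cpu` whenever an MPS accelerator is available, with a warning if that actually changed the user's choice. Below is a self-contained sketch of that fallback, mirroring the diff above; `force_cpu_on_mps` and the `mps_available` flag are illustrative names, not part of the codebase.

import warnings
from typing import Any, Dict


def force_cpu_on_mps(kwargs: Dict[str, Any], mps_available: bool) -> Dict[str, Any]:
    """Force `accelerator="cpu"` when MPS is the local accelerator."""
    if mps_available:
        old_acc_value = kwargs.get("accelerator", "auto")
        kwargs["accelerator"] = "cpu"
        if old_acc_value != kwargs["accelerator"]:
            # Only warn when we actually overrode an explicit (or default) choice.
            warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
    else:
        kwargs["accelerator"] = "auto"
    return kwargs


# On an MPS machine, a "gpu" request is downgraded to "cpu" with a warning;
# elsewhere the accelerator stays on automatic selection.
assert force_cpu_on_mps({"accelerator": "gpu"}, mps_available=True)["accelerator"] == "cpu"
assert force_cpu_on_mps({}, mps_available=False)["accelerator"] == "auto"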

src/lightning_app/components/multi_node/pytorch_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def run(
         elif world_size > 1:
             raise Exception("Torch distributed should be available.")

-        work_run(world_size, node_rank, global_rank, local_rank)
+        return work_run(world_size, node_rank, global_rank, local_rank)


 class PyTorchSpawnMultiNode(MultiNode):
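
As in the lite and trainer executors, the only functional change here is that the executor now returns the work's result instead of discarding it, so values produced by the spawned run propagate back to the caller.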

src/lightning_app/components/multi_node/trainer.py

Lines changed: 28 additions & 9 deletions
@@ -1,4 +1,6 @@
+import importlib
 import os
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Type

@@ -30,9 +32,16 @@ def run(
         node_rank: int,
         nprocs: int,
     ):
-        from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy
-        from lightning.pytorch import Trainer as LTrainer
-        from pytorch_lightning import Trainer as PLTrainer
+        trainers = []
+        strategies = []
+        mps_accelerators = []
+
+        for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
+            pkg = importlib.import_module(pkg_name)
+            trainers.append(pkg.Trainer)
+            strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
+            strategies.append(pkg.strategies.DDPSpawnStrategy)
+            mps_accelerators.append(pkg.accelerators.MPSAccelerator)

         # Used to configure the PyTorch process group
         os.environ["MASTER_ADDR"] = main_address

@@ -50,24 +59,34 @@ def run(
         def pre_fn(trainer, *args, **kwargs):
             kwargs["devices"] = nprocs
             kwargs["num_nodes"] = num_nodes
-            kwargs["accelerator"] = "auto"
+            if any(acc.is_available() for acc in mps_accelerators):
+                old_acc_value = kwargs.get("accelerator", "auto")
+                kwargs["accelerator"] = "cpu"
+
+                if old_acc_value != kwargs["accelerator"]:
+                    warnings.warn("Forcing `accelerator=cpu` as MPS does not support distributed training.")
+            else:
+                kwargs["accelerator"] = "auto"
+
             strategy = kwargs.get("strategy", None)
             if strategy:
                 if isinstance(strategy, str):
                     if strategy == "ddp_spawn":
                         strategy = "ddp"
                     elif strategy == "ddp_sharded_spawn":
                         strategy = "ddp_sharded"
-                elif isinstance(strategy, (DDPSpawnStrategy, DDPSpawnShardedStrategy)):
-                    raise Exception("DDP Spawned strategies aren't supported yet.")
+                elif isinstance(strategy, tuple(strategies)):
+                    raise ValueError("DDP Spawned strategies aren't supported yet.")
+                kwargs["strategy"] = strategy
             return {}, args, kwargs

         tracer = Tracer()
-        tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn)
-        tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn)
+        for trainer in trainers:
+            tracer.add_traced(trainer, "__init__", pre_fn=pre_fn)
         tracer._instrument()
-        work_run()
+        ret_val = work_run()
         tracer._restore()
+        return ret_val


 class LightningTrainerMultiNode(MultiNode):
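
Both executors rely on the same trick: `tracer.add_traced(cls, "__init__", pre_fn=pre_fn)` wraps the constructor so `pre_fn` can rewrite the keyword arguments before the real `__init__` runs. Here is a simplified illustration of that mechanism; it is not the actual `Tracer` from lightning_app, and `patch_init` and `FakeTrainer` exist only for this sketch.

import functools


def patch_init(cls, pre_fn):
    original_init = cls.__init__

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # pre_fn may rewrite args/kwargs (devices, num_nodes, accelerator, ...)
        # and, per the contract above, returns (stash, args, kwargs).
        _, args, kwargs = pre_fn(self, *args, **kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return original_init  # keep a handle so the patch can be restored later


class FakeTrainer:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


def pre_fn(trainer, *args, **kwargs):
    kwargs["devices"] = 8
    kwargs["num_nodes"] = 7
    return {}, args, kwargs


patch_init(FakeTrainer, pre_fn)
assert FakeTrainer(accelerator="auto").kwargs["devices"] == 8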

tests/tests_app/components/multi_node/__init__.py

Whitespace-only changes.
tests/tests_app/components/multi_node/test_lite.py

Lines changed: 103 additions & 0 deletions
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import lightning_lite as ll
from lightning_app.components.multi_node.lite import _LiteRunExecutor


class DummyLite(ll.LightningLite):
    def run(self):
        pass


def dummy_callable(**kwargs):
    lite = DummyLite(**kwargs)
    return lite._all_passed_kwargs


def dummy_init(self, **kwargs):
    self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
    with mock.patch.object(ll.LightningLite, "__init__", dummy_init):
        ret_val = _LiteRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        env_vars = deepcopy(os.environ)
    return ret_val, env_vars


def check_lightning_lite_mps():
    if module_available("lightning_lite"):
        return ll.accelerators.MPSAccelerator.is_available()
    return False


@pytest.mark.skipif(not check_lightning_lite_mps(), reason="Lightning lite not available or mps not available")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_lite_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    if accelerator_expected != accelerator_given:
        warning_context = pytest.warns(UserWarning, match=warning_str)
    else:
        warning_context = no_warning_call(match=warning_str + "*")

    with warning_context:
        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
@pytest.mark.skipif(not module_available("lightning"), reason="Lightning is required for this test")
def test_trainer_run_executor_arguments_choices(args_given: dict, args_expected: dict):
    # ddp with mps devices not available (tested separately, just patching here for cross-os testing of other args)
    if ll.accelerators.MPSAccelerator.is_available():
        args_expected["accelerator"] = "cpu"

    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

    for k, v in args_expected.items():
        assert ret_val[k] == v

    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
    assert env_vars["MASTER_PORT"] == "5"
    assert env_vars["GROUP_RANK"] == "6"
    assert env_vars["RANK"] == str(0 + 6 * 8)
    assert env_vars["LOCAL_RANK"] == "0"
    assert env_vars["WORLD_SIZE"] == str(7 * 8)
    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"
    assert env_vars["LT_CLI_USED"] == "1"


@pytest.mark.skipif(not module_available("lightning"), reason="Lightning not available")
def test_lite_run_executor_invalid_strategy_instances():
    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnStrategy())

    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=ll.strategies.DDPSpawnShardedStrategy())
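
The environment-variable assertions in these tests encode the usual torch.distributed bookkeeping: with `local_rank=0`, `node_rank=6`, `num_nodes=7`, and `nprocs=8`, the global rank is `local_rank + node_rank * nprocs = 48` and the world size is `num_nodes * nprocs = 56`. A short sketch of that arithmetic; `distributed_env` is an illustrative helper, not part of the test suite.

def distributed_env(local_rank: int, node_rank: int, num_nodes: int, nprocs: int) -> dict:
    # The rank/world-size math the assertions above rely on.
    return {
        "GROUP_RANK": str(node_rank),
        "RANK": str(local_rank + node_rank * nprocs),  # global rank
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(num_nodes * nprocs),         # total processes
        "LOCAL_WORLD_SIZE": str(nprocs),               # processes per node
    }


env = distributed_env(local_rank=0, node_rank=6, num_nodes=7, nprocs=8)
assert env["RANK"] == "48" and env["WORLD_SIZE"] == "56"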
tests/tests_app/components/multi_node/test_trainer.py

Lines changed: 99 additions & 0 deletions
import os
from copy import deepcopy
from functools import partial
from unittest import mock

import pytest
from lightning_utilities.core.imports import module_available
from tests_app.helpers.utils import no_warning_call

import pytorch_lightning as pl
from lightning_app.components.multi_node.trainer import _LightningTrainerRunExecutor


def dummy_callable(**kwargs):
    t = pl.Trainer(**kwargs)
    return t._all_passed_kwargs


def dummy_init(self, **kwargs):
    self._all_passed_kwargs = kwargs


def _get_args_after_tracer_injection(**kwargs):
    with mock.patch.object(pl.Trainer, "__init__", dummy_init):
        ret_val = _LightningTrainerRunExecutor.run(
            local_rank=0,
            work_run=partial(dummy_callable, **kwargs),
            main_address="1.2.3.4",
            main_port=5,
            node_rank=6,
            num_nodes=7,
            nprocs=8,
        )
        env_vars = deepcopy(os.environ)
    return ret_val, env_vars


def check_lightning_pytorch_and_mps():
    if module_available("pytorch_lightning"):
        return pl.accelerators.MPSAccelerator.is_available()
    return False


@pytest.mark.skipif(not check_lightning_pytorch_and_mps(), reason="pytorch_lightning and mps are required")
@pytest.mark.parametrize("accelerator_given,accelerator_expected", [("cpu", "cpu"), ("auto", "cpu"), ("gpu", "cpu")])
def test_trainer_run_executor_mps_forced_cpu(accelerator_given, accelerator_expected):
    warning_str = (
        r"Forcing accelerator=cpu as other accelerators \(specifically MPS\) are not supported "
        + "by PyTorch for distributed training on mps capable devices"
    )
    if accelerator_expected != accelerator_given:
        warning_context = pytest.warns(UserWarning, match=warning_str)
    else:
        warning_context = no_warning_call(match=warning_str + "*")

    with warning_context:
        ret_val, env_vars = _get_args_after_tracer_injection(accelerator=accelerator_given)
        assert ret_val["accelerator"] == accelerator_expected


@pytest.mark.parametrize(
    "args_given,args_expected",
    [
        ({"devices": 1, "num_nodes": 1, "accelerator": "gpu"}, {"devices": 8, "num_nodes": 7, "accelerator": "auto"}),
        ({"strategy": "ddp_spawn"}, {"strategy": "ddp"}),
        ({"strategy": "ddp_sharded_spawn"}, {"strategy": "ddp_sharded"}),
    ],
)
@pytest.mark.skipif(not module_available("pytorch"), reason="Lightning is not available")
def test_trainer_run_executor_arguments_choices(
    args_given: dict,
    args_expected: dict,
):
    if pl.accelerators.MPSAccelerator.is_available():
        args_expected.pop("accelerator", None)  # Cross platform tests -> MPS is tested separately

    ret_val, env_vars = _get_args_after_tracer_injection(**args_given)

    for k, v in args_expected.items():
        assert ret_val[k] == v

    assert env_vars["MASTER_ADDR"] == "1.2.3.4"
    assert env_vars["MASTER_PORT"] == "5"
    assert env_vars["GROUP_RANK"] == "6"
    assert env_vars["RANK"] == str(0 + 6 * 8)
    assert env_vars["LOCAL_RANK"] == "0"
    assert env_vars["WORLD_SIZE"] == str(7 * 8)
    assert env_vars["LOCAL_WORLD_SIZE"] == "8"
    assert env_vars["TORCHELASTIC_RUN_ID"] == "1"


@pytest.mark.skipif(not module_available("lightning"), reason="lightning not available")
def test_trainer_run_executor_invalid_strategy_instances():
    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnStrategy())

    with pytest.raises(ValueError, match="DDP Spawned strategies aren't supported yet."):
        _, _ = _get_args_after_tracer_injection(strategy=pl.strategies.DDPSpawnShardedStrategy())

tests/tests_app/utilities/packaging/test_build_spec.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ def test_build_config_requirements_provided():
     assert spec.requirements == [
         "dask",
         "pandas",
-        "pytorch_" + "lightning==1.5.9",  # ugly hack due to replacing the `pytorch_lightning` string
+        "pytorch_lightning==1.5.9",
         "git+https://github.com/mit-han-lab/[email protected]",
     ]
     assert spec == BuildConfig.from_dict(spec.to_dict())

@@ -50,7 +50,7 @@ def test_build_config_dockerfile_provided():
     spec = BuildConfig(dockerfile="./projects/Dockerfile.cpu")
     assert not spec.requirements
     # ugly hack due to replacing the `pytorch_lightning` string
-    assert "pytorchlightning/pytorch_" + "lightning" in spec.dockerfile.data[0]
+    assert "pytorchlightning/pytorch_lightning" in spec.dockerfile.data[0]


 class DockerfileLightningTestApp(LightningTestApp):
