From 82a7db02db70c5565cd69bcd2ad1fd7cce837a24 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 7 Nov 2022 12:10:36 +0100
Subject: [PATCH 01/10] Fix result transfer in multiprocessing launcher on
 multi-node

---
 src/pytorch_lightning/strategies/launchers/multiprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index fc9723b36550e..04a6cef786cfe 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -172,7 +172,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()

-        if self._strategy.global_rank != 0:
+        if self._strategy.local_rank != 0:
             return None

         # save the last weights
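Note on the fix above: with the multiprocessing launcher, processes are spawned once per node, so every node has its own main process waiting for results, and only the worker with local_rank == 0 on each node can hand the weights back to it. Gating on global_rank meant that on every node except node 0 no worker returned anything. A toy standalone sketch of the rank layout (2 nodes x 2 processes, matching the test added later in this series; not Lightning code):

    # One process per node must report back to that node's main process,
    # so the gate has to be on local_rank, not global_rank.
    procs_per_node, num_nodes = 2, 2
    for node_rank in range(num_nodes):
        for local_rank in range(procs_per_node):
            global_rank = node_rank * procs_per_node + local_rank
            new_gate = local_rank == 0   # fixed: one returning process per node
            old_gate = global_rank == 0  # buggy: only node 0 ever returned
            print(node_rank, local_rank, global_rank, new_gate, old_gate)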
From 2f1aed7d6aefa37f874072a840856347e92d74f1 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Mon, 7 Nov 2022 16:30:36 +0100
Subject: [PATCH 02/10] add simple test

---
 .../launchers/test_multiprocessing.py         | 48 +++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index ef1a5ccce1547..d0d33d62faaf2 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -11,15 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock

 import pytest
 import torch

+from lightning_lite.plugins.environments.debug import _DebugEnvironment
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
 from tests_pytorch.helpers.runif import RunIf

+from pytorch_lightning.trainer.states import TrainerFn
+

 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
 def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
@@ -76,3 +84,43 @@ def test_global_state_snapshot():
     assert torch.are_deterministic_algorithms_enabled()
     assert not torch.backends.cudnn.benchmark
     assert torch.initial_seed() == 123
+
+
+@pytest.mark.parametrize("trainer_fn", [
+    TrainerFn.FITTING,
+    "other",
+])
+@pytest.mark.parametrize("fake_node_rank", [0, 1])
+@pytest.mark.parametrize("fake_local_rank", [0, 1])
+def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir, monkeypatch):
+    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
+    file."""
+    monkeypatch.setitem(os.environ, "NODE_RANK", str(fake_node_rank))
+    monkeypatch.setitem(os.environ, "LOCAL_RANK", str(fake_local_rank))
+
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    fake_global_rank = 2 * fake_node_rank + fake_local_rank
+    strategy = DDPSpawnStrategy(cluster_environment=_DebugEnvironment(world_size=4, node_rank=fake_node_rank, local_rank=fake_local_rank, global_rank=fake_global_rank))
+    strategy._local_rank = fake_local_rank
+
+    launcher = _MultiProcessingLauncher(strategy=strategy)
+    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+
+    assert strategy.node_rank == fake_node_rank
+    assert strategy.local_rank == fake_local_rank
+    assert strategy.global_rank == fake_global_rank
+
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+    temp_file = Path(tmpdir, ".temp.ckpt")
+
+    assert not temp_file.exists()
+    spawn_output = launcher._collect_rank_zero_results(trainer, {})
+
+    model.state_dict.assert_called_once()
+    is_fitting = trainer_fn == TrainerFn.FITTING
+    if strategy.local_rank == 0:
+        assert spawn_output.weights_path == (str(temp_file) if is_fitting else None)
+        assert not is_fitting or temp_file.exists()
+    else:
+        assert spawn_output is None

From 179ccbebe2a9f7e8af0658c14e4ce56768d7e2cb Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 01:01:47 +0100
Subject: [PATCH 03/10] add comment

---
 .../tests_pytorch/strategies/launchers/test_multiprocessing.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index d0d33d62faaf2..256de4ac2b3ae 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -120,7 +120,9 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     model.state_dict.assert_called_once()
     is_fitting = trainer_fn == TrainerFn.FITTING
     if strategy.local_rank == 0:
+        # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
         assert spawn_output.weights_path == (str(temp_file) if is_fitting else None)
         assert not is_fitting or temp_file.exists()
     else:
+        # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
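The test in patches 02 and 03 pins down a simple per-rank contract for _collect_rank_zero_results. A rough sketch of that contract (names assumed, result simplified to a dict for illustration):

    # Only local rank 0 returns anything; its weights_path is set only while fitting.
    def expected(local_rank, is_fitting, temp_file):
        if local_rank != 0:
            return None
        return {"weights_path": temp_file if is_fitting else None}

    assert expected(1, True, "x.ckpt") is None
    assert expected(0, True, "x.ckpt") == {"weights_path": "x.ckpt"}
    assert expected(0, False, "x.ckpt") == {"weights_path": None}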
From 7bb28bd3fd7637a71261b8e13f0f6b9a724800a6 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 01:13:34 +0100
Subject: [PATCH 04/10] update test

---
 .../launchers/test_multiprocessing.py         | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 256de4ac2b3ae..7499d341fc44c 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock
@@ -19,14 +18,13 @@
 import pytest
 import torch

-from lightning_lite.plugins.environments.debug import _DebugEnvironment
+from lightning_lite.plugins import ClusterEnvironment
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
-from tests_pytorch.helpers.runif import RunIf
-
 from pytorch_lightning.trainer.states import TrainerFn
+from tests_pytorch.helpers.runif import RunIf


 @mock.patch("pytorch_lightning.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -86,21 +84,22 @@ def test_global_state_snapshot():
     assert torch.initial_seed() == 123


-@pytest.mark.parametrize("trainer_fn", [
-    TrainerFn.FITTING,
-    "other",
-])
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
 @pytest.mark.parametrize("fake_node_rank", [0, 1])
 @pytest.mark.parametrize("fake_local_rank", [0, 1])
-def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir, monkeypatch):
+def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
     """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
     file."""
-    monkeypatch.setitem(os.environ, "NODE_RANK", str(fake_node_rank))
-    monkeypatch.setitem(os.environ, "LOCAL_RANK", str(fake_local_rank))
-
     model = Mock(wraps=BoringModel(), spec=BoringModel)
     fake_global_rank = 2 * fake_node_rank + fake_local_rank
-    strategy = DDPSpawnStrategy(cluster_environment=_DebugEnvironment(world_size=4, node_rank=fake_node_rank, local_rank=fake_local_rank, global_rank=fake_global_rank))
+
+    cluster_environment = Mock(spec=ClusterEnvironment)
+    cluster_environment.world_size.return_value = 4
+    cluster_environment.node_rank.return_value = fake_node_rank
+    cluster_environment.local_rank.return_value = fake_local_rank
+    cluster_environment.global_rank.return_value = fake_global_rank
+
+    strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
     strategy._local_rank = fake_local_rank

     launcher = _MultiProcessingLauncher(strategy=strategy)

From 4b162c983a0a71d3d02df21561c6f1f74b4f3898 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 01:18:02 +0100
Subject: [PATCH 05/10] changelog

---
 src/pytorch_lightning/CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 2cb24cb961f75..f8d9ed40924ac 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with `WandbLogger(log_model=True|'all)` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))

+- Fixed model state transfer in multiprocessing launcher when running multi-node ([#15567](https://github.com/Lightning-AI/lightning/pull/15567))
+

 ## [1.8.0] - 2022-11-01
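The test update in patch 04 swaps the private _DebugEnvironment for a Mock(spec=ClusterEnvironment), which keeps the test free of debug-only helpers while still faking every rank the strategy queries. A small self-contained sketch of that pattern (FakeEnv is a made-up stand-in for the real interface):

    from unittest.mock import Mock

    class FakeEnv:
        def node_rank(self) -> int: ...
        def local_rank(self) -> int: ...

    env = Mock(spec=FakeEnv)
    env.node_rank.return_value = 1  # each method's answer is faked per test case
    assert env.node_rank() == 1
    try:
        env.world_size  # not part of FakeEnv, so spec= makes this raise
    except AttributeError:
        pass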
From 592ecab0577b3d109975edf714fbfe9c49142ccf Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 02:19:16 +0100
Subject: [PATCH 06/10] use tempfile

---
 src/pytorch_lightning/strategies/launchers/multiprocessing.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 04a6cef786cfe..19a013844307b 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import tempfile
 from collections import UserList
 from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
@@ -178,7 +179,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         # save the last weights
         weights_path = None
         if trainer.state.fn == TrainerFn.FITTING:
-            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
+            weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
             self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

         # adds the `callback_metrics` to the queue
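Writing the temporary checkpoint into trainer.default_root_dir is unsafe for multi-node runs because that directory often lives on a filesystem shared between nodes, so every node's local rank 0 would save and later delete the very same .temp.ckpt path. tempfile.mkdtemp() sidesteps this by creating a fresh directory per call. A minimal sketch of the collision being avoided (paths are illustrative):

    import os
    import tempfile

    shared = os.path.join("/shared/root", ".temp.ckpt")  # identical string on all nodes
    a = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")   # unique directory per call
    b = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
    assert a != b  # two nodes can no longer clobber each other's checkpoint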
From d34e343aeeeb5092e42e7e34a2c7954466922831 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 02:48:48 +0100
Subject: [PATCH 07/10] fix

---
 .../strategies/launchers/test_multiprocessing.py          | 7 +++----
 tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py | 9 +++------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 7499d341fc44c..fdaa649f02b6e 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock
@@ -111,17 +112,15 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,

     trainer.strategy.connect(model)
     trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-    temp_file = Path(tmpdir, ".temp.ckpt")

-    assert not temp_file.exists()
     spawn_output = launcher._collect_rank_zero_results(trainer, {})

     model.state_dict.assert_called_once()
     is_fitting = trainer_fn == TrainerFn.FITTING
     if strategy.local_rank == 0:
         # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
-        assert spawn_output.weights_path == (str(temp_file) if is_fitting else None)
-        assert not is_fitting or temp_file.exists()
+        assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
+        assert not is_fitting or os.path.isfile(spawn_output.weights_path)
     else:
         # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
index 7c1d347970b43..6e5775cb2523f 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from datetime import timedelta
 from pathlib import Path
 from unittest import mock
@@ -135,23 +136,19 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
     trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
     trainer.strategy.connect(model)
     trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-    temp_file = Path(tmpdir, ".temp.ckpt")

-    assert not temp_file.exists()
     spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})

     model.state_dict.assert_called_once()
     if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path == str(temp_file)
-        assert temp_file.exists()
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
     else:
         assert spawn_output.weights_path is None
-        assert not temp_file.exists()

     # <-- here would normally be the multiprocessing boundary
     strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
     assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-    assert not temp_file.exists()


 @mock.patch("torch.distributed.init_process_group")

From 585d69127b11aae109f38d83b98c85b3a33af610 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 02:50:32 +0100
Subject: [PATCH 08/10] assert None

---
 tests/tests_pytorch/strategies/launchers/test_multiprocessing.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index fdaa649f02b6e..957c2f57296ba 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -121,6 +121,7 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
         # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
         assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
         assert not is_fitting or os.path.isfile(spawn_output.weights_path)
+        assert is_fitting or spawn_output.weights_path is None
     else:
         # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
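Because mkdtemp() makes the checkpoint location unpredictable, the tests in patches 07 and 08 stop comparing against a precomputed Path and instead assert on the parts that stay stable: the .temp.ckpt suffix and the file's existence. The same assertion style as a tiny standalone example:

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
    open(path, "wb").close()
    assert path.endswith(".temp.ckpt")  # the name is stable even if the directory is not
    assert os.path.isfile(path)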
From b2b17b24978dcfb7264091b701c68c9e66da5fd8 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 03:36:02 +0100
Subject: [PATCH 09/10] unused import

---
 tests/tests_pytorch/strategies/launchers/test_multiprocessing.py | 1 -
 tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py        | 1 -
 2 files changed, 2 deletions(-)

diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 957c2f57296ba..142f6c53d2015 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock

diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
index 6e5775cb2523f..22a30b927b1e4 100644
--- a/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from datetime import timedelta
-from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock

From 3ceb5525f5e5fdec0640614482c7d5340b32fe00 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 8 Nov 2022 12:02:13 +0100
Subject: [PATCH 10/10] add comment

---
 src/pytorch_lightning/strategies/launchers/multiprocessing.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 19a013844307b..1f225d749d1fe 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -179,6 +179,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
         # save the last weights
         weights_path = None
         if trainer.state.fn == TrainerFn.FITTING:
+            # use tempdir here to avoid race conditions because the filesystem may be shared between nodes
             weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
             self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
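Taken together, the series leaves the launcher with the following shape of flow on every node. The sketch below is a heavily simplified stand-in, not the real launcher: worker and main_process are invented names, and the mp.spawn boundary is replaced by a direct call.

    import os
    import tempfile

    import torch


    def worker(local_rank, model):
        # ... training would happen here on every rank ...
        if local_rank != 0:
            return None  # only local rank 0 reports back to this node's main process
        weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
        torch.save(model.state_dict(), weights_path)
        return weights_path


    def main_process(model):
        weights_path = worker(local_rank=0, model=model)  # stand-in for mp.spawn
        if weights_path is not None:
            model.load_state_dict(torch.load(weights_path))


    main_process(torch.nn.Linear(2, 2))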