2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue with `WandbLogger(log_model=True|'all')` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))

+- Fixed model state transfer in multiprocessing launcher when running multi-node ([#15567](https://github.com/Lightning-AI/lightning/pull/15567))


## [1.8.0] - 2022-11-01

src/pytorch_lightning/strategies/launchers/multiprocessing.py

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
+import tempfile
from collections import UserList
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
@@ -172,13 +173,14 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
        # requires to compute the state_dict on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

-        if self._strategy.global_rank != 0:
+        if self._strategy.local_rank != 0:
            return None

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
-            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
+            # use tempdir here to avoid race conditions because the filesystem may be shared between nodes
+            weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
        self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # adds the `callback_metrics` to the queue
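A minimal standalone sketch (not the library's code) of why the old fixed path raced across nodes: when `default_root_dir` sits on a filesystem shared between nodes, every node's rank-0 process saved and later deleted the same `.temp.ckpt`, so one node could delete the file while another was still reading it. `tempfile.mkdtemp()` gives each process a uniquely named directory instead:

```python
import os
import tempfile

# Hypothetical shared directory, visible to every node in the cluster.
default_root_dir = "/shared/nfs/run"

# Before the fix: every node computed the identical path, so their
# save/load/delete steps could interleave and clobber each other.
old_path = os.path.join(default_root_dir, ".temp.ckpt")

# After the fix: mkdtemp() creates a fresh, uniquely named directory per
# process, so checkpoint paths can no longer collide between nodes.
new_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")

print(old_path)  # identical on every node
print(new_path)  # e.g. /tmp/tmpab12cd34/.temp.ckpt, unique per process
```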
tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

@@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch

+from lightning_lite.plugins import ClusterEnvironment
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf


@@ -76,3 +82,45 @@ def test_global_state_snapshot():
    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark
    assert torch.initial_seed() == 123


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
@pytest.mark.parametrize("fake_node_rank", [0, 1])
@pytest.mark.parametrize("fake_local_rank", [0, 1])
def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
"""Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
fake_global_rank = 2 * fake_node_rank + fake_local_rank

cluster_environment = Mock(spec=ClusterEnvironment)
cluster_environment.world_size.return_value = 4
cluster_environment.node_rank.return_value = fake_node_rank
cluster_environment.local_rank.return_value = fake_local_rank
cluster_environment.global_rank.return_value = fake_global_rank

strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
strategy._local_rank = fake_local_rank

launcher = _MultiProcessingLauncher(strategy=strategy)
trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)

assert strategy.node_rank == fake_node_rank
assert strategy.local_rank == fake_local_rank
assert strategy.global_rank == fake_global_rank

trainer.strategy.connect(model)
trainer.state.fn = trainer_fn # pretend we are in a particular trainer state

spawn_output = launcher._collect_rank_zero_results(trainer, {})

model.state_dict.assert_called_once()
is_fitting = trainer_fn == TrainerFn.FITTING
if strategy.local_rank == 0:
# on local rank 0 (each node), we expect a temp checkpoint (when fitting)
assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
assert not is_fitting or os.path.isfile(spawn_output.weights_path)
assert is_fitting or spawn_output.weights_path is None
else:
# all other ranks don't have outputs (rank 0 needs to handle the output)
assert spawn_output is None
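The test fakes a 2-node, 2-devices-per-node layout, which is why `world_size` is 4 and `fake_global_rank = 2 * fake_node_rank + fake_local_rank`. A standalone sketch of that rank arithmetic (`devices_per_node` is an assumed name, not from the PR), showing which processes return results under the new `local_rank` check versus the old `global_rank` check:

```python
# Hypothetical sketch of the rank layout the test fakes: 2 nodes x 2 devices.
devices_per_node = 2
world_size = 2 * devices_per_node  # 4, matching cluster_environment.world_size

for node_rank in (0, 1):
    for local_rank in (0, 1):
        global_rank = node_rank * devices_per_node + local_rank
        # After the fix, one process per node (local_rank == 0) saves weights;
        # before it, only global_rank == 0 did, so node 1 never transferred its state.
        saves_now = local_rank == 0
        saved_before = global_rank == 0
        print(node_rank, local_rank, global_rank, saves_now, saved_before)
```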
10 changes: 3 additions & 7 deletions tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from datetime import timedelta
-from pathlib import Path
from unittest import mock
from unittest.mock import Mock

@@ -135,23 +135,19 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
    trainer.strategy.connect(model)
    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-    temp_file = Path(tmpdir, ".temp.ckpt")

-    assert not temp_file.exists()
    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})

    model.state_dict.assert_called_once()
    if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path == str(temp_file)
-        assert temp_file.exists()
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
    else:
        assert spawn_output.weights_path is None
-        assert not temp_file.exists()

    # <-- here would normally be the multiprocessing boundary
    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-    assert not temp_file.exists()
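For orientation, a simplified sketch of the weights hand-off this test exercises across the multiprocessing boundary. This is an assumed illustration with a hypothetical `recover_weights` helper, not the body of `_recover_results_in_main_process` (the real method goes through the strategy's `checkpoint_io` plugin):

```python
import os

import torch


def recover_weights(model, weights_path):
    # The spawned worker saved a temporary checkpoint; the main process
    # loads it into its own copy of the model, then deletes the file,
    # which is why the test checks weights_path rather than a fixed path.
    if weights_path is not None:
        state_dict = torch.load(weights_path)  # weights written by the worker
        model.load_state_dict(state_dict)      # transfer into the main-process model
        os.remove(weights_path)                # clean up the temporary checkpoint
```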


@mock.patch("torch.distributed.init_process_group")