
Commit 7d36db8

awaelchli, pre-commit-ci[bot], and Borda authored
Disable strict loading in multiprocessing launcher (#16365)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
1 parent: 85786d0 · commit: 7d36db8

File tree

4 files changed (+44, -29 lines)


src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -63,7 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))
 - Overriding the `on_train_batch_{start,end}` hooks in conjunction with taking a `dataloader_iter` in the `training_step` no longer errors out and instead shows a warning ([#16062](https://github.com/Lightning-AI/lightning/pull/16062))
 - Move `tensorboardX` to extra dependencies. Use the `CSVLogger` by default ([#16349](https://github.com/Lightning-AI/lightning/pull/16349))
-
+- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))
 
 ### Deprecated
 
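
For context, a minimal sketch of what strict=False changes in plain PyTorch. This is not part of the commit; the two models below are illustrative stand-ins for the main-process and worker-process modules:

from torch import nn

main_model = nn.Sequential(nn.Linear(4, 4))

# Pretend the worker process grew an extra layer before checkpointing.
worker_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
ckpt = worker_model.state_dict()

# Old behavior: strict=True (the default) raises a RuntimeError because
# ckpt contains the unexpected keys "1.weight" and "1.bias".
# main_model.load_state_dict(ckpt)

# New behavior: strict=False loads the matching keys and reports the rest.
incompatible = main_model.load_state_dict(ckpt, strict=False)
print(incompatible.unexpected_keys)  # ['1.weight', '1.bias']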

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 3 additions & 1 deletion

@@ -153,7 +153,9 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)
+            # choose non-strict loading of parameters on the main process, because the model's composition
+            # could have changed in the worker process (layers added or removed)
+            trainer.lightning_module.load_state_dict(ckpt, strict=False)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
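
The new comment refers to models whose composition changes inside the spawned worker. A hypothetical sketch of that failure mode, assuming a LightningModule that creates a layer lazily in setup() (LazyHeadModel is an illustrative name, not from this PR):

import pytorch_lightning as pl
from torch import nn

class LazyHeadModel(pl.LightningModule):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)
        self.head = None  # created later, in setup()

    def setup(self, stage=None):
        # With "ddp_spawn" this hook runs in the spawned workers, so only the
        # worker copies of the model ever gain head.weight / head.bias.
        if self.head is None:
            self.head = nn.Linear(32, 2)

When the rank-0 worker checkpoints such a model, the saved state dict carries head.* keys that the main-process copy lacks; with strict=True the recovery step above would raise, while strict=False simply skips them.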

tests/tests_pytorch/strategies/launchers/test_multiprocessing.py

Lines changed: 40 additions & 1 deletion

@@ -103,7 +103,7 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     strategy._local_rank = fake_local_rank
 
     launcher = _MultiProcessingLauncher(strategy=strategy)
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
 
     assert strategy.node_rank == fake_node_rank
     assert strategy.local_rank == fake_local_rank
@@ -124,3 +124,42 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank,
     else:
         # all other ranks don't have outputs (rank 0 needs to handle the output)
         assert spawn_output is None
+
+
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
+def test_transfer_weights(tmpdir, trainer_fn):
+    """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the
+    temporary file."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+
+    model.state_dict.assert_called_once()
+    if trainer_fn == TrainerFn.FITTING:
+        assert spawn_output.weights_path.endswith(".temp.ckpt")
+        assert os.path.isfile(spawn_output.weights_path)
+    else:
+        assert spawn_output.weights_path is None
+
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
+
+
+def test_non_strict_loading(tmpdir):
+    """Tests that the multiprocessing launcher loads the weights back into the main process but with strict loading
+    disabled, not erroring for missing keys."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
+    strategy = DDPSpawnStrategy()
+    trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy)
+    trainer.strategy.connect(model)
+    trainer.state.fn = TrainerFn.FITTING  # state dict loading only relevant for the FITTING case
+
+    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
+    # <-- here would normally be the multiprocessing boundary
+    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
+    model.load_state_dict.assert_called_once_with(ANY, strict=False)
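
These tests lean on two unittest.mock features worth noting: Mock(wraps=..., spec=...) records calls while still delegating to the real object, and ANY matches any argument in a call assertion. A self-contained sketch of the pattern, using a hypothetical Loader class in place of the LightningModule:

from unittest.mock import ANY, Mock

class Loader:  # hypothetical stand-in for the real module under test
    def load_state_dict(self, state_dict, strict=True):
        return state_dict

loader = Mock(wraps=Loader(), spec=Loader)
loader.load_state_dict({"weight": 1}, strict=False)  # the real method still runs

# Passes: ANY matches the state dict, and the keyword argument is checked exactly.
loader.load_state_dict.assert_called_once_with(ANY, strict=False)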

tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py

Lines changed: 0 additions & 26 deletions

@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from datetime import timedelta
 from unittest import mock
-from unittest.mock import Mock
 
 import pytest
 import torch
@@ -126,30 +124,6 @@ def test_ddp_spawn_configure_ddp(tmpdir):
     trainer.predict(model, dataloaders=model.predict_dataloader())
 
 
-@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
-def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
-    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
-    file."""
-    model = Mock(wraps=BoringModel(), spec=BoringModel)
-    strategy = DDPSpawnStrategy()
-    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
-    trainer.strategy.connect(model)
-    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
-
-    spawn_output = strategy._launcher._collect_rank_zero_results(trainer, {})
-
-    model.state_dict.assert_called_once()
-    if trainer_fn == TrainerFn.FITTING:
-        assert spawn_output.weights_path.endswith(".temp.ckpt")
-        assert os.path.isfile(spawn_output.weights_path)
-    else:
-        assert spawn_output.weights_path is None
-
-    # <-- here would normally be the multiprocessing boundary
-    strategy._launcher._recover_results_in_main_process(spawn_output, trainer)
-    assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
-
-
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
