|
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from unittest import mock
 from unittest.mock import ANY, Mock
 
 import pytest
 import torch
 
+from lightning_lite.plugins import ClusterEnvironment
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
 
 
@@ -76,3 +82,45 @@ def test_global_state_snapshot():
|
76 | 82 | assert torch.are_deterministic_algorithms_enabled()
|
77 | 83 | assert not torch.backends.cudnn.benchmark
|
78 | 84 | assert torch.initial_seed() == 123
|
+
+
+@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
+@pytest.mark.parametrize("fake_node_rank", [0, 1])
+@pytest.mark.parametrize("fake_local_rank", [0, 1])
+def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir):
+    """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary
+    file."""
+    model = Mock(wraps=BoringModel(), spec=BoringModel)
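+    # the mocked cluster environment below spans 2 nodes x 2 local ranks (world size 4)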
+    fake_global_rank = 2 * fake_node_rank + fake_local_rank
+
+    cluster_environment = Mock(spec=ClusterEnvironment)
+    cluster_environment.world_size.return_value = 4
+    cluster_environment.node_rank.return_value = fake_node_rank
+    cluster_environment.local_rank.return_value = fake_local_rank
+    cluster_environment.global_rank.return_value = fake_global_rank
+
+    strategy = DDPSpawnStrategy(cluster_environment=cluster_environment)
+    strategy._local_rank = fake_local_rank
+
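+    # the launcher is the unit under test; the Trainer only supplies the state that
+    # _collect_rank_zero_results reads (the connected model, state.fn, default_root_dir)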
+    launcher = _MultiProcessingLauncher(strategy=strategy)
+    trainer = Trainer(default_root_dir=tmpdir, strategy=strategy)
+
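+    # sanity check: the strategy reports the ranks fed in through the mocked cluster environment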
+    assert strategy.node_rank == fake_node_rank
+    assert strategy.local_rank == fake_local_rank
+    assert strategy.global_rank == fake_global_rank
+
+    trainer.strategy.connect(model)
+    trainer.state.fn = trainer_fn  # pretend we are in a particular trainer state
+
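+    # an empty dict stands in for the results returned by the worker's trainer function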
+    spawn_output = launcher._collect_rank_zero_results(trainer, {})
+
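+    # the weights should be extracted exactly once for transfer back to the main process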
+    model.state_dict.assert_called_once()
+    is_fitting = trainer_fn == TrainerFn.FITTING
+    if strategy.local_rank == 0:
+        # on local rank 0 (each node), we expect a temp checkpoint (when fitting)
+        assert not is_fitting or spawn_output.weights_path.endswith(".temp.ckpt")
+        assert not is_fitting or os.path.isfile(spawn_output.weights_path)
+        assert is_fitting or spawn_output.weights_path is None
+    else:
+        # all other ranks don't have outputs (rank 0 needs to handle the output)
+        assert spawn_output is None