Commit 822a7f5

awaelchli, krshrimali, and carmocca authored
Align ddp and ddp-spawn strategies in setting up the environment (#11073)
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 3a70e5d commit 822a7f5

File tree

11 files changed, +30 -33 lines changed

src/lightning_lite/lite.py

Lines changed: 4 additions & 3 deletions
@@ -380,16 +380,17 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
         return seed_everything(seed=seed, workers=workers)

     def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        # apply sharded context to prevent OOM
-        run_method = partial(self._run_with_strategy_setup, run_method)
+        # wrap the real run method with setup logic
+        run_method = partial(self._run_with_setup, run_method)

         if self._strategy.launcher is not None:
             return self._strategy.launcher.launch(run_method, *args, **kwargs)
         else:
             return run_method(*args, **kwargs)

-    def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
+    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
+        # apply sharded context to prevent OOM
         with self._strategy.module_sharded_context(), _replace_dunder_methods(
             DataLoader, "dataset"
         ), _replace_dunder_methods(BatchSampler):
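
For orientation, the following is a minimal, self-contained sketch of the wrapping pattern in this hunk, not the Lightning source: the user's run method is wrapped with functools.partial so that setup_environment() always runs before the user code, whether a launcher spawns worker processes or the method is called directly. The _ToyStrategy and _ToyLite classes are hypothetical stand-ins.

from functools import partial
from typing import Any, Callable


class _ToyStrategy:
    launcher = None  # a spawn-based strategy would set a launcher here

    def setup_environment(self) -> None:
        print("setting up environment (ranks, process group, ...)")


class _ToyLite:
    def __init__(self) -> None:
        self._strategy = _ToyStrategy()

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # wrap the real run method with setup logic
        run_method = partial(self._run_with_setup, run_method)
        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        return run_method(*args, **kwargs)

    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        self._strategy.setup_environment()  # runs inside the worker when a launcher is used
        return run_method(*args, **kwargs)


if __name__ == "__main__":
    _ToyLite()._run_impl(lambda: print("user run() body"))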

src/lightning_lite/strategies/xla.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def is_distributed(self) -> bool:
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)

-    def setup_environment(self) -> None:
+    def _setup_distributed(self) -> None:
         self._launched = True
         self._set_world_ranks()
         rank_zero_only.rank = self.global_rank

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed fall-back to `LightningEnvironment` when number of SLURM tasks does not correspond to number of processes in Trainer ([#14300](https://github.com/Lightning-AI/lightning/pull/14300))


+- Aligned DDP and DDPSpawn strategies in setting up the environment ([#11073](https://github.com/Lightning-AI/lightning/pull/11073))
+
+
 - Integrated the Lite Precision plugins into the PL Precision plugins - the base class in PL now extends the `lightning_lite.precision.Precision` base class ([#14798](https://github.com/Lightning-AI/lightning/pull/14798))
   * The `PrecisionPlugin.backward` signature changed: The `closure_loss` argument was renamed to `tensor`
   * The `PrecisionPlugin.{pre_,post_}backward` signature changed: The `closure_loss` argument was renamed to `tensor` and moved as the first argument

src/pytorch_lightning/strategies/ddp.py

Lines changed: 0 additions & 5 deletions
@@ -196,13 +196,8 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
     def setup_distributed(self) -> None:
         log.detail(f"{self.__class__.__name__}: setting up distributed...")
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
-        # set warning rank
         rank_zero_only.rank = self.global_rank
-
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
         init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 14 additions & 10 deletions
@@ -133,6 +133,10 @@ def process_group_backend(self) -> Optional[str]:
     def _configure_launcher(self) -> None:
         self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

+    def setup_environment(self) -> None:
+        self.setup_distributed()
+        super().setup_environment()
+
     def setup(self, trainer: "pl.Trainer") -> None:

         assert self.cluster_environment is not None
@@ -160,16 +164,9 @@ def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)

-    def set_world_ranks(self, process_idx: int = 0) -> None:
-        self._local_rank = process_idx
-        if self.cluster_environment is None:
-            return
-        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
-        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
-        rank_zero_only.rank = self.cluster_environment.global_rank()
-
-    def _worker_setup(self, process_idx: int) -> None:
-        self.set_world_ranks(process_idx)
+    def setup_distributed(self) -> None:
+        log.detail(f"{self.__class__.__name__}: setting up distributed...")
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
         assert self.cluster_environment is not None
@@ -181,6 +178,13 @@ def _worker_setup(self, process_idx: int) -> None:
             timeout=self._timeout,
         )

+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)
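
To make the new division of labor concrete, here is a rough sketch of the call order after this change, using hypothetical stand-in code rather than the real strategy (no ClusterEnvironment, rank_zero_only, or init_dist_connection): the launcher only records the worker's process index as the local rank, and setup_environment() drives setup_distributed(), which computes the world ranks and then initializes the process group, mirroring DDPStrategy.

class _SketchDDPSpawn:
    # hypothetical stand-in; the real strategy delegates to a ClusterEnvironment,
    # rank_zero_only, and init_dist_connection
    def __init__(self, node_rank: int, num_nodes: int, num_processes: int) -> None:
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        self.num_processes = num_processes
        self._local_rank = 0  # assigned per worker by the launcher's _wrapping_function

    @property
    def local_rank(self) -> int:
        return self._local_rank

    def setup_environment(self) -> None:
        self.setup_distributed()

    def setup_distributed(self) -> None:
        self.set_world_ranks()
        # real code: rank_zero_only.rank = self.global_rank, then
        # init_dist_connection(self.cluster_environment, backend, timeout=self._timeout)
        print(f"init process group: global_rank={self.global_rank}, world_size={self.world_size}")

    def set_world_ranks(self) -> None:
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes


if __name__ == "__main__":
    strategy = _SketchDDPSpawn(node_rank=0, num_nodes=1, num_processes=2)
    strategy._local_rank = 1  # what the multiprocessing launcher now does for each spawned worker
    strategy.setup_environment()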

src/pytorch_lightning/strategies/deepspeed.py

Lines changed: 2 additions & 5 deletions
@@ -45,7 +45,7 @@
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

 log = logging.getLogger(__name__)
@@ -348,12 +348,9 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option

     def setup_distributed(self) -> None:
         reset_seed()
-
-        # determine which process we are and world size
         self.set_world_ranks()
-
+        rank_zero_only.rank = self.global_rank
         self._init_deepspeed_distributed()
-
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)

         if trainer is not None:
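
On the launcher side, a small hypothetical sketch of the effect of this change (stand-in code, not _MultiProcessingLauncher itself): the wrapping function no longer performs distributed setup; it only records the process index as the strategy's local rank, and the strategy's setup_environment() call later in the normal Trainer flow does the rest inside the worker.

from typing import Any, Callable


class _SketchLauncher:
    # hypothetical stand-in for the multiprocessing launcher's wrapping function
    def __init__(self, strategy: Any) -> None:
        self._strategy = strategy

    def _wrapping_function(self, process_idx: int, function: Callable, *args: Any, **kwargs: Any) -> Any:
        # previously the launcher called self._strategy._worker_setup(process_idx) here
        self._strategy._local_rank = process_idx
        return function(*args, **kwargs)


if __name__ == "__main__":
    from types import SimpleNamespace

    strategy = SimpleNamespace(_local_rank=None)
    _SketchLauncher(strategy)._wrapping_function(1, lambda: print("worker function runs after the local rank is set"))
    print("local rank recorded:", strategy._local_rank)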

src/pytorch_lightning/strategies/launchers/xla.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def _wrapping_function(
         return_queue: SimpleQueue,
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        self._strategy._worker_setup(process_idx)
+        self._strategy._local_rank = process_idx
         results = function(*args, **kwargs)

         if trainer is not None:

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 2 additions & 2 deletions
@@ -212,9 +212,9 @@ def reduce(

         return output

-    def _worker_setup(self, process_idx: int) -> None:
+    def setup_distributed(self) -> None:
         self._launched = True
-        self.set_world_ranks(process_idx)
+        self.set_world_ranks()
         rank_zero_only.rank = self.global_rank

     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:

tests/tests_pytorch/strategies/test_ddp_spawn_strategy.py

Lines changed: 1 addition & 3 deletions
@@ -154,10 +154,9 @@ def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
     assert not temp_file.exists()


-@RunIf(min_cuda_gpus=1)
 @mock.patch("torch.distributed.init_process_group")
 def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
-    """Tests with ddp strategy."""
+    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
     test_timedelta = timedelta(seconds=30)
     model = BoringModel()
     ddp_spawn_strategy = DDPSpawnStrategy(timeout=test_timedelta)
@@ -170,7 +169,6 @@ def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
     trainer.strategy.connect(model)
     trainer.lightning_module.trainer = trainer
     trainer.strategy.setup_environment()
-    trainer.strategy._worker_setup(0)

     process_group_backend = trainer.strategy._get_process_group_backend()
     global_rank = trainer.strategy.cluster_environment.global_rank()
