Skip to content

Commit e65d344

Browse files
Support multi-run with hydra + DDP
1 parent 664aa5b commit e65d344

File tree

2 files changed: +65 −6 lines changed

src/lightning/fabric/strategies/launchers/subprocess_script.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import subprocess
1616
import sys
17+
from pathlib import Path
1718
from typing import Any, Callable, Optional, Sequence, Tuple
1819

1920
from lightning_utilities.core.imports import RequirementCache
@@ -143,6 +144,8 @@ def _basic_subprocess_cmd() -> Sequence[str]:
143144

144145
def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
145146
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
147+
from hydra.core.hydra_config import HydraConfig
148+
from hydra.types import RunMode
146149
from hydra.utils import get_original_cwd, to_absolute_path
147150

148151
# when user is using hydra find the absolute path
@@ -151,9 +154,18 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
151154
else:
152155
command = [sys.executable, "-m", __main__.__spec__.name]
153156

154-
command += sys.argv[1:]
155-
156157
cwd = get_original_cwd()
157-
os_cwd = f'"{os.getcwd()}"'
158-
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
158+
hydra_cfg = HydraConfig.get()
159+
run_dir = Path(hydra_cfg.runtime.output_dir)
160+
161+
if hydra_cfg.output_subdir is None: # config isn't saved, so re-run original command
162+
if hydra_cfg.mode == RunMode.MULTIRUN:
163+
raise RuntimeError(f"DDP with multirun requires saved config file")
164+
command += sys.argv[1:]
165+
else:
166+
hydra_subdir = run_dir / hydra_cfg.output_subdir
167+
command += ["-cp", str(hydra_subdir), "-cn", "config.yaml"] # Used saved config for new run
168+
command += [f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}"] # Log to different subdir
169+
170+
command += [f"hydra.run.dir={run_dir}", f"hydra.job.name=train_ddp_process_{local_rank}"]
159171
return command, cwd

tests/tests_pytorch/strategies/launchers/test_subprocess_script.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import subprocess
22
import sys
3+
from pathlib import Path
34
from unittest.mock import Mock
45

56
import pytest
@@ -13,6 +14,7 @@
1314

1415
if _HYDRA_WITH_RUN_PROCESS:
1516
from hydra.test_utils.test_utils import run_process
17+
from omegaconf import OmegaConf
1618

1719

1820
# Script to run from command line
@@ -48,7 +50,7 @@ def task_fn(cfg):
4850

4951
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
5052
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
51-
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
53+
@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"])
5254
def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
5355
monkeypatch.chdir(tmpdir)
5456

@@ -58,11 +60,56 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
5860

5961
# Run CLI
6062
devices = 2
61-
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"']
63+
run_dir = Path(tmpdir) / "hydra_output"
64+
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"]
6265
if subdir is not None:
6366
cmd += [f"hydra.output_subdir={subdir}"]
6467
run_process(cmd)
6568

69+
# Make sure config.yaml was created for additional processes iff subdir is present.
70+
saved_confs = list(run_dir.glob("**/config.yaml"))
71+
assert len(saved_confs) == (0 if subdir == "null" else devices)
72+
73+
if saved_confs: # Make sure the parameter was set and used
74+
cfg = OmegaConf.load(saved_confs[0])
75+
assert cfg.devices == devices
76+
77+
# Make sure PL spawned jobs that are logged by Hydra
78+
logs = list(run_dir.glob("**/*.log"))
79+
assert len(logs) == devices
80+
81+
82+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
83+
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
84+
@pytest.mark.parametrize("num_jobs", [1, 2])
85+
def test_ddp_with_hydra_multirunjob(tmpdir, num_jobs, monkeypatch):
86+
monkeypatch.chdir(tmpdir)
87+
88+
# Save script locally
89+
with open("temp.py", "w") as fn:
90+
fn.write(script)
91+
92+
# Run CLI
93+
devices = 2
94+
sweep_dir = Path(tmpdir) / "hydra_output"
95+
command = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.sweep.dir={sweep_dir}"]
96+
command += ["--multirun", "+foo=" + ",".join(str(i) for i in range(num_jobs))] # fake multirun params
97+
run_process(command)
98+
99+
# Make sure config.yaml was created for each job
100+
saved_confs = list(sweep_dir.glob("**/config.yaml"))
101+
assert len(saved_confs) == devices * num_jobs
102+
103+
# Make sure the parameter was set and used for each job
104+
for config in saved_confs:
105+
cfg = OmegaConf.load(config)
106+
local_rank = int(config.parent.parent.parts[-1])
107+
assert cfg.devices == devices
108+
assert cfg.foo == local_rank
109+
110+
logs = list(sweep_dir.glob("**/*.log"))
111+
assert len(logs) == devices * num_jobs
112+
66113

67114
def test_kill():
68115
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)

Comments (0)