Skip to content

Commit f2d57e0

Browse files
Hydra + DDP improvements
* Create different hydra output subdirectories for processes started by DDP
* Support experimental-rerun
* If rerun is not enabled but multi-run is used, raise an explicit error

Reverts parts of Lightning-AI#15737
1 parent 324d90a commit f2d57e0

File tree

2 files changed

+114
-4
lines changed

2 files changed

+114
-4
lines changed

src/lightning/fabric/strategies/launchers/subprocess_script.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import subprocess
1616
import sys
17+
from pathlib import Path
1718
from typing import Any, Callable, Optional, Sequence, Tuple
1819

1920
from lightning_utilities.core.imports import RequirementCache
@@ -143,6 +144,8 @@ def _basic_subprocess_cmd() -> Sequence[str]:
143144

144145
def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
145146
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
147+
from hydra.core.hydra_config import HydraConfig
148+
from hydra.types import RunMode
146149
from hydra.utils import get_original_cwd, to_absolute_path
147150

148151
# when user is using hydra find the absolute path
@@ -151,9 +154,22 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
151154
else:
152155
command = [sys.executable, "-m", __main__.__spec__.name]
153156

154-
command += sys.argv[1:]
157+
# extract the hydra configuration
158+
hydra_cfg = HydraConfig.get()
159+
hydra_output = Path(hydra_cfg.runtime.output_dir)
160+
161+
if hydra_cfg.output_subdir is None: # config isn't saved, so re-run original command
162+
if hydra_cfg.mode == RunMode.MULTIRUN:
163+
raise RuntimeError(f"DDP with multirun requires either re-run callback or saved config file")
164+
command += sys.argv[1:] + [f"hydra.run.dir={hydra_output}"] # Keep output directory same
165+
else:
166+
hydra_subdir = hydra_output / hydra_cfg.output_subdir
167+
pickled_config_path = hydra_subdir / "config.pickle"
168+
if pickled_config_path.exists():
169+
command += ["--experimental-rerun", str(pickled_config_path)]
170+
else:
171+
command += ["-cp", str(hydra_subdir), "-cn", "config.yaml"] # Use the saved config for the new run
172+
command += [f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}", f"hydra.run.dir={hydra_output}"]
155173

156174
cwd = get_original_cwd()
157-
os_cwd = f'"{os.getcwd()}"'
158-
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
159175
return command, cwd

tests/tests_pytorch/strategies/launchers/test_subprocess_script.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import subprocess
22
import sys
3+
from pathlib import Path
34
from unittest.mock import Mock
45

56
import pytest
@@ -13,6 +14,7 @@
1314

1415
if _HYDRA_WITH_RUN_PROCESS:
1516
from hydra.test_utils.test_utils import run_process
17+
from omegaconf import OmegaConf
1618

1719

1820
# Script to run from command line
@@ -48,7 +50,7 @@ def task_fn(cfg):
4850

4951
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
5052
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
51-
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
53+
@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"])
5254
def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
5355
monkeypatch.chdir(tmpdir)
5456

@@ -63,6 +65,98 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
6365
cmd += [f"hydra.output_subdir={subdir}"]
6466
run_process(cmd)
6567

68+
if subdir == "null": # There should be no subdirectory created
69+
# Make sure there's no config.yaml
70+
logs = list(Path.cwd().glob("**/config.yaml"))
71+
assert len(logs) == 0
72+
else:
73+
# Make sure config.yaml was created for additional processes.
74+
logs = list(Path.cwd().glob("**/config.yaml"))
75+
assert len(logs) == devices
76+
77+
# Make sure the parameter was set and used
78+
cfg = OmegaConf.load(logs[0])
79+
assert cfg.devices == devices
80+
81+
# Make sure PL spawned a job that is logged by Hydra
82+
logs = list(Path.cwd().glob("**/*.log"))
83+
assert len(logs) == 1
84+
85+
86+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
87+
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
88+
@pytest.mark.parametrize("num_jobs", [1, 2])
89+
def test_ddp_with_hydra_multirunjob(tmpdir, num_jobs, monkeypatch):
90+
monkeypatch.chdir(tmpdir)
91+
92+
# Save script locally
93+
with open("temp.py", "w") as fn:
94+
fn.write(script)
95+
96+
# create fake multirun params based on `num_jobs`
97+
fake_param = "+foo=" + ",".join(str(i) for i in range(num_jobs))
98+
99+
# Run CLI
100+
run_process([sys.executable, "temp.py", "+devices=2", '+strategy="ddp"', fake_param, "--multirun"])
101+
102+
# Make sure config.yaml was created for each job
103+
configs = sorted(Path.cwd().glob("**/.pl_ddp_hydra_*/config.yaml"))
104+
assert len(configs) == num_jobs
105+
106+
# Make sure the parameter was set and used for each job
107+
for i, config in enumerate(configs):
108+
cfg = OmegaConf.load(config)
109+
local_rank = int(config.parent.parent.parts[-1])
110+
assert cfg.devices == 2
111+
assert cfg.foo == local_rank
112+
113+
logs = list(Path.cwd().glob("**/*.log"))
114+
assert len(logs) == num_jobs
115+
116+
117+
yaml_file = """
118+
hydra:
119+
callbacks:
120+
save_job_info:
121+
_target_: hydra.experimental.callbacks.PickleJobInfoCallback
122+
"""
123+
124+
125+
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
126+
@pytest.mark.skipif(not _HYDRA_WITH_RERUN, reason=str(_HYDRA_WITH_RERUN))
127+
@pytest.mark.parametrize("num_jobs", [1, 2])
128+
def test_ddp_with_hydra_multirunjob_rerun(tmpdir, num_jobs, monkeypatch):
129+
monkeypatch.chdir(tmpdir)
130+
131+
# Save script locally
132+
with open("temp.py", "w") as fn:
133+
fn.write(script)
134+
135+
with open("config.yaml", "w") as fn:
136+
fn.write(yaml_file)
137+
138+
# create fake multirun params based on `num_jobs`
139+
fake_param = "+foo=" + ",".join(str(i) for i in range(num_jobs))
140+
141+
# Run CLI
142+
run_process(
143+
[
144+
sys.executable,
145+
"temp.py",
146+
"-cp",
147+
".",
148+
"-cn",
149+
"config.yaml",
150+
"+devices=2",
151+
'+strategy="ddp"',
152+
fake_param,
153+
"--multirun",
154+
]
155+
)
156+
157+
pickles = sorted(Path.cwd().glob("**/.hydra/config.pickle"))
158+
assert len(pickles) == num_jobs
159+
66160

67161
def test_kill():
68162
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)

0 commit comments

Comments
 (0)