Skip to content

Commit ebbd538

Browse files
Use hydra.run.dir (not os.getcwd) for DDP subprocesses' run dir (#18145)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 27b10ca commit ebbd538

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
147147

148148
### Fixed
149149

150+
- Fixed an issue where DDP subprocesses launched under Hydra would set Hydra's working directory (`hydra.run.dir`) to the current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))
151+
152+
150153
- Fixed issue where running on TPUs would select the wrong device index ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
151154

152155

src/lightning/fabric/strategies/launchers/subprocess_script.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _basic_subprocess_cmd() -> Sequence[str]:
143143

144144
def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
145145
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
146+
from hydra.core.hydra_config import HydraConfig
146147
from hydra.utils import get_original_cwd, to_absolute_path
147148

148149
# when user is using hydra find the absolute path
@@ -154,6 +155,7 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
154155
command += sys.argv[1:]
155156

156157
cwd = get_original_cwd()
157-
os_cwd = f'"{os.getcwd()}"'
158-
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
158+
rundir = f'"{HydraConfig.get().run.dir}"'
159+
# Set output_subdir null since we don't want different subprocesses trying to write to config.yaml
160+
command += [f"hydra.run.dir={rundir}", f"hydra.job.name=train_ddp_process_{local_rank}", "hydra.output_subdir=null"]
159161
return command, cwd

tests/tests_pytorch/strategies/launchers/test_subprocess_script.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
99
from tests_pytorch.helpers.runif import RunIf
1010

11-
_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2")
1211
_HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7")
1312

1413
if _HYDRA_WITH_RUN_PROCESS:
1514
from hydra.test_utils.test_utils import run_process
15+
from omegaconf import OmegaConf
1616

1717

1818
# Script to run from command line
@@ -48,21 +48,34 @@ def task_fn(cfg):
4848

4949
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
5050
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
51-
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
52-
def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
53-
monkeypatch.chdir(tmpdir)
51+
@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"])
52+
def test_ddp_with_hydra_runjob(subdir, tmp_path, monkeypatch):
53+
monkeypatch.chdir(tmp_path)
5454

5555
# Save script locally
5656
with open("temp.py", "w") as fn:
5757
fn.write(script)
5858

5959
# Run CLI
6060
devices = 2
61-
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"']
61+
run_dir = tmp_path / "hydra_output"
62+
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"]
6263
if subdir is not None:
6364
cmd += [f"hydra.output_subdir={subdir}"]
6465
run_process(cmd)
6566

67+
# Make sure no config.yaml was created for additional processes
68+
saved_confs = list(run_dir.glob("**/config.yaml"))
69+
assert len(saved_confs) == (0 if subdir == "null" else 1) # Main process has config.yaml iff subdir!="null"
70+
71+
if saved_confs: # Make sure the parameter was set and used
72+
cfg = OmegaConf.load(saved_confs[0])
73+
assert cfg.devices == devices
74+
75+
# Make sure PL spawned jobs that are logged by Hydra
76+
logs = list(run_dir.glob("**/*.log"))
77+
assert len(logs) == devices
78+
6679

6780
def test_kill():
6881
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)

0 commit comments

Comments
 (0)