|
8 | 8 | from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
|
9 | 9 | from tests_pytorch.helpers.runif import RunIf
|
10 | 10 |
|
11 |
| -_HYDRA_WITH_RERUN = RequirementCache("hydra-core>=1.2") |
12 | 11 | _HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7")
|
13 | 12 |
|
14 | 13 | if _HYDRA_WITH_RUN_PROCESS:
|
15 | 14 | from hydra.test_utils.test_utils import run_process
|
| 15 | + from omegaconf import OmegaConf |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # Script to run from command line
|
@@ -48,21 +48,34 @@ def task_fn(cfg):
|
48 | 48 |
|
49 | 49 | @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
|
50 | 50 | @pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
|
51 |
| -@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"]) |
52 |
| -def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch): |
53 |
| - monkeypatch.chdir(tmpdir) |
| 51 | +@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"]) |
| 52 | +def test_ddp_with_hydra_runjob(subdir, tmp_path, monkeypatch): |
| 53 | + monkeypatch.chdir(tmp_path) |
54 | 54 |
|
55 | 55 | # Save script locally
|
56 | 56 | with open("temp.py", "w") as fn:
|
57 | 57 | fn.write(script)
|
58 | 58 |
|
59 | 59 | # Run CLI
|
60 | 60 | devices = 2
|
61 |
| - cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"'] |
| 61 | + run_dir = tmp_path / "hydra_output" |
| 62 | + cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"] |
62 | 63 | if subdir is not None:
|
63 | 64 | cmd += [f"hydra.output_subdir={subdir}"]
|
64 | 65 | run_process(cmd)
|
65 | 66 |
|
| 67 | + # Make sure no config.yaml was created for additional processes |
| 68 | + saved_confs = list(run_dir.glob("**/config.yaml")) |
| 69 | + assert len(saved_confs) == (0 if subdir == "null" else 1) # Main process has config.yaml iff subdir!="null" |
| 70 | + |
| 71 | + if saved_confs: # Make sure the parameter was set and used |
| 72 | + cfg = OmegaConf.load(saved_confs[0]) |
| 73 | + assert cfg.devices == devices |
| 74 | + |
| 75 | + # Make sure PL spawned jobs that are logged by Hydra |
| 76 | + logs = list(run_dir.glob("**/*.log")) |
| 77 | + assert len(logs) == devices |
| 78 | + |
66 | 79 |
|
67 | 80 | def test_kill():
|
68 | 81 | launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
|
|
0 commit comments