1
1
import subprocess
2
2
import sys
3
+ from pathlib import Path
3
4
from unittest .mock import Mock
4
5
5
6
import pytest
13
14
14
15
if _HYDRA_WITH_RUN_PROCESS :
15
16
from hydra .test_utils .test_utils import run_process
17
+ from omegaconf import OmegaConf
16
18
17
19
18
20
# Script to run from command line
@@ -48,7 +50,7 @@ def task_fn(cfg):
48
50
49
51
@RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
50
52
@pytest .mark .skipif (not _HYDRA_WITH_RUN_PROCESS , reason = str (_HYDRA_WITH_RUN_PROCESS ))
51
- @pytest .mark .parametrize ("subdir" , [None , "dksa" , ".hello" ])
53
+ @pytest .mark .parametrize ("subdir" , [None , "null" , " dksa" , ".hello" ])
52
54
def test_ddp_with_hydra_runjob (subdir , tmpdir , monkeypatch ):
53
55
monkeypatch .chdir (tmpdir )
54
56
@@ -63,6 +65,98 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
63
65
cmd += [f"hydra.output_subdir={ subdir } " ]
64
66
run_process (cmd )
65
67
68
+ if subdir == "null" : # There should be no subdirectory created
69
+ # Make sure there's no config.yaml
70
+ logs = list (Path .cwd ().glob ("**/config.yaml" ))
71
+ assert len (logs ) == 0
72
+ else :
73
+ # Make sure config.yaml was created for additional processes.
74
+ logs = list (Path .cwd ().glob ("**/config.yaml" ))
75
+ assert len (logs ) == devices
76
+
77
+ # Make sure the parameter was set and used
78
+ cfg = OmegaConf .load (logs [0 ])
79
+ assert cfg .devices == devices
80
+
81
+ # Make sure PL spawned a job that is logged by Hydra
82
+ logs = list (Path .cwd ().glob ("**/*.log" ))
83
+ assert len (logs ) == 1
84
+
85
+
86
+ @RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
87
+ @pytest .mark .skipif (not _HYDRA_WITH_RUN_PROCESS , reason = str (_HYDRA_WITH_RUN_PROCESS ))
88
+ @pytest .mark .parametrize ("num_jobs" , [1 , 2 ])
89
+ def test_ddp_with_hydra_multirunjob (tmpdir , num_jobs , monkeypatch ):
90
+ monkeypatch .chdir (tmpdir )
91
+
92
+ # Save script locally
93
+ with open ("temp.py" , "w" ) as fn :
94
+ fn .write (script )
95
+
96
+ # create fake multirun params based on `num_jobs`
97
+ fake_param = "+foo=" + "," .join (str (i ) for i in range (num_jobs ))
98
+
99
+ # Run CLI
100
+ run_process ([sys .executable , "temp.py" , "+devices=2" , '+strategy="ddp"' , fake_param , "--multirun" ])
101
+
102
+ # Make sure config.yaml was created for each job
103
+ configs = sorted (Path .cwd ().glob ("**/.pl_ddp_hydra_*/config.yaml" ))
104
+ assert len (configs ) == num_jobs
105
+
106
+ # Make sure the parameter was set and used for each job
107
+ for i , config in enumerate (configs ):
108
+ cfg = OmegaConf .load (config )
109
+ local_rank = int (config .parent .parent .parts [- 1 ])
110
+ assert cfg .devices == 2
111
+ assert cfg .foo == local_rank
112
+
113
+ logs = list (Path .cwd ().glob ("**/*.log" ))
114
+ assert len (logs ) == num_jobs
115
+
116
+
117
+ yaml_file = """
118
+ hydra:
119
+ callbacks:
120
+ save_job_info:
121
+ _target_: hydra.experimental.callbacks.PickleJobInfoCallback
122
+ """
123
+
124
+
125
+ @RunIf (min_cuda_gpus = 2 , skip_windows = True , standalone = True )
126
+ @pytest .mark .skipif (not _HYDRA_WITH_RERUN , reason = str (_HYDRA_WITH_RERUN ))
127
+ @pytest .mark .parametrize ("num_jobs" , [1 , 2 ])
128
+ def test_ddp_with_hydra_multirunjob_rerun (tmpdir , num_jobs , monkeypatch ):
129
+ monkeypatch .chdir (tmpdir )
130
+
131
+ # Save script locally
132
+ with open ("temp.py" , "w" ) as fn :
133
+ fn .write (script )
134
+
135
+ with open ("config.yaml" , "w" ) as fn :
136
+ fn .write (yaml_file )
137
+
138
+ # create fake multirun params based on `num_jobs`
139
+ fake_param = "+foo=" + "," .join (str (i ) for i in range (num_jobs ))
140
+
141
+ # Run CLI
142
+ run_process (
143
+ [
144
+ sys .executable ,
145
+ "temp.py" ,
146
+ "-cp" ,
147
+ "." ,
148
+ "-cn" ,
149
+ "config.yaml" ,
150
+ "+devices=2" ,
151
+ '+strategy="ddp"' ,
152
+ fake_param ,
153
+ "--multirun" ,
154
+ ]
155
+ )
156
+
157
+ pickles = sorted (Path .cwd ().glob ("**/.hydra/config.pickle" ))
158
+ assert len (pickles ) == num_jobs
159
+
66
160
67
161
def test_kill ():
68
162
launcher = _SubprocessScriptLauncher (Mock (), 1 , 1 )
0 commit comments