22
22
import pytest
23
23
from jsonargparse import Namespace
24
24
25
- from lightning .fabric .cli import FabricCLI , _get_supported_strategies , main as _run_main
25
+ from lightning .fabric .cli import FabricCLI , _get_supported_strategies
26
+ from lightning .fabric .cli import main as _run_main
26
27
from lightning .fabric .utilities .consolidate_checkpoint import main as _consolidate_main
27
28
from tests_fabric .helpers .runif import RunIf
28
29
@@ -38,7 +39,17 @@ def fake_script(tmp_path):
38
39
def test_run_env_vars_defaults (monkeypatch , fake_script ):
39
40
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
40
41
with pytest .raises (SystemExit ) as e :
41
- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
42
+ args = Namespace (
43
+ script = fake_script ,
44
+ accelerator = None ,
45
+ strategy = None ,
46
+ devices = "1" ,
47
+ num_nodes = 1 ,
48
+ node_rank = 0 ,
49
+ main_address = "127.0.0.1" ,
50
+ main_port = 29400 ,
51
+ precision = None ,
52
+ )
42
53
_run_main (args )
43
54
assert e .value .code == 0
44
55
assert os .environ ["LT_CLI_USED" ] == "1"
@@ -55,7 +66,17 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
55
66
def test_run_env_vars_accelerator (_ , accelerator , monkeypatch , fake_script ):
56
67
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
57
68
with pytest .raises (SystemExit ) as e :
58
- args = Namespace (script = fake_script , accelerator = accelerator , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
69
+ args = Namespace (
70
+ script = fake_script ,
71
+ accelerator = accelerator ,
72
+ strategy = None ,
73
+ devices = "1" ,
74
+ num_nodes = 1 ,
75
+ node_rank = 0 ,
76
+ main_address = "127.0.0.1" ,
77
+ main_port = 29400 ,
78
+ precision = None ,
79
+ )
59
80
_run_main (args )
60
81
assert e .value .code == 0
61
82
assert os .environ ["LT_ACCELERATOR" ] == accelerator
@@ -67,7 +88,17 @@ def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
67
88
def test_run_env_vars_strategy (_ , strategy , monkeypatch , fake_script ):
68
89
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
69
90
with pytest .raises (SystemExit ) as e :
70
- args = Namespace (script = fake_script , accelerator = None , strategy = strategy , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
91
+ args = Namespace (
92
+ script = fake_script ,
93
+ accelerator = None ,
94
+ strategy = strategy ,
95
+ devices = "1" ,
96
+ num_nodes = 1 ,
97
+ node_rank = 0 ,
98
+ main_address = "127.0.0.1" ,
99
+ main_port = 29400 ,
100
+ precision = None ,
101
+ )
71
102
_run_main (args )
72
103
assert e .value .code == 0
73
104
assert os .environ ["LT_STRATEGY" ] == strategy
@@ -96,7 +127,17 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
96
127
def test_run_env_vars_devices_cuda (_ , devices , monkeypatch , fake_script ):
97
128
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
98
129
with pytest .raises (SystemExit ) as e :
99
- args = Namespace (script = fake_script , accelerator = "cuda" , strategy = None , devices = devices , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
130
+ args = Namespace (
131
+ script = fake_script ,
132
+ accelerator = "cuda" ,
133
+ strategy = None ,
134
+ devices = devices ,
135
+ num_nodes = 1 ,
136
+ node_rank = 0 ,
137
+ main_address = "127.0.0.1" ,
138
+ main_port = 29400 ,
139
+ precision = None ,
140
+ )
100
141
_run_main (args )
101
142
assert e .value .code == 0
102
143
assert os .environ ["LT_DEVICES" ] == devices
@@ -108,7 +149,17 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
108
149
def test_run_env_vars_devices_mps (accelerator , monkeypatch , fake_script ):
109
150
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
110
151
with pytest .raises (SystemExit ) as e :
111
- args = Namespace (script = fake_script , accelerator = accelerator , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
152
+ args = Namespace (
153
+ script = fake_script ,
154
+ accelerator = accelerator ,
155
+ strategy = None ,
156
+ devices = "1" ,
157
+ num_nodes = 1 ,
158
+ node_rank = 0 ,
159
+ main_address = "127.0.0.1" ,
160
+ main_port = 29400 ,
161
+ precision = None ,
162
+ )
112
163
_run_main (args )
113
164
assert e .value .code == 0
114
165
assert os .environ ["LT_DEVICES" ] == "1"
@@ -119,7 +170,17 @@ def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
119
170
def test_run_env_vars_num_nodes (num_nodes , monkeypatch , fake_script ):
120
171
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
121
172
with pytest .raises (SystemExit ) as e :
122
- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = int (num_nodes ), node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
173
+ args = Namespace (
174
+ script = fake_script ,
175
+ accelerator = None ,
176
+ strategy = None ,
177
+ devices = "1" ,
178
+ num_nodes = int (num_nodes ),
179
+ node_rank = 0 ,
180
+ main_address = "127.0.0.1" ,
181
+ main_port = 29400 ,
182
+ precision = None ,
183
+ )
123
184
_run_main (args )
124
185
assert e .value .code == 0
125
186
assert os .environ ["LT_NUM_NODES" ] == num_nodes
@@ -130,7 +191,17 @@ def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
130
191
def test_run_env_vars_precision (precision , monkeypatch , fake_script ):
131
192
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
132
193
with pytest .raises (SystemExit ) as e :
133
- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = precision )
194
+ args = Namespace (
195
+ script = fake_script ,
196
+ accelerator = None ,
197
+ strategy = None ,
198
+ devices = "1" ,
199
+ num_nodes = 1 ,
200
+ node_rank = 0 ,
201
+ main_address = "127.0.0.1" ,
202
+ main_port = 29400 ,
203
+ precision = precision ,
204
+ )
134
205
_run_main (args )
135
206
assert e .value .code == 0
136
207
assert os .environ ["LT_PRECISION" ] == precision
@@ -141,7 +212,17 @@ def test_run_torchrun_defaults(monkeypatch, fake_script):
141
212
torchrun_mock = Mock ()
142
213
monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
143
214
with pytest .raises (SystemExit ) as e :
144
- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
215
+ args = Namespace (
216
+ script = fake_script ,
217
+ accelerator = None ,
218
+ strategy = None ,
219
+ devices = "1" ,
220
+ num_nodes = 1 ,
221
+ node_rank = 0 ,
222
+ main_address = "127.0.0.1" ,
223
+ main_port = 29400 ,
224
+ precision = None ,
225
+ )
145
226
_run_main (args )
146
227
assert e .value .code == 0
147
228
torchrun_mock .main .assert_called_with ([
@@ -170,7 +251,17 @@ def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
170
251
torchrun_mock = Mock ()
171
252
monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
172
253
with pytest .raises (SystemExit ) as e :
173
- args = Namespace (script = fake_script , accelerator = "cuda" , strategy = None , devices = devices , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
254
+ args = Namespace (
255
+ script = fake_script ,
256
+ accelerator = "cuda" ,
257
+ strategy = None ,
258
+ devices = devices ,
259
+ num_nodes = 1 ,
260
+ node_rank = 0 ,
261
+ main_address = "127.0.0.1" ,
262
+ main_port = 29400 ,
263
+ precision = None ,
264
+ )
174
265
_run_main (args )
175
266
assert e .value .code == 0
176
267
torchrun_mock .main .assert_called_with ([
0 commit comments