Skip to content

Commit 5070508

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent bb7bd91 commit 5070508

File tree

1 file changed

+101
-10
lines changed

1 file changed

+101
-10
lines changed

tests/tests_fabric/test_cli.py

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import pytest
2323
from jsonargparse import Namespace
2424

25-
from lightning.fabric.cli import FabricCLI, _get_supported_strategies, main as _run_main
25+
from lightning.fabric.cli import FabricCLI, _get_supported_strategies
26+
from lightning.fabric.cli import main as _run_main
2627
from lightning.fabric.utilities.consolidate_checkpoint import main as _consolidate_main
2728
from tests_fabric.helpers.runif import RunIf
2829

@@ -38,7 +39,17 @@ def fake_script(tmp_path):
3839
def test_run_env_vars_defaults(monkeypatch, fake_script):
3940
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
4041
with pytest.raises(SystemExit) as e:
41-
args = Namespace(script=fake_script, accelerator=None, strategy=None, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
42+
args = Namespace(
43+
script=fake_script,
44+
accelerator=None,
45+
strategy=None,
46+
devices="1",
47+
num_nodes=1,
48+
node_rank=0,
49+
main_address="127.0.0.1",
50+
main_port=29400,
51+
precision=None,
52+
)
4253
_run_main(args)
4354
assert e.value.code == 0
4455
assert os.environ["LT_CLI_USED"] == "1"
@@ -55,7 +66,17 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
5566
def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
5667
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
5768
with pytest.raises(SystemExit) as e:
58-
args = Namespace(script=fake_script, accelerator=accelerator, strategy=None, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
69+
args = Namespace(
70+
script=fake_script,
71+
accelerator=accelerator,
72+
strategy=None,
73+
devices="1",
74+
num_nodes=1,
75+
node_rank=0,
76+
main_address="127.0.0.1",
77+
main_port=29400,
78+
precision=None,
79+
)
5980
_run_main(args)
6081
assert e.value.code == 0
6182
assert os.environ["LT_ACCELERATOR"] == accelerator
@@ -67,7 +88,17 @@ def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
6788
def test_run_env_vars_strategy(_, strategy, monkeypatch, fake_script):
6889
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
6990
with pytest.raises(SystemExit) as e:
70-
args = Namespace(script=fake_script, accelerator=None, strategy=strategy, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
91+
args = Namespace(
92+
script=fake_script,
93+
accelerator=None,
94+
strategy=strategy,
95+
devices="1",
96+
num_nodes=1,
97+
node_rank=0,
98+
main_address="127.0.0.1",
99+
main_port=29400,
100+
precision=None,
101+
)
71102
_run_main(args)
72103
assert e.value.code == 0
73104
assert os.environ["LT_STRATEGY"] == strategy
@@ -96,7 +127,17 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
96127
def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
97128
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
98129
with pytest.raises(SystemExit) as e:
99-
args = Namespace(script=fake_script, accelerator="cuda", strategy=None, devices=devices, num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
130+
args = Namespace(
131+
script=fake_script,
132+
accelerator="cuda",
133+
strategy=None,
134+
devices=devices,
135+
num_nodes=1,
136+
node_rank=0,
137+
main_address="127.0.0.1",
138+
main_port=29400,
139+
precision=None,
140+
)
100141
_run_main(args)
101142
assert e.value.code == 0
102143
assert os.environ["LT_DEVICES"] == devices
@@ -108,7 +149,17 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
108149
def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
109150
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
110151
with pytest.raises(SystemExit) as e:
111-
args = Namespace(script=fake_script, accelerator=accelerator, strategy=None, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
152+
args = Namespace(
153+
script=fake_script,
154+
accelerator=accelerator,
155+
strategy=None,
156+
devices="1",
157+
num_nodes=1,
158+
node_rank=0,
159+
main_address="127.0.0.1",
160+
main_port=29400,
161+
precision=None,
162+
)
112163
_run_main(args)
113164
assert e.value.code == 0
114165
assert os.environ["LT_DEVICES"] == "1"
@@ -119,7 +170,17 @@ def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
119170
def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
120171
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
121172
with pytest.raises(SystemExit) as e:
122-
args = Namespace(script=fake_script, accelerator=None, strategy=None, devices="1", num_nodes=int(num_nodes), node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
173+
args = Namespace(
174+
script=fake_script,
175+
accelerator=None,
176+
strategy=None,
177+
devices="1",
178+
num_nodes=int(num_nodes),
179+
node_rank=0,
180+
main_address="127.0.0.1",
181+
main_port=29400,
182+
precision=None,
183+
)
123184
_run_main(args)
124185
assert e.value.code == 0
125186
assert os.environ["LT_NUM_NODES"] == num_nodes
@@ -130,7 +191,17 @@ def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
130191
def test_run_env_vars_precision(precision, monkeypatch, fake_script):
131192
monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock())
132193
with pytest.raises(SystemExit) as e:
133-
args = Namespace(script=fake_script, accelerator=None, strategy=None, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=precision)
194+
args = Namespace(
195+
script=fake_script,
196+
accelerator=None,
197+
strategy=None,
198+
devices="1",
199+
num_nodes=1,
200+
node_rank=0,
201+
main_address="127.0.0.1",
202+
main_port=29400,
203+
precision=precision,
204+
)
134205
_run_main(args)
135206
assert e.value.code == 0
136207
assert os.environ["LT_PRECISION"] == precision
@@ -141,7 +212,17 @@ def test_run_torchrun_defaults(monkeypatch, fake_script):
141212
torchrun_mock = Mock()
142213
monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
143214
with pytest.raises(SystemExit) as e:
144-
args = Namespace(script=fake_script, accelerator=None, strategy=None, devices="1", num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
215+
args = Namespace(
216+
script=fake_script,
217+
accelerator=None,
218+
strategy=None,
219+
devices="1",
220+
num_nodes=1,
221+
node_rank=0,
222+
main_address="127.0.0.1",
223+
main_port=29400,
224+
precision=None,
225+
)
145226
_run_main(args)
146227
assert e.value.code == 0
147228
torchrun_mock.main.assert_called_with([
@@ -170,7 +251,17 @@ def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
170251
torchrun_mock = Mock()
171252
monkeypatch.setitem(sys.modules, "torch.distributed.run", torchrun_mock)
172253
with pytest.raises(SystemExit) as e:
173-
args = Namespace(script=fake_script, accelerator="cuda", strategy=None, devices=devices, num_nodes=1, node_rank=0, main_address="127.0.0.1", main_port=29400, precision=None)
254+
args = Namespace(
255+
script=fake_script,
256+
accelerator="cuda",
257+
strategy=None,
258+
devices=devices,
259+
num_nodes=1,
260+
node_rank=0,
261+
main_address="127.0.0.1",
262+
main_port=29400,
263+
precision=None,
264+
)
174265
_run_main(args)
175266
assert e.value.code == 0
176267
torchrun_mock.main.assert_called_with([

0 commit comments

Comments
 (0)