Skip to content

Commit e00b23d

Browse files
committed
Fix strategy type validation in connectors (#16693)
1 parent b6122ec commit e00b23d

File tree

6 files changed

+28
-19
lines changed

6 files changed

+28
-19
lines changed

src/lightning_fabric/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
### Fixed
2121

22-
-
22+
- Fixed an attribute error and improved input validation for invalid strategy types being passed to Fabric ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))
2323

2424

2525
## [1.9.1] - 2023-02-10

src/lightning_fabric/connector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def _check_config_and_set_final_flags(
187187
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
188188
raise ValueError(
189189
f"You selected an invalid strategy name: `strategy={strategy!r}`."
190+
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
190191
" Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
191192
" Find a complete list of options in our documentation at https://lightning.ai"
192193
)
@@ -436,12 +437,12 @@ def _check_strategy_and_fallback(self) -> None:
436437

437438
def _init_strategy(self) -> None:
438439
"""Instantiate the Strategy given depending on the setting of ``_strategy_flag``."""
440+
# The validation of `_strategy_flag` already happened earlier on in the connector
441+
assert isinstance(self._strategy_flag, (str, Strategy))
439442
if isinstance(self._strategy_flag, str):
440443
self.strategy = STRATEGY_REGISTRY.get(self._strategy_flag)
441-
elif isinstance(self._strategy_flag, Strategy):
442-
self.strategy = self._strategy_flag
443444
else:
444-
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
445+
self.strategy = self._strategy_flag
445446

446447
def _check_and_init_precision(self) -> Precision:
447448
self._validate_precision_choice()

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
### Fixed
2121

22-
-
22+
- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))
2323

2424

2525
## [1.9.1] - 2023-02-10

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ def _check_config_and_set_final_flags(
273273
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
274274
)
275275

276+
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
277+
raise ValueError(
278+
f"You selected an invalid strategy name: `strategy={strategy!r}`."
279+
" It must be either a string or an instance of `lightning.pytorch.strategies.Strategy`."
280+
" Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
281+
" Find a complete list of options in our documentation at https://lightning.ai"
282+
)
283+
276284
if (
277285
accelerator is not None
278286
and accelerator not in self._accelerator_types
@@ -694,13 +702,13 @@ def _init_strategy(self) -> None:
694702
# handle horovod has to happen before initialize strategy because HorovodStrategy needs hvd.init() first.
695703
# TODO lazy initialized and setup horovod strategy `global_rank`
696704
self._handle_horovod()
705+
# The validation of `_strategy_flag` already happened earlier on in the connector
706+
assert isinstance(self._strategy_flag, (str, Strategy))
697707
if isinstance(self._strategy_flag, str):
698708
self.strategy = StrategyRegistry.get(self._strategy_flag)
699-
elif isinstance(self._strategy_flag, Strategy):
709+
else:
700710
# TODO(fabric): remove ignore after merging Fabric and PL strategies
701711
self.strategy = self._strategy_flag # type: ignore[assignment]
702-
else:
703-
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
704712

705713
def _check_and_init_precision(self) -> PrecisionPlugin:
706714
self._validate_precision_choice()

tests/tests_fabric/test_connector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,10 @@ def test_invalid_accelerator_choice():
386386
_Connector(accelerator="cocofruit")
387387

388388

389-
def test_invalid_strategy_choice():
390-
with pytest.raises(ValueError, match="You selected an invalid strategy name: `strategy='cocofruit'`"):
391-
_Connector(strategy="cocofruit")
389+
@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()])
390+
def test_invalid_strategy_choice(invalid_strategy):
391+
with pytest.raises(ValueError, match="You selected an invalid strategy name:"):
392+
_Connector(strategy=invalid_strategy)
392393

393394

394395
@pytest.mark.parametrize(

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def test_accelerator_invalid_choice():
6464
Trainer(accelerator="invalid")
6565

6666

67+
@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()])
68+
def test_invalid_strategy_choice(invalid_strategy):
69+
with pytest.raises(ValueError, match="You selected an invalid strategy name:"):
70+
AcceleratorConnector(strategy=invalid_strategy)
71+
72+
6773
@RunIf(skip_windows=True, standalone=True)
6874
def test_strategy_choice_ddp_on_cpu(tmpdir):
6975
"""Test that selecting DDPStrategy on CPU works."""
@@ -373,13 +379,6 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
373379
assert isinstance(trainer.strategy, DDPStrategy)
374380

375381

376-
def test_exception_invalid_strategy():
377-
with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"):
378-
Trainer(strategy="ddp_cpu")
379-
with pytest.raises(MisconfigurationException, match=r"strategy='tpu_spawn'\)` is not a valid"):
380-
Trainer(strategy="tpu_spawn")
381-
382-
383382
@pytest.mark.parametrize(
384383
["strategy", "strategy_class"],
385384
(
@@ -442,7 +441,7 @@ def test_strategy_choice_cpu_instance(strategy_class):
442441
pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
443442
],
444443
)
445-
def test_strategy_choice_gpu_str(strategy, strategy_class):
444+
def test_strategy_choice_gpu_str(strategy, strategy_class, cuda_count_2, mps_count_0):
446445
if "sharded" in strategy:
447446
with pytest.deprecated_call(match="FairScale has been deprecated in v1.9.0"):
448447
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)

0 commit comments

Comments
 (0)