diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 2c0a9ea43348b..aa0670e5cd8a5 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an import error when `torch.distributed` is not available ([#16658](https://github.com/Lightning-AI/lightning/pull/16658)) +- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693)) + ## [1.9.0] - 2023-01-17 diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 237e7be4ec6f1..b0e997dcf83a2 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -188,6 +188,7 @@ def _check_config_and_set_final_flags( if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy): raise ValueError( f"You selected an invalid strategy name: `strategy={strategy!r}`." + " It must be either a string or an instance of `lightning.fabric.strategies.Strategy`." " Example choices: ddp, ddp_spawn, deepspeed, dp, ..." " Find a complete list of options in our documentation at https://lightning.ai" ) @@ -439,12 +440,12 @@ def _check_strategy_and_fallback(self) -> None: def _init_strategy(self) -> None: """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.""" + # The validation of `_strategy_flag` already happened earlier on in the connector + assert isinstance(self._strategy_flag, (str, Strategy)) if isinstance(self._strategy_flag, str): self.strategy = STRATEGY_REGISTRY.get(self._strategy_flag) - elif isinstance(self._strategy_flag, Strategy): - self.strategy = self._strategy_flag else: - raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}") + self.strategy = self._strategy_flag def _check_and_init_precision(self) -> Precision: self._validate_precision_choice() diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3c46617316d83..e921d4cef7d18 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -239,6 +239,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an import error when `torch.distributed` is not available ([#16658](https://github.com/Lightning-AI/lightning/pull/16658)) +- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693)) + ## [1.9.0] - 2023-01-17 diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index e94f7d6fc555f..98417cd8f5965 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -206,16 +206,14 @@ def _check_config_and_set_final_flags( if strategy is not None: self._strategy_flag = strategy - if strategy == "ddp_cpu": - raise MisconfigurationException( - "`Trainer(strategy='ddp_cpu')` is not a valid strategy," - " you can use `Trainer(strategy='ddp'|'ddp_spawn'|'ddp_fork', accelerator='cpu')` instead." - ) - if strategy == "tpu_spawn": - raise MisconfigurationException( - "`Trainer(strategy='tpu_spawn')` is not a valid strategy," - " you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead." - ) + + if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy): + raise ValueError( + f"You selected an invalid strategy name: `strategy={strategy!r}`." + " It must be either a string or an instance of `lightning.pytorch.strategies.Strategy`." + " Example choices: ddp, ddp_spawn, deepspeed, dp, ..." + " Find a complete list of options in our documentation at https://lightning.ai" + ) if ( accelerator is not None @@ -505,13 +503,13 @@ def _check_strategy_and_fallback(self) -> None: def _init_strategy(self) -> None: """Instantiate the Strategy given depending on the setting of ``_strategy_flag``.""" + # The validation of `_strategy_flag` already happened earlier on in the connector + assert isinstance(self._strategy_flag, (str, Strategy)) if isinstance(self._strategy_flag, str): self.strategy = StrategyRegistry.get(self._strategy_flag) - elif isinstance(self._strategy_flag, Strategy): + else: # TODO(fabric): remove ignore after merging Fabric and PL strategies self.strategy = self._strategy_flag # type: ignore[assignment] - else: - raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}") def _check_and_init_precision(self) -> PrecisionPlugin: self._validate_precision_choice() diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index c9411e9e5f3f3..3d9f56b5076b7 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -386,9 +386,10 @@ def test_invalid_accelerator_choice(): _Connector(accelerator="cocofruit") -def test_invalid_strategy_choice(): - with pytest.raises(ValueError, match="You selected an invalid strategy name: `strategy='cocofruit'`"): - _Connector(strategy="cocofruit") +@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()]) +def test_invalid_strategy_choice(invalid_strategy): + with pytest.raises(ValueError, match="You selected an invalid strategy name:"): + _Connector(strategy=invalid_strategy) @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index c8ad3310ebd48..55397e708d6fe 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -62,6 +62,12 @@ def test_accelerator_invalid_choice(): Trainer(accelerator="invalid") +@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()]) +def test_invalid_strategy_choice(invalid_strategy): + with pytest.raises(ValueError, match="You selected an invalid strategy name:"): + AcceleratorConnector(strategy=invalid_strategy) + + @RunIf(skip_windows=True, standalone=True) def test_strategy_choice_ddp_on_cpu(tmpdir): """Test that selecting DDPStrategy on CPU works.""" @@ -351,13 +357,6 @@ def test_unsupported_strategy_types_on_cpu_and_fallback(): assert isinstance(trainer.strategy, DDPStrategy) -def test_exception_invalid_strategy(): - with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"): - Trainer(strategy="ddp_cpu") - with pytest.raises(MisconfigurationException, match=r"strategy='tpu_spawn'\)` is not a valid"): - Trainer(strategy="tpu_spawn") - - @pytest.mark.parametrize( ["strategy", "strategy_class"], ( @@ -411,7 +410,7 @@ def test_strategy_choice_cpu_instance(strategy_class): pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)), ], ) -def test_strategy_choice_gpu_str(strategy, strategy_class, cuda_count_2): +def test_strategy_choice_gpu_str(strategy, strategy_class, cuda_count_2, mps_count_0): trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2) assert isinstance(trainer.strategy, strategy_class)