Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an import error when `torch.distributed` is not available ([#16658](https://github.com/Lightning-AI/lightning/pull/16658))

- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))


## [1.9.0] - 2023-01-17

Expand Down
7 changes: 4 additions & 3 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _check_config_and_set_final_flags(
if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
" Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
" Find a complete list of options in our documentation at https://lightning.ai"
)
Expand Down Expand Up @@ -439,12 +440,12 @@ def _check_strategy_and_fallback(self) -> None:

def _init_strategy(self) -> None:
"""Instantiate the Strategy given depending on the setting of ``_strategy_flag``."""
# The validation of `_strategy_flag` already happened earlier on in the connector
assert isinstance(self._strategy_flag, (str, Strategy))
if isinstance(self._strategy_flag, str):
self.strategy = STRATEGY_REGISTRY.get(self._strategy_flag)
elif isinstance(self._strategy_flag, Strategy):
self.strategy = self._strategy_flag
else:
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
self.strategy = self._strategy_flag

def _check_and_init_precision(self) -> Precision:
self._validate_precision_choice()
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an import error when `torch.distributed` is not available ([#16658](https://github.com/Lightning-AI/lightning/pull/16658))

- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))


## [1.9.0] - 2023-01-17

Expand Down
24 changes: 11 additions & 13 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,14 @@ def _check_config_and_set_final_flags(

if strategy is not None:
self._strategy_flag = strategy
if strategy == "ddp_cpu":
raise MisconfigurationException(
"`Trainer(strategy='ddp_cpu')` is not a valid strategy,"
" you can use `Trainer(strategy='ddp'|'ddp_spawn'|'ddp_fork', accelerator='cpu')` instead."
)
if strategy == "tpu_spawn":
raise MisconfigurationException(
"`Trainer(strategy='tpu_spawn')` is not a valid strategy,"
" you can use `Trainer(strategy='ddp_spawn', accelerator='tpu')` instead."
)

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `lightning.pytorch.strategies.Strategy`."
" Example choices: ddp, ddp_spawn, deepspeed, dp, ..."
" Find a complete list of options in our documentation at https://lightning.ai"
)

if (
accelerator is not None
Expand Down Expand Up @@ -505,13 +503,13 @@ def _check_strategy_and_fallback(self) -> None:

def _init_strategy(self) -> None:
"""Instantiate the Strategy given depending on the setting of ``_strategy_flag``."""
# The validation of `_strategy_flag` already happened earlier on in the connector
assert isinstance(self._strategy_flag, (str, Strategy))
if isinstance(self._strategy_flag, str):
self.strategy = StrategyRegistry.get(self._strategy_flag)
elif isinstance(self._strategy_flag, Strategy):
else:
# TODO(fabric): remove ignore after merging Fabric and PL strategies
self.strategy = self._strategy_flag # type: ignore[assignment]
else:
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")

def _check_and_init_precision(self) -> PrecisionPlugin:
self._validate_precision_choice()
Expand Down
7 changes: 4 additions & 3 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,10 @@ def test_invalid_accelerator_choice():
_Connector(accelerator="cocofruit")


def test_invalid_strategy_choice():
with pytest.raises(ValueError, match="You selected an invalid strategy name: `strategy='cocofruit'`"):
_Connector(strategy="cocofruit")
@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()])
def test_invalid_strategy_choice(invalid_strategy):
with pytest.raises(ValueError, match="You selected an invalid strategy name:"):
_Connector(strategy=invalid_strategy)


@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def test_accelerator_invalid_choice():
Trainer(accelerator="invalid")


@pytest.mark.parametrize("invalid_strategy", ["cocofruit", object()])
def test_invalid_strategy_choice(invalid_strategy):
with pytest.raises(ValueError, match="You selected an invalid strategy name:"):
AcceleratorConnector(strategy=invalid_strategy)


@RunIf(skip_windows=True, standalone=True)
def test_strategy_choice_ddp_on_cpu(tmpdir):
"""Test that selecting DDPStrategy on CPU works."""
Expand Down Expand Up @@ -351,13 +357,6 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
assert isinstance(trainer.strategy, DDPStrategy)


def test_exception_invalid_strategy():
with pytest.raises(MisconfigurationException, match=r"strategy='ddp_cpu'\)` is not a valid"):
Trainer(strategy="ddp_cpu")
with pytest.raises(MisconfigurationException, match=r"strategy='tpu_spawn'\)` is not a valid"):
Trainer(strategy="tpu_spawn")


@pytest.mark.parametrize(
["strategy", "strategy_class"],
(
Expand Down Expand Up @@ -411,7 +410,7 @@ def test_strategy_choice_cpu_instance(strategy_class):
pytest.param("deepspeed", DeepSpeedStrategy, marks=RunIf(deepspeed=True)),
],
)
def test_strategy_choice_gpu_str(strategy, strategy_class, cuda_count_2):
def test_strategy_choice_gpu_str(strategy, strategy_class, cuda_count_2, mps_count_0):
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
assert isinstance(trainer.strategy, strategy_class)

Expand Down