diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index ba0d843e77c3f..de554d87c7448 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -55,7 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 
 
 ## [1.9.2] - 2023-02-15
diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index c8276f6ef8459..ad976f14c232a 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -517,7 +517,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
             if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 18fc06564c254..3b406c100b8bf 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -339,6 +339,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the `ColossalAIStrategy` and `ColossalAIPrecisionPlugin` in favor of the new [lightning-colossalai](https://github.com/Lightning-AI/lightning-colossalai) package ([#16757](https://github.com/Lightning-AI/lightning/pull/16757), [#16778](https://github.com/Lightning-AI/lightning/pull/16778))
 
 
+### Fixed
+
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
+
+
 ## [1.9.2] - 2023-02-15
 
 ### Fixed
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index ced2daefb1508..29147d1869f61 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -583,7 +583,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
             if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index 8296e2426f4db..e8f6c062d7ca1 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -37,6 +37,7 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from lightning.fabric.plugins.io import TorchCheckpointIO
 from lightning.fabric.strategies import (
@@ -70,6 +71,8 @@ def test_accelerator_choice_tpu(accelerator, devices):
         # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
         # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
         assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
     else:
         assert isinstance(connector.strategy, SingleTPUStrategy)
 
diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py
index a2a4389142da8..9762f73a0a44e 100644
--- a/tests/tests_pytorch/accelerators/test_tpu.py
+++ b/tests/tests_pytorch/accelerators/test_tpu.py
@@ -89,7 +89,8 @@ def test_accelerator_cpu_when_tpu_available(tpu_available):
 
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
-def test_accelerator_tpu(accelerator, devices, tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_accelerator_tpu(_, accelerator, devices, tpu_available):
     assert TPUAccelerator.is_available()
 
     trainer = Trainer(accelerator=accelerator, devices=devices)
@@ -177,7 +178,8 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available):
 
 
 @RunIf(skip_windows=True)
-def test_strategy_choice_tpu_str_xla_debug(tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_strategy_choice_tpu_str_xla_debug(_, tpu_available):
     trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy, XLAStrategy)
 
@@ -261,7 +263,8 @@ def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available):
 
 
 @RunIf(skip_windows=True)
-def test_xla_checkpoint_plugin_being_default(tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_xla_checkpoint_plugin_being_default(_, tpu_available):
     trainer = Trainer(accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)
 
diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py
index 75b7b63957387..cb9efbe4ea95e 100644
--- a/tests/tests_pytorch/strategies/test_registry.py
+++ b/tests/tests_pytorch/strategies/test_registry.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 
 from lightning.pytorch import Trainer
@@ -54,7 +56,8 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy):
 
 
 @RunIf(skip_windows=True)
-def test_xla_debug_strategy_registry(xla_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_xla_debug_strategy_registry(_, tpu_available, xla_available):
     strategy = "xla_debug"
 
     assert strategy in StrategyRegistry
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index e98d4df2a9c54..7347505ecc77e 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -29,8 +29,10 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import TPUAccelerator
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.accelerators.cpu import CPUAccelerator
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator
@@ -43,6 +45,8 @@
     DeepSpeedStrategy,
     FSDPStrategy,
     SingleDeviceStrategy,
+    SingleTPUStrategy,
+    XLAStrategy,
 )
 from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from lightning.pytorch.strategies.hpu_parallel import HPUParallelStrategy
@@ -57,6 +61,24 @@ def test_accelerator_choice_cpu(tmpdir):
     assert isinstance(trainer.strategy, SingleDeviceStrategy)
 
 
+@RunIf(tpu=True, standalone=True)
+@pytest.mark.parametrize(
+    ["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
+)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_accelerator_choice_tpu(accelerator, devices):
+    connector = AcceleratorConnector(accelerator=accelerator, devices=devices)
+    assert isinstance(connector.accelerator, TPUAccelerator)
+    if devices is None or (isinstance(devices, int) and devices > 1):
+        # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
+        # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
+        assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
+    else:
+        assert isinstance(connector.strategy, SingleTPUStrategy)
+
+
 def test_accelerator_invalid_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='invalid'`"):
         Trainer(accelerator="invalid")
@@ -248,7 +270,8 @@ def test_interactive_incompatible_backend_error(cuda_count_2, monkeypatch):
 
 
 @RunIf(skip_windows=True)
-def test_interactive_compatible_strategy_tpu(tpu_available, monkeypatch):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_interactive_compatible_strategy_tpu(_, tpu_available, monkeypatch):
     monkeypatch.setattr(lightning.pytorch.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher.is_interactive_compatible
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index 8109520bd7610..901187b6beffd 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -15,7 +15,6 @@
 import logging
 import os
 from unittest import mock
-from unittest.mock import PropertyMock
 
 import pytest
 import torch
@@ -143,19 +142,19 @@ def test_num_stepping_batches_with_tpu_single():
     assert trainer.estimated_stepping_batches == len(model.train_dataloader())
 
 
+class MultiprocessModel(BoringModel):
+    def on_train_start(self):
+        assert self.trainer.world_size == 8
+        assert self.trainer.estimated_stepping_batches == len(self.train_dataloader()) // 8
+
+
 @RunIf(tpu=True)
-@mock.patch(
-    "lightning.pytorch.strategies.xla.XLAStrategy.root_device",
-    new_callable=PropertyMock,
-    return_value=torch.device("xla:0"),
-)
-def test_num_stepping_batches_with_tpu_multi(_):
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_num_stepping_batches_with_tpu_multi():
     """Test stepping batches with the TPU strategy across multiple devices."""
     trainer = Trainer(accelerator="tpu", devices=8, max_epochs=1)
-    model = BoringModel()
-    trainer._data_connector.attach_data(model)
-    trainer.strategy.connect(model)
-    assert trainer.estimated_stepping_batches == len(model.train_dataloader()) // 8
+    model = MultiprocessModel()
+    trainer.fit(model)
 
 
 @mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True)
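
The behavioral core of this patch is the cluster-environment resolution in `_lazy_init_strategy`: an environment the strategy already carries (the tests assert `XLAStrategy` ends up with `XLAEnvironment`) is kept, the connector only injects its own selection when the strategy has none, and the connector's attribute is then synced back from the strategy. Previously the assignment was unconditional, which is how a wrong environment plugin could land on `XLAStrategy` with `accelerator=tpu` and `devices > 1`. The sketch below is a minimal, standalone illustration of that precedence; `_FakeStrategy`, `_FakeEnvironment`, and `resolve_cluster_environment` are hypothetical stand-ins, not Lightning classes or APIs.

class _FakeEnvironment:
    # Stand-in for a ClusterEnvironment such as XLAEnvironment (illustrative only).
    def __init__(self, name):
        self.name = name


class _FakeStrategy:
    # Stand-in for a strategy that may pre-populate its own environment, as XLAStrategy does.
    def __init__(self, cluster_environment=None):
        self.cluster_environment = cluster_environment


def resolve_cluster_environment(strategy, connector_environment):
    # Mirrors the patched branch: only fill in the connector's choice when the
    # strategy did not bring its own environment, then report what it ends up using.
    if strategy.cluster_environment is None:
        strategy.cluster_environment = connector_environment
    return strategy.cluster_environment


# A strategy that already carries an XLA-style environment keeps it.
xla_like = _FakeStrategy(cluster_environment=_FakeEnvironment("xla"))
assert resolve_cluster_environment(xla_like, _FakeEnvironment("lightning")).name == "xla"

# A strategy without an environment still receives the connector's selection.
plain = _FakeStrategy()
assert resolve_cluster_environment(plain, _FakeEnvironment("lightning")).name == "lightning"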