src/lightning/fabric/CHANGELOG.md (2 changes: 1 addition & 1 deletion)

@@ -55,7 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
 
 
 ## [1.9.2] - 2023-02-15
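To make the changelog entry concrete, the sketch below mirrors the assertions this PR adds to `tests/tests_fabric/test_connector.py`. It is an illustration only: it assumes a TPU host (or the suite's TPU mocks) and that the Fabric connector is exposed as `lightning.fabric.connector._Connector` in this release.

```python
# Illustrative sketch, not part of the PR. Assumes TPU support is available
# (real or mocked) so that the connector picks XLAStrategy for devices=8.
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.environments import XLAEnvironment

connector = _Connector(accelerator="tpu", devices=8)

# Before this fix the connector could overwrite the strategy's XLAEnvironment
# with a generic default environment plugin; after it, both references agree.
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
assert isinstance(connector.cluster_environment, XLAEnvironment)
```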
src/lightning/fabric/connector.py (4 changes: 3 additions & 1 deletion)

@@ -517,7 +517,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
             if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
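Read in isolation, the new assignment order says: keep whatever cluster environment the strategy already carries (for example the `XLAEnvironment` that `XLAStrategy` brings along), fall back to the connector's choice only when the strategy has none, and then point the connector at the strategy's environment so the two never disagree. A simplified sketch (the standalone function and its name are illustrative, not the real API):

```python
def _sync_cluster_environment(connector, strategy):
    """Simplified, standalone rendering of the assignment order added above."""
    if hasattr(strategy, "cluster_environment"):
        # The environment the strategy already owns (e.g. XLAEnvironment) wins.
        if strategy.cluster_environment is None:
            strategy.cluster_environment = connector.cluster_environment
        # The connector then mirrors whatever the strategy ended up with.
        connector.cluster_environment = strategy.cluster_environment
```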
src/lightning/pytorch/CHANGELOG.md (5 changes: 5 additions & 0 deletions)

@@ -339,6 +339,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the `ColossalAIStrategy` and `ColossalAIPrecisionPlugin` in favor of the new [lightning-colossalai](https://github.com/Lightning-AI/lightning-colossalai) package ([#16757](https://github.com/Lightning-AI/lightning/pull/16757), [#16778](https://github.com/Lightning-AI/lightning/pull/16778))
 
 
+### Fixed
+
+- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
+
+
 ## [1.9.2] - 2023-02-15
 
 ### Fixed
@@ -583,7 +583,9 @@ def _lazy_init_strategy(self) -> None:
         if self.checkpoint_io:
             self.strategy.checkpoint_io = self.checkpoint_io
         if hasattr(self.strategy, "cluster_environment"):
-            self.strategy.cluster_environment = self.cluster_environment
+            if self.strategy.cluster_environment is None:
+                self.strategy.cluster_environment = self.cluster_environment
+            self.cluster_environment = self.strategy.cluster_environment
         if hasattr(self.strategy, "parallel_devices"):
            if self.strategy.parallel_devices:
                 self._parallel_devices = self.strategy.parallel_devices
tests/tests_fabric/test_connector.py (3 changes: 3 additions & 0 deletions)

@@ -37,6 +37,7 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from lightning.fabric.plugins.io import TorchCheckpointIO
 from lightning.fabric.strategies import (

@@ -70,6 +71,8 @@ def test_accelerator_choice_tpu(accelerator, devices):
         # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
         # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
         assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
     else:
         assert isinstance(connector.strategy, SingleTPUStrategy)
 
tests/tests_pytorch/accelerators/test_tpu.py (9 changes: 6 additions & 3 deletions)

@@ -89,7 +89,8 @@ def test_accelerator_cpu_when_tpu_available(tpu_available):
 
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(["accelerator", "devices"], [("auto", 8), ("auto", "auto"), ("tpu", None)])
-def test_accelerator_tpu(accelerator, devices, tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_accelerator_tpu(_, accelerator, devices, tpu_available):
     assert TPUAccelerator.is_available()
 
     trainer = Trainer(accelerator=accelerator, devices=devices)

@@ -177,7 +178,8 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available):
 
 
 @RunIf(skip_windows=True)
-def test_strategy_choice_tpu_str_xla_debug(tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_strategy_choice_tpu_str_xla_debug(_, tpu_available):
     trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy, XLAStrategy)
 

@@ -261,7 +263,8 @@ def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available):
 
 
 @RunIf(skip_windows=True)
-def test_xla_checkpoint_plugin_being_default(tpu_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_xla_checkpoint_plugin_being_default(_, tpu_available):
     trainer = Trainer(accelerator="tpu", devices=8)
     assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)
 
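A likely reason for patching `XLAStrategy.set_world_ranks` in the tests above: now that the strategy keeps its own `XLAEnvironment`, rank setup would touch XLA state that is not available on the machines these non-TPU tests run on. The pattern in isolation (the test body is hypothetical; `tpu_available` is a fixture provided by the Lightning test suite):

```python
from unittest import mock

from lightning.pytorch import Trainer


# Patching set_world_ranks keeps Trainer construction from doing real XLA rank
# bookkeeping; the mock object is passed to the test as the first argument ("_").
@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
def test_example_tpu_setup(_, tpu_available):  # hypothetical test, for illustration
    trainer = Trainer(accelerator="tpu", devices=8)
    assert trainer.strategy.cluster_environment is not None
```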
tests/tests_pytorch/strategies/test_registry.py (5 changes: 4 additions & 1 deletion)

@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 
 from lightning.pytorch import Trainer

@@ -54,7 +56,8 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy):
 
 
 @RunIf(skip_windows=True)
-def test_xla_debug_strategy_registry(xla_available):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_xla_debug_strategy_registry(_, tpu_available, xla_available):
     strategy = "xla_debug"
 
     assert strategy in StrategyRegistry
@@ -29,8 +29,10 @@
     LSFEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
+    XLAEnvironment,
 )
 from lightning.pytorch import Trainer
+from lightning.pytorch.accelerators import TPUAccelerator
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.accelerators.cpu import CPUAccelerator
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator

@@ -43,6 +45,8 @@
     DeepSpeedStrategy,
     FSDPStrategy,
     SingleDeviceStrategy,
+    SingleTPUStrategy,
+    XLAStrategy,
 )
 from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from lightning.pytorch.strategies.hpu_parallel import HPUParallelStrategy

@@ -57,6 +61,24 @@ def test_accelerator_choice_cpu(tmpdir):
     assert isinstance(trainer.strategy, SingleDeviceStrategy)
 
 
+@RunIf(tpu=True, standalone=True)
+@pytest.mark.parametrize(
+    ["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
+)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_accelerator_choice_tpu(accelerator, devices):
+    connector = AcceleratorConnector(accelerator=accelerator, devices=devices)
+    assert isinstance(connector.accelerator, TPUAccelerator)
+    if devices is None or (isinstance(devices, int) and devices > 1):
+        # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
+        # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
+        assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
+    else:
+        assert isinstance(connector.strategy, SingleTPUStrategy)
+
+
 def test_accelerator_invalid_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='invalid'`"):
         Trainer(accelerator="invalid")

@@ -248,7 +270,8 @@ def test_interactive_incompatible_backend_error(cuda_count_2, monkeypatch):
 
 
 @RunIf(skip_windows=True)
-def test_interactive_compatible_strategy_tpu(tpu_available, monkeypatch):
+@mock.patch("lightning.pytorch.strategies.xla.XLAStrategy.set_world_ranks")
+def test_interactive_compatible_strategy_tpu(_, tpu_available, monkeypatch):
     monkeypatch.setattr(lightning.pytorch.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
     trainer = Trainer(accelerator="tpu")
     assert trainer.strategy.launcher.is_interactive_compatible
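The `@mock.patch.dict(os.environ, os.environ.copy(), clear=True)` decorator on the new `test_accelerator_choice_tpu` above deserves a note: it repopulates `os.environ` with an identical copy for the duration of the test and restores the original afterwards, so any environment variables set while the connector and `XLAEnvironment` initialize cannot leak into other tests. A small demonstration of that behaviour (the variable name is made up):

```python
import os
from unittest import mock

with mock.patch.dict(os.environ, os.environ.copy(), clear=True):
    # Inside the block the environment looks unchanged, but it is a patched dict.
    os.environ["HYPOTHETICAL_RANK_VAR"] = "0"  # mutation stays local to the block
    assert os.environ["HYPOTHETICAL_RANK_VAR"] == "0"

# On exit, patch.dict restores the original environment, dropping the mutation.
assert "HYPOTHETICAL_RANK_VAR" not in os.environ
```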
@@ -15,7 +15,6 @@
 import logging
 import os
 from unittest import mock
-from unittest.mock import PropertyMock
 
 import pytest
 import torch

@@ -143,19 +142,19 @@ def test_num_stepping_batches_with_tpu_single():
     assert trainer.estimated_stepping_batches == len(model.train_dataloader())
 
 
+class MultiprocessModel(BoringModel):
+    def on_train_start(self):
+        assert self.trainer.world_size == 8
+        assert self.trainer.estimated_stepping_batches == len(self.train_dataloader()) // 8
+
+
 @RunIf(tpu=True)
-@mock.patch(
-    "lightning.pytorch.strategies.xla.XLAStrategy.root_device",
-    new_callable=PropertyMock,
-    return_value=torch.device("xla:0"),
-)
-def test_num_stepping_batches_with_tpu_multi(_):
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_num_stepping_batches_with_tpu_multi():
     """Test stepping batches with the TPU strategy across multiple devices."""
     trainer = Trainer(accelerator="tpu", devices=8, max_epochs=1)
-    model = BoringModel()
-    trainer._data_connector.attach_data(model)
-    trainer.strategy.connect(model)
-    assert trainer.estimated_stepping_batches == len(model.train_dataloader()) // 8
+    model = MultiprocessModel()
+    trainer.fit(model)
 
 
 @mock.patch("lightning.pytorch.accelerators.ipu.IPUAccelerator.is_available", return_value=True)