Commit a8abb3a

test
1 parent d53b183

2 files changed: 23 additions, 3 deletions

tests/tests_fabric/test_connector.py

Lines changed: 2 additions & 1 deletion

@@ -36,7 +36,7 @@
     LightningEnvironment,
     LSFEnvironment,
     SLURMEnvironment,
-    TorchElasticEnvironment,
+    TorchElasticEnvironment, XLAEnvironment,
 )
 from lightning.fabric.plugins.io import TorchCheckpointIO
 from lightning.fabric.strategies import (
@@ -70,6 +70,7 @@ def test_accelerator_choice_tpu(accelerator, devices):
         # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
         # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
         assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
     else:
         assert isinstance(connector.strategy, SingleTPUStrategy)
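For context, a minimal sketch of what the new Fabric-side assertion exercises. `_Connector` is private API and its import path is an assumption based on this test file; running the snippet requires a TPU host:

```python
# Hedged sketch: selecting the TPU accelerator should make the connector
# pick XLAEnvironment as the cluster environment.
# `_Connector` is private API; the import path is assumed from this diff.
from lightning.fabric.connector import _Connector
from lightning.fabric.plugins.environments import XLAEnvironment

connector = _Connector(accelerator="tpu", devices=8)  # needs a TPU host
assert isinstance(connector.cluster_environment, XLAEnvironment)
```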

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 21 additions & 2 deletions

@@ -20,6 +20,8 @@
 import pytest
 import torch
 import torch.distributed
+
+from lightning.pytorch.accelerators import TPUAccelerator
 from lightning_utilities.core.imports import package_available

 import lightning.pytorch
@@ -28,7 +30,7 @@
     LightningEnvironment,
     LSFEnvironment,
     SLURMEnvironment,
-    TorchElasticEnvironment,
+    TorchElasticEnvironment, XLAEnvironment,
 )
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -42,7 +44,7 @@
     DDPStrategy,
     DeepSpeedStrategy,
     FSDPStrategy,
-    SingleDeviceStrategy,
+    SingleDeviceStrategy, XLAStrategy, SingleTPUStrategy,
 )
 from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from lightning.pytorch.strategies.hpu_parallel import HPUParallelStrategy
@@ -57,6 +59,23 @@ def test_accelerator_choice_cpu(tmpdir):
     assert isinstance(trainer.strategy, SingleDeviceStrategy)


+@RunIf(tpu=True, standalone=True)
+@pytest.mark.parametrize(
+    ["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
+)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_accelerator_choice_tpu(accelerator, devices):
+    connector = AcceleratorConnector(accelerator=accelerator, devices=devices)
+    assert isinstance(connector.accelerator, TPUAccelerator)
+    if devices is None or (isinstance(devices, int) and devices > 1):
+        # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
+        # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
+        assert isinstance(connector.strategy, XLAStrategy)
+        assert isinstance(connector.cluster_environment, XLAEnvironment)
+    else:
+        assert isinstance(connector.strategy, SingleTPUStrategy)
+
+
 def test_accelerator_invalid_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='invalid'`"):
         Trainer(accelerator="invalid")
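The new test drives `AcceleratorConnector` directly; the same selection logic runs when a `Trainer` is constructed. A minimal user-level sketch of the equivalent check (the `trainer.strategy.cluster_environment` accessor is an assumption, and a TPU host is required):

```python
# Hedged Trainer-level sketch: with the TPU accelerator, the chosen
# strategy's cluster environment should be XLAEnvironment.
# The accessor used below is an assumption; requires a TPU host.
from lightning.pytorch import Trainer
from lightning.fabric.plugins.environments import XLAEnvironment

trainer = Trainer(accelerator="tpu", devices=8)
assert isinstance(trainer.strategy.cluster_environment, XLAEnvironment)
```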
