20
20
import pytest
21
21
import torch
22
22
import torch .distributed
23
+
24
+ from lightning .pytorch .accelerators import TPUAccelerator
23
25
from lightning_utilities .core .imports import package_available
24
26
25
27
import lightning .pytorch
28
30
LightningEnvironment ,
29
31
LSFEnvironment ,
30
32
SLURMEnvironment ,
31
- TorchElasticEnvironment ,
33
+ TorchElasticEnvironment , XLAEnvironment ,
32
34
)
33
35
from lightning .pytorch import Trainer
34
36
from lightning .pytorch .accelerators .accelerator import Accelerator
42
44
DDPStrategy ,
43
45
DeepSpeedStrategy ,
44
46
FSDPStrategy ,
45
- SingleDeviceStrategy ,
47
+ SingleDeviceStrategy , XLAStrategy , SingleTPUStrategy ,
46
48
)
47
49
from lightning .pytorch .strategies .ddp_spawn import _DDP_FORK_ALIASES
48
50
from lightning .pytorch .strategies .hpu_parallel import HPUParallelStrategy
@@ -57,6 +59,23 @@ def test_accelerator_choice_cpu(tmpdir):
57
59
assert isinstance (trainer .strategy , SingleDeviceStrategy )
58
60
59
61
62
+ @RunIf (tpu = True , standalone = True )
63
+ @pytest .mark .parametrize (
64
+ ["accelerator" , "devices" ], [("tpu" , None ), ("tpu" , 1 ), ("tpu" , [1 ]), ("tpu" , 8 ), ("auto" , 1 ), ("auto" , 8 )]
65
+ )
66
+ @mock .patch .dict (os .environ , os .environ .copy (), clear = True )
67
+ def test_accelerator_choice_tpu (accelerator , devices ):
68
+ connector = AcceleratorConnector (accelerator = accelerator , devices = devices )
69
+ assert isinstance (connector .accelerator , TPUAccelerator )
70
+ if devices is None or (isinstance (devices , int ) and devices > 1 ):
71
+ # accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
72
+ # This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
73
+ assert isinstance (connector .strategy , XLAStrategy )
74
+ assert isinstance (connector .cluster_environment , XLAEnvironment )
75
+ else :
76
+ assert isinstance (connector .strategy , SingleTPUStrategy )
77
+
78
+
60
79
def test_accelerator_invalid_choice ():
61
80
with pytest .raises (ValueError , match = "You selected an invalid accelerator name: `accelerator='invalid'`" ):
62
81
Trainer (accelerator = "invalid" )
0 commit comments