Merged
6 changes: 3 additions & 3 deletions docs/source-pytorch/accelerators/tpu_faq.rst
@@ -61,7 +61,7 @@ How to resolve the replication issue?
.format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

This error is raised when the XLA device is called outside the spawn process. Internally in `TPUSpawn` Strategy for training on multiple tpu cores, we use XLA's `xmp.spawn`.
This error is raised when the XLA device is called outside the spawn process. Internally in the `XLAStrategy` for training on multiple TPU cores, we use XLA's `xmp.spawn`.
Don't use ``xm.xla_device()`` while working on Lightning + TPUs!
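
For illustration, a minimal sketch (not part of the original page) that relies on the device Lightning assigns inside each spawned process instead of calling ``xm.xla_device()`` directly:

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel


class MyModel(BoringModel):
    def on_train_start(self):
        # Inside the spawned process, `self.device` already points at the XLA device.
        assert self.device.type == "xla"


trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla")
trainer.fit(MyModel())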

----
@@ -91,7 +91,7 @@ How to setup the debug mode for Training on TPUs?
import pytorch_lightning as pl

my_model = MyLightningModule()
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="tpu_spawn_debug")
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla_debug")
trainer.fit(my_model)

Example Metrics report:
@@ -108,7 +108,7 @@

A lot of PyTorch operations aren't lowered to XLA, which could lead to significant slowdown of the training process.
These operations are moved to the CPU memory and evaluated, and then the results are transferred back to the XLA device(s).
By using the `tpu_spawn_debug` Strategy, users could create a metrics report to diagnose issues.
By using the `xla_debug` strategy, users can create a metrics report to diagnose issues.

The report includes things like (`XLA Reference <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#troubleshooting>`_):

2 changes: 1 addition & 1 deletion docs/source-pytorch/advanced/strategy_registry.rst
@@ -18,7 +18,7 @@ It also returns the optional description and parameters for initialising the Str
trainer = Trainer(strategy="deepspeed_stage_3_offload", accelerator="gpu", devices=3)

# Training with the TPU Spawn Strategy with `debug` as True
trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8)
trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8)


Additionally, you can pass your custom registered training strategies to the ``strategy`` argument.
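
A minimal sketch of such a custom registration (illustrative only; the shorthand name below is made up):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry

# Register a hypothetical shorthand that pre-configures DDPStrategy.
StrategyRegistry.register(
    "ddp_static_graph",  # made-up name, used only for this example
    DDPStrategy,
    description="DDP with static_graph enabled",
    static_graph=True,
)

trainer = Trainer(strategy="ddp_static_graph", accelerator="gpu", devices=2)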
2 changes: 1 addition & 1 deletion docs/source-pytorch/api_references.rst
@@ -222,7 +222,7 @@ strategies
SingleHPUStrategy
SingleTPUStrategy
Strategy
TPUSpawnStrategy
XLAStrategy

tuner
-----
4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/strategy.rst
@@ -90,8 +90,8 @@ The below table lists all relevant strategies available in Lightning with their
* - ipu_strategy
- :class:`~pytorch_lightning.strategies.IPUStrategy`
- Plugin for training on IPU devices. :doc:`Learn more. <../accelerators/ipu>`
* - tpu_spawn
- :class:`~pytorch_lightning.strategies.TPUSpawnStrategy`
* - xla
- :class:`~pytorch_lightning.strategies.XLAStrategy`
- Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>`
* - single_tpu
- :class:`~pytorch_lightning.strategies.SingleTPUStrategy`
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -28,14 +28,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- `DataParallelStrategy.get_module_state_dict()` and `DDPStrategy.get_module_state_dict()` now correctly extracts the state dict without keys prefixed with 'module' ([#16487](https://github.com/Lightning-AI/lightning/pull/16487))


- "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490))
* `strategy="fsdp_full_shard_offload"` is now `strategy="fsdp_cpu_offload"`
* `lightning.fabric.plugins.precision.native_amp` is now `lightning.fabric.plugins.precision.amp`


- Enabled all shorthand strategy names that can be supported in the CLI ([#16485](https://github.com/Lightning-AI/lightning/pull/16485))

- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))
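
In user code the rename amounts to swapping the shorthand string, e.g. (a before/after sketch, not taken from this diff):

from lightning.fabric import Fabric

# before: Fabric(accelerator="tpu", devices=8, strategy="tpu_spawn")
fabric = Fabric(accelerator="tpu", devices=8, strategy="xla")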


### Deprecated

2 changes: 1 addition & 1 deletion src/lightning/fabric/connector.py
@@ -385,7 +385,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "tpu":
if self._parallel_devices and len(self._parallel_devices) > 1:
return "tpu_spawn"
return "xla"
else:
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore
2 changes: 0 additions & 2 deletions src/lightning/fabric/strategies/xla.py
@@ -209,8 +209,6 @@ def remove_checkpoint(self, filepath: _PATH) -> None:

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
# TODO(fabric): Deprecate the name "tpu_spawn" through the connector
strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__)
strategy_registry.register("xla", cls, description=cls.__class__.__name__)

def _set_world_ranks(self) -> None:
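With the old alias gone, only "xla" resolves on the Fabric side; a quick check might look like this (a sketch that assumes Fabric exposes its registry as STRATEGY_REGISTRY):

from lightning.fabric.strategies import STRATEGY_REGISTRY

assert "xla" in STRATEGY_REGISTRY
assert "tpu_spawn" not in STRATEGY_REGISTRY  # the old alias is no longer registered
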
6 changes: 6 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -89,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753))


- Renamed `TPUSpawnStrategy` to `XLAStrategy` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))

- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))
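
For `Trainer` users the renames map roughly as follows (an illustrative sketch, not taken from this diff):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import XLAStrategy  # previously TPUSpawnStrategy

trainer = Trainer(accelerator="tpu", devices=8, strategy="xla")        # was "tpu_spawn"
trainer = Trainer(accelerator="tpu", devices=8, strategy="xla_debug")  # was "tpu_spawn_debug"
trainer = Trainer(accelerator="tpu", devices=8, strategy=XLAStrategy(debug=True))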


### Deprecated

-
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/__init__.py
@@ -23,8 +23,8 @@
from lightning.pytorch.strategies.single_hpu import SingleHPUStrategy # noqa: F401
from lightning.pytorch.strategies.single_tpu import SingleTPUStrategy # noqa: F401
from lightning.pytorch.strategies.strategy import Strategy # noqa: F401
from lightning.pytorch.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401
from lightning.pytorch.strategies.utils import _call_register_strategies
from lightning.pytorch.strategies.xla import XLAStrategy # noqa: F401

_STRATEGIES_BASE_MODULE = "lightning.pytorch.strategies"
StrategyRegistry = _StrategyRegistry()
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/launchers/xla.py
@@ -47,7 +47,7 @@ class _XLALauncher(_MultiProcessingLauncher):
strategy: A reference to the strategy that is used together with this launcher
"""

def __init__(self, strategy: "pl.strategies.TPUSpawnStrategy") -> None:
def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(strategy=strategy, start_method="fork")
@@ -43,11 +43,11 @@
MpDeviceLoader = None


class TPUSpawnStrategy(DDPSpawnStrategy):
class XLAStrategy(DDPSpawnStrategy):
"""Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
method."""

strategy_name = "tpu_spawn"
strategy_name = "xla"

def __init__(
self,
@@ -143,7 +143,7 @@ def is_distributed(self) -> bool:
return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader":
TPUSpawnStrategy._validate_dataloader(dataloader)
XLAStrategy._validate_dataloader(dataloader)
from torch_xla.distributed.parallel_loader import MpDeviceLoader

if isinstance(dataloader, MpDeviceLoader):
@@ -192,7 +192,7 @@ def reduce(
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if invalid_reduce_op or invalid_reduce_op_str:
raise ValueError(
"Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
"Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
f" {reduce_op}"
)

@@ -293,10 +293,7 @@ def teardown(self) -> None:

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True
)

strategy_registry.register("xla_debug", cls, description="XLA strategy with `debug` as True", debug=True)
strategy_registry.register(
cls.strategy_name,
cls,
@@ -65,7 +65,7 @@
SingleTPUStrategy,
Strategy,
StrategyRegistry,
TPUSpawnStrategy,
XLAStrategy,
)
from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -442,7 +442,7 @@ def _choose_strategy(self) -> Union[Strategy, str]:
return SingleHPUStrategy(device=torch.device("hpu"))
if self._accelerator_flag == "tpu":
if self._parallel_devices and len(self._parallel_devices) > 1:
return TPUSpawnStrategy.strategy_name
return XLAStrategy.strategy_name
else:
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore
@@ -617,10 +617,10 @@ def _lazy_init_strategy(self) -> None:
# TODO: should be moved to _check_strategy_and_fallback().
# Current test check precision first, so keep this check here to meet error order
if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
self.strategy, (SingleTPUStrategy, TPUSpawnStrategy)
self.strategy, (SingleTPUStrategy, XLAStrategy)
):
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`,"
"The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy`,"
f" found {self.strategy.__class__.__name__}."
)

@@ -644,7 +644,7 @@ def is_distributed(self) -> bool:
FSDPStrategy,
DDPSpawnStrategy,
DeepSpeedStrategy,
TPUSpawnStrategy,
XLAStrategy,
HPUParallelStrategy,
)
is_distributed = isinstance(self.strategy, distributed_strategy)
1 change: 0 additions & 1 deletion tests/tests_fabric/strategies/test_registry.py
@@ -55,7 +55,6 @@ def test_available_strategies_in_registry():
"ddp_fork",
"ddp_notebook",
"single_tpu",
"tpu_spawn",
"xla",
"dp",
}
24 changes: 12 additions & 12 deletions tests/tests_pytorch/accelerators/test_tpu.py
@@ -27,7 +27,7 @@
from lightning.pytorch.accelerators.tpu import TPUAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO
from lightning.pytorch.strategies import DDPStrategy, TPUSpawnStrategy
from lightning.pytorch.strategies import DDPStrategy, XLAStrategy
from lightning.pytorch.utilities import find_shared_parameters
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad
@@ -94,7 +94,7 @@ def test_accelerator_tpu(accelerator, devices, tpu_available):

trainer = Trainer(accelerator=accelerator, devices=devices)
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert isinstance(trainer.strategy, XLAStrategy)
assert trainer.num_devices == 8


@@ -177,15 +177,15 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available):


@RunIf(skip_windows=True)
def test_strategy_choice_tpu_str_tpu_spawn_debug(tpu_available):
trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
def test_strategy_choice_tpu_str_xla_debug(tpu_available):
trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, XLAStrategy)


@RunIf(tpu=True)
def test_strategy_choice_tpu_strategy():
trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
trainer = Trainer(strategy=XLAStrategy(), accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, XLAStrategy)


@RunIf(tpu=True)
@@ -237,7 +237,7 @@ def forward(self, x):


def test_tpu_invalid_raises(tpu_available):
strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
strategy = XLAStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
Trainer(strategy=strategy, devices=8)

@@ -248,14 +248,14 @@

def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available):
accelerator = TPUAccelerator()
strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin())
strategy = XLAStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin())
with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
Trainer(strategy=strategy, devices=8)

accelerator = TPUAccelerator()
strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin())
with pytest.raises(
ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy"
ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy"
):
Trainer(strategy=strategy, devices=8)

@@ -267,11 +267,11 @@ def test_xla_checkpoint_plugin_being_default(tpu_available):


@RunIf(tpu=True)
@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
@patch("lightning.pytorch.strategies.xla.XLAStrategy.root_device")
def test_xla_mp_device_dataloader_attribute(_, monkeypatch):
dataset = RandomDataset(32, 64)
dataloader = DataLoader(dataset)
strategy = TPUSpawnStrategy()
strategy = XLAStrategy()
isinstance_return = True

import torch_xla.distributed.parallel_loader as parallel_loader
2 changes: 1 addition & 1 deletion tests/tests_pytorch/conftest.py
@@ -174,7 +174,7 @@ def mps_count_4(monkeypatch):
@pytest.fixture(scope="function")
def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(lightning.pytorch.accelerators.tpu, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.pytorch.strategies.tpu_spawn, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.pytorch.strategies.single_tpu, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.pytorch.plugins.precision.tpu, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", True)
10 changes: 5 additions & 5 deletions tests/tests_pytorch/models/test_tpu.py
@@ -24,7 +24,7 @@
from lightning.pytorch.accelerators import TPUAccelerator
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.strategies import TPUSpawnStrategy
from lightning.pytorch.strategies import XLAStrategy
from lightning.pytorch.strategies.launchers.xla import _XLALauncher
from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync
from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -285,7 +285,7 @@ def wrap_launch_function(fn, strategy, *args, **kwargs):
def xla_launch(fn):
# TODO: the accelerator should be optional to just launch processes, but this requires lazy initialization
accelerator = TPUAccelerator()
strategy = TPUSpawnStrategy(accelerator=accelerator, parallel_devices=list(range(8)))
strategy = XLAStrategy(accelerator=accelerator, parallel_devices=list(range(8)))
launcher = _XLALauncher(strategy=strategy)
wrapped = partial(wrap_launch_function, fn, strategy)
return launcher.launch(wrapped, strategy)
@@ -325,7 +325,7 @@ def teardown(self, stage):
devices=8,
limit_train_batches=0.4,
limit_val_batches=0.4,
strategy=TPUSpawnStrategy(debug=True),
strategy=XLAStrategy(debug=True),
)

model = DebugModel()
@@ -359,6 +359,6 @@ def on_train_start(self):

@RunIf(tpu=True)
def test_device_type_when_tpu_strategy_passed(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
trainer = Trainer(default_root_dir=tmpdir, strategy=XLAStrategy(), accelerator="tpu", devices=8)
assert isinstance(trainer.strategy, XLAStrategy)
assert isinstance(trainer.accelerator, TPUAccelerator)
10 changes: 5 additions & 5 deletions tests/tests_pytorch/strategies/test_registry.py
@@ -21,7 +21,7 @@
DeepSpeedStrategy,
FSDPStrategy,
StrategyRegistry,
TPUSpawnStrategy,
XLAStrategy,
)
from tests_pytorch.helpers.runif import RunIf

@@ -54,15 +54,15 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy):


@RunIf(skip_windows=True)
def test_tpu_spawn_debug_strategy_registry(xla_available):
strategy = "tpu_spawn_debug"
def test_xla_debug_strategy_registry(xla_available):
strategy = "xla_debug"

assert strategy in StrategyRegistry
assert StrategyRegistry[strategy]["init_params"] == {"debug": True}
assert StrategyRegistry[strategy]["strategy"] == TPUSpawnStrategy
assert StrategyRegistry[strategy]["strategy"] == XLAStrategy

trainer = Trainer(strategy=strategy)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert isinstance(trainer.strategy, XLAStrategy)


@RunIf(min_torch="1.12")
@@ -21,7 +21,7 @@

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.strategies import TPUSpawnStrategy
from lightning.pytorch.strategies import XLAStrategy
from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader
from tests_pytorch.helpers.runif import RunIf

@@ -45,7 +45,7 @@ def predict_dataloader(self):


def test_error_process_iterable_dataloader(xla_available):
strategy = TPUSpawnStrategy(MagicMock())
strategy = XLAStrategy(MagicMock())
with pytest.raises(TypeError, match="TPUs do not currently support"):
strategy.process_dataloader(_loader_no_len)

@@ -60,9 +60,9 @@ def on_train_start(self) -> None:
@RunIf(tpu=True, standalone=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_model_tpu_one_core():
"""Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy."""
"""Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy."""
model = BoringModelTPU()
trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True))
assert isinstance(trainer.strategy, TPUSpawnStrategy)
trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=XLAStrategy(debug=True))
assert isinstance(trainer.strategy, XLAStrategy)
trainer.fit(model)
assert "PT_XLA_DEBUG" not in os.environ