3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fabric now chooses `accelerator="auto", strategy="auto", devices="auto"` as defaults ([#16842](https://github.com/Lightning-AI/lightning/pull/16842)); see the usage sketch below this file's diff


- Checkpoint saving and loading redesign ([#16434](https://github.com/Lightning-AI/lightning/pull/16434))
* Changed the method signature of `Fabric.save` and `Fabric.load`
* Changed the method signature of `Strategy.save_checkpoint` and `Fabric.load_checkpoint`
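A usage sketch of the new defaults from the first changelog entry above (the resolved accelerator, strategy, and device list are illustrative, they depend on the host machine):

```python
from lightning.fabric import Fabric

# These two constructions are now equivalent:
fabric = Fabric()
fabric = Fabric(accelerator="auto", strategy="auto", devices="auto")

# On a machine with 4 CUDA GPUs this resolves to CUDAAccelerator,
# DDPStrategy and devices=[0, 1, 2, 3]; on a CPU-only machine to
# CPUAccelerator, SingleDeviceStrategy and devices=1
# (see `test_connector_auto_selection` further down in this diff).
```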
48 changes: 19 additions & 29 deletions src/lightning/fabric/connector.py
@@ -100,18 +100,18 @@ class _Connector:

def __init__(
self,
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
accelerator: Union[str, Accelerator] = "auto",
strategy: Union[str, Strategy] = "auto",
devices: Union[List[int], str, int] = "auto",
num_nodes: int = 1,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
) -> None:

# These arguments can be set through environment variables set by the CLI
accelerator = self._argument_from_env("accelerator", accelerator, default=None)
strategy = self._argument_from_env("strategy", strategy, default=None)
devices = self._argument_from_env("devices", devices, default=None)
accelerator = self._argument_from_env("accelerator", accelerator, default="auto")
strategy = self._argument_from_env("strategy", strategy, default="auto")
devices = self._argument_from_env("devices", devices, default="auto")
num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
precision = self._argument_from_env("precision", precision, default="32-true")
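For context, a rough sketch of how such an env-var override could resolve, assuming the `LT_*` variable names exported by the Fabric CLI; this is an illustration, not the actual `_argument_from_env` implementation:

```python
import os
from typing import Any


def _argument_from_env_sketch(name: str, current: Any, default: Any) -> Any:
    # Hypothetical helper: assumes the CLI exports e.g. LT_ACCELERATOR, LT_DEVICES
    env_value = os.environ.get("LT_" + name.upper())
    if env_value is None:
        return current  # nothing set by the CLI, keep the passed-in value
    if str(current) != str(default) and str(current) != env_value:
        # the CLI and the code both set the flag, to different values
        raise ValueError(f"`{name}={current!r}` conflicts with `LT_{name.upper()}={env_value}`")
    return env_value
```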

@@ -123,8 +123,8 @@ def __init__(
# Raise an exception if there are conflicts between flags
# Set each valid flag to `self._x_flag` after validation
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._strategy_flag: Union[Strategy, str] = "auto"
self._accelerator_flag: Union[Accelerator, str] = "auto"
self._precision_input: _PRECISION_INPUT_STR = "32-true"
self._precision_instance: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
@@ -141,7 +141,7 @@ def __init__(

# 2. Instantiate Accelerator
# handle `auto` and `gpu`
if self._accelerator_flag == "auto" or self._accelerator_flag is None:
if self._accelerator_flag == "auto":
self._accelerator_flag = self._choose_auto_accelerator()
elif self._accelerator_flag == "gpu":
self._accelerator_flag = self._choose_gpu_accelerator_backend()
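With `None` gone, `"auto"` is the single sentinel for accelerator selection. A sketch of the resolution order, inferred from the assertions in `test_connector_auto_selection` below: the test pins TPU winning over CUDA, while the relative CUDA/MPS priority here is an assumption (the test never mocks both as available):

```python
from lightning.fabric.accelerators import CUDAAccelerator, MPSAccelerator, TPUAccelerator


def choose_auto_accelerator_sketch() -> str:
    if TPUAccelerator.is_available():  # preferred even when CUDA is also present
        return "tpu"
    if CUDAAccelerator.is_available():
        return "cuda"
    if MPSAccelerator.is_available():
        return "mps"
    return "cpu"  # fallback when no accelerator hardware is found
```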
@@ -152,7 +152,7 @@
self.cluster_environment: ClusterEnvironment = self._choose_and_init_cluster_environment()

# 4. Instantiate Strategy - Part 1
if self._strategy_flag is None:
if self._strategy_flag == "auto":
self._strategy_flag = self._choose_strategy()
# In specific cases, ignore user selection and fall back to a different strategy
self._check_strategy_and_fallback()
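Likewise, `_choose_strategy` now only runs for the `"auto"` sentinel. A rough sketch of what it resolves to, based on the new auto-selection test at the bottom of this diff (the branching and the registry names are inferred, not copied from the source):

```python
def choose_strategy_sketch(accelerator: str, device_count: int) -> str:
    if accelerator == "tpu":
        # multiple TPU devices -> XLAStrategy, a single one -> SingleTPUStrategy
        return "xla" if device_count > 1 else "single_tpu"
    if device_count > 1:
        return "ddp"  # default multi-device strategy (subprocess launcher)
    return "single_device"
```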
@@ -166,8 +166,8 @@

def _check_config_and_set_final_flags(
self,
strategy: Optional[Union[str, Strategy]],
accelerator: Optional[Union[str, Accelerator]],
strategy: Union[str, Strategy],
accelerator: Union[str, Accelerator],
precision: _PRECISION_INPUT,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
) -> None:
@@ -188,10 +188,9 @@ def _check_config_and_set_final_flags(
if isinstance(strategy, str):
strategy = strategy.lower()

if strategy is not None:
self._strategy_flag = strategy
self._strategy_flag = strategy

if strategy is not None and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
if strategy != "auto" and strategy not in self._registered_strategies and not isinstance(strategy, Strategy):
raise ValueError(
f"You selected an invalid strategy name: `strategy={strategy!r}`."
" It must be either a string or an instance of `lightning.fabric.strategies.Strategy`."
@@ -200,8 +199,7 @@ def _check_config_and_set_final_flags(
)

if (
accelerator is not None
and accelerator not in self._registered_accelerators
accelerator not in self._registered_accelerators
and accelerator not in ("auto", "gpu")
and not isinstance(accelerator, Accelerator)
):
@@ -256,9 +254,9 @@ def _check_config_and_set_final_flags(
# handle the case when the user passes in a strategy instance which has an accelerator, precision,
# checkpoint io or cluster env set up
# TODO: improve the error messages below
if self._strategy_flag and isinstance(self._strategy_flag, Strategy):
if isinstance(self._strategy_flag, Strategy):
if self._strategy_flag._accelerator:
if self._accelerator_flag:
if self._accelerator_flag and self._accelerator_flag != "auto":
raise ValueError("accelerator set through both strategy class and accelerator flag, choose one")
else:
self._accelerator_flag = self._strategy_flag._accelerator
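The extra `!= "auto"` guard matters because the accelerator flag now always holds a value. A small sketch of what this permits versus rejects (real Fabric classes, but an illustration only):

```python
from lightning.fabric import Fabric
from lightning.fabric.accelerators import CPUAccelerator
from lightning.fabric.strategies import DDPStrategy

# Fine: the default accelerator="auto" no longer counts as a conflicting choice
fabric = Fabric(strategy=DDPStrategy(accelerator=CPUAccelerator()))

# Would raise ValueError: accelerator set both on the strategy and as a flag
# Fabric(strategy=DDPStrategy(accelerator=CPUAccelerator()), accelerator="cuda")
```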
@@ -297,9 +295,7 @@ def _check_config_and_set_final_flags(
self._accelerator_flag = "cuda"
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(
self, devices: Optional[Union[List[int], str, int]], num_nodes: int
) -> None:
def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
self._devices_flag = devices

@@ -314,12 +310,6 @@ def _check_device_config_and_set_final_flags(
f" using {accelerator_name} accelerator."
)

if self._devices_flag == "auto" and self._accelerator_flag is None:
raise ValueError(
f"You passed `devices={devices}` but haven't specified"
" `accelerator=('auto'|'tpu'|'gpu'|'cpu'|'mps')` for the devices mapping."
)

def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
if self._accelerator_flag == "auto":
@@ -368,7 +358,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)

def _set_devices_flag_if_auto_passed(self) -> None:
if self._devices_flag == "auto" or self._devices_flag is None:
if self._devices_flag == "auto":
self._devices_flag = self.accelerator.auto_device_count()
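Since `devices` can no longer be `None`, the `"auto"` check is all that remains, and the count always comes from the chosen accelerator. Illustrative outcomes (they depend on the host):

```python
from lightning.fabric.connector import _Connector

connector = _Connector()  # accelerator, strategy and devices all default to "auto"
# CPU-only host: CPUAccelerator.auto_device_count() -> 1, so _devices_flag == 1
# TPU host:      TPUAccelerator.auto_device_count() -> 8, so _devices_flag == 8
```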

def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
6 changes: 3 additions & 3 deletions src/lightning/fabric/fabric.py
@@ -78,9 +78,9 @@ class Fabric:

def __init__(
self,
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
accelerator: Union[str, Accelerator] = "auto",
strategy: Union[str, Strategy] = "auto",
devices: Union[List[int], str, int] = "auto",
num_nodes: int = 1,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
21 changes: 15 additions & 6 deletions tests/tests_fabric/conftest.py
@@ -75,17 +75,26 @@ def reset_deterministic_algorithm():
torch.use_deterministic_algorithms(False)


def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value)
monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", value)
monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)


@pytest.fixture(scope="function")
def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.fabric.strategies.xla, "_XLA_AVAILABLE", True)
monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", True)
mock_xla_available(monkeypatch)


def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
mock_xla_available(monkeypatch, value)
monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: value)


@pytest.fixture(scope="function")
def tpu_available(xla_available, monkeypatch) -> None:
monkeypatch.setattr(lightning.fabric.accelerators.tpu.TPUAccelerator, "is_available", lambda: True)
def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None:
mock_tpu_available(monkeypatch)
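Splitting the monkeypatching out of the fixtures lets a test flip availability inline instead of requesting a fixture; a minimal usage sketch (the test itself is hypothetical):

```python
import pytest
from tests_fabric.conftest import mock_tpu_available


def test_pretend_tpu_is_present(monkeypatch: pytest.MonkeyPatch) -> None:
    mock_tpu_available(monkeypatch)  # value defaults to True; undone at teardown
    ...  # assertions that rely on XLA/TPU appearing available
```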


@pytest.fixture
@@ -68,7 +68,9 @@ def after_backward(self, model):
],
)
def test_amp(accelerator, precision, expected_dtype):
fabric = MixedPrecisionBoringFabric(accelerator=accelerator, precision=precision)
# TODO: devices>1 fails with:
# DDP expects same model across all ranks, but Rank 0 has 2 params, while rank 1 has inconsistent 1 params
fabric = MixedPrecisionBoringFabric(accelerator=accelerator, precision=precision, devices=1)
fabric.expected_dtype = expected_dtype
fabric.run()

92 changes: 84 additions & 8 deletions tests/tests_fabric/test_connector.py
@@ -13,13 +13,16 @@
# limitations under the License
import inspect
import os
import sys
from typing import Any, Dict
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
import torch.distributed
from lightning_utilities.test.warning import no_warning_call
from tests_fabric.conftest import mock_tpu_available
from tests_fabric.helpers.runif import RunIf

import lightning.fabric
@@ -51,6 +54,7 @@
from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12


def test_accelerator_choice_cpu():
@@ -61,15 +65,13 @@ def test_accelerator_choice_cpu():

@RunIf(tpu=True, standalone=True)
@pytest.mark.parametrize(
["accelerator", "devices"], [("tpu", None), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
["accelerator", "devices"], [("tpu", "auto"), ("tpu", 1), ("tpu", [1]), ("tpu", 8), ("auto", 1), ("auto", 8)]
)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_accelerator_choice_tpu(accelerator, devices):
connector = _Connector(accelerator=accelerator, devices=devices)
assert isinstance(connector.accelerator, TPUAccelerator)
if devices is None or (isinstance(devices, int) and devices > 1):
# accelerator=tpu, devices=None (default) maps to devices=auto (8) and then chooses XLAStrategy
# This behavior may change in the future: https://github.com/Lightning-AI/lightning/issues/10606
if devices == "auto" or (isinstance(devices, int) and devices > 1):
assert isinstance(connector.strategy, XLAStrategy)
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
assert isinstance(connector.cluster_environment, XLAEnvironment)
@@ -248,7 +250,7 @@ def test_interactive_incompatible_backend_error(_, monkeypatch):

with pytest.raises(RuntimeError, match=r"strategy='ddp'\)`.*is not compatible"):
# Edge case: _Connector maps dp to ddp if accelerator != gpu
_Connector(strategy="dp")
_Connector(strategy="dp", accelerator="cpu")


@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@@ -476,7 +478,7 @@ def test_precision_conversion(patch1, patch2, precision, expected_precision, sho

def test_multi_device_default_strategy():
"""The default strategy when multiple devices are selected is "ddp" with the subprocess launcher."""
connector = _Connector(strategy=None, accelerator="cpu", devices=2)
connector = _Connector(strategy="auto", accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert connector.strategy._start_method == "popen"
@@ -537,7 +539,7 @@ def test_strategy_choice_ddp_spawn(*_):

@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)])
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy])
@pytest.mark.parametrize("strategy", ["auto", "ddp", DDPStrategy])
def test_strategy_choice_ddp_slurm(_, strategy, job_name, expected_env):
if strategy and not isinstance(strategy, str):
strategy = strategy()
@@ -636,7 +638,7 @@ def test_strategy_choice_ddp_cpu_kubeflow():
"SLURM_LOCALID": "0",
},
)
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy()])
@pytest.mark.parametrize("strategy", ["auto", "ddp", DDPStrategy()])
def test_strategy_choice_ddp_cpu_slurm(strategy):
connector = _Connector(strategy=strategy, accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)
@@ -907,3 +909,77 @@ def get_defaults(cls):
# defaults should match on the intersection of argument names
for name, connector_default in connector_defaults.items():
assert connector_default == fabric_defaults[name]
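For reference, a plausible shape of the `get_defaults` helper whose output is compared here, sketched under the assumption that it reads `__init__` signatures with `inspect`:

```python
import inspect


def get_defaults_sketch(cls) -> dict:
    # Map each __init__ parameter that declares a default to that default
    params = inspect.signature(cls.__init__).parameters.values()
    return {p.name: p.default for p in params if p.default is not inspect.Parameter.empty}
```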


def test_connector_auto_selection(monkeypatch):
no_cuda = mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0)
single_cuda = mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
multi_cuda = mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=4)
no_mps = mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
single_mps = mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=True)

# CPU
with no_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, False)
connector = _Connector()
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, SingleDeviceStrategy)
assert connector._devices_flag == 1

# single CUDA
with single_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, False)
connector = _Connector()
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, SingleDeviceStrategy)
assert connector._devices_flag == [0]

# multi CUDA
with multi_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, False)
connector = _Connector()
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert connector._devices_flag == list(range(4))

# MPS (there's no distributed)
with no_cuda, single_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, False)
if not _TORCH_GREATER_EQUAL_1_12:
monkeypatch.setattr(torch, "device", Mock())
connector = _Connector()
assert isinstance(connector.accelerator, MPSAccelerator)
assert isinstance(connector.strategy, SingleDeviceStrategy)
assert connector._devices_flag == [0]

# single TPU
with no_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, True)
# TPUAccelerator.auto_device_count always returns 8, but in case this changes in the future...
monkeypatch.setattr(lightning.fabric.accelerators.TPUAccelerator, "auto_device_count", lambda *_: 1)
monkeypatch.setitem(sys.modules, "torch_xla", Mock())
monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
monkeypatch.setattr(torch, "device", Mock())
connector = _Connector()
assert isinstance(connector.accelerator, TPUAccelerator)
assert isinstance(connector.strategy, SingleTPUStrategy)
assert connector._devices_flag == 1

monkeypatch.undo() # for some reason `.context()` is not working properly
assert lightning.fabric.accelerators.TPUAccelerator.auto_device_count() == 8

# Multi TPU
with no_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, True)
connector = _Connector()
assert isinstance(connector.accelerator, TPUAccelerator)
assert isinstance(connector.strategy, XLAStrategy)
assert connector._devices_flag == 8

# TPU and CUDA: prefers TPU
with multi_cuda, no_mps, monkeypatch.context():
mock_tpu_available(monkeypatch, True)
connector = _Connector()
assert isinstance(connector.accelerator, TPUAccelerator)
assert isinstance(connector.strategy, XLAStrategy)
assert connector._devices_flag == 8
16 changes: 8 additions & 8 deletions tests/tests_fabric/test_parity.py
@@ -107,23 +107,23 @@ def precision_context(precision, accelerator) -> Generator[None, None, None]:


@pytest.mark.parametrize(
"precision, strategy, devices, accelerator",
"precision, accelerator",
[
pytest.param(32, None, 1, "cpu"),
pytest.param(32, None, 1, "gpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param(16, None, 1, "gpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param("bf16", None, 1, "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
pytest.param(32, None, 1, "mps", marks=RunIf(mps=True)),
(32, "cpu"),
pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
pytest.param(32, "mps", marks=RunIf(mps=True)),
],
)
def test_boring_fabric_model_single_device(precision, strategy, devices, accelerator, tmpdir):
def test_boring_fabric_model_single_device(precision, accelerator):
Fabric.seed_everything(42)
train_dataloader = DataLoader(RandomDataset(32, 8))
model = BoringModel()
num_epochs = 1
state_dict = deepcopy(model.state_dict())

fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator)
fabric = FabricRunner(precision=precision, accelerator=accelerator)
fabric.run(model, train_dataloader, num_epochs=num_epochs)
fabric_state_dict = model.state_dict()
