Commit 0dfb3d2

Support individual setup of model and optimizer in Lite (#15185)
1 parent 32e319a commit 0dfb3d2


8 files changed (+315, -40 lines)


src/lightning_lite/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 
--
+- Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185))
 
 
 ### Changed
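
In practice, the two new methods split what `setup()` previously did in one call: the module is wrapped first, and the optimizers are wrapped afterwards from the already set-up module's parameters. A minimal usage sketch, assuming a plain CPU run and an illustrative toy model (none of the names below come from this commit):

import torch
from torch import nn

from lightning_lite import LightningLite  # import path assumed


class Sketch(LightningLite):
    def run(self):
        # Illustrative model; any nn.Module works here.
        model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 2))

        # New: set up the module on its own, so strategies that must wrap or
        # shard the model first can do so before any optimizer exists.
        model = self.setup_module(model)

        # Create the optimizer from the already wrapped module, then set it up.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        optimizer = self.setup_optimizers(optimizer)

        batch = self.to_device(torch.randn(8, 32))
        loss = model(batch).sum()
        self.backward(loss)
        optimizer.step()


Sketch(accelerator="cpu", devices=1).run()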

src/lightning_lite/lite.py

Lines changed: 109 additions & 17 deletions
@@ -31,7 +31,14 @@
 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
 from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
+from lightning_lite.strategies import (
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DeepSpeedStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning_lite.strategies.strategy import _Sharded, TBroadcast
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.apply_func import convert_to_tensors
@@ -139,42 +146,100 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
 
     def setup(
         self,
-        model: nn.Module,
+        module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         """Set up a model and its optimizers for accelerated training.
 
         Args:
-            model: A model to set up
+            module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
 
         Returns:
-            The tuple of the wrapped model and list of optimizers, in the same order they were passed in.
+            The tuple containing the wrapped module and the optimizers, in the same order they were passed in.
         """
-        self._validate_setup(model, optimizers)
-        original_model = model
+        self._validate_setup(module, optimizers)
+        original_module = module
 
-        model = self._precision.convert_module(model)
+        module = self._precision.convert_module(module)
 
         if move_to_device:
-            model = self._move_model_to_device(model=model, optimizers=list(optimizers))
+            module = self._move_model_to_device(model=module, optimizers=list(optimizers))
 
         # Let accelerator/plugin wrap and connect the models and optimizers
-        model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers))
-        model = _LiteModule(model, self._precision, original_module=original_model)
+        if optimizers:
+            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers)
+            )
+        else:
+            module = self._strategy.setup_module(module)
+
+        module = _LiteModule(module, self._precision, original_module=original_module)
 
         # Update the _DeviceDtypeModuleMixin's device parameter
-        model.to(self.device if move_to_device else next(model.parameters()).device)
+        module.to(self.device if move_to_device else next(module.parameters()).device)
 
         optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+
         self._models_setup += 1
+
         if optimizers:
-            # join both types in a list for API convenience
-            return [model] + optimizers
-        return model
+            # join both types in a tuple for API convenience
+            return tuple((module, *optimizers))
+        return module
+
+    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule:
+        """Set up a model for accelerated training or inference.
+
+        This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
+        strategies like `FSDP` that require setting up the module before the optimizer can be created and set up.
+        See also :meth:`setup_optimizers`.
+
+        Args:
+            module: A :class:`torch.nn.Module` to set up
+            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
+                and alternatively use :meth:`to_device` manually.
+
+        Returns:
+            The wrapped model.
+        """
+        self._validate_setup_module(module)
+        original_module = module
+
+        module = self._precision.convert_module(module)
+
+        if move_to_device:
+            module = self._move_model_to_device(model=module, optimizers=[])
+
+        # Let strategy wrap and connect the module alone
+        module = self._strategy.setup_module(module)
+        module = _LiteModule(module, self._precision, original_module=original_module)
+
+        # Update the _DeviceDtypeModuleMixin's device parameter
+        module.to(self.device if move_to_device else next(module.parameters()).device)
+
+        self._models_setup += 1
+        return module
+
+    def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]:
+        """Set up one or more optimizers for accelerated training.
+
+        Some strategies do not allow setting up model and optimizer independently. For them, you should call
+        ``.setup(model, optimizer, ...)`` instead to jointly set them up.
+
+        Args:
+            *optimizers: One or more optimizers to set up.
+
+        Returns:
+            The wrapped optimizer(s).
+        """
+        self._validate_setup_optimizers(optimizers)
+        optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
+        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
+        return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
 
     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
@@ -529,17 +594,44 @@ def _prepare_run_method(self) -> None:
         setattr(self, "run", partial(self._run_impl, self.run))
 
     @staticmethod
-    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
-        if isinstance(model, _LiteModule):
+    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(module, _LiteModule):
             raise ValueError("A model should be passed only once to the `setup` method.")
 
         if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")
 
+    def _validate_setup_module(self, module: nn.Module) -> None:
+        if isinstance(module, _LiteModule):
+            raise ValueError("A model should be passed only once to the `setup_module` method.")
+
+        if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
+                " `ddp`."
+            )
+
+    def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
+        if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)):
+            raise RuntimeError(
+                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
+                " through `.setup(model, optimizer, ...)`."
+            )
+
+        if not optimizers:
+            raise ValueError("`setup_optimizers` requires at least one optimizer as input.")
+
+        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
+            raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
+
     @staticmethod
     def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
+        if not dataloaders:
+            raise ValueError("`setup_dataloaders` requires at least one dataloader as input.")
+
         if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
-            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method")
+            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.")
 
         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
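
The two `RuntimeError`s raised by the validators above are what a user sees when the split calls hit a strategy that only supports joint setup. A hedged sketch of what that looks like from the user side (the class and the printed fallback are illustrative, not an API recommendation from this commit):

import torch
from torch import nn

from lightning_lite import LightningLite  # import path assumed


class SetupSketch(LightningLite):
    def run(self):
        model = nn.Linear(4, 4)
        optimizer = torch.optim.Adam(model.parameters())

        try:
            # Sharded strategies reject setup_module; DeepSpeed, the sharded
            # strategies, and XLA reject setup_optimizers (see the validators above).
            model = self.setup_module(model)
            optimizer = self.setup_optimizers(optimizer)
        except RuntimeError as err:
            # e.g. "The `DeepSpeedStrategy` requires the model and optimizer(s)
            # to be set up jointly through `.setup(model, optimizer, ...)`."
            print(err)

With the default single-device strategy both calls succeed; with `strategy="deepspeed"` or `strategy="ddp_sharded"` the corresponding validator raises and `.setup(model, optimizer)` must be used instead.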

src/lightning_lite/strategies/deepspeed.py

Lines changed: 22 additions & 9 deletions
@@ -35,7 +35,7 @@
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.rank_zero import rank_zero_info
 from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.utilities.types import _LRScheduler, _PATH, ReduceLROnPlateau
+from lightning_lite.utilities.types import _PATH
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
@@ -305,11 +305,11 @@ def model(self) -> "deepspeed.DeepSpeedEngine":
         return self._deepspeed_engine
 
     def setup_module_and_optimizers(
-        self, model: Module, optimizers: List[Optimizer]
+        self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple["deepspeed.DeepSpeedEngine", List[Optimizer]]:
-        """Setup a model and multiple optimizers together.
+        """Set up a model and multiple optimizers together.
 
-        Currently only a single optimizer is supported.
+        Currently, only a single optimizer is supported.
 
         Return:
             The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
@@ -321,10 +321,25 @@ def setup_module_and_optimizers(
             f" Got {len(optimizers)} optimizers instead."
         )
 
-        self._deepspeed_engine, optimizer = self._setup_module_and_optimizer(model, optimizers[0])
+        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
         return self._deepspeed_engine, [optimizer]
 
+    def setup_module(self, module: Module) -> "deepspeed.DeepSpeedEngine":
+        """Set up a module for inference (no optimizers).
+
+        For training, see :meth:`setup_module_and_optimizers`.
+        """
+        self._deepspeed_engine, _ = self._initialize_engine(module)
+        return self._deepspeed_engine
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @contextmanager
     def module_sharded_context(self) -> Generator[None, None, None]:
         # Current limitation in Lite: The config needs to be fully determined at the time of calling the
@@ -401,11 +416,10 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             offload_optimizer_device="nvme",
         )
 
-    def _setup_module_and_optimizer(
+    def _initialize_engine(
         self,
         model: Module,
-        optimizer: Optional[Optimizer],
-        lr_scheduler: Optional[Union[_LRScheduler, ReduceLROnPlateau]] = None,
+        optimizer: Optional[Optimizer] = None,
     ) -> Tuple["deepspeed.DeepSpeedEngine", Optimizer]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
@@ -420,7 +434,6 @@ def _setup_module_and_optimizer(
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
             dist_init_required=False,
         )
         return deepspeed_engine, deepspeed_optimizer
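
With `setup_module`, an inference-only DeepSpeed run no longer needs a placeholder optimizer: the engine is initialized with `optimizer=None`. A hedged sketch of that use case (the model, default config, and single-GPU launch are illustrative assumptions; `deepspeed` must be installed):

import torch
from torch import nn

from lightning_lite import LightningLite  # import path assumed


class InferenceSketch(LightningLite):
    def run(self):
        model = nn.Linear(32, 2)  # illustrative model

        # Builds the deepspeed.DeepSpeedEngine without an optimizer.
        model = self.setup_module(model)

        # self.setup_optimizers(...) would raise a RuntimeError here, because
        # DeepSpeed only supports joint setup through .setup(model, optimizer).
        with torch.no_grad():
            print(model(self.to_device(torch.randn(8, 32))).shape)


InferenceSketch(strategy="deepspeed", accelerator="cuda", devices=1).run()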

src/lightning_lite/strategies/fairscale.py

Lines changed: 29 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 from lightning_utilities.core.imports import module_available
 from torch.nn import Module
+from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Optimizer
 
 from lightning_lite.accelerators import Accelerator
@@ -89,6 +90,20 @@ def setup_module_and_optimizers(
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers
 
+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
@@ -153,6 +168,20 @@ def setup_module_and_optimizers(
         model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers
 
+    def setup_module(self, module: Module) -> DistributedDataParallel:
+        """Setting up the module without optimizers in this strategy is not supported.
+
+        Please use :meth:`setup_module_and_optimizers` instead.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
+    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
+        """Optimizers can only be set up jointly with the model in this strategy.
+
+        Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer(s) together.
+        """
+        raise NotImplementedError(self._err_msg_joint_setup_required())
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
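
Both sharded strategies now fail loudly instead of silently skipping the wrapping. A hypothetical test, in the same style as the DeepSpeed tests further below, could assert that behaviour (it is not part of this commit, and fairscale must be installed for the strategies to be usable):

import pytest
from torch import nn
from torch.optim import SGD

from lightning_lite.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy


@pytest.mark.parametrize("strategy_cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
def test_fairscale_requires_joint_setup(strategy_cls):
    """Hypothetical test: both sharded strategies reject individual module/optimizer setup."""
    strategy = strategy_cls()
    model = nn.Linear(2, 2)

    with pytest.raises(NotImplementedError, match="does not support setting up the module"):
        strategy.setup_module(model)

    with pytest.raises(NotImplementedError, match="does not support setting up the module"):
        strategy.setup_optimizer(SGD(model.parameters(), lr=0.1))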

src/lightning_lite/strategies/strategy.py

Lines changed: 7 additions & 1 deletion
@@ -118,7 +118,7 @@ def setup_module_and_optimizers(
         """Set up a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
+        call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs.
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
@@ -288,6 +288,12 @@ def teardown(self) -> None:
     def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
         pass
 
+    def _err_msg_joint_setup_required(self) -> str:
+        return (
+            f"The `{type(self).__name__}` does not support setting up the module and optimizer(s) independently."
+            " Please call `setup_module_and_optimizers(model, [optimizer, ...])` to jointly set them up."
+        )
+
 
 class _BackwardSyncControl(ABC):
     """Interface for any :class:`Strategy` that wants to offer a functionality to enable or disable gradient

tests/tests_lite/strategies/test_deepspeed.py

Lines changed: 35 additions & 0 deletions
@@ -13,8 +13,12 @@
 # limitations under the License.
 import json
 import os
+from re import escape
+from unittest import mock
+from unittest.mock import ANY, Mock
 
 import pytest
+import torch
 from tests_lite.helpers.runif import RunIf
 
 from lightning_lite.accelerators import CPUAccelerator
@@ -116,3 +120,34 @@ def test_deepspeed_config_zero_offload(deepspeed_zero_config):
     deepspeed_zero_config["zero_optimization"]["offload_optimizer"] = False
     strategy = DeepSpeedStrategy(config=deepspeed_zero_config)
     assert strategy.config["zero_optimization"]["offload_optimizer"] is False
+
+
+@RunIf(deepspeed=True)
+@mock.patch("lightning_lite.strategies.deepspeed.deepspeed.initialize")
+def test_deepspeed_setup_module(init_mock):
+    """Test that the DeepSpeed strategy can set up the model for inference (no optimizer required)."""
+    model = Mock()
+    model.parameters.return_value = []
+    strategy = DeepSpeedStrategy()
+    strategy.parallel_devices = [torch.device("cuda", 1)]
+    init_mock.return_value = [Mock()] * 4  # mock to make tuple unpacking work
+
+    strategy.setup_module(model)
+    init_mock.assert_called_with(
+        args=ANY,
+        config=strategy.config,
+        model=model,
+        model_parameters=ANY,
+        optimizer=None,
+        dist_init_required=False,
+    )
+
+
+@RunIf(deepspeed=True)
+def test_deepspeed_requires_joint_setup():
+    """Test that the DeepSpeed strategy does not support setting up model and optimizer independently."""
+    strategy = DeepSpeedStrategy()
+    with pytest.raises(
+        NotImplementedError, match=escape("does not support setting up the module and optimizer(s) independently")
+    ):
+        strategy.setup_optimizer(Mock())
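
A hypothetical companion test could cover the user-facing side of the same restriction, i.e. that `LightningLite.setup_optimizers()` raises the `RuntimeError` from `_validate_setup_optimizers` when a DeepSpeed strategy is active (the `_EmptyLite` helper and the direct `_strategy` swap below are test shortcuts, not part of this commit):

import pytest
from unittest.mock import Mock

from lightning_lite import LightningLite
from lightning_lite.strategies import DeepSpeedStrategy


class _EmptyLite(LightningLite):
    def run(self):
        pass


def test_lite_setup_optimizers_requires_joint_setup_with_deepspeed():
    """Hypothetical test: the Lite-level validation mirrors the strategy-level one."""
    lite = _EmptyLite(accelerator="cpu", devices=1)
    # Swap in a mock whose spec makes the isinstance check in _validate_setup_optimizers pass.
    lite._strategy = Mock(spec=DeepSpeedStrategy)
    with pytest.raises(RuntimeError, match="requires the model and optimizer"):
        lite.setup_optimizers(Mock())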
