31 | 31 | from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
32 | 32 | from lightning_lite.accelerators.accelerator import Accelerator
33 | 33 | from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
34 | | -from lightning_lite.strategies import DeepSpeedStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
| 34 | +from lightning_lite.strategies import (
| 35 | +    DDPShardedStrategy,
| 36 | +    DDPSpawnShardedStrategy,
| 37 | +    DeepSpeedStrategy,
| 38 | +    SingleDeviceStrategy,
| 39 | +    Strategy,
| 40 | +    XLAStrategy,
| 41 | +)
35 | 42 | from lightning_lite.strategies.strategy import _Sharded, TBroadcast
36 | 43 | from lightning_lite.utilities import move_data_to_device
37 | 44 | from lightning_lite.utilities.apply_func import convert_to_tensors
@@ -139,42 +146,100 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
139 | 146 |
140 | 147 |     def setup(
141 | 148 |         self,
142 | | -        model: nn.Module,
| 149 | +        module: nn.Module,
143 | 150 |         *optimizers: Optimizer,
144 | 151 |         move_to_device: bool = True,
145 | 152 |     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
146 | 153 |         """Set up a model and its optimizers for accelerated training.
147 | 154 |
148 | 155 |         Args:
149 | | -            model: A model to set up
| 156 | +            module: A :class:`torch.nn.Module` to set up
150 | 157 |             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
151 | 158 |             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
152 | 159 |                 and alternatively use :meth:`to_device` manually.
153 | 160 |
154 | 161 |         Returns:
155 | | -            The tuple of the wrapped model and list of optimizers, in the same order they were passed in.
| 162 | +            The tuple containing the wrapped module and the optimizers, in the same order they were passed in.
156 | 163 |         """
157 | | -        self._validate_setup(model, optimizers)
158 | | -        original_model = model
| 164 | +        self._validate_setup(module, optimizers)
| 165 | +        original_module = module
159 | 166 |
160 | | -        model = self._precision.convert_module(model)
| 167 | +        module = self._precision.convert_module(module)
161 | 168 |
162 | 169 |         if move_to_device:
163 | | -            model = self._move_model_to_device(model=model, optimizers=list(optimizers))
| 170 | +            module = self._move_model_to_device(model=module, optimizers=list(optimizers))
164 | 171 |
165 | 172 |         # Let accelerator/plugin wrap and connect the models and optimizers
166 | | -        model, optimizers = self._strategy.setup_module_and_optimizers(model, list(optimizers))
167 | | -        model = _LiteModule(model, self._precision, original_module=original_model)
| 173 | +        if optimizers:
| 174 | +            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
| 175 | +                module, list(optimizers)
| 176 | +            )
| 177 | +        else:
| 178 | +            module = self._strategy.setup_module(module)
| 179 | +
| 180 | +        module = _LiteModule(module, self._precision, original_module=original_module)
168 | 181 |
169 | 182 |         # Update the _DeviceDtypeModuleMixin's device parameter
170 | | -        model.to(self.device if move_to_device else next(model.parameters()).device)
| 183 | +        module.to(self.device if move_to_device else next(module.parameters()).device)
171 | 184 |
172 | 185 |         optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
| 186 | +
173 | 187 |         self._models_setup += 1
| 188 | +
174 | 189 |         if optimizers:
175 | | -            # join both types in a list for API convenience
176 | | -            return [model] + optimizers
177 | | -        return model
| 190 | +            # join both types in a tuple for API convenience
| 191 | +            return tuple((module, *optimizers))
| 192 | +        return module
| 193 | +
| 194 | +    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule:
| 195 | +        """Set up a model for accelerated training or inference.
| 196 | +
| 197 | +        This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
| 198 | +        strategies like `FSDP` that require setting up the module before the optimizer can be created and set up.
| 199 | +        See also :meth:`setup_optimizers`.
| 200 | +
| 201 | +        Args:
| 202 | +            module: A :class:`torch.nn.Module` to set up
| 203 | +            move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
| 204 | +                and alternatively use :meth:`to_device` manually.
| 205 | +
| 206 | +        Returns:
| 207 | +            The wrapped model.
| 208 | +        """
| 209 | +        self._validate_setup_module(module)
| 210 | +        original_module = module
| 211 | +
| 212 | +        module = self._precision.convert_module(module)
| 213 | +
| 214 | +        if move_to_device:
| 215 | +            module = self._move_model_to_device(model=module, optimizers=[])
| 216 | +
| 217 | +        # Let strategy wrap and connect the module alone
| 218 | +        module = self._strategy.setup_module(module)
| 219 | +        module = _LiteModule(module, self._precision, original_module=original_module)
| 220 | +
| 221 | +        # Update the _DeviceDtypeModuleMixin's device parameter
| 222 | +        module.to(self.device if move_to_device else next(module.parameters()).device)
| 223 | +
| 224 | +        self._models_setup += 1
| 225 | +        return module
| 226 | +
| 227 | +    def setup_optimizers(self, *optimizers: Optimizer) -> Union[_LiteOptimizer, Tuple[_LiteOptimizer, ...]]:
| 228 | +        """Set up one or more optimizers for accelerated training.
| 229 | +
| 230 | +        Some strategies do not allow setting up model and optimizer independently. For them, you should call
| 231 | +        ``.setup(model, optimizer, ...)`` instead to jointly set them up.
| 232 | +
| 233 | +        Args:
| 234 | +            *optimizers: One or more optimizers to set up.
| 235 | +
| 236 | +        Returns:
| 237 | +            The wrapped optimizer(s).
| 238 | +        """
| 239 | +        self._validate_setup_optimizers(optimizers)
| 240 | +        optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
| 241 | +        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
| 242 | +        return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
178 | 243 |
179 | 244 |     def setup_dataloaders(
180 | 245 |         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
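Taken together, the hunk above means `setup()` now routes through `strategy.setup_module` when no optimizers are passed and otherwise returns a tuple of the wrapped module and optimizers. A minimal usage sketch of both forms, assuming the top-level `LightningLite` export from this package and a CPU run (the class and model names are illustrative, not part of the diff):

```python
import torch
from torch import nn

from lightning_lite import LightningLite  # top-level export assumed; the class lives in lightning_lite/lite.py


class JointSetupLite(LightningLite):
    def run(self):
        model = nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        # With optimizers, `setup` returns a tuple: (wrapped module, *wrapped optimizers)
        model, optimizer = self.setup(model, optimizer)

        # Without optimizers, `setup` takes the new `strategy.setup_module` path
        # and returns only the wrapped module (useful for inference)
        eval_model = self.setup(nn.Linear(32, 2))


JointSetupLite(accelerator="cpu").run()
```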
@@ -529,17 +594,44 @@ def _prepare_run_method(self) -> None:
529 | 594 |         setattr(self, "run", partial(self._run_impl, self.run))
530 | 595 |
531 | 596 |     @staticmethod
532 | | -    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
533 | | -        if isinstance(model, _LiteModule):
| 597 | +    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
| 598 | +        if isinstance(module, _LiteModule):
534 | 599 |             raise ValueError("A model should be passed only once to the `setup` method.")
535 | 600 |
536 | 601 |         if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
537 | 602 |             raise ValueError("An optimizer should be passed only once to the `setup` method.")
538 | 603 |
| 604 | +    def _validate_setup_module(self, module: nn.Module) -> None:
| 605 | +        if isinstance(module, _LiteModule):
| 606 | +            raise ValueError("A model should be passed only once to the `setup_module` method.")
| 607 | +
| 608 | +        if isinstance(self._strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
| 609 | +            raise RuntimeError(
| 610 | +                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
| 611 | +                " through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
| 612 | +                " `ddp`."
| 613 | +            )
| 614 | +
| 615 | +    def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
| 616 | +        if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, DDPSpawnShardedStrategy, XLAStrategy)):
| 617 | +            raise RuntimeError(
| 618 | +                f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
| 619 | +                " through `.setup(model, optimizer, ...)`."
| 620 | +            )
| 621 | +
| 622 | +        if not optimizers:
| 623 | +            raise ValueError("`setup_optimizers` requires at least one optimizer as input.")
| 624 | +
| 625 | +        if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
| 626 | +            raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
| 627 | +
539 | 628 |     @staticmethod
540 | 629 |     def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:
| 630 | +        if not dataloaders:
| 631 | +            raise ValueError("`setup_dataloaders` requires at least one dataloader as input.")
| 632 | +
541 | 633 |         if any(isinstance(dl, _LiteDataLoader) for dl in dataloaders):
542 | | -            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method")
| 634 | +            raise ValueError("A dataloader should be passed only once to the `setup_dataloaders` method.")
543 | 635 |
544 | 636 |         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
545 | 637 |             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
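The validators above guard the split setup flow introduced in this PR. A minimal sketch of that flow under a strategy that permits independent setup (the default single-device strategy on CPU here; the class and model names are illustrative, not part of the diff):

```python
import torch
from torch import nn

from lightning_lite import LightningLite  # top-level export assumed


class SplitSetupLite(LightningLite):
    def run(self):
        # Wrap the module first; `_validate_setup_module` rejects this for the
        # sharded DDP strategies, which only support the joint `setup(...)` call
        model = self.setup_module(nn.Linear(32, 2))

        # Create the optimizer on the already wrapped module, then set it up separately;
        # DeepSpeed, the sharded DDP strategies, and XLA raise a RuntimeError here
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        optimizer = self.setup_optimizers(optimizer)


SplitSetupLite(accelerator="cpu").run()
```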