
Commit 67c09e3

awaelchli, carmocca, and pre-commit-ci[bot] authored
Separate the Gradient Accumulation Scheduler from Trainer (#16729)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c7ffd41 commit 67c09e3

20 files changed: +149 additions, -235 deletions

docs/source-pytorch/common/gradient_accumulation.rst

Lines changed: 4 additions & 12 deletions
@@ -19,25 +19,17 @@ effective batch size is increased but there is no memory overhead.
     # Accumulate gradients for 7 batches
     trainer = Trainer(accumulate_grad_batches=7)
 
-You can set different values for it at different epochs by passing a dictionary, where the key represents the epoch at which the value for gradient accumulation
-should be updated.
-
-.. testcode::
-
-    # till 5th epoch, it will accumulate every 8 batches. From 5th epoch
-    # till 9th epoch it will accumulate every 4 batches and after that no accumulation
-    # will happen. Note that you need to use zero-indexed epoch keys here
-    trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})
-
-Or, you can create custom :class:`~pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler`
+Optionally, you can make the ``accumulate_grad_batches`` value change over time by using the :class:`~pytorch_lightning.callbacks.gradient_accumulation_scheduler.GradientAccumulationScheduler`.
+Pass in a scheduling dictionary, where the key represents the epoch at which the value for gradient accumulation should be updated.
 
 .. testcode::
 
     from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
-
     # till 5th epoch, it will accumulate every 8 batches. From 5th epoch
     # till 9th epoch it will accumulate every 4 batches and after that no accumulation
    # will happen. Note that you need to use zero-indexed epoch keys here
     accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
     trainer = Trainer(callbacks=accumulator)
+
+Note: Not all strategies and accelerators support variable gradient accumulation windows.
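
The scheduling dictionary in the updated docs maps a zero-indexed starting epoch to an accumulation factor. As a rough illustration of that semantics (an assumption drawn from the comments above; `resolve_accumulation` is a hypothetical helper, not part of this commit):

    # Hypothetical helper, for illustration only: the factor for an epoch is the
    # value of the largest configured epoch key that does not exceed it.
    def resolve_accumulation(scheduling: dict, epoch: int) -> int:
        factor = 1
        for start_epoch in sorted(scheduling):
            if epoch >= start_epoch:
                factor = scheduling[start_epoch]
        return factor

    scheduling = {0: 8, 4: 4, 8: 1}
    # Epochs 0-3 accumulate 8 batches, epochs 4-7 accumulate 4, epoch 8 onward accumulates 1.
    assert [resolve_accumulation(scheduling, e) for e in range(10)] == [8, 8, 8, 8, 4, 4, 4, 4, 1, 1]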

docs/source-pytorch/common/optimization.rst

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ always switch to :ref:`manual optimization <manual_optimization>`.
 Manual optimization is required if you wish to work with multiple optimizers.
 
 
+.. _gradient_accumulation:
+
 Gradient Accumulation
 =====================
 
docs/source-pytorch/common/trainer.rst

Lines changed: 2 additions & 4 deletions
@@ -271,8 +271,7 @@ accumulate_grad_batches
 
 |
 
-Accumulates grads every k batches or as set up in the dict.
-Trainer also calls ``optimizer.step()`` for the last indivisible step number.
+Accumulates gradients over k batches before stepping the optimizer.
 
 .. testcode::
 
@@ -284,8 +283,7 @@ Example::
     # accumulate every 4 batches (effective batch size is batch*4)
     trainer = Trainer(accumulate_grad_batches=4)
 
-    # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
-    trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
+See also: :ref:`gradient_accumulation` to enable more fine-grained accumulation schedules.
 
 
 benchmark
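
For context on the reworded ``accumulate_grad_batches`` description, here is a minimal plain-PyTorch sketch of the pattern the option automates (not Lightning's internal loop): scale the loss by 1/k and step the optimizer every k batches, giving an effective batch size of batch_size * k.

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(12)]  # toy data
    k = 4  # accumulate over 4 batches -> effective batch size 8 * 4 = 32

    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / k).backward()  # scale so the accumulated gradient matches one large batch
        if (i + 1) % k == 0:
            optimizer.step()
            optimizer.zero_grad()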

src/lightning/pytorch/CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -241,8 +241,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the `using_lbfgs` argument from `LightningModule.optimizer_step` hook ([#16538](https://github.com/Lightning-AI/lightning/pull/16538))
 
+
 - Removed the `Trainer.data_parallel` property. Use `isinstance(trainer.strategy, ParallelStrategy)` instead ([#16703](https://github.com/Lightning-AI/lightning/pull/16703))
 
+
 - Removed support for multiple optimizers in automatic optimization mode ([#16539](https://github.com/Lightning-AI/lightning/pull/16539))
     * Removed `opt_idx` argument from `BaseFinetuning.finetune_function` callback method
     * Removed `opt_idx` argument from `Callback.on_before_optimizer_step` callback method
@@ -265,10 +267,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed `PrecisionPlugin.dispatch` ([#16618](https://github.com/Lightning-AI/lightning/pull/16618))
 
+
 - Removed the unused `lightning.pytorch.utilities.metrics.metrics_to_scalars` function ([#16681](https://github.com/Lightning-AI/lightning/pull/16681))
 
+
+- Removed support for passing a scheduling dictionary to `Trainer(accumulate_grad_batches=...)` ([#16729](https://github.com/Lightning-AI/lightning/pull/16729))
+
+
 - Removed the unused `lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator}` classes ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))
 
+
 ### Fixed
 
 -
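
The CHANGELOG entry for #16729 implies a small migration for users who passed a scheduling dictionary to the Trainer. A sketch of the before/after, based on the documentation and docstring changes in this commit (imports follow the `lightning.pytorch` namespace used in the source files below):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import GradientAccumulationScheduler

    # Before (removed in this commit):
    # trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})

    # After: pass the schedule through the callback instead.
    accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
    trainer = Trainer(callbacks=[accumulator])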

src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py

Lines changed: 47 additions & 4 deletions
@@ -25,6 +25,8 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.model_helpers import is_overridden
+from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
 
 class GradientAccumulationScheduler(Callback):
@@ -58,9 +60,6 @@ class GradientAccumulationScheduler(Callback):
         # because epoch (key) should be zero-indexed.
         >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
         >>> trainer = Trainer(callbacks=[accumulator])
-
-        # alternatively, pass the scheduling dict directly to the Trainer
-        >>> trainer = Trainer(accumulate_grad_batches={4: 2})
     """
 
     def __init__(self, scheduling: Dict[int, int]):
@@ -82,7 +81,7 @@ def __init__(self, scheduling: Dict[int, int]):
         minimal_epoch = min(scheduling.keys())
         if minimal_epoch < 0:
             raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
-        if minimal_epoch != 0:  # if user didnt define first epoch accumulation factor
+        if minimal_epoch != 0:  # if user didn't define first epoch accumulation factor
             scheduling.update({0: 1})
 
         self.scheduling = scheduling
@@ -99,5 +98,49 @@ def get_accumulate_grad_batches(self, epoch: int) -> int:
                 break
         return accumulate_grad_batches
 
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Performns a configuration validation before training starts and raises errors for incompatible
+        settings."""
+
+        if not pl_module.automatic_optimization:
+            raise RuntimeError(
+                """Automatic gradient accumulation and the `GradientAccumulationScheduler` is not supported for
+                manual optimization. Please remove the callback or switch to automatic optimization."""
+            )
+
+        overridden_optimizer_step = is_overridden("optimizer_step", pl_module)
+        overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", pl_module)
+        going_to_accumulate_grad_batches = self.going_to_accumulate_grad_batches()
+        has_overridden_optimization_functions = overridden_optimizer_step or overridden_optimizer_zero_grad
+        if has_overridden_optimization_functions and going_to_accumulate_grad_batches:
+            rank_zero_warn(
+                "When using `Trainer(accumulate_grad_batches != 1)` and overriding"
+                " `LightningModule.optimizer_{step,zero_grad}`, the hooks will not be called on every batch"
+                " (rather, they are called on every optimization step)."
+            )
+
+        # local import to avoid circular import
+        from lightning.pytorch.accelerators import IPUAccelerator
+        from lightning.pytorch.strategies import ColossalAIStrategy, DeepSpeedStrategy
+
+        unsupported_strategies = (DeepSpeedStrategy, ColossalAIStrategy)
+        unsupported_accelerators = (IPUAccelerator,)
+
+        if isinstance(trainer.accelerator, unsupported_accelerators):
+            raise RuntimeError(
+                f"The `{type(trainer.accelerator).__name__}` does not support `accumulate_grad_batches` changing"
+                " between epochs."
+            )
+        if isinstance(trainer.strategy, unsupported_strategies):
+            raise RuntimeError(
+                f"The `{type(trainer.strategy).__name__}` does not support `accumulate_grad_batches` changing"
+                " between epochs."
+            )
+        if trainer.accumulate_grad_batches != 1:
+            raise ValueError(
+                "You have set `accumulate_grad_batches` and are using the `GradientAccumulationScheduler`"
+                " callback. Either remove `accumulate_grad_batches` from the Trainer or remove the callback."
+            )
+
     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         trainer.accumulate_grad_batches = self.get_accumulate_grad_batches(trainer.current_epoch)
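
Based on the validation added in `on_train_start`, combining a non-default `Trainer(accumulate_grad_batches=...)` with the callback should now fail when training starts. A hedged sketch of that expectation (not a test from this commit; `BoringModel` from `lightning.pytorch.demos.boring_classes` is an assumed demo model used only for illustration):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import GradientAccumulationScheduler
    from lightning.pytorch.demos.boring_classes import BoringModel  # assumed demo model

    trainer = Trainer(
        accumulate_grad_batches=2,  # conflicts with the scheduler per the new check
        callbacks=[GradientAccumulationScheduler(scheduling={0: 4})],
        max_epochs=1,
    )
    try:
        trainer.fit(BoringModel())
    except ValueError as err:
        print(err)  # suggests removing either `accumulate_grad_batches` or the callback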

src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 1 addition & 0 deletions
@@ -251,6 +251,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
             # There is no need to perform either backward or optimizer.step as we are
             # performing only one pass over the train data-loader to compute activation statistics
             # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
+            assert isinstance(trainer.num_training_batches, int)
             trainer.num_training_batches += 1
             trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 0 additions & 3 deletions
@@ -238,9 +238,6 @@ def on_advance_start(self) -> None:
         assert isinstance(self.trainer.train_dataloader, CombinedLoader)
         _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
 
-        # changing gradient according accumulation_scheduler
-        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
-
         self.epoch_progress.increment_ready()
 
         self.trainer._logger_connector.on_epoch_start()
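
With the explicit call removed from `fit_loop.py`, the per-epoch update now reaches the scheduler through the ordinary callback dispatch (its `on_train_epoch_start`, shown earlier, sets `trainer.accumulate_grad_batches`). A sketch of that mechanism with a hypothetical custom callback (illustration only, not part of this commit):

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import Callback

    class HalvingAccumulation(Callback):
        """Hypothetical callback: halve the accumulation factor each epoch, 8 -> 4 -> 2 -> 1."""

        def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
            # Same attribute the GradientAccumulationScheduler sets via the standard hook dispatch.
            trainer.accumulate_grad_batches = max(1, 8 // (2 ** trainer.current_epoch))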

src/lightning/pytorch/strategies/colossalai.py

Lines changed: 0 additions & 6 deletions
@@ -348,12 +348,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
                 "ColossalAI does not support gradient accumulation now. Please set `accumulate_grad_batches` to 1."
             )
 
-        accumulation_scheduler = trainer.accumulation_scheduler
-        if accumulation_scheduler.epochs != [0]:
-            raise ValueError(
-                "ColossalAI currently does not support different `accumulate_grad_batches` at different epochs."
-            )
-
         if not isinstance(self.precision_plugin, ColossalAIPrecisionPlugin):
             raise ValueError("`ColossalAIStrategy` is only compatible with `ColossalAIPrecisionPlugin`.")
 
src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 0 additions & 7 deletions
@@ -441,13 +441,6 @@ def init_deepspeed(self) -> None:
                 f"DeepSpeed strategy is only supported on GPU but `{self.accelerator.__class__.__name__}` is used."
             )
 
-        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
-
-        if accumulation_scheduler.epochs != [0]:
-            raise MisconfigurationException(
-                "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
-            )
-
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         model = _LightningModuleWrapperBase(forward_module=self.model)
 

src/lightning/pytorch/strategies/ipu.py

Lines changed: 0 additions & 20 deletions
@@ -105,9 +105,6 @@ def __init__(
         self._optimizer_zero_grad_original: Optional[Callable] = None
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        # set the `accumulate_grad_batches` property as early as possible
-        self._handle_gradient_accumulation_steps()
-
         # patch the dataloader creation function with the custom `poptorch.DataLoader`.
         # this violates the intended control flow for the plugins, but since this is experimental, we have chosen
         # to use the simpler solution before adding abstractions to override the `DataLoader` class
@@ -217,23 +214,6 @@ def _convert_to_poptorch_loader(
         )
         return dataloader
 
-    def _handle_gradient_accumulation_steps(self) -> None:
-        """Override the trainer.accumulation_scheduler to act as ``accumulate_grad_batches=1`` if gradient
-        accumulation has been set.
-
-        ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation internally.
-        """
-        assert self.lightning_module is not None
-        accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
-
-        if accumulation_scheduler.epochs != [0]:
-            raise MisconfigurationException(
-                "IPUs currently does not support different `accumulate_grad_batches` at different epochs."
-            )
-
-        # TODO(@tchaton): Add support for accumulate_grad_batches being a dictionary
-        accumulation_scheduler.scheduling.update({0: 1})
-
     @property
     def _n_replicate(self) -> int:
         assert self.lightning_module is not None
