
Commit 70deac2

Authored by cschell (Christian Schell) and rohitgr7 (Rohit Gupta)

Reset epoch progress with batch size scaler (#13846)

Co-authored-by: Christian Schell <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent c418828 commit 70deac2

File tree

3 files changed: +27 -5 lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))


+- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
+
+
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))



src/pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 17 additions & 2 deletions
@@ -128,7 +128,10 @@ def _run_power_scaling(
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -166,7 +169,10 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -249,3 +255,12 @@ def _adjust_batch_size(
 def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+
+
+def _reset_progress(trainer: "pl.Trainer") -> None:
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
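
To make the intent of `_reset_progress` concrete: before this commit only `trainer.fit_loop.global_step` was zeroed between trials, whereas the new helper also resets the optimizer and epoch progress trackers. Below is a minimal sketch of driving the batch size finder so this reset path runs on every trial. The `Trainer(auto_scale_batch_size=True)` / `trainer.tune(..., scale_batch_size_kwargs=...)` calls mirror the test in this commit; `ToyModel` and `RandomDataset` are hypothetical stand-ins for illustration only, not part of the diff.

```python
# Minimal sketch (ToyModel and RandomDataset are assumed names, not library API).
# The tuner looks up a `batch_size` attribute on the module and rewrites it between trials.
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32), torch.randn(2)


class ToyModel(pl.LightningModule):
    def __init__(self, batch_size=4):
        super().__init__()
        self.batch_size = batch_size  # rewritten by the batch size scaler
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=self.batch_size)


trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)
# With this commit, each trial calls _reset_progress(trainer) before fitting,
# so epoch and optimizer progress no longer carry over between trials.
result = trainer.tune(ToyModel(), scale_batch_size_kwargs={"max_trials": 3, "mode": "power"})
print(result["scale_batch_size"])
```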

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 7 additions & 3 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from unittest.mock import patch

 import pytest
 import torch
@@ -308,10 +309,13 @@ def __init__(self):
 def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
     """Test that train and val dataloaders are reset at every update in scale batch size."""
     model = BatchSizeModel(batch_size=16)
-    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
+    max_trials = 5
+    scale_batch_size_kwargs = {"max_trials": max_trials, "steps_per_trial": 2, "init_val": 4, "mode": scale_method}

-    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
-    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    with patch.object(model, "on_train_epoch_end") as advance_mocked:
+        new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    assert advance_mocked.call_count == max_trials

     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
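
The updated test checks that each scaling trial now runs exactly one (reset) epoch: `on_train_epoch_end` is mocked on the model and its call count is compared with `max_trials`. For readers unfamiliar with that pattern, here is a small self-contained illustration of it; the `Runner` class and its method names are arbitrary, chosen only to mimic the hook-per-trial structure.

```python
from unittest.mock import patch


class Runner:
    def on_epoch_end(self):
        pass

    def run(self, trials):
        for _ in range(trials):
            self.on_epoch_end()  # one epoch-end hook per trial, as in the tuner loop


runner = Runner()
with patch.object(runner, "on_epoch_end") as mocked:
    runner.run(trials=5)
# Each trial triggered the hook exactly once, so call_count equals the number of
# trials, mirroring the assertion against max_trials in the test above.
assert mocked.call_count == 5
```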
