6 changes: 6 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -141,6 +141,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect values after transferring data to an MPS device ([#13285](https://github.com/Lightning-AI/lightning/issues/13285))


- Reset the dataloaders on OOM failure in the batch size finder so that the last successful batch size is used ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))


- Fixed the batch size finder with `mode="power"` so that it keeps downscaling the batch size while not even a single batch size has succeeded yet ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))


- Fixed an issue so that the sanity check no longer affects `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))


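Both `#14372` entries above concern the batch size finder that the rest of this diff touches. As a point of reference, a minimal usage sketch is shown below (the model class and its `batch_size` attribute are placeholders; any `LightningModule` whose dataloaders read the attribute the tuner scales would do):

```python
import pytorch_lightning as pl

# `MyBatchSizeModel` is a hypothetical LightningModule whose dataloaders build
# themselves from `self.batch_size`, the attribute the tuner adjusts in place.
model = MyBatchSizeModel(batch_size=16)

# "power" keeps doubling the batch size until an OOM is hit; "binsearch" then
# additionally bisects between the last passing and the first failing size.
trainer = pl.Trainer(auto_scale_batch_size="power", max_epochs=2)
trainer.tune(model)  # the value found is written back to `model.batch_size`
```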
38 changes: 26 additions & 12 deletions src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -126,6 +129,9 @@ def _run_power_scaling(
trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int
) -> int:
"""Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
# this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
# if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
any_success = False
for _ in range(max_trials):
garbage_collection_cuda()

@@ -137,22 +140,28 @@ def _run_power_scaling(
            trainer.tuner._run(model)
            # Double in size
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")

            if not changed:
                break

            # Force the train dataloader to reset as the batch size has changed
            trainer.reset_train_dataloader(model)
            trainer.reset_val_dataloader(model)
            any_success = True
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, half the size and return
                garbage_collection_cuda()
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc="failed")
                break
                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
                trainer.reset_val_dataloader(model)
                if any_success:
                    break
            else:
                raise  # some other error not memory related

        if changed:
            # Force the train dataloader to reset as the batch size has changed
            trainer.reset_train_dataloader(model)
            trainer.reset_val_dataloader(model)
        else:
            break
    return new_size
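
For readers skimming the diff, the control flow above boils down to the following standalone sketch of the "power" policy (a simplified sketch only: `try_fit` is a hypothetical callable that raises on out-of-memory, and the dataloader resets and the `changed` check performed by the real code are omitted):

```python
def find_batch_size_power(try_fit, init_val: int = 2, max_trials: int = 25) -> int:
    """Double the batch size until OOM, then keep halving until one size has worked."""
    size = init_val
    any_success = False
    for _ in range(max_trials):
        try:
            try_fit(size)
            size *= 2          # trial succeeded: attempt twice the size next
            any_success = True
        except RuntimeError:   # stands in for `is_oom_error(exception)`
            size //= 2         # trial failed: fall back towards a smaller size
            if any_success:    # stop only once at least one size has succeeded
                break
    return size
```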


@@ -189,13 +198,13 @@ def _run_binsearch_scaling(
            else:
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")

            if changed:
                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
                trainer.reset_val_dataloader(model)
            else:
            if not changed:
                break

            # Force the train dataloader to reset as the batch size has changed
            trainer.reset_train_dataloader(model)
            trainer.reset_val_dataloader(model)

        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
@@ -204,6 +213,11 @@
                high = new_size
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")

                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
                trainer.reset_val_dataloader(model)

                if high - low <= 1:
                    break
            else:
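The binary-search branch follows the same pattern once a window between a passing and a failing size is known; a comparable standalone sketch (again with a hypothetical `try_fit`, and without the dataloader resets and trial bookkeeping of the real code):

```python
def find_batch_size_binsearch(try_fit, init_val: int = 2, max_trials: int = 25) -> int:
    """Double until the first OOM, then bisect between passing and failing sizes."""
    size, low, high = init_val, 1, None
    for _ in range(max_trials):
        try:
            try_fit(size)
            low = size                    # largest size known to fit
            if high is None:
                size *= 2                 # no failure seen yet: keep doubling
            else:
                if high - low <= 1:
                    break                 # window closed: settle on the last fit
                size = (high + low) // 2  # try the midpoint of the open window
        except RuntimeError:              # stands in for `is_oom_error(exception)`
            high = size                   # smallest size known to fail
            size = (high + low) // 2      # fall back to the midpoint
            if high - low <= 1:
                break
    return size
```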
23 changes: 23 additions & 0 deletions tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -319,3 +319,26 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):

    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
    assert trainer.val_dataloaders[0].batch_size == new_batch_size


@pytest.mark.parametrize("scale_method, expected_batch_size", [("power", 62), ("binsearch", 100)])
@patch("pytorch_lightning.tuner.batch_size_scaling.is_oom_error", return_value=True)
def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expected_batch_size):
    class CustomBatchSizeModel(BatchSizeModel):
        def training_step(self, *_, **__):
            if self.batch_size > 100:
                raise RuntimeError

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, 1000), batch_size=self.batch_size)

    model = CustomBatchSizeModel(batch_size=16)
    model.validation_step = None
    model.training_epoch_end = None
    scale_batch_size_kwargs = {"max_trials": 10, "steps_per_trial": 1, "init_val": 500, "mode": scale_method}

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, auto_scale_batch_size=True)
    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
    assert new_batch_size == model.batch_size
    assert new_batch_size == expected_batch_size
    assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
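
For reference, a rough walk-through of where the parametrized expectations come from under this test's setup (any batch size above 100 raises, and `is_oom_error` is patched to treat every `RuntimeError` as an OOM): with `mode="power"` the initial value of 500 is halved to 250, 125 and then 62 before a trial finally succeeds; the subsequent doubling to 124 fails again, so the finder settles on the last successful size, 62. With `mode="binsearch"` the same initial failures narrow the window to roughly [63, 125], and bisection then converges on 100, the largest batch size the model accepts, before the window shrinks below two.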