Commit e359268

Authored by robertomestro, rohitgr7, Borda, and awaelchli
Estimate stepping batches with max_steps if max_epochs is not set (#14317)
Co-authored-by: Roberto Estevão <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent cadff2d commit e359268

File tree

3 files changed: +7 −4 lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))


+- Fixed `Trainer.estimated_stepping_batches` when maximum number of epochs is not set ([#14317](https://github.com/Lightning-AI/lightning/pull/14317))
+
+
 ## [1.7.2] - 2022-08-17


 ### Added
```

src/pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -2769,8 +2769,8 @@ def configure_optimizers(self):
         )

         # infinite training
-        if self.max_epochs == -1 and self.max_steps == -1:
-            return float("inf")
+        if self.max_epochs == -1:
+            return float("inf") if self.max_steps == -1 else self.max_steps

         if self.train_dataloader is None:
             rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.")
```
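The effect of the change can be sketched in isolation. The helper below is hypothetical and is not the actual `Trainer.estimated_stepping_batches` implementation; it only mirrors the branch touched by this commit. Per the commit title, when `max_epochs` is not set (resolving to `-1` in this code path), the estimate should be infinite only if `max_steps` is also unbounded (`-1`); otherwise `max_steps` itself bounds the number of optimizer steps.

```python
from typing import Union


def estimate_stepping_batches(max_epochs: int, max_steps: int) -> Union[int, float]:
    """Hypothetical sketch of the branch fixed in this commit.

    `max_epochs == -1` means "no epoch limit"; `max_steps == -1` means
    "no step limit". The pre-fix code returned inf only when *both* were
    -1, so a finite `max_steps` combined with an unset `max_epochs` was
    mishandled; the fix falls back to `max_steps` when it is set.
    """
    if max_epochs == -1:
        return float("inf") if max_steps == -1 else max_steps
    # The real Trainer would continue by loading the train dataloader to
    # estimate batches per epoch; that path is out of scope for this sketch.
    raise NotImplementedError("dataloader-based estimation not sketched here")


print(estimate_stepping_batches(-1, -1))   # inf
print(estimate_stepping_batches(-1, 100))  # 100
```

With the old logic, `estimate_stepping_batches(-1, 100)` would have fallen through to the dataloader-based path even though `max_steps` already gives an exact bound.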

tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -95,9 +95,9 @@ def test_num_stepping_batches_infinite_training():
     assert trainer.estimated_stepping_batches == float("inf")


-def test_num_stepping_batches_with_max_steps():
+@pytest.mark.parametrize("max_steps", [2, 100])
+def test_num_stepping_batches_with_max_steps(max_steps):
     """Test stepping batches with `max_steps`."""
-    max_steps = 2
     trainer = Trainer(max_steps=max_steps)
     model = BoringModel()
     trainer.fit(model)
```
