
Commit 565d611

Move max_batches definition to the Loops (#16820)

1 parent f969411 commit 565d611

10 files changed: +85 -94 lines changed

src/lightning/pytorch/callbacks/stochastic_weight_avg.py

Lines changed: 5 additions & 6 deletions

@@ -250,13 +250,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
             # There is no need to perform either backward or optimizer.step as we are
             # performing only one pass over the train data-loader to compute activation statistics
-            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
-            assert isinstance(trainer.num_training_batches, int)
-            trainer.num_training_batches += 1
+            # Therefore, we will virtually increase the number of training batches by 1 and skip backward.
+            trainer.fit_loop.max_batches += 1
             trainer.fit_loop._skip_backward = True
             self._accumulate_grad_batches = trainer.accumulate_grad_batches
-
-            trainer.accumulate_grad_batches = trainer.num_training_batches
+            assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
+            trainer.accumulate_grad_batches = trainer.fit_loop.max_batches

     def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
         trainer.fit_loop._skip_backward = False

@@ -266,7 +265,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
             # BatchNorm epoch update. Reset state
             trainer.accumulate_grad_batches = self._accumulate_grad_batches
-            trainer.num_training_batches -= 1
+            trainer.fit_loop.max_batches -= 1
             assert trainer.fit_loop.max_epochs is not None
             trainer.fit_loop.max_epochs -= 1
             self.reset_momenta()
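
For consumers like the SWA callback, the practical change is that the extra pass for the BatchNorm statistics is now requested by writing to trainer.fit_loop.max_batches rather than trainer.num_training_batches (which becomes a read-only property, see the trainer.py diff below). A minimal, self-contained sketch of that pattern; _FitLoopSketch and _TrainerSketch are illustrative stand-ins, not the real Lightning classes:

# Illustrative stand-ins for the ownership change; not the real Lightning classes.
class _FitLoopSketch:
    def __init__(self, max_batches):
        self.max_batches = max_batches  # the loop owns the value now
        self._skip_backward = False


class _TrainerSketch:
    def __init__(self, num_batches):
        self.fit_loop = _FitLoopSketch(num_batches)
        self.accumulate_grad_batches = 1

    @property
    def num_training_batches(self):
        # read-only view kept for backward compatibility
        return self.fit_loop.max_batches


def on_train_epoch_start(trainer):
    # request one extra pass over the data (as SWA does for the BatchNorm update)
    assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
    trainer.fit_loop.max_batches += 1
    trainer.fit_loop._skip_backward = True
    trainer.accumulate_grad_batches = trainer.fit_loop.max_batches


trainer = _TrainerSketch(num_batches=10)
on_train_epoch_start(trainer)
assert trainer.num_training_batches == 11 and trainer.accumulate_grad_batches == 11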

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 15 additions & 23 deletions

@@ -47,6 +47,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: bool = True) -> None:
         self.verbose = verbose
         self.inference_mode = inference_mode
         self.batch_progress = BatchProgress()  # across dataloaders
+        self._max_batches: List[Union[int, float]] = []

         self._results = _ResultCollection(training=False)
         self._logged_outputs: List[_OUT_DICT] = []

@@ -55,6 +56,7 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: bool = True) -> None:
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
         self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
+        self._last_val_dl_reload_epoch = float("-inf")

     @property
     def num_dataloaders(self) -> int:

@@ -66,19 +68,22 @@ def num_dataloaders(self) -> int:
     @property
     def max_batches(self) -> List[Union[int, float]]:
         """The max number of batches this loop will run for each dataloader."""
-        if self.trainer.testing:
-            return self.trainer.num_test_batches
-        elif self.trainer.sanity_checking:
-            return self.trainer.num_sanity_val_batches
-        elif self.trainer.validating:
-            return self.trainer.num_val_batches
-        raise RuntimeError(f"Unexpected stage: {self.trainer.state.stage}")
+        max_batches = self._max_batches
+        if self.trainer.sanity_checking:
+            return [min(self.trainer.num_sanity_val_steps, batches) for batches in max_batches]
+        return max_batches

     @property
     def skip(self) -> bool:
         """Returns whether the evaluation should be skipped."""
         return sum(self.max_batches) == 0

+    @property
+    def _should_reload_val_dl(self) -> bool:
+        """Check if validation dataloader should be reloaded."""
+        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
+        return bool(n_epochs and self.trainer.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)
+
     @_no_grad_context
     def run(self) -> List[_OUT_DICT]:
         self.setup_data()

@@ -110,11 +115,7 @@ def run(self) -> List[_OUT_DICT]:
     def setup_data(self) -> None:
         trainer = self.trainer

-        if (
-            self._combined_loader is not None
-            and trainer.state.fn == "fit"
-            and not trainer._data_connector._should_reload_val_dl
-        ):
+        if self._combined_loader is not None and trainer.state.fn == "fit" and not self._should_reload_val_dl:
             return

         source = self._data_source

@@ -130,20 +131,11 @@ def setup_data(self) -> None:
             (trainer.sanity_checking and trainer.fit_loop.epoch_loop._should_check_val_epoch())
             or not trainer.sanity_checking
         ):
-            trainer._last_val_dl_reload_epoch = trainer.current_epoch
+            self._last_val_dl_reload_epoch = trainer.current_epoch

         stage = trainer.state.stage
         assert stage is not None
-        num_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)
-        if trainer.testing:
-            trainer.num_test_batches = num_batches
-        elif trainer.sanity_checking:
-            trainer.num_val_batches = num_batches
-            trainer.num_sanity_val_batches = [
-                min(trainer.num_sanity_val_steps, val_batches) for val_batches in num_batches
-            ]
-        else:
-            trainer.num_val_batches = num_batches
+        self._max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(stage, model=pl_module)

         if trainer.state.fn != "fit":  # if we are fitting, we need to do this in the loop
             for dl in combined_loader.flattened:
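
The evaluation loop now caches its own per-dataloader limits in _max_batches (filled by setup_data) and only caps them with num_sanity_val_steps while sanity checking, instead of branching on trainer state. A simplified, self-contained sketch of that property, with a SimpleNamespace standing in for the trainer (not the real class):

from types import SimpleNamespace


class _EvalLoopSketch:
    """Simplified stand-in for the evaluation loop's new max_batches handling."""

    def __init__(self, trainer):
        self.trainer = trainer
        self._max_batches = []  # filled by setup_data()

    @property
    def max_batches(self):
        # cap per-dataloader limits only while sanity checking
        if self.trainer.sanity_checking:
            return [min(self.trainer.num_sanity_val_steps, b) for b in self._max_batches]
        return self._max_batches

    @property
    def skip(self):
        return sum(self.max_batches) == 0


trainer = SimpleNamespace(sanity_checking=True, num_sanity_val_steps=2)
loop = _EvalLoopSketch(trainer)
loop._max_batches = [10, 5]        # what setup_data() would have computed
assert loop.max_batches == [2, 2]  # capped during the sanity check
trainer.sanity_checking = False
assert loop.max_batches == [10, 5] and not loop.skip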

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 25 additions & 18 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import Optional, Union

 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _auto_add_worker_init_fn

@@ -79,10 +79,12 @@ def __init__(
         self.min_epochs = min_epochs
         self.epoch_loop = _TrainingEpochLoop(trainer)
         self.epoch_progress = Progress()
+        self.max_batches: Union[int, float] = float("inf")

         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
+        self._last_train_dl_reload_epoch = float("-inf")

     @property
     def total_batch_idx(self) -> int:

@@ -136,10 +138,16 @@ def _can_stop_early(self) -> bool:
         met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
         return met_min_epochs and met_min_steps

+    @property
+    def _should_reload_train_dl(self) -> bool:
+        """Check if train dataloader should be reloaded."""
+        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
+        return n_epochs and self.trainer.current_epoch - self._last_train_dl_reload_epoch >= n_epochs
+
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
-        if self.trainer.num_training_batches == 0:
+        if self.max_batches == 0:
             rank_zero_info("`Trainer.fit` stopped: No training batches.")
             return True

@@ -168,8 +176,8 @@ def done(self) -> bool:
     @property
     def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
-        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
-        # until `on_run_start`, we use `limit_train_batches` instead
+        # if `limit_train_batches == 0` then `setup_data` won't set the `self.max_batches` attribute (checked in `done`)
+        # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0

     def run(self) -> None:

@@ -190,11 +198,10 @@ def run(self) -> None:
         self.on_run_end()

     def setup_data(self, shuffle: bool = True) -> None:
-        trainer = self.trainer
-
-        if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl:
+        if self._combined_loader is not None and not self._should_reload_train_dl:
             return

+        trainer = self.trainer
         source = self._data_source
         pl_module = trainer.lightning_module
         if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):

@@ -227,7 +234,7 @@ def setup_data(self, shuffle: bool = True) -> None:
         self._combined_loader = combined_loader

         module = pl_module or trainer.datamodule
-        orig_train_batches = trainer.num_training_batches = (
+        orig_train_batches = self.max_batches = (
            len(self._combined_loader)
            if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
            else float("inf")

@@ -236,12 +243,12 @@ def setup_data(self, shuffle: bool = True) -> None:
             return

         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        trainer._last_train_dl_reload_epoch = trainer.current_epoch
+        self._last_train_dl_reload_epoch = trainer.current_epoch

         if isinstance(trainer.limit_train_batches, int):
-            trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches)
-        elif trainer.num_training_batches != float("inf"):
-            trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches)
+            self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
+        elif self.max_batches != float("inf"):
+            self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
         elif trainer.limit_train_batches != 1.0:
             raise MisconfigurationException(
                 "When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."

@@ -250,10 +257,10 @@ def setup_data(self, shuffle: bool = True) -> None:
         if isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
-            if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None:
+            if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f" `val_check_interval` ({trainer.val_check_interval}) must be less than or equal"
-                    f" to the number of the training batches ({trainer.num_training_batches})."
+                    f" to the number of the training batches ({self.max_batches})."
                     " If you want to disable validation set `limit_val_batches` to 0.0 instead."
                     " If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                 )

@@ -268,19 +275,19 @@ def setup_data(self, shuffle: bool = True) -> None:
                     " checking validation every k training batches."
                 )
         else:
-            trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval)
+            trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
             trainer.val_check_batch = max(1, trainer.val_check_batch)

-        if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps:
+        if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
             rank_zero_warn(
-                f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval"
+                f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                 " you want to see logs for the training epoch.",
                 category=PossibleUserWarning,
             )

         if (
-            trainer.num_training_batches == 0
+            self.max_batches == 0
             and trainer.limit_train_batches > 0.0
             and isinstance(trainer.limit_train_batches, float)
             and orig_train_batches != float("inf")

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 2 additions & 6 deletions

@@ -31,6 +31,7 @@ def __init__(self, trainer: "pl.Trainer", inference_mode: bool = True) -> None:
         self.epoch_batch_indices: List[List[List[int]]] = []
         self.current_batch_indices: List[int] = []  # used by PredictionWriter
         self.batch_progress = Progress()  # across dataloaders
+        self.max_batches: List[Union[int, float]] = []

         self._warning_cache = WarningCache()
         self._data_source = _DataLoaderSource(None, "predict_dataloader")

@@ -71,11 +72,6 @@ def num_dataloaders(self) -> int:
         assert combined_loader is not None
         return len(combined_loader.flattened)

-    @property
-    def max_batches(self) -> List[Union[int, float]]:
-        """The max number of batches this loop will run for each dataloader."""
-        return self.trainer.num_predict_batches
-
     @property
     def skip(self) -> bool:
         return sum(self.max_batches) == 0

@@ -109,7 +105,7 @@ def setup_data(self) -> None:
         if not source.is_defined() or trainer.limit_predict_batches == 0:
             return

-        trainer.num_predict_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
+        self.max_batches, combined_loader = trainer._data_connector._reset_eval_dataloader(
             RunningStage.PREDICTING, model=pl_module
         )
         for dl in combined_loader.flattened:
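
For the prediction loop the change is purely structural: the max_batches property that read trainer.num_predict_batches is replaced by a plain attribute that setup_data fills directly. A tiny illustrative sketch (a stand-in, not the real class):

class _PredictionLoopSketch:
    """Illustrative stand-in: the loop owns max_batches as a plain attribute."""

    def __init__(self):
        self.max_batches = []  # populated by setup_data()

    def setup_data(self, per_dataloader_limits):
        # stands in for the limits returned by the data connector
        self.max_batches = list(per_dataloader_limits)

    @property
    def skip(self):
        return sum(self.max_batches) == 0


loop = _PredictionLoopSketch()
assert loop.skip                 # nothing to predict before setup
loop.setup_data([3, 7])
assert loop.max_batches == [3, 7] and not loop.skip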

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 0 additions & 14 deletions

@@ -44,18 +44,6 @@ def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
         self._datahook_selector: Optional[_DataHookSelector] = None

-    @property
-    def _should_reload_train_dl(self) -> bool:
-        """Check if train dataloader should be reloaded."""
-        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
-
-    @property
-    def _should_reload_val_dl(self) -> bool:
-        """Check if validation dataloader should be reloaded."""
-        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return bool(n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs)
-
     def on_trainer_init(
         self,
         val_check_interval: Optional[Union[int, float]],

@@ -83,7 +71,6 @@ def on_trainer_init(
         )

         self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
-        self.trainer._is_data_prepared = False

     def prepare_data(self) -> None:
         trainer = self.trainer

@@ -107,7 +94,6 @@ def prepare_data(self) -> None:
         lm_prepare_data_per_node = lightning_module.prepare_data_per_node
         if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
             call._call_lightning_module_hook(trainer, "prepare_data")
-        trainer._is_data_prepared = True

     def attach_data(
         self,
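
These reload checks are not lost; they reappear as _should_reload_train_dl on the fit loop and _should_reload_val_dl on the evaluation loop (see the diffs above), each tracking its own _last_*_dl_reload_epoch. The cadence logic itself is unchanged; a small standalone sketch with illustrative names:

def should_reload_dl(current_epoch, last_reload_epoch, reload_every_n_epochs):
    """Sketch of the reload-cadence check that moved from the data connector into the loops."""
    return bool(reload_every_n_epochs and current_epoch - last_reload_epoch >= reload_every_n_epochs)


assert should_reload_dl(current_epoch=4, last_reload_epoch=2, reload_every_n_epochs=2)
assert not should_reload_dl(current_epoch=3, last_reload_epoch=2, reload_every_n_epochs=2)
assert not should_reload_dl(current_epoch=10, last_reload_epoch=0, reload_every_n_epochs=0)  # feature disabled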

src/lightning/pytorch/trainer/trainer.py

Lines changed: 30 additions & 17 deletions

@@ -349,7 +349,11 @@ def __init__(
         )

         self._detect_anomaly: bool = detect_anomaly
-        self._setup_on_init()
+
+        setup._log_device_info(self)
+
+        self.should_stop = False
+        self.state = TrainerState()

         # configure profiler
         setup._init_profiler(self, profiler)

@@ -378,22 +382,6 @@
             num_sanity_val_steps,
         )

-    def _setup_on_init(self) -> None:
-        setup._log_device_info(self)
-
-        self.should_stop = False
-        self.state = TrainerState()
-
-        # TODO(carmocca): move these to the loops
-        self.num_training_batches = float("inf")
-        self.num_sanity_val_batches: List[Union[int, float]] = []
-        self.num_test_batches: List[Union[int, float]] = []
-        self.num_val_batches: List[Union[int, float]] = []
-        self.num_predict_batches: List[Union[int, float]] = []
-
-        self._last_train_dl_reload_epoch = float("-inf")
-        self._last_val_dl_reload_epoch = float("-inf")
-
     def fit(
         self,
         model: "pl.LightningModule",

@@ -1305,6 +1293,31 @@ def predict_dataloaders(self) -> EVAL_DATALOADERS:
         if (combined_loader := self.predict_loop._combined_loader) is not None:
             return combined_loader.iterables

+    @property
+    def num_training_batches(self) -> Union[int, float]:
+        return self.fit_loop.max_batches
+
+    @property
+    def num_sanity_val_batches(self) -> List[Union[int, float]]:
+        max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
+        return [min(self.num_sanity_val_steps, batches) for batches in max_batches]
+
+    @property
+    def num_val_batches(self) -> List[Union[int, float]]:
+        if self.state.fn == TrainerFn.VALIDATING:
+            return self.validate_loop.max_batches
+        # if no trainer.fn is set, assume fit's validation
+        # use the protected access, because it shouldn't return the sanity_val batches
+        return self.fit_loop.epoch_loop.val_loop._max_batches
+
+    @property
+    def num_test_batches(self) -> List[Union[int, float]]:
+        return self.test_loop.max_batches
+
+    @property
+    def num_predict_batches(self) -> List[Union[int, float]]:
+        return self.predict_loop.max_batches
+
     @property
     def _evaluation_loop(self) -> _EvaluationLoop:
         if self.state.fn == TrainerFn.FITTING:
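
The old trainer.num_*_batches names stay usable for reads but become read-only views over loop state, so external code keeps working while the loops own the values. A condensed sketch of the delegation pattern, with SimpleNamespace objects standing in for the real loops:

from types import SimpleNamespace


class _TrainerFacadeSketch:
    def __init__(self):
        # stand-ins for the real loop objects
        self.fit_loop = SimpleNamespace(max_batches=float("inf"))
        self.test_loop = SimpleNamespace(max_batches=[])
        self.predict_loop = SimpleNamespace(max_batches=[])

    @property
    def num_training_batches(self):
        return self.fit_loop.max_batches

    @property
    def num_test_batches(self):
        return self.test_loop.max_batches

    @property
    def num_predict_batches(self):
        return self.predict_loop.max_batches


trainer = _TrainerFacadeSketch()
trainer.fit_loop.max_batches = 12          # writes go through the loop...
assert trainer.num_training_batches == 12  # ...reads still work via the old name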

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions

@@ -1366,9 +1366,9 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):

 def test_train_epoch_end_ckpt_with_no_validation():
     trainer = Trainer(val_check_interval=0.5)
-    trainer.num_val_batches = [0]
+    trainer.fit_loop.epoch_loop.val_loop._max_batches = [0]
     assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
-    trainer.num_val_batches = [1]
+    trainer.fit_loop.epoch_loop.val_loop._max_batches = [1]
     assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
     trainer.val_check_interval = 0.8
     assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
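
The test changes follow directly from the above: because trainer.num_val_batches is now a property without a setter, assigning to it raises, so the tests set the loop's private _max_batches instead. A tiny generic-Python illustration of why (not Lightning code):

class _ReadOnlyView:
    def __init__(self):
        self._max_batches = []

    @property
    def num_val_batches(self):
        return self._max_batches


view = _ReadOnlyView()
try:
    view.num_val_batches = [0]   # the old test-style assignment
except AttributeError:
    view._max_batches = [0]      # the backing state must be set directly instead
assert view.num_val_batches == [0]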

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ def test_should_stop_early_stopping_conditions_not_met(
     """Test that checks that info message is logged when users sets `should_stop` but min conditions are not
     met."""
     trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
-    trainer.num_training_batches = 10
+    trainer.fit_loop.max_batches = 10
     trainer.should_stop = True
     trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed = global_step
     trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step

0 commit comments

Comments
 (0)