 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _auto_add_worker_init_fn
@@ -79,10 +79,12 @@ def __init__(
         self.min_epochs = min_epochs
         self.epoch_loop = _TrainingEpochLoop(trainer)
         self.epoch_progress = Progress()
+        self.max_batches: Union[int, float] = float("inf")
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
+        self._last_train_dl_reload_epoch = float("-inf")
 
     @property
     def total_batch_idx(self) -> int:
@@ -136,10 +138,16 @@ def _can_stop_early(self) -> bool:
         met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
         return met_min_epochs and met_min_steps
 
+    @property
+    def _should_reload_train_dl(self) -> bool:
+        """Check if train dataloader should be reloaded."""
+        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
+        return n_epochs and self.trainer.current_epoch - self._last_train_dl_reload_epoch >= n_epochs
+
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
-        if self.trainer.num_training_batches == 0:
+        if self.max_batches == 0:
             rank_zero_info("`Trainer.fit` stopped: No training batches.")
             return True
 
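The reload check added above keeps the same behaviour the data connector used to provide: because `_last_train_dl_reload_epoch` is initialised to `float("-inf")`, the condition is true on the first call and then again once `reload_dataloaders_every_n_epochs` epochs have passed since the recorded reload epoch. A minimal standalone sketch of that condition (plain Python with illustrative names, not the real loop):

def should_reload(current_epoch: int, last_reload_epoch: float, every_n_epochs: int) -> bool:
    # 0 (the default) disables reloading; otherwise compare the epoch distance
    return bool(every_n_epochs) and current_epoch - last_reload_epoch >= every_n_epochs

assert should_reload(0, float("-inf"), 2)   # first setup: -inf makes the distance infinite
assert not should_reload(1, 0, 2)           # reloaded at epoch 0, epoch 1 is too soon
assert should_reload(2, 0, 2)               # two epochs later the loader is rebuilt again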
@@ -168,8 +176,8 @@ def done(self) -> bool:
     @property
     def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
-        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
-        # until `on_run_start`, we use `limit_train_batches` instead
+        # if `limit_train_batches == 0` then `setup_data` won't set the `self.max_batches` attribute (checked in `done`)
+        # so we cannot use it solely
         return self.done or self.trainer.limit_train_batches == 0
 
     def run(self) -> None:
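The reworded comment points at a subtle ordering issue: when `limit_train_batches == 0`, `setup_data` returns before `self.max_batches` is ever assigned, so `done` (which checks `max_batches == 0`) would never fire on its own. A rough illustration of why both checks are needed, using simplified stand-in classes rather than the real loop:

class TinyLoop:
    def __init__(self, limit_train_batches: float) -> None:
        self.limit_train_batches = limit_train_batches
        self.max_batches: float = float("inf")  # default set in __init__, as in the diff

    def setup_data(self) -> None:
        if self.limit_train_batches == 0:
            return  # early exit: max_batches keeps its default of inf
        self.max_batches = 100  # stands in for the real dataloader length

    @property
    def done(self) -> bool:
        return self.max_batches == 0

    @property
    def skip(self) -> bool:
        return self.done or self.limit_train_batches == 0

loop = TinyLoop(limit_train_batches=0)
loop.setup_data()
assert not loop.done  # max_batches is still inf, so done alone misses this case
assert loop.skip      # the explicit limit_train_batches check catches it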
@@ -190,11 +198,10 @@ def run(self) -> None:
         self.on_run_end()
 
     def setup_data(self, shuffle: bool = True) -> None:
-        trainer = self.trainer
-
-        if self._combined_loader is not None and not trainer._data_connector._should_reload_train_dl:
+        if self._combined_loader is not None and not self._should_reload_train_dl:
             return
 
+        trainer = self.trainer
         source = self._data_source
         pl_module = trainer.lightning_module
         if not source.is_defined() or trainer.limit_train_batches == 0 or not is_overridden("training_step", pl_module):
@@ -227,7 +234,7 @@ def setup_data(self, shuffle: bool = True) -> None:
         self._combined_loader = combined_loader
 
         module = pl_module or trainer.datamodule
-        orig_train_batches = trainer.num_training_batches = (
+        orig_train_batches = self.max_batches = (
             len(self._combined_loader)
             if has_len_all_ranks(self._combined_loader, trainer.strategy, module)
             else float("inf")
@@ -236,12 +243,12 @@ def setup_data(self, shuffle: bool = True) -> None:
             return
 
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        trainer._last_train_dl_reload_epoch = trainer.current_epoch
+        self._last_train_dl_reload_epoch = trainer.current_epoch
 
         if isinstance(trainer.limit_train_batches, int):
-            trainer.num_training_batches = min(orig_train_batches, trainer.limit_train_batches)
-        elif trainer.num_training_batches != float("inf"):
-            trainer.num_training_batches = int(orig_train_batches * trainer.limit_train_batches)
+            self.max_batches = min(orig_train_batches, trainer.limit_train_batches)
+        elif self.max_batches != float("inf"):
+            self.max_batches = int(orig_train_batches * trainer.limit_train_batches)
         elif trainer.limit_train_batches != 1.0:
             raise MisconfigurationException(
                 "When using an `IterableDataset`, `Trainer(limit_train_batches)` must be `1.0` or an int."
@@ -250,10 +257,10 @@ def setup_data(self, shuffle: bool = True) -> None:
 
         if isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
-            if trainer.val_check_batch > trainer.num_training_batches and trainer.check_val_every_n_epoch is not None:
+            if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
                 raise ValueError(
                     f"`val_check_interval` ({trainer.val_check_interval}) must be less than or equal"
-                    f" to the number of the training batches ({trainer.num_training_batches})."
+                    f" to the number of the training batches ({self.max_batches})."
                     " If you want to disable validation set `limit_val_batches` to 0.0 instead."
                     " If you want to validate based on the total training batches, set `check_val_every_n_epoch=None`."
                 )
@@ -268,19 +275,19 @@ def setup_data(self, shuffle: bool = True) -> None:
                         " checking validation every k training batches."
                     )
             else:
-                trainer.val_check_batch = int(trainer.num_training_batches * trainer.val_check_interval)
+                trainer.val_check_batch = int(self.max_batches * trainer.val_check_interval)
                 trainer.val_check_batch = max(1, trainer.val_check_batch)
 
-        if trainer.loggers and trainer.num_training_batches < trainer.log_every_n_steps:
+        if trainer.loggers and self.max_batches < trainer.log_every_n_steps:
             rank_zero_warn(
-                f"The number of training batches ({trainer.num_training_batches}) is smaller than the logging interval"
+                f"The number of training batches ({self.max_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={trainer.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                 " you want to see logs for the training epoch.",
                 category=PossibleUserWarning,
             )
 
         if (
-            trainer.num_training_batches == 0
+            self.max_batches == 0
             and trainer.limit_train_batches > 0.0
             and isinstance(trainer.limit_train_batches, float)
             and orig_train_batches != float("inf")
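For a fractional `val_check_interval`, the final hunk above derives `val_check_batch` from the (possibly capped) `max_batches`, and the `max(1, ...)` guard keeps at least one validation check per epoch. A quick arithmetic check of that conversion, using illustrative values only:

max_batches = 100

# val_check_interval=0.25 -> validate every 25 training batches
assert max(1, int(max_batches * 0.25)) == 25

# a very small fraction would floor to 0, so the guard bumps it to 1
assert max(1, int(max_batches * 0.001)) == 1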