@@ -283,6 +283,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
283
283
state_dict ["loop_child.state_dict" ]["a" ] = 3
284
284
# check restarting after `load_state_dict`
285
285
loop_parent .load_state_dict (state_dict )
286
+ loop_parent .restarting = True
286
287
assert loop_parent .restarting
287
288
288
289
loop_parent .run ()
@@ -306,6 +307,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
306
307
loop_child = Simple (2 )
307
308
loop_parent .loop_child = loop_child
308
309
loop_parent .load_state_dict (state_dict )
310
+ loop_parent .restarting = True
309
311
assert loop_parent .progress .increment == 1
310
312
assert loop_parent .loop_child .progress .increment == 1
311
313
@@ -359,6 +361,7 @@ def val_dataloader(self):
359
361
assert checkpoint ["epoch_loop.val_loop.dataloader_progress" ] == expected
360
362
361
363
trainer .fit_loop .load_state_dict (checkpoint )
364
+ trainer .fit_loop .restarting = True
362
365
363
366
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
364
367
# the fit-validation total batch progress is reset per epoch so it's not counted for the total value.
@@ -548,6 +551,7 @@ def configure_optimizers_multiple(self):
548
551
assert checkpoint ["loops" ]["fit_loop" ] == expected
549
552
550
553
trainer .fit_loop .load_state_dict (checkpoint ["loops" ]["fit_loop" ])
554
+ trainer .fit_loop .restarting = True
551
555
state_dict = trainer .fit_loop .state_dict ()
552
556
553
557
# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@@ -557,6 +561,7 @@ def configure_optimizers_multiple(self):
557
561
assert state_dict == checkpoint ["loops" ]["fit_loop" ]
558
562
559
563
trainer .fit_loop .load_state_dict (checkpoint ["loops" ]["fit_loop" ])
564
+ trainer .fit_loop .restarting = True
560
565
# test resetting manually, we expect all `ready` counters to be reset to `completed`
561
566
trainer .fit_loop .reset ()
562
567
trainer .fit_loop .epoch_loop .reset ()
@@ -753,6 +758,7 @@ def test_fit_loop_reset(tmpdir):
753
758
754
759
# we load exactly what was saved - no reset yet
755
760
fit_loop .load_state_dict (mid_epoch_ckpt ["loops" ]["fit_loop" ])
761
+ fit_loop .restarting = True
756
762
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
757
763
fit_loop .reset ()
758
764
epoch_loop .reset ()
@@ -785,6 +791,7 @@ def test_fit_loop_reset(tmpdir):
785
791
786
792
# we load exactly what was saved - no reset yet
787
793
fit_loop .load_state_dict (end_of_epoch_ckpt ["loops" ]["fit_loop" ])
794
+ fit_loop .restarting = True
788
795
# resetting from an end-of-epoch checkpoint SHOULD reset the current counters to 0
789
796
fit_loop .reset ()
790
797
epoch_loop .reset ()
0 commit comments