@@ -254,7 +254,7 @@ def __init__(self, *args, **kwargs):
     (CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
 ])
 def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
-    """Test that the CheckpointConnector saves the state of stateful dataloaders and can reloead them."""
+    """Test that the CheckpointConnector saves the state of stateful dataloaders and can reload them."""
     class DataLoaderModel(BoringModel):
         def training_step(self, batch, batch_idx):
             if isinstance(batch, list):
@@ -264,8 +264,7 @@ def training_step(self, batch, batch_idx):
         def train_dataloader(self):
             return train_dataloaders
 
-    model = DataLoaderModel()
-    trainer = Trainer(
+    trainer_kwargs = dict(
         default_root_dir=tmp_path,
         accelerator="cpu",
         max_steps=1,
@@ -275,6 +274,10 @@ def train_dataloader(self):
         logger=False,
         num_sanity_val_steps=0,
     )
+
+    model = DataLoaderModel()
+    trainer = Trainer(**trainer_kwargs)
+
     # Fit to init the state of CheckpointConnector
     trainer.fit(model)
     checkpoint = trainer._checkpoint_connector.dump_checkpoint()
@@ -283,3 +286,10 @@ def train_dataloader(self):
         assert "train_dataloaders" not in checkpoint
     else:
         assert checkpoint["train_dataloaders"] == expected_states
+
+    torch.save(checkpoint, tmp_path / "checkpoint.ckpt")
+
+    model = DataLoaderModel()
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    # TODO: Test here
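
The trailing `# TODO: Test here` is left open in the diff. For context only, here is a minimal, standalone sketch of the state round-trip the test exercises: an object with `state_dict()`/`load_state_dict()` has its state written into a checkpoint via `torch.save`, and a fresh instance picks it up after `torch.load`. The `ToyStatefulLoader` class below is hypothetical and is not the `StatefulDataLoader` helper used in the test file; it only mirrors the `{"label": ...}` states in the parametrization.

    import torch


    class ToyStatefulLoader:
        """Hypothetical stand-in for a dataloader that supports state_dict/load_state_dict."""

        def __init__(self, label):
            self.label = label

        def state_dict(self):
            # Mirrors the {"label": ...} states expected by the parametrized cases above.
            return {"label": self.label}

        def load_state_dict(self, state):
            self.label = state["label"]


    def roundtrip(tmp_path):
        # Save the loader's state under the same key the test asserts on.
        loader = ToyStatefulLoader(label=1)
        checkpoint = {"train_dataloaders": [loader.state_dict()]}
        torch.save(checkpoint, tmp_path / "checkpoint.ckpt")

        # A fresh loader restores the saved state from the checkpoint on disk.
        restored = ToyStatefulLoader(label=0)
        loaded = torch.load(tmp_path / "checkpoint.ckpt")
        restored.load_state_dict(loaded["train_dataloaders"][0])
        assert restored.state_dict() == {"label": 1}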