
Commit 5317545
test restore
1 parent d1cc990

2 files changed: +16 -4 lines


src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 1 deletion
@@ -398,7 +398,9 @@ def _restore_train_dataloaders(self) -> None:
         if not state_dicts:
             return
 
-        for train_dataloader, state_dict in zip(self.trainer.train_dataloader, state_dicts):
+        combined_loader = self.trainer.fit_loop._combined_loader
+        iterables = combined_loader.flattened if combined_loader is not None else []
+        for train_dataloader, state_dict in zip(iterables, state_dicts):
             if isinstance(train_dataloader, _Stateful):
                 train_dataloader.load_state_dict(state_dict)
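For context, the restore loop above only applies saved state to dataloaders that pass the _Stateful check, i.e. that expose a state_dict/load_state_dict pair. A minimal sketch of such a dataloader follows; the class and its attributes are hypothetical, not part of this commit, although the test's own StatefulDataLoader helper presumably looks similar.

# Hypothetical sketch, not from this commit: a dataloader the restore loop above would pick up.
from torch.utils.data import DataLoader


class MyStatefulDataLoader(DataLoader):
    def __init__(self, label):
        super().__init__(range(8))  # dummy dataset, purely for illustration
        self.label = label

    def state_dict(self):
        # Whatever is returned here ends up under checkpoint["train_dataloaders"].
        return {"label": self.label}

    def load_state_dict(self, state_dict):
        # Called by _restore_train_dataloaders for each flattened iterable.
        self.label = state_dict["label"]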

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 13 additions & 3 deletions
@@ -254,7 +254,7 @@ def __init__(self, *args, **kwargs):
     (CombinedLoader([NotStatefulDataLoader(3), StatefulDataLoader(1), NotStatefulDataLoader(2)]), [{"label": 1}]),
 ])
 def test_train_dataloaders_restore(train_dataloaders, expected_states, tmp_path):
-    """Test that the CheckpointConnector saves the state of stateful dataloaders and can reloead them."""
+    """Test that the CheckpointConnector saves the state of stateful dataloaders and can reload them."""
     class DataLoaderModel(BoringModel):
         def training_step(self, batch, batch_idx):
             if isinstance(batch, list):
@@ -264,8 +264,7 @@ def training_step(self, batch, batch_idx):
         def train_dataloader(self):
             return train_dataloaders
 
-    model = DataLoaderModel()
-    trainer = Trainer(
+    trainer_kwargs = dict(
         default_root_dir=tmp_path,
         accelerator="cpu",
         max_steps=1,
@@ -275,6 +274,10 @@ def train_dataloader(self):
         logger=False,
         num_sanity_val_steps=0,
     )
+
+    model = DataLoaderModel()
+    trainer = Trainer(**trainer_kwargs)
+
     # Fit to init the state of CheckpointConnector
     trainer.fit(model)
     checkpoint = trainer._checkpoint_connector.dump_checkpoint()
@@ -283,3 +286,10 @@ def train_dataloader(self):
         assert "train_dataloaders" not in checkpoint
     else:
         assert checkpoint["train_dataloaders"] == expected_states
+
+    torch.save(checkpoint, tmp_path / "checkpoint.ckpt")
+
+    model = DataLoaderModel()
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
+    # TODO: Test here
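The trailing TODO leaves the post-restore check open. One possible assertion is sketched below purely as an illustration; it is not part of this commit and assumes the flattened combined loader is still attached to the fit loop after the second fit and that _Stateful is importable in the test module.

# Hypothetical follow-up assertion, not part of this commit.
combined_loader = trainer.fit_loop._combined_loader
stateful = [dl for dl in combined_loader.flattened if isinstance(dl, _Stateful)]
assert [dl.state_dict() for dl in stateful] == expected_states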
