Skip to content

Commit 15aa9c6

Browse files
authored
An instance of SaveConfigCallback should only save the config once (#14927)
1 parent abea29b commit 15aa9c6

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
295295

296296

297297

298+
- `SaveConfigCallback` instances should only save the config once to allow having the `overwrite=False` safeguard when using `LightningCLI(..., run=False)` ([#14927](https://github.com/Lightning-AI/lightning/pull/14927))
299+
300+
298301
## [1.7.7] - 2022-09-22
299302

300303
### Fixed

src/pytorch_lightning/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,12 @@ def __init__(
209209
self.config_filename = config_filename
210210
self.overwrite = overwrite
211211
self.multifile = multifile
212+
self.already_saved = False
212213

213214
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
215+
if self.already_saved:
216+
return
217+
214218
log_dir = trainer.log_dir # this broadcasts the directory
215219
assert log_dir is not None
216220
config_path = os.path.join(log_dir, self.config_filename)
@@ -238,6 +242,10 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> Non
238242
self.parser.save(
239243
self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
240244
)
245+
self.already_saved = True
246+
247+
# broadcast so that all ranks are in sync on future calls to .setup()
248+
self.already_saved = trainer.strategy.broadcast(self.already_saved)
241249

242250

243251
class LightningCLI:

tests/tests_pytorch/test_cli.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,22 @@ def test_lightning_cli_save_config_cases(tmpdir):
249249
LightningCLI(BoringModel)
250250

251251

252+
def test_lightning_cli_save_config_only_once(tmpdir):
253+
config_path = tmpdir / "config.yaml"
254+
cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.max_epochs=1"]
255+
256+
with mock.patch("sys.argv", ["any.py"] + cli_args):
257+
cli = LightningCLI(BoringModel, run=False)
258+
259+
save_config_callback = next(c for c in cli.trainer.callbacks if isinstance(c, SaveConfigCallback))
260+
assert not save_config_callback.overwrite
261+
assert not save_config_callback.already_saved
262+
cli.trainer.fit(cli.model)
263+
assert os.path.isfile(config_path)
264+
assert save_config_callback.already_saved
265+
cli.trainer.test(cli.model) # Should not fail because config already saved
266+
267+
252268
def test_lightning_cli_config_and_subclass_mode(tmpdir):
253269
input_config = {
254270
"fit": {

0 commit comments

Comments
 (0)