Commit 365bf10

Resolve FitLoop setter TODOs (#16803)

1 parent 781768d commit 365bf10

7 files changed: +11 -25 lines
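
Taken together, the commit removes the `_FitLoop.min_steps`/`max_steps` setters and makes every call site write to the inner epoch loop (`fit_loop.epoch_loop`) directly; reads still go through the `_FitLoop` properties. A minimal before/after sketch (the `Trainer` construction here is illustrative, not part of the commit):

    import lightning.pytorch as pl

    trainer = pl.Trainer(max_epochs=1)

    # Before this commit, _FitLoop exposed setters that forwarded to the epoch loop:
    #     trainer.fit_loop.max_steps = 100
    # That assignment now raises AttributeError, since only the getter remains.

    # After this commit, callers assign to the epoch loop directly:
    trainer.fit_loop.epoch_loop.max_steps = 100

    # Reads are unchanged; the property still forwards to the epoch loop:
    assert trainer.fit_loop.max_steps == 100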

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -336,6 +336,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `ProgressBarBase.{train_batch_idx,val_batch_idx,test_batch_idx,predict_batch_idx}` properties ([#16760](https://github.com/Lightning-AI/lightning/pull/16760))
 
 
+- Removed the `fit_loop.{min,max}_steps` setters ([#16803](https://github.com/Lightning-AI/lightning/pull/16803))
+
 
 - Removed the `Trainer(track_grad_norm=...)` argument ([#16745](https://github.com/Lightning-AI/lightning/pull/16745))
 

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 0 additions & 16 deletions

@@ -99,27 +99,11 @@ def min_steps(self) -> Optional[int]:
         """Returns the minimum number of steps to run."""
         return self.epoch_loop.min_steps
 
-    @min_steps.setter
-    def min_steps(self, value: Optional[int]) -> None:
-        """Sets the minimum number of steps (forwards to epoch_loop)"""
-        # TODO: This setter is required by debugging connector (fast dev run), should be avoided
-        self.epoch_loop.min_steps = value
-
     @property
     def max_steps(self) -> int:
         """Returns the maximum number of steps to run."""
         return self.epoch_loop.max_steps
 
-    @max_steps.setter
-    def max_steps(self, value: int) -> None:
-        """Sets the maximum number of steps (forwards to epoch_loop)"""
-        # TODO: This setter is required by debugging connector (fast dev run), should be avoided
-        if value < -1:
-            raise MisconfigurationException(
-                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
-            )
-        self.epoch_loop.max_steps = value
-
     @_Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
         # if the last epoch completely finished, we are not actually restarting
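
Note that the removed `max_steps` setter also carried a range check (`value < -1` raised `MisconfigurationException`). The sketch below isolates that guard as a standalone helper; `_check_max_steps` is hypothetical, and where Lightning validates `max_steps` after this commit is not shown in this diff:

    from lightning.pytorch.utilities.exceptions import MisconfigurationException

    def _check_max_steps(value: int) -> None:
        # Hypothetical helper mirroring the guard from the deleted setter:
        # -1 means "infinite steps"; any other negative value is a misconfiguration.
        if value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )

    _check_max_steps(-1)   # ok: infinite steps
    _check_max_steps(100)  # ok
    # _check_max_steps(-2)  # would raise MisconfigurationException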

src/lightning/pytorch/trainer/setup.py

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ def _init_debugging_flags(
 
         trainer.limit_test_batches = num_batches
         trainer.limit_predict_batches = num_batches
-        trainer.fit_loop.max_steps = num_batches
+        trainer.fit_loop.epoch_loop.max_steps = num_batches
         trainer.num_sanity_val_steps = 0
         trainer.fit_loop.max_epochs = 1
         trainer.val_check_interval = 1.0
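
This is the call site the setters' TODO comments pointed at (the fast-dev-run debugging connector). Nothing changes from the user's side; a run like the sketch below still caps each stage at `num_batches` batches, and only the internal assignment now goes through `epoch_loop`:

    import lightning.pytorch as pl

    # With fast_dev_run=5, _init_debugging_flags sets num_batches = 5 and, per
    # the diff above, writes trainer.fit_loop.epoch_loop.max_steps = 5.
    trainer = pl.Trainer(fast_dev_run=5)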

src/lightning/pytorch/tuner/batch_size_scaling.py

Lines changed: 2 additions & 2 deletions

@@ -129,7 +129,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N
     if isinstance(loop, pl.loops._FitLoop):
         trainer.limit_train_batches = 1.0
         trainer.limit_val_batches = steps_per_trial
-        trainer.fit_loop.max_steps = steps_per_trial
+        trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
     elif isinstance(loop, pl.loops._EvaluationLoop):
         stage = trainer.state.stage
         assert stage is not None

@@ -145,7 +145,7 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any])
     loop = trainer._active_loop
     assert loop is not None
     if isinstance(loop, pl.loops._FitLoop):
-        loop.max_steps = params["max_steps"]
+        loop.epoch_loop.max_steps = params["max_steps"]
         trainer.limit_train_batches = params["limit_train_batches"]
         trainer.limit_val_batches = params["limit_val_batches"]
     elif isinstance(loop, pl.loops._EvaluationLoop):
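
These two hunks are the halves of a save-mutate-restore cycle around a batch-size scaling trial. Stripped of the tuner machinery, the pattern looks roughly like this (`trainer`, `steps_per_trial`, and `params` stand in for the real objects):

    import lightning.pytorch as pl

    trainer = pl.Trainer()
    steps_per_trial = 3

    # __scale_batch_reset_params: remember the user's value, then cap the trial.
    params = {"max_steps": trainer.fit_loop.max_steps}       # read via the property
    trainer.fit_loop.epoch_loop.max_steps = steps_per_trial  # write via the epoch loop

    # ... the scaling trial runs here ...

    # __scale_batch_restore_params: put the user's value back, again via the epoch loop.
    trainer.fit_loop.epoch_loop.max_steps = params["max_steps"]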

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 3 additions & 3 deletions

@@ -320,7 +320,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     # No logging
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     # Max step set to number of iterations starting at current number of iterations
-    trainer.fit_loop.max_steps = num_training + trainer.global_step
+    trainer.fit_loop.epoch_loop.max_steps = num_training + trainer.global_step
     trainer.limit_val_batches = num_training
 
 
@@ -329,10 +329,10 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) ->
     trainer.strategy.lr_scheduler_configs = params["lr_scheduler_configs"]
     trainer.callbacks = params["callbacks"]
     trainer.loggers = params["loggers"]
-    trainer.fit_loop.max_steps = params["max_steps"]
+    loop = trainer.fit_loop
+    loop.epoch_loop.max_steps = params["max_steps"]
     trainer.limit_val_batches = params["limit_val_batches"]
 
-    loop = trainer.fit_loop
     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
     loop.restarting = False
     trainer.should_stop = False
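
The second hunk hoists `loop = trainer.fit_loop` above the `max_steps` restore so the whole restore path shares one local; behavior is otherwise unchanged. For context, these helpers run inside the learning-rate finder, which is driven through the tuner API roughly as below (`MyModel` is a placeholder `LightningModule`, not something from this commit):

    import lightning.pytorch as pl
    from lightning.pytorch.tuner import Tuner

    trainer = pl.Trainer()
    tuner = Tuner(trainer)
    # lr_find bumps epoch_loop.max_steps for the search (reset_params), then
    # puts the saved trainer state back (restore_params).
    lr_finder = tuner.lr_find(MyModel())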

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 2 additions & 2 deletions

@@ -152,11 +152,11 @@ def test_fit_loop_done_log_messages(caplog):
     epoch_loop = Mock()
     epoch_loop.global_step = 10
     fit_loop.epoch_loop = epoch_loop
-    fit_loop.max_steps = 10
+    epoch_loop.max_steps = 10
     assert fit_loop.done
     assert "max_steps=10` reached" in caplog.text
     caplog.clear()
-    fit_loop.max_steps = 20
+    epoch_loop.max_steps = 20
 
     fit_loop.epoch_progress.current.processed = 3
     fit_loop.max_epochs = 3
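
The test replaces the real epoch loop with a `Mock`, so only the assignment target changes: with the `_FitLoop` setter gone, `fit_loop.max_steps = 10` would hit a read-only property, and the test writes to the mock instead. A stand-in class (not the real `_FitLoop`) showing the mechanics:

    from unittest.mock import Mock

    class FitLoopLike:
        """Stand-in with the same read-only property shape as _FitLoop."""

        def __init__(self) -> None:
            self.epoch_loop = Mock()

        @property
        def max_steps(self) -> int:
            # Read-only: forwards to the epoch loop, as in the diff above.
            return self.epoch_loop.max_steps

    fit_loop = FitLoopLike()
    fit_loop.epoch_loop.max_steps = 10  # what the updated test does
    assert fit_loop.max_steps == 10
    # fit_loop.max_steps = 10  # AttributeError: the property has no setter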

tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 1 addition & 1 deletion

@@ -1196,7 +1196,7 @@ def test_dataloaders_reset_and_attach(tmpdir):
     assert trainer.train_dataloader.dataset is dataloader_0.dataset
     assert trainer.val_dataloaders.dataset is dataloader_1.dataset
     # 2nd fit
-    trainer.fit_loop.max_steps += 1
+    trainer.fit_loop.epoch_loop.max_steps += 1
     trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
     assert trainer.train_dataloader.dataset is dataloader_2.dataset
     assert trainer.val_dataloaders.dataset is dataloader_3.dataset
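
One subtlety in this last hunk: `+=` on an attribute desugars into a read followed by a write, so the old `trainer.fit_loop.max_steps += 1` needed both the getter and the now-removed setter. The updated line does both halves on the epoch loop; spelled out (assuming an existing `trainer`):

    # trainer.fit_loop.epoch_loop.max_steps += 1 is equivalent to:
    current = trainer.fit_loop.epoch_loop.max_steps      # read
    trainer.fit_loop.epoch_loop.max_steps = current + 1  # write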
