Skip to content

Commit 051d316

Browse files
authored
Loop flattening: remove .connect() (#16384)
1 parent f506c54 commit 051d316

File tree

12 files changed

+22
-276
lines changed

12 files changed

+22
-276
lines changed

docs/source-pytorch/extensions/loops.rst

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -259,28 +259,6 @@ run (optional)
259259

260260
----------
261261

262-
Subloops
263-
--------
264-
265-
When you want to customize nested loops within loops use the :meth:`~pytorch_lightning.loops.loop.Loop.connect` method:
266-
267-
.. code-block:: python
268-
269-
# Optional: stitch back the trainer arguments
270-
epoch_loop = MyEpochLoop(trainer.fit_loop.epoch_loop.min_steps, trainer.fit_loop.epoch_loop.max_steps)
271-
# Optional: connect children loops as they might have existing state
272-
epoch_loop.connect(trainer.fit_loop.epoch_loop.batch_loop, trainer.fit_loop.epoch_loop.val_loop)
273-
# Instantiate and connect the loop.
274-
trainer.fit_loop.connect(epoch_loop=epoch_loop)
275-
trainer.fit(model)
276-
277-
More about the built-in loops and how they are composed is explained in the next section.
278-
279-
.. image:: https://pl-public-data.s3.amazonaws.com/docs/static/images/loops/connect-epoch-loop.gif
280-
:alt: Animation showing how to connect a custom subloop
281-
282-
----------
283-
284262
Built-in Loops
285263
--------------
286264

@@ -342,71 +320,6 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
342320
It simply iterates over each prediction dataloader from one to the next by calling :code:`PredictionEpochLoop.run()` in its :code:`advance()` method.
343321

344322

345-
----------
346-
347-
Available Loops in Lightning Flash
348-
----------------------------------
349-
350-
`Active Learning <https://en.wikipedia.org/wiki/Active_learning_(machine_learning)>`__ is a machine learning practice in which the user interacts with the learner in order to provide new labels when required.
351-
352-
You can find a real use case in `Lightning Flash <https://github.com/Lightning-AI/lightning-flash>`_.
353-
354-
Flash implements the :code:`ActiveLearningLoop` that you can use together with the :code:`ActiveLearningDataModule` to label new data on the fly.
355-
To run the following demo, install Flash and `BaaL <https://github.com/ElementAI/baal>`__ first:
356-
357-
.. code-block:: bash
358-
359-
pip install lightning-flash[image] baal
360-
361-
.. code-block:: python
362-
363-
import torch
364-
365-
import flash
366-
from flash.core.classification import Probabilities
367-
from flash.core.data.utils import download_data
368-
from flash.image import ImageClassificationData, ImageClassifier
369-
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
370-
371-
# 1. Create the DataModule
372-
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
373-
374-
# Implement the research use-case where we mask labels from labelled dataset.
375-
datamodule = ActiveLearningDataModule(
376-
ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
377-
initial_num_labels=5,
378-
val_split=0.1,
379-
)
380-
381-
# 2. Build the task
382-
head = torch.nn.Sequential(
383-
torch.nn.Dropout(p=0.1),
384-
torch.nn.Linear(512, datamodule.num_classes),
385-
)
386-
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=Probabilities())
387-
388-
389-
# 3.1 Create the trainer
390-
trainer = flash.Trainer(max_epochs=3)
391-
392-
# 3.2 Create the active learning loop and connect it to the trainer
393-
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
394-
active_learning_loop.connect(trainer.fit_loop)
395-
trainer.fit_loop = active_learning_loop
396-
397-
# 3.3 Finetune
398-
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
399-
400-
# 4. Predict what's on a few images! ants or bees?
401-
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
402-
print(predictions)
403-
404-
# 5. Save the model!
405-
trainer.save_checkpoint("image_classification_model.pt")
406-
407-
Here is the `Active Learning Loop example <https://github.com/Lightning-AI/lightning-flash/blob/master/flash_examples/integrations/baal/image_classification_active_learning.py>`_ and the `code for the active learning loop <https://github.com/Lightning-AI/lightning-flash/blob/master/flash/image/classification/integrations/baal/loop.py>`_.
408-
409-
410323
----------
411324

412325
Advanced Examples

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5555

5656
- Removed support for loop customization
5757
* Removed `Loop.replace()` ([#16361](https://github.com/Lightning-AI/lightning/pull/16361))
58+
* Removed `Loop.connect()` ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
59+
* Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
5860

5961
- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
6062
* Removed the `LightningModule.truncated_bptt_steps` attribute

src/pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,6 @@ def prefetch_batches(self) -> int:
7878
is_unsized = batches[self.current_dataloader_idx] == float("inf")
7979
return int(is_unsized)
8080

81-
def connect(self, epoch_loop: EvaluationEpochLoop) -> None: # type: ignore[override]
82-
"""Connect the evaluation epoch loop with this loop."""
83-
self.epoch_loop = epoch_loop
84-
8581
@property
8682
def done(self) -> bool:
8783
"""Returns whether all dataloaders are processed or evaluation should be skipped altogether."""

src/pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,6 @@ def dataloaders(self) -> Sequence[DataLoader]:
6666
def skip(self) -> bool:
6767
return sum(self.max_batches) == 0
6868

69-
def connect(self, epoch_loop: PredictionEpochLoop) -> None: # type: ignore[override]
70-
"""Connect the prediction epoch loop with this loop."""
71-
self.epoch_loop = epoch_loop
72-
7369
def reset(self) -> None:
7470
"""Resets the internal state of the loop for a new run."""
7571
self.predictions = []

src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ def should_store_predictions(self) -> bool:
3838
any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
3939
return self.return_predictions or any_pred
4040

41-
def connect(self, **kwargs: "Loop") -> None:
42-
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
43-
4441
def reset(self) -> None:
4542
"""Resets the loops internal state."""
4643
self._seen_batch_indices = []

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,6 @@ def done(self) -> bool:
121121

122122
return False
123123

124-
def connect( # type: ignore[override]
125-
self,
126-
optimizer_loop: Optional[OptimizerLoop] = None,
127-
manual_loop: Optional[ManualOptimization] = None,
128-
val_loop: Optional["loops.EvaluationLoop"] = None,
129-
) -> None:
130-
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
131-
if optimizer_loop is not None:
132-
self.optimizer_loop = optimizer_loop
133-
if manual_loop is not None:
134-
self.manual_loop = manual_loop
135-
if val_loop is not None:
136-
self.val_loop = val_loop
137-
138124
def reset(self) -> None:
139125
"""Resets the internal state of the loop for a new run."""
140126
if self.restarting:

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,6 @@ def skip(self) -> bool:
169169
# until `on_run_start`, we use `limit_train_batches` instead
170170
return self.done or self.trainer.limit_train_batches == 0
171171

172-
def connect(self, epoch_loop: TrainingEpochLoop) -> None: # type: ignore[override]
173-
"""Connects a training epoch loop to this fit loop."""
174-
self.epoch_loop = epoch_loop
175-
176172
def reset(self) -> None:
177173
"""Resets the internal state of this loop."""
178174
if self.restarting:

src/pytorch_lightning/loops/loop.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,6 @@ def skip(self):
100100
"""
101101
return False
102102

103-
def connect(self, **kwargs: "Loop") -> None:
104-
"""Optionally connect one or multiple loops to this one.
105-
106-
Linked loops should form a tree.
107-
"""
108-
109103
def on_skip(self) -> T:
110104
"""The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
111105

src/pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,6 @@ def done(self) -> bool:
172172
"""Returns ``True`` when the last optimizer in the sequence has run."""
173173
return self.optim_progress.optimizer_position >= len(self._indices)
174174

175-
def connect(self, **kwargs: "Loop") -> None:
176-
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
177-
178175
def reset(self) -> None:
179176
if not self.restarting:
180177
# when reset() is called from outside (manually), we reset the loop progress

src/pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -371,21 +371,16 @@ def __init__(
371371
self._signal_connector = SignalConnector(self)
372372
self.tuner = Tuner(self)
373373

374-
fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
375-
training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
376-
fit_loop.connect(epoch_loop=training_epoch_loop)
377-
378-
# default .fit() loop
379-
self.fit_loop = fit_loop
380-
381-
# default .validate() loop
374+
# init loops
375+
self.fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
376+
self.fit_loop.epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
382377
self.validate_loop = EvaluationLoop()
383-
384-
# default .test() loop
385378
self.test_loop = EvaluationLoop()
386-
387-
# default .predict() loop
388379
self.predict_loop = PredictionLoop()
380+
self.fit_loop.trainer = self
381+
self.validate_loop.trainer = self
382+
self.test_loop.trainer = self
383+
self.predict_loop.trainer = self
389384

390385
# init callbacks
391386
# Declare attributes to be set in _callback_connector on_trainer_init
@@ -1103,8 +1098,6 @@ def _run_train(self) -> None:
11031098
self.model.train()
11041099
torch.set_grad_enabled(True)
11051100

1106-
self.fit_loop.trainer = self
1107-
11081101
with torch.autograd.set_detect_anomaly(self._detect_anomaly):
11091102
self.fit_loop.run()
11101103

@@ -1114,9 +1107,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
11141107
# reload dataloaders
11151108
self._evaluation_loop._reload_evaluation_dataloaders()
11161109

1117-
# reset trainer on this loop and all child loops in case user connected a custom loop
1118-
self._evaluation_loop.trainer = self
1119-
11201110
with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context(
11211111
self.accelerator, self._inference_mode
11221112
):
@@ -1133,8 +1123,6 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
11331123

11341124
def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
11351125
self.reset_predict_dataloader(self.lightning_module)
1136-
# reset trainer on this loop and all child loops in case user connected a custom loop
1137-
self.predict_loop.trainer = self
11381126
with _evaluation_context(self.accelerator, self._inference_mode):
11391127
return self.predict_loop.run()
11401128

@@ -1955,63 +1943,6 @@ def is_last_batch(self) -> bool:
19551943
"""Whether trainer is executing the last batch."""
19561944
return self.fit_loop.epoch_loop.batch_progress.is_last_batch
19571945

1958-
@property
1959-
def fit_loop(self) -> FitLoop:
1960-
return self._fit_loop
1961-
1962-
@fit_loop.setter
1963-
def fit_loop(self, loop: FitLoop) -> None:
1964-
"""Attach a custom fit loop to this Trainer.
1965-
1966-
It will run with
1967-
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
1968-
"""
1969-
loop.trainer = self
1970-
self._fit_loop = loop
1971-
1972-
@property
1973-
def validate_loop(self) -> EvaluationLoop:
1974-
return self._validate_loop
1975-
1976-
@validate_loop.setter
1977-
def validate_loop(self, loop: EvaluationLoop) -> None:
1978-
"""Attach a custom validation loop to this Trainer.
1979-
1980-
It will run with
1981-
:meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is different from the one
1982-
running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
1983-
"""
1984-
loop.trainer = self
1985-
self._validate_loop = loop
1986-
1987-
@property
1988-
def test_loop(self) -> EvaluationLoop:
1989-
return self._test_loop
1990-
1991-
@test_loop.setter
1992-
def test_loop(self, loop: EvaluationLoop) -> None:
1993-
"""Attach a custom test loop to this Trainer.
1994-
1995-
It will run with
1996-
:meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
1997-
"""
1998-
loop.trainer = self
1999-
self._test_loop = loop
2000-
2001-
@property
2002-
def predict_loop(self) -> PredictionLoop:
2003-
return self._predict_loop
2004-
2005-
@predict_loop.setter
2006-
def predict_loop(self, loop: PredictionLoop) -> None:
2007-
"""Attach a custom prediction loop to this Trainer.
2008-
2009-
It will run with
2010-
:meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
2011-
"""
2012-
loop.trainer = self
2013-
self._predict_loop = loop
2014-
20151946
@property
20161947
def _evaluation_loop(self) -> EvaluationLoop:
20171948
if self.state.fn == TrainerFn.FITTING:

0 commit comments

Comments
 (0)