
Commit 71bf4fc

Flatten fetching abstract interface (#16664)
1 parent 2add5d3 commit 71bf4fc

File tree

8 files changed: +93, -156 lines


src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -155,6 +155,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
   * Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
   * The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445))
+  * The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))

 - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
   * Removed the `LightningModule.truncated_bptt_steps` attribute

src/lightning/pytorch/loops/dataloader/evaluation_loop.py

Lines changed: 5 additions & 22 deletions
@@ -15,22 +15,19 @@
 import shutil
 import sys
 from collections import ChainMap, OrderedDict
-from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.utils.data.dataloader import DataLoader

-import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.dataloader import _DataLoaderLoop
 from lightning.pytorch.loops.epoch import _EvaluationEpochLoop
-from lightning.pytorch.loops.utilities import _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _select_data_fetcher, _set_sampler_epoch
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
-from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
+from lightning.pytorch.utilities.fetching import _DataFetcher

 if _RICH_AVAILABLE:
     from rich import get_console
@@ -53,7 +50,7 @@ def __init__(self, verbose: bool = True) -> None:
         self._logged_outputs: List[_OUT_DICT] = []
         self._max_batches: List[Union[int, float]] = []
         self._has_run: bool = False
-        self._data_fetcher: Optional[AbstractDataFetcher] = None
+        self._data_fetcher: Optional[_DataFetcher] = None

     @property
     def num_dataloaders(self) -> int:
@@ -125,8 +122,7 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
-        data_fetcher_cls = _select_data_fetcher_type(self.trainer)
-        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches)

         # hook
         self._on_evaluation_model_eval()
@@ -373,16 +369,3 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
             lines.append(row_format.format(metric, *row).rstrip())
         lines.append(bar)
         print(os.linesep.join(lines))
-
-
-def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
-    lightning_module = trainer.lightning_module
-    step_fx_name = "test_step" if trainer.testing else "validation_step"
-    step_fx = getattr(lightning_module, step_fx_name)
-    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
-        rank_zero_warn(
-            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
-            "this signature is experimental and the behavior is subject to change."
-        )
-        return DataLoaderIterDataFetcher
-    return DataFetcher
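The hunks above show the shape of the whole refactor: the loop previously picked a fetcher class with the now-removed module-level `_select_data_fetcher_type` helper and instantiated it itself, whereas it now receives a ready-made instance from the shared `_select_data_fetcher` in `loops/utilities.py` (added at the bottom of this diff). A minimal self-contained sketch of that pattern change, using hypothetical stand-in names rather than the Lightning classes:

# Hypothetical stand-ins to illustrate the "flattening": return a configured
# instance instead of a class that every call site must instantiate.
class FetcherSketch:
    def __init__(self, prefetch_batches: int = 0) -> None:
        self.prefetch_batches = prefetch_batches


def select_fetcher_old() -> type:
    return FetcherSketch  # caller still has to construct it


def select_fetcher_new(prefetch_batches: int = 0) -> FetcherSketch:
    return FetcherSketch(prefetch_batches=prefetch_batches)  # ready to use


fetcher = select_fetcher_old()(prefetch_batches=2)  # before: select, then construct
fetcher = select_fetcher_new(prefetch_batches=2)    # after: one call
assert fetcher.prefetch_batches == 2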

src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py

Lines changed: 6 additions & 8 deletions
@@ -19,7 +19,7 @@
 from lightning.pytorch.loops.progress import BatchProgress
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.exceptions import SIGTERMException
-from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
+from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.utilities.types import STEP_OUTPUT


@@ -35,15 +35,15 @@ def __init__(self) -> None:
         self.batch_progress = BatchProgress()

         self._dl_max_batches: Union[int, float] = 0
-        self._data_fetcher: Optional[AbstractDataFetcher] = None
+        self._data_fetcher: Optional[_DataFetcher] = None
         self._dl_batch_idx = [0]

     @property
     def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches

-    def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
+    def run(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
         self.reset()
         self.on_run_start(data_fetcher, dl_max_batches, kwargs)
         while not self.done:
@@ -69,9 +69,7 @@ def reset(self) -> None:
         if self.done and self.trainer.state.fn != TrainerFn.FITTING:
             self.batch_progress.reset_on_run()

-    def on_run_start(
-        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
-    ) -> None:
+    def on_run_start(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
         """Adds the passed arguments to the loop's state if necessary.

         Args:
@@ -102,7 +100,7 @@ def _on_after_fetch(self) -> None:

     def advance(
         self,
-        data_fetcher: AbstractDataFetcher,
+        data_fetcher: _DataFetcher,
         kwargs: OrderedDict,
     ) -> None:
         """Calls the evaluation step with the corresponding hooks and updates the logger connector.
@@ -114,7 +112,7 @@
         Raises:
             StopIteration: If the current batch is None
         """
-        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
+        if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
             batch_idx = self.batch_progress.current.ready
             batch = next(data_fetcher)
         else:
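The `advance` hunk above keeps the same dispatch as before, only with the renamed class: a regular fetcher is driven by the loop one batch at a time, while a `_DataLoaderIterDataFetcher` hands the iterator to the user's step. A toy sketch of that branching, with hypothetical classes rather than the Lightning ones:

from typing import Any, Iterator, List


class LoopDrivenFetcher:
    """Toy stand-in: the loop calls next() on it once per step."""

    def __init__(self, data: List[Any]) -> None:
        self._it: Iterator[Any] = iter(data)

    def __iter__(self) -> Iterator[Any]:
        return self._it

    def __next__(self) -> Any:
        return next(self._it)


class StepDrivenFetcher(LoopDrivenFetcher):
    """Toy marker subclass: the user's step consumes the iterator itself."""


def advance_sketch(fetcher: LoopDrivenFetcher) -> Any:
    if not isinstance(fetcher, StepDrivenFetcher):
        return next(fetcher)  # loop pulls exactly one ready batch
    return iter(fetcher)      # the step receives `dataloader_iter` instead


print(advance_sketch(LoopDrivenFetcher([1, 2, 3])))  # 1: loop-driven
it = advance_sketch(StepDrivenFetcher([1, 2, 3]))    # iterator handed over
print(next(it), next(it))                            # 1 2: step-driven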

src/lightning/pytorch/loops/epoch/training_epoch_loop.py

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@
 from lightning.pytorch.loops.utilities import _is_max_limit_reached
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
-from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
+from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

@@ -125,7 +125,7 @@ def done(self) -> bool:

         return False

-    def run(self, data_fetcher: AbstractDataFetcher) -> None:
+    def run(self, data_fetcher: _DataFetcher) -> None:
         self.reset()
         self.on_run_start(data_fetcher)
         while not self.done:
@@ -160,7 +160,7 @@ def reset(self) -> None:
         # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
         self.val_loop.epoch_loop.batch_progress.total.reset()

-    def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:
+    def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         _ = iter(data_fetcher)  # creates the iterator inside the fetcher
         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
         data_fetcher.fetched += self.batch_progress.current.ready
@@ -174,7 +174,7 @@ def _on_before_fetch(self) -> None:
     def _on_after_fetch(self) -> None:
         self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

-    def advance(self, data_fetcher: AbstractDataFetcher) -> None:
+    def advance(self, data_fetcher: _DataFetcher) -> None:
         """Runs a single training batch.

         Raises:
@@ -186,7 +186,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
         # we are going to train first so the val loop does not need to restart
         self.val_loop.restarting = False

-        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
+        if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
             batch_idx = self.batch_idx + 1
             batch = next(data_fetcher)
         else:
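Beyond the renames, the `on_run_start` hunk shows the restart bookkeeping this loop relies on: after re-creating the iterator, the fetcher's `fetched` counter is advanced by the number of batches the progress tracker already marked ready, so `is_last_batch` stays accurate when resuming with no prefetching. A tiny illustration with a hypothetical counter object:

# Hypothetical counter standing in for the fetcher's `fetched` attribute.
class CounterSketch:
    fetched: int = 0


fetcher = CounterSketch()
batches_ready_before_restart = 7  # stands in for batch_progress.current.ready
fetcher.fetched += batches_ready_before_restart
# The fetcher now resumes as if it had already served 7 batches this epoch.
assert fetcher.fetched == 7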

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 6 additions & 20 deletions
@@ -12,19 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Optional, Type
+from typing import Any, Optional

-import lightning.pytorch as pl
 from lightning.pytorch.loops import _Loop
 from lightning.pytorch.loops.epoch import _TrainingEpochLoop
 from lightning.pytorch.loops.progress import Progress
-from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
 from lightning.pytorch.trainer.supporters import CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
-from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
-from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
+from lightning.pytorch.utilities.fetching import _DataFetcher
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info

 log = logging.getLogger(__name__)

@@ -74,7 +72,7 @@ def __init__(
         self.epoch_progress = Progress()

         self._is_fresh_start_epoch: bool = True
-        self._data_fetcher: Optional[AbstractDataFetcher] = None
+        self._data_fetcher: Optional[_DataFetcher] = None

     @property
     def total_batch_idx(self) -> int:
@@ -217,8 +215,7 @@ def on_run_start(self) -> None:
         if self.epoch_loop._should_check_val_epoch():
             self.epoch_loop.val_loop._reload_evaluation_dataloaders()

-        data_fetcher_cls = _select_data_fetcher(self.trainer)
-        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(self.trainer, self.prefetch_batches)

         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)
@@ -324,14 +321,3 @@ def _should_accumulate(self) -> bool:

     def _iteration_based_training(self) -> bool:
         return self.trainer.max_steps != -1
-
-
-def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
-    training_step_fx = getattr(trainer.lightning_module, "training_step")
-    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
-        rank_zero_warn(
-            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
-            "this signature is experimental and the behavior is subject to change."
-        )
-        return DataLoaderIterDataFetcher
-    return DataFetcher
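Note the behavioral widening here: this removed module-level helper hard-coded `training_step`, and the one removed from `evaluation_loop.py` only chose between `test_step` and `validation_step`. The shared replacement in the next file instead derives the hook from the trainer's state, roughly the mapping below; `stage_to_hook` is an illustrative name, not a Lightning symbol:

# Illustrative mapping of trainer state to the inspected step hook, as encoded
# by the consolidated `_select_data_fetcher` in loops/utilities.py.
stage_to_hook = {
    "testing": "test_step",
    "training": "training_step",
    "validating": "validation_step",
    "sanity_checking": "validation_step",
}
assert stage_to_hook["sanity_checking"] == "validation_step"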

src/lightning/pytorch/loops/utilities.py

Lines changed: 22 additions & 0 deletions
@@ -24,7 +24,9 @@
 from lightning.pytorch.loops.progress import BaseProgress
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import Strategy
+from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
+from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


 def check_finite_loss(loss: Optional[Tensor]) -> None:
@@ -130,3 +132,23 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
         sampler = getattr(dataloader, sampler_name, None)
         if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
             sampler.set_epoch(epoch)
+
+
+def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher:
+    lightning_module = trainer.lightning_module
+    if trainer.testing:
+        step_fx_name = "test_step"
+    elif trainer.training:
+        step_fx_name = "training_step"
+    elif trainer.validating or trainer.sanity_checking:
+        step_fx_name = "validation_step"
+    else:
+        raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}")
+    step_fx = getattr(lightning_module, step_fx_name)
+    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
+        rank_zero_warn(
+            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
+            "this signature is experimental and the behavior is subject to change."
+        )
+        return _DataLoaderIterDataFetcher()
+    return _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
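For context, what this detection targets on the user side: declaring an explicit `dataloader_iter` parameter in a step hook is what makes `is_param_in_hook_signature` return True, selecting `_DataLoaderIterDataFetcher` (and emitting the experimental-support warning), while any conventional signature gets a `_PrefetchDataFetcher`. A hedged example module, illustrative rather than taken from this commit:

import torch
import lightning.pytorch as pl


class ManualIterationModel(pl.LightningModule):
    """Hypothetical module that opts into driving iteration itself."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, dataloader_iter):
        # The explicit `dataloader_iter` parameter is what
        # `is_param_in_hook_signature(..., "dataloader_iter", explicit=True)`
        # detects; the loop then passes the raw iterator and this step pulls
        # batches on its own instead of receiving one per call.
        batch = next(dataloader_iter)
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)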
