1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -153,6 +153,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384))
* The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445))
* The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664))

- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute
27 changes: 5 additions & 22 deletions src/lightning/pytorch/loops/dataloader/evaluation_loop.py
@@ -15,22 +15,19 @@
import shutil
import sys
from collections import ChainMap, OrderedDict
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.utils.data.dataloader import DataLoader

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.dataloader import _DataLoaderLoop
from lightning.pytorch.loops.epoch import _EvaluationEpochLoop
from lightning.pytorch.loops.utilities import _set_sampler_epoch
from lightning.pytorch.loops.utilities import _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.fetching import _DataFetcher

if _RICH_AVAILABLE:
from rich import get_console
@@ -53,7 +50,7 @@ def __init__(self, verbose: bool = True) -> None:
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[Union[int, float]] = []
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._data_fetcher: Optional[_DataFetcher] = None

@property
def num_dataloaders(self) -> int:
@@ -125,8 +122,7 @@ def reset(self) -> None:
def on_run_start(self) -> None:
"""Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
hooks."""
data_fetcher_cls = _select_data_fetcher_type(self.trainer)
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches)

# hook
self._on_evaluation_model_eval()
@@ -373,16 +369,3 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
lines.append(row_format.format(metric, *row).rstrip())
lines.append(bar)
print(os.linesep.join(lines))


def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
lightning_module = trainer.lightning_module
step_fx_name = "test_step" if trainer.testing else "validation_step"
step_fx = getattr(lightning_module, step_fx_name)
if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher
return DataFetcher
14 changes: 6 additions & 8 deletions src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py
@@ -19,7 +19,7 @@
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.utilities.types import STEP_OUTPUT


@@ -35,15 +35,15 @@ def __init__(self) -> None:
self.batch_progress = BatchProgress()

self._dl_max_batches: Union[int, float] = 0
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._data_fetcher: Optional[_DataFetcher] = None
self._dl_batch_idx = [0]

@property
def done(self) -> bool:
"""Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
return self.batch_progress.current.completed >= self._dl_max_batches

def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
def run(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
self.reset()
self.on_run_start(data_fetcher, dl_max_batches, kwargs)
while not self.done:
@@ -69,9 +69,7 @@ def reset(self) -> None:
if self.done and self.trainer.state.fn != TrainerFn.FITTING:
self.batch_progress.reset_on_run()

def on_run_start(
self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
) -> None:
def on_run_start(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
"""Adds the passed arguments to the loop's state if necessary.

Args:
@@ -102,7 +100,7 @@ def _on_after_fetch(self) -> None:

def advance(
self,
data_fetcher: AbstractDataFetcher,
data_fetcher: _DataFetcher,
kwargs: OrderedDict,
) -> None:
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.
@@ -114,7 +112,7 @@
Raises:
StopIteration: If the current batch is None
"""
if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
batch_idx = self.batch_progress.current.ready
batch = next(data_fetcher)
else:
10 changes: 5 additions & 5 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
@@ -25,7 +25,7 @@
from lightning.pytorch.loops.utilities import _is_max_limit_reached
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

@@ -125,7 +125,7 @@ def done(self) -> bool:

return False

def run(self, data_fetcher: AbstractDataFetcher) -> None:
def run(self, data_fetcher: _DataFetcher) -> None:
self.reset()
self.on_run_start(data_fetcher)
while not self.done:
@@ -160,7 +160,7 @@ def reset(self) -> None:
# seen per epoch, this is useful for tracking when validation is run multiple times per epoch
self.val_loop.epoch_loop.batch_progress.total.reset()

def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:
def on_run_start(self, data_fetcher: _DataFetcher) -> None:
_ = iter(data_fetcher) # creates the iterator inside the fetcher
# add the previous `fetched` value to properly track `is_last_batch` with no prefetching
data_fetcher.fetched += self.batch_progress.current.ready
@@ -174,7 +174,7 @@ def _on_before_fetch(self) -> None:
def _on_after_fetch(self) -> None:
self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

def advance(self, data_fetcher: AbstractDataFetcher) -> None:
def advance(self, data_fetcher: _DataFetcher) -> None:
"""Runs a single training batch.

Raises:
@@ -186,7 +186,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:
# we are going to train first so the val loop does not need to restart
self.val_loop.restarting = False

if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
batch_idx = self.batch_idx + 1
batch = next(data_fetcher)
else:
26 changes: 6 additions & 20 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Optional, Type
from typing import Any, Optional

import lightning.pytorch as pl
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.epoch import _TrainingEpochLoop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
from lightning.pytorch.trainer.supporters import CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.fetching import _DataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info

log = logging.getLogger(__name__)

@@ -74,7 +72,7 @@ def __init__(
self.epoch_progress = Progress()

self._is_fresh_start_epoch: bool = True
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._data_fetcher: Optional[_DataFetcher] = None

@property
def total_batch_idx(self) -> int:
@@ -217,8 +215,7 @@ def on_run_start(self) -> None:
if self.epoch_loop._should_check_val_epoch():
self.epoch_loop.val_loop._reload_evaluation_dataloaders()

data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
self._data_fetcher = _select_data_fetcher(self.trainer, self.prefetch_batches)

self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)
@@ -324,14 +321,3 @@ def _should_accumulate(self) -> bool:

def _iteration_based_training(self) -> bool:
return self.trainer.max_steps != -1


def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
training_step_fx = getattr(trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher
return DataFetcher
22 changes: 22 additions & 0 deletions src/lightning/pytorch/loops/utilities.py
@@ -24,7 +24,9 @@
from lightning.pytorch.loops.progress import BaseProgress
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


def check_finite_loss(loss: Optional[Tensor]) -> None:
Expand Down Expand Up @@ -130,3 +132,23 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
sampler.set_epoch(epoch)


def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher:
lightning_module = trainer.lightning_module
if trainer.testing:
step_fx_name = "test_step"
elif trainer.training:
step_fx_name = "training_step"
elif trainer.validating or trainer.sanity_checking:
step_fx_name = "validation_step"
else:
raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}")
step_fx = getattr(lightning_module, step_fx_name)
if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return _DataLoaderIterDataFetcher()
return _PrefetchDataFetcher(prefetch_batches=prefetch_batches)
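
For reviewers, a minimal sketch (hypothetical model classes, not part of this PR) of the two step signatures that the consolidated `_select_data_fetcher` helper distinguishes: a regular `training_step(batch, batch_idx)` resolves to `_PrefetchDataFetcher`, while a step declaring an explicit `dataloader_iter` argument resolves to the experimental `_DataLoaderIterDataFetcher` and emits the warning shown above.

```python
# Hypothetical illustration only -- these models are not part of the PR.
import lightning.pytorch as pl


class PrefetchingModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # No `dataloader_iter` parameter: the loop fetches batches and
        # `_select_data_fetcher` returns a `_PrefetchDataFetcher`.
        ...


class ManualFetchingModel(pl.LightningModule):
    def training_step(self, dataloader_iter):
        # An explicit `dataloader_iter` parameter makes
        # `is_param_in_hook_signature(..., "dataloader_iter", explicit=True)`
        # return True, so `_select_data_fetcher` returns a
        # `_DataLoaderIterDataFetcher` and the model pulls batches itself.
        batch = next(dataloader_iter)
        ...
```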