diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 6402d001c1e20..ada70e66fcada 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -153,6 +153,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Removed the `trainer.{fit,validate,test,predict}_loop` properties ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) * Removed the default `Loop.run()` implementation ([#16384](https://github.com/Lightning-AI/lightning/pull/16384)) * The loop classes are now marked as protected ([#16445](https://github.com/Lightning-AI/lightning/pull/16445)) + * The fetching classes are now marked as protected ([#16664](https://github.com/Lightning-AI/lightning/pull/16664)) - Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172)) * Removed the `LightningModule.truncated_bptt_steps` attribute diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py index 65b2065f47cbf..5b8fe157f599a 100644 --- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py +++ b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py @@ -15,22 +15,19 @@ import shutil import sys from collections import ChainMap, OrderedDict -from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.utils.data.dataloader import DataLoader -import lightning.pytorch as pl from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.loops.dataloader import _DataLoaderLoop from lightning.pytorch.loops.epoch import _EvaluationEpochLoop -from lightning.pytorch.loops.utilities import _set_sampler_epoch +from lightning.pytorch.loops.utilities import _select_data_fetcher, _set_sampler_epoch from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from lightning.pytorch.trainer.states import TrainerFn -from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher -from lightning.pytorch.utilities.rank_zero import rank_zero_warn -from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +from lightning.pytorch.utilities.fetching import _DataFetcher if _RICH_AVAILABLE: from rich import get_console @@ -53,7 +50,7 @@ def __init__(self, verbose: bool = True) -> None: self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False - self._data_fetcher: Optional[AbstractDataFetcher] = None + self._data_fetcher: Optional[_DataFetcher] = None @property def num_dataloaders(self) -> int: @@ -125,8 +122,7 @@ def reset(self) -> None: def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" - data_fetcher_cls = _select_data_fetcher_type(self.trainer) - self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches) # hook self._on_evaluation_model_eval() @@ -373,16 +369,3 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None: lines.append(row_format.format(metric, 
*row).rstrip()) lines.append(bar) print(os.linesep.join(lines)) - - -def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: - lightning_module = trainer.lightning_module - step_fx_name = "test_step" if trainer.testing else "validation_step" - step_fx = getattr(lightning_module, step_fx_name) - if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True): - rank_zero_warn( - f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for " - "this signature is experimental and the behavior is subject to change." - ) - return DataLoaderIterDataFetcher - return DataFetcher diff --git a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py index 4b8eb66255bfa..0ce0b51391914 100644 --- a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py @@ -19,7 +19,7 @@ from lightning.pytorch.loops.progress import BatchProgress from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import SIGTERMException -from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher +from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -35,7 +35,7 @@ def __init__(self) -> None: self.batch_progress = BatchProgress() self._dl_max_batches: Union[int, float] = 0 - self._data_fetcher: Optional[AbstractDataFetcher] = None + self._data_fetcher: Optional[_DataFetcher] = None self._dl_batch_idx = [0] @property @@ -43,7 +43,7 @@ def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches - def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: + def run(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: self.reset() self.on_run_start(data_fetcher, dl_max_batches, kwargs) while not self.done: @@ -69,9 +69,7 @@ def reset(self) -> None: if self.done and self.trainer.state.fn != TrainerFn.FITTING: self.batch_progress.reset_on_run() - def on_run_start( - self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict - ) -> None: + def on_run_start(self, data_fetcher: _DataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: """Adds the passed arguments to the loop's state if necessary. Args: @@ -102,7 +100,7 @@ def _on_after_fetch(self) -> None: def advance( self, - data_fetcher: AbstractDataFetcher, + data_fetcher: _DataFetcher, kwargs: OrderedDict, ) -> None: """Calls the evaluation step with the corresponding hooks and updates the logger connector. 
@@ -114,7 +112,7 @@ def advance( Raises: StopIteration: If the current batch is None """ - if not isinstance(data_fetcher, DataLoaderIterDataFetcher): + if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): batch_idx = self.batch_progress.current.ready batch = next(data_fetcher) else: diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 8573ebf3f379f..3836f9738c571 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -25,7 +25,7 @@ from lightning.pytorch.loops.utilities import _is_max_limit_reached from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException -from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher +from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature @@ -125,7 +125,7 @@ def done(self) -> bool: return False - def run(self, data_fetcher: AbstractDataFetcher) -> None: + def run(self, data_fetcher: _DataFetcher) -> None: self.reset() self.on_run_start(data_fetcher) while not self.done: @@ -160,7 +160,7 @@ def reset(self) -> None: # seen per epoch, this is useful for tracking when validation is run multiple times per epoch self.val_loop.epoch_loop.batch_progress.total.reset() - def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: + def on_run_start(self, data_fetcher: _DataFetcher) -> None: _ = iter(data_fetcher) # creates the iterator inside the fetcher # add the previous `fetched` value to properly track `is_last_batch` with no prefetching data_fetcher.fetched += self.batch_progress.current.ready @@ -174,7 +174,7 @@ def _on_before_fetch(self) -> None: def _on_after_fetch(self) -> None: self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next") - def advance(self, data_fetcher: AbstractDataFetcher) -> None: + def advance(self, data_fetcher: _DataFetcher) -> None: """Runs a single training batch. Raises: @@ -186,7 +186,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False - if not isinstance(data_fetcher, DataLoaderIterDataFetcher): + if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): batch_idx = self.batch_idx + 1 batch = next(data_fetcher) else: diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 0826b0f81a003..7c5dd7639ee6a 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Any, Optional, Type +from typing import Any, Optional -import lightning.pytorch as pl from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.epoch import _TrainingEpochLoop from lightning.pytorch.loops.progress import Progress -from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch +from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection from lightning.pytorch.trainer.supporters import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException -from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher -from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn -from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +from lightning.pytorch.utilities.fetching import _DataFetcher +from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info log = logging.getLogger(__name__) @@ -74,7 +72,7 @@ def __init__( self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True - self._data_fetcher: Optional[AbstractDataFetcher] = None + self._data_fetcher: Optional[_DataFetcher] = None @property def total_batch_idx(self) -> int: @@ -217,8 +215,7 @@ def on_run_start(self) -> None: if self.epoch_loop._should_check_val_epoch(): self.epoch_loop.val_loop._reload_evaluation_dataloaders() - data_fetcher_cls = _select_data_fetcher(self.trainer) - self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(self.trainer, self.prefetch_batches) self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) @@ -324,14 +321,3 @@ def _should_accumulate(self) -> bool: def _iteration_based_training(self) -> bool: return self.trainer.max_steps != -1 - - -def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: - training_step_fx = getattr(trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - rank_zero_warn( - "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior is subject to change." 
- ) - return DataLoaderIterDataFetcher - return DataFetcher diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 366aee6947d34..6e94665bc9ec7 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -24,7 +24,9 @@ from lightning.pytorch.loops.progress import BaseProgress from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.strategy import Strategy +from lightning.pytorch.utilities.fetching import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher from lightning.pytorch.utilities.rank_zero import rank_zero_warn +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature def check_finite_loss(loss: Optional[Tensor]) -> None: @@ -130,3 +132,23 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None: sampler = getattr(dataloader, sampler_name, None) if sampler is not None and callable(getattr(sampler, "set_epoch", None)): sampler.set_epoch(epoch) + + +def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher: + lightning_module = trainer.lightning_module + if trainer.testing: + step_fx_name = "test_step" + elif trainer.training: + step_fx_name = "training_step" + elif trainer.validating or trainer.sanity_checking: + step_fx_name = "validation_step" + else: + raise RuntimeError(f"DataFetcher is unsupported for {trainer.state.stage}") + step_fx = getattr(lightning_module, step_fx_name) + if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True): + rank_zero_warn( + f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for " + "this signature is experimental and the behavior is subject to change." + ) + return _DataLoaderIterDataFetcher() + return _PrefetchDataFetcher(prefetch_batches=prefetch_batches) diff --git a/src/lightning/pytorch/utilities/fetching.py b/src/lightning/pytorch/utilities/fetching.py index 90636c88a0797..12b3c9c8bc463 100644 --- a/src/lightning/pytorch/utilities/fetching.py +++ b/src/lightning/pytorch/utilities/fetching.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, Iterator, List, Optional, Sized, Tuple from torch.utils.data.dataloader import DataLoader @@ -26,43 +25,8 @@ def _profile_nothing() -> None: pass -class AbstractDataFetcher(ABC): - - """This base class should be used to implement a ``DataFetcher``. It is required to override the - ``fetching_function`` with fetching logic. 
- - Example:: - - class SimpleDataFetcher(AbstractDataFetcher): - def fetching_function(self): - while True: - try: - return next(self.dataloader_iter), False - except StopIteration: - return None, True - """ - - @abstractmethod - def fetching_function(self) -> Any: - """Override with your own fetching logic.""" - - @abstractmethod - def prefetching(self) -> None: - """Override with your own pre-fetching logic.""" - - def on_fetch_start(self) -> Any: - """Hook to override to handle the logic before fetching a batch.""" - - def on_fetch_end(self, batch: Any, start_output: Any) -> None: - """Hook to extend which handles the logic after fetching a batch.""" - - def wait(self) -> None: - """Hook to override to indicate the `DataFetcher` to wait for an event.""" - - def __init__(self, prefetch_batches: int = 0) -> None: - if prefetch_batches < 0: - raise MisconfigurationException("`prefetch_batches` should at least be 0.") - self.prefetch_batches = prefetch_batches +class _DataFetcher(Iterator): + def __init__(self) -> None: self._dataloader: Optional[Iterable] = None self.dataloader_iter: Optional[Iterator] = None self.fetched: int = 0 @@ -89,14 +53,23 @@ def loader_iters(self) -> Any: return self.dataloader_iter.loader_iters return self.dataloader_iter - def __iter__(self) -> "AbstractDataFetcher": + def __iter__(self) -> "_DataFetcher": self.reset() self.dataloader_iter = iter(self.dataloader) - self.prefetching() return self def __next__(self) -> Any: - return self.fetching_function() + self._start_profiler() + assert self.dataloader_iter is not None + try: + data = next(self.dataloader_iter) + except StopIteration as e: + self.done = True + raise e + finally: + self._stop_profiler() + self.fetched += 1 + return data def reset(self) -> None: self.fetched = 0 @@ -115,7 +88,7 @@ def _no_op_batch_to_device(batch: Any) -> Any: return batch -class DataFetcher(AbstractDataFetcher): +class _PrefetchDataFetcher(_DataFetcher): """This class is used to control batch fetching flow. 
Args: @@ -125,7 +98,10 @@ class DataFetcher(AbstractDataFetcher): """ def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None: - super().__init__(prefetch_batches=prefetch_batches) + super().__init__() + if prefetch_batches < 0: + raise ValueError("`prefetch_batches` should at least be 0.") + self.prefetch_batches = prefetch_batches self.store_on_device = store_on_device self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device self.batches: List[Any] = [] @@ -141,15 +117,8 @@ def setup( # type: ignore[override] if batch_to_device is not None: self.batch_to_device = batch_to_device - def on_fetch_start(self) -> Any: - self._start_profiler() - - def on_fetch_end(self, batch: Any, start_output: Any) -> None: - """Hook to extend which handles the logic after fetching a batch.""" - self._stop_profiler() - self.batches.append(batch) - - def prefetching(self) -> None: + def __iter__(self) -> "_PrefetchDataFetcher": + super().__iter__() iterator = self.dataloader_iter assert iterator is not None for _ in range(self.prefetch_batches): @@ -157,11 +126,12 @@ def prefetching(self) -> None: self._fetch_next_batch(iterator) except StopIteration: # this would only happen when prefetch_batches > the number of batches available and makes - # `fetching_function` jump directly to the empty iterator case without trying to fetch again + # `__next__` jump directly to the empty iterator case without trying to fetch again self.done = True break + return self - def fetching_function(self) -> Any: + def __next__(self) -> Any: assert self.dataloader_iter is not None if self.batches: # there are pre-fetched batches already from a previous `prefetching` call. @@ -185,23 +155,21 @@ def fetching_function(self) -> Any: else: # the iterator is empty raise StopIteration - self.wait() return self.move_to_device(batch) def _fetch_next_batch(self, iterator: Iterator) -> None: - start_output = self.on_fetch_start() + self._start_profiler() try: batch = next(iterator) - except StopIteration as e: + finally: self._stop_profiler() - raise e self.fetched += 1 if not self.prefetch_batches and self._has_len: # when we don't prefetch but the dataloader is sized, we use the length for `done` dataloader = self.dataloader assert isinstance(dataloader, Sized) # `_has_len` is True self.done = self.fetched >= len(dataloader) - self.on_fetch_end(batch, start_output) + self.batches.append(batch) def move_to_device(self, batch: Any) -> Any: if self.store_on_device: @@ -213,39 +181,13 @@ def reset(self) -> None: self.batches = [] -class StepFuncDataLoaderIter(Iterator): - - """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user - control.""" - - def __init__(self, iterator: Iterator, data_fetcher: AbstractDataFetcher) -> None: - self.iterator = iterator - self.data_fetcher = data_fetcher - - def __next__(self) -> Any: - try: - self.data_fetcher._start_profiler() - data = next(self.iterator) - self.data_fetcher._stop_profiler() - self.data_fetcher.fetched += 1 - return data - except StopIteration as e: - self.data_fetcher.done = True - raise e - - -class DataLoaderIterDataFetcher(AbstractDataFetcher): - +class _DataLoaderIterDataFetcher(_DataFetcher): """This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step for users to implement their own pre-fetching logic. 
This feature can be activated as follows: Example:: Class MyModel(LightningModule): - - def __init__(self): - self.automatic_optimization = False - def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: # it is the user responsibility to fetch and move the batch to the right device. batch = next(dataloader_iter) @@ -253,17 +195,22 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: ... """ - def __init__(self, prefetch_batches: int = 0) -> None: - # prefetch batches is not used for this class - super().__init__() - self.store_on_device = False - - def prefetching(self) -> None: + def __iter__(self) -> "_DataLoaderIterDataFetcher": + super().__iter__() iterator = self.dataloader_iter assert iterator is not None - self.iterator = iter(StepFuncDataLoaderIter(iterator, self)) + self.iterator = iter(_DataFetcherWrapper(self)) + return self - def fetching_function(self) -> Tuple[int, Iterator]: + def __next__(self) -> Tuple[int, Iterator]: if not self.done: return self.fetched, self.iterator raise StopIteration + + +class _DataFetcherWrapper(Iterator): + def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None: + self.data_fetcher = data_fetcher + + def __next__(self) -> Any: + return super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__() diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index 3d3aaee498b06..43ba495b092bd 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -23,7 +23,7 @@ from lightning.pytorch.profilers import SimpleProfiler from lightning.pytorch.trainer.supporters import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher +from lightning.pytorch.utilities.fetching import _DataLoaderIterDataFetcher, _PrefetchDataFetcher from lightning.pytorch.utilities.types import STEP_OUTPUT from tests_pytorch.helpers.runif import RunIf @@ -47,7 +47,7 @@ def __getitem__(self, idx): @pytest.mark.parametrize("dataset_cls", [IterDataset, SizedDataset]) @pytest.mark.parametrize("prefetch_batches", list(range(5))) def test_prefetch_iterator(use_combined_loader, dataset_cls, prefetch_batches): - fetcher = DataFetcher(prefetch_batches=prefetch_batches) + fetcher = _PrefetchDataFetcher(prefetch_batches=prefetch_batches) assert fetcher.prefetch_batches == prefetch_batches if use_combined_loader: @@ -87,7 +87,7 @@ def __init__(self): def __iter__(self): return iter(self.list) - fetcher = DataFetcher() + fetcher = _PrefetchDataFetcher() if use_combined_loader: loader = CombinedLoader([DataLoader(TestDataset()), DataLoader(TestDataset())]) else: @@ -115,7 +115,7 @@ def __len__(self): @pytest.mark.parametrize("prefetch_batches", list(range(2))) def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): loader = DataLoader(dataset_cls()) - fetcher = DataFetcher(prefetch_batches=prefetch_batches) + fetcher = _PrefetchDataFetcher(prefetch_batches=prefetch_batches) fetcher.setup(loader) assert not fetcher.done @@ -124,7 +124,7 @@ def test_empty_prefetch_iterator(dataset_cls, prefetch_batches): def test_misconfiguration_error(): - fetcher = DataFetcher() + fetcher = _PrefetchDataFetcher() loader = DataLoader(range(10)) fetcher.setup(loader) with pytest.raises( @@ -224,7 +224,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs): def training_step(self, 
dataloader_iter, batch_idx): assert self.count == batch_idx - assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher) + assert isinstance(self.trainer.fit_loop._data_fetcher, _DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) self.batches.append(next(dataloader_iter)) @@ -258,7 +258,7 @@ def on_train_epoch_end(self): def test_fetching_dataloader_iter_running_stages(fn, tmpdir): class TestModel(BoringModel): def fetch(self, data_fetcher, dataloader_iter, batch_idx): - assert isinstance(data_fetcher, DataLoaderIterDataFetcher) + assert isinstance(data_fetcher, _DataLoaderIterDataFetcher) assert data_fetcher.fetched == batch_idx batch = next(dataloader_iter) assert data_fetcher.fetched == batch_idx + 1
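
Usage sketch (not part of the patch). The snippet below illustrates the two fetchers under their new protected names: driving ``_PrefetchDataFetcher`` directly as an iterator, mirroring ``test_prefetch_iterator`` above, and opting into ``_DataLoaderIterDataFetcher`` by accepting ``dataloader_iter`` in ``training_step``, which the unified ``_select_data_fetcher`` now detects for the training, validation, test and sanity-checking stages alike. The ``BoringModel`` import path and the ``Trainer`` arguments are assumptions for illustration only, and the ``dataloader_iter`` signature remains experimental per the warning in ``_select_data_fetcher``.

Example::

    # Hedged sketch; assumes the `lightning.pytorch` package layout used throughout this diff.
    from torch.utils.data import DataLoader

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.utilities.fetching import _PrefetchDataFetcher

    # 1) Drive the prefetching fetcher directly, as the updated tests do:
    #    `__iter__` prefetches up to `prefetch_batches`, then `__next__` replays the
    #    stored batches and keeps fetching until the dataloader is exhausted.
    fetcher = _PrefetchDataFetcher(prefetch_batches=2)
    fetcher.setup(DataLoader(range(10), batch_size=2))
    batches = list(fetcher)
    assert fetcher.done and len(batches) == 5

    # 2) Opt into `_DataLoaderIterDataFetcher` by accepting `dataloader_iter`;
    #    fetching and device placement then become the user's responsibility.
    class ManualFetchModel(BoringModel):
        def training_step(self, dataloader_iter, batch_idx):
            batch = next(dataloader_iter).to(self.device)
            return self(batch).sum()  # scalar loss for automatic optimization

    if __name__ == "__main__":
        trainer = Trainer(fast_dev_run=2, accelerator="cpu", logger=False)
        trainer.fit(ManualFetchModel())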