 import shutil
 import sys
 from collections import ChainMap, OrderedDict
-from typing import Any, Iterable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.utils.data.dataloader import DataLoader

-import lightning.pytorch as pl
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
 from lightning.pytorch.loops.dataloader import _DataLoaderLoop
 from lightning.pytorch.loops.epoch import _EvaluationEpochLoop
-from lightning.pytorch.loops.utilities import _set_sampler_epoch
+from lightning.pytorch.loops.utilities import _select_data_fetcher, _set_sampler_epoch
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher
-from lightning.pytorch.utilities.rank_zero import rank_zero_warn
-from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
+from lightning.pytorch.utilities.fetching import _DataFetcher

 if _RICH_AVAILABLE:
     from rich import get_console
@@ -53,7 +50,7 @@ def __init__(self, verbose: bool = True) -> None:
         self._logged_outputs: List[_OUT_DICT] = []
         self._max_batches: List[Union[int, float]] = []
         self._has_run: bool = False
-        self._data_fetcher: Optional[AbstractDataFetcher] = None
+        self._data_fetcher: Optional[_DataFetcher] = None

     @property
     def num_dataloaders(self) -> int:
@@ -125,8 +122,7 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
-        data_fetcher_cls = _select_data_fetcher_type(self.trainer)
-        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
+        self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches)

         # hook
         self._on_evaluation_model_eval()
@@ -373,16 +369,3 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
             lines.append(row_format.format(metric, *row).rstrip())
         lines.append(bar)
         print(os.linesep.join(lines))
-
-
-def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
-    lightning_module = trainer.lightning_module
-    step_fx_name = "test_step" if trainer.testing else "validation_step"
-    step_fx = getattr(lightning_module, step_fx_name)
-    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
-        rank_zero_warn(
-            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
-            "this signature is experimental and the behavior is subject to change."
-        )
-        return DataLoaderIterDataFetcher
-    return DataFetcher
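
Note: the `_select_data_fetcher_type` helper removed above is superseded by the shared `_select_data_fetcher` now imported from `lightning.pytorch.loops.utilities` at the top of the file. That helper's implementation is not part of this diff; below is a minimal sketch of what it plausibly looks like, reconstructed from the removed body and the new call site `_select_data_fetcher(self.trainer, prefetch_batches=...)`. The fetcher class names are assumptions carried over from the removed import (they may have been renamed elsewhere in this refactor), and the shared helper presumably also covers the training and predict hooks, which are omitted here.

```python
# Hypothetical sketch of the consolidated helper, assumed to live in
# lightning/pytorch/loops/utilities.py; not the verbatim implementation.
import lightning.pytorch as pl
from lightning.pytorch.utilities.fetching import (  # pre-rename names from the removed import
    DataFetcher,
    DataLoaderIterDataFetcher,
    _DataFetcher,
)
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature


def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher:
    lightning_module = trainer.lightning_module
    step_fx_name = "test_step" if trainer.testing else "validation_step"
    step_fx = getattr(lightning_module, step_fx_name)
    if is_param_in_hook_signature(step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        # Unlike the removed class-returning helper, the new call site expects a
        # constructed instance, so `prefetch_batches` is applied here.
        return DataLoaderIterDataFetcher(prefetch_batches=prefetch_batches)
    return DataFetcher(prefetch_batches=prefetch_batches)
```

Returning an instance instead of a class lets every loop share one selection-plus-construction entry point, which is why `on_run_start` above collapses from two lines to one.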