77
77
XLAProfiler ,
78
78
)
79
79
from pytorch_lightning .strategies import ParallelStrategy , Strategy
80
- from pytorch_lightning .trainer .configuration_validator import verify_loop_configurations , _check_dataloader_none , \
81
- _check_datamodule_none
80
+ from pytorch_lightning .trainer .configuration_validator import (
81
+ _check_dataloader_none ,
82
+ _check_datamodule_none ,
83
+ verify_loop_configurations ,
84
+ )
82
85
from pytorch_lightning .trainer .connectors .accelerator_connector import _LITERAL_WARN , AcceleratorConnector
83
86
from pytorch_lightning .trainer .connectors .callback_connector import CallbackConnector
84
87
from pytorch_lightning .trainer .connectors .checkpoint_connector import CheckpointConnector
@@ -729,7 +732,6 @@ def _fit_impl(
729
732
)
730
733
train_dataloaders = None if train_dataloaders is _NO_DATALOADER else train_dataloaders
731
734
val_dataloaders = None if val_dataloaders is _NO_DATALOADER else val_dataloaders
732
-
733
735
_check_datamodule_none (stage = self .state .fn , datamodule = datamodule )
734
736
datamodule = None if datamodule is _NO_DATAMODULE else datamodule
735
737
@@ -796,10 +798,10 @@ def validate(
796
798
def _validate_impl (
797
799
self ,
798
800
model : Optional ["pl.LightningModule" ] = None ,
799
- dataloaders : Optional [ Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
801
+ dataloaders : Union [EVAL_DATALOADERS , LightningDataModule ] = _NO_DATALOADER ,
800
802
ckpt_path : Optional [str ] = None ,
801
803
verbose : bool = True ,
802
- datamodule : Optional [ LightningDataModule ] = None ,
804
+ datamodule : LightningDataModule = _NO_DATAMODULE ,
803
805
) -> _EVALUATE_OUTPUT :
804
806
# --------------------
805
807
# SETUP HOOK
@@ -814,7 +816,13 @@ def _validate_impl(
814
816
# if a datamodule comes in as the second arg, then fix it for the user
815
817
if isinstance (dataloaders , LightningDataModule ):
816
818
datamodule = dataloaders
817
- dataloaders = None
819
+ dataloaders = _NO_DATALOADER
820
+
821
+ _check_dataloader_none (stage = self .state .fn , dataloaders = dataloaders )
822
+ dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
823
+ _check_datamodule_none (stage = self .state .fn , datamodule = datamodule )
824
+ datamodule = None if datamodule is _NO_DATAMODULE else datamodule
825
+
818
826
# If you supply a datamodule you can't supply val_dataloaders
819
827
if dataloaders is not None and datamodule :
820
828
raise MisconfigurationException ("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`" )
@@ -886,10 +894,10 @@ def test(
886
894
def _test_impl (
887
895
self ,
888
896
model : Optional ["pl.LightningModule" ] = None ,
889
- dataloaders : Optional [ Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
897
+ dataloaders : Union [EVAL_DATALOADERS , LightningDataModule ] = _NO_DATALOADER ,
890
898
ckpt_path : Optional [str ] = None ,
891
899
verbose : bool = True ,
892
- datamodule : Optional [ LightningDataModule ] = None ,
900
+ datamodule : LightningDataModule = _NO_DATAMODULE ,
893
901
) -> _EVALUATE_OUTPUT :
894
902
# --------------------
895
903
# SETUP HOOK
@@ -904,7 +912,13 @@ def _test_impl(
904
912
# if a datamodule comes in as the second arg, then fix it for the user
905
913
if isinstance (dataloaders , LightningDataModule ):
906
914
datamodule = dataloaders
907
- dataloaders = None
915
+ dataloaders = _NO_DATALOADER
916
+
917
+ _check_dataloader_none (stage = self .state .fn , dataloaders = dataloaders )
918
+ dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
919
+ _check_datamodule_none (stage = self .state .fn , datamodule = datamodule )
920
+ datamodule = None if datamodule is _NO_DATAMODULE else datamodule
921
+
908
922
# If you supply a datamodule you can't supply test_dataloaders
909
923
if dataloaders is not None and datamodule :
910
924
raise MisconfigurationException ("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`" )
@@ -977,8 +991,8 @@ def predict(
977
991
def _predict_impl (
978
992
self ,
979
993
model : Optional ["pl.LightningModule" ] = None ,
980
- dataloaders : Optional [ Union [EVAL_DATALOADERS , LightningDataModule ]] = None ,
981
- datamodule : Optional [ LightningDataModule ] = None ,
994
+ dataloaders : Union [EVAL_DATALOADERS , LightningDataModule ] = _NO_DATALOADER ,
995
+ datamodule : LightningDataModule = _NO_DATAMODULE ,
982
996
return_predictions : Optional [bool ] = None ,
983
997
ckpt_path : Optional [str ] = None ,
984
998
) -> Optional [_PREDICT_OUTPUT ]:
@@ -998,6 +1012,12 @@ def _predict_impl(
998
1012
if isinstance (dataloaders , LightningDataModule ):
999
1013
datamodule = dataloaders
1000
1014
dataloaders = None
1015
+
1016
+ _check_dataloader_none (stage = self .state .fn , dataloaders = dataloaders )
1017
+ dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
1018
+ _check_datamodule_none (stage = self .state .fn , datamodule = datamodule )
1019
+ datamodule = None if datamodule is _NO_DATAMODULE else datamodule
1020
+
1001
1021
if dataloaders is not None and datamodule :
1002
1022
raise MisconfigurationException ("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`" )
1003
1023
0 commit comments