Skip to content

Commit 2e16932

Browse files
committed
more
1 parent c0c363e commit 2e16932

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

examples/pl_bug_report/bug_report_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import torch
2-
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
2+
3+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
34

45

56
class BoringData(LightningDataModule):
67
pass
78

9+
810
class BoringModel(LightningModule):
911
def __init__(self):
1012
super().__init__()

src/pytorch_lightning/trainer/trainer.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,11 @@
7777
XLAProfiler,
7878
)
7979
from pytorch_lightning.strategies import ParallelStrategy, Strategy
80-
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations, _check_dataloader_none, \
81-
_check_datamodule_none
80+
from pytorch_lightning.trainer.configuration_validator import (
81+
_check_dataloader_none,
82+
_check_datamodule_none,
83+
verify_loop_configurations,
84+
)
8285
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
8386
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
8487
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
@@ -729,7 +732,6 @@ def _fit_impl(
729732
)
730733
train_dataloaders = None if train_dataloaders is _NO_DATALOADER else train_dataloaders
731734
val_dataloaders = None if val_dataloaders is _NO_DATALOADER else val_dataloaders
732-
733735
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
734736
datamodule = None if datamodule is _NO_DATAMODULE else datamodule
735737

@@ -796,10 +798,10 @@ def validate(
796798
def _validate_impl(
797799
self,
798800
model: Optional["pl.LightningModule"] = None,
799-
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
801+
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule] = _NO_DATALOADER,
800802
ckpt_path: Optional[str] = None,
801803
verbose: bool = True,
802-
datamodule: Optional[LightningDataModule] = None,
804+
datamodule: LightningDataModule = _NO_DATAMODULE,
803805
) -> _EVALUATE_OUTPUT:
804806
# --------------------
805807
# SETUP HOOK
@@ -814,7 +816,13 @@ def _validate_impl(
814816
# if a datamodule comes in as the second arg, then fix it for the user
815817
if isinstance(dataloaders, LightningDataModule):
816818
datamodule = dataloaders
817-
dataloaders = None
819+
dataloaders = _NO_DATALOADER
820+
821+
_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
822+
dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
823+
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
824+
datamodule = None if datamodule is _NO_DATAMODULE else datamodule
825+
818826
# If you supply a datamodule you can't supply val_dataloaders
819827
if dataloaders is not None and datamodule:
820828
raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")
@@ -886,10 +894,10 @@ def test(
886894
def _test_impl(
887895
self,
888896
model: Optional["pl.LightningModule"] = None,
889-
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
897+
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule] = _NO_DATALOADER,
890898
ckpt_path: Optional[str] = None,
891899
verbose: bool = True,
892-
datamodule: Optional[LightningDataModule] = None,
900+
datamodule: LightningDataModule = _NO_DATAMODULE,
893901
) -> _EVALUATE_OUTPUT:
894902
# --------------------
895903
# SETUP HOOK
@@ -904,7 +912,13 @@ def _test_impl(
904912
# if a datamodule comes in as the second arg, then fix it for the user
905913
if isinstance(dataloaders, LightningDataModule):
906914
datamodule = dataloaders
907-
dataloaders = None
915+
dataloaders = _NO_DATALOADER
916+
917+
_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
918+
dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
919+
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
920+
datamodule = None if datamodule is _NO_DATAMODULE else datamodule
921+
908922
# If you supply a datamodule you can't supply test_dataloaders
909923
if dataloaders is not None and datamodule:
910924
raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")
@@ -977,8 +991,8 @@ def predict(
977991
def _predict_impl(
978992
self,
979993
model: Optional["pl.LightningModule"] = None,
980-
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
981-
datamodule: Optional[LightningDataModule] = None,
994+
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule] = _NO_DATALOADER,
995+
datamodule: LightningDataModule = _NO_DATAMODULE,
982996
return_predictions: Optional[bool] = None,
983997
ckpt_path: Optional[str] = None,
984998
) -> Optional[_PREDICT_OUTPUT]:
@@ -998,6 +1012,12 @@ def _predict_impl(
9981012
if isinstance(dataloaders, LightningDataModule):
9991013
datamodule = dataloaders
10001014
dataloaders = None
1015+
1016+
_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
1017+
dataloaders = None if dataloaders is _NO_DATALOADER else dataloaders
1018+
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
1019+
datamodule = None if datamodule is _NO_DATAMODULE else datamodule
1020+
10011021
if dataloaders is not None and datamodule:
10021022
raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")
10031023

0 commit comments

Comments
 (0)