3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list ([14325](https://github.com/Lightning-AI/lightning/pull/14325))


- Explicitly passing `train_dataloaders=None`, `val_dataloaders=None`, `dataloaders=None` or `datamodule=None` is officially no longer supported and will inform using a better error message than before ([14614](https://github.com/Lightning-AI/lightning/pull/14614))
Contributor

IMO this is a huge breaking change as described in #14602 (comment).

Very often people write scripts with

train_dataloader = None
if ...:
    train_dataloader = ...
trainer.fit(model, train_dataloader=train_dataloader)

(you had to update examples like this in our CI)

I don't think we can afford to break this use case. Sure, it can be error-prone, but it's a perfectly valid Python pattern when you want to pass arguments optionally.

I strongly think it should be a PossibleUserWarning for this reason.

Contributor Author

@awaelchli awaelchli Sep 9, 2022

If we make it a warning, then another error will be raised later on and completely overshadow any warning printed earlier. It won't help the user.

The only other option I see is to change the message in "No `training_step()` defined ..." to hint at the possibility that one of the dataloaders is None. But this is very strange. Ideally, we would want to tell the user directly what the issue is, not give a list of possibilities.

Contributor

@carmocca carmocca Sep 9, 2022

Another possibility would be to check at the same time whether there's a `*_dataloader` fallback. If there is, allow None and proceed normally. Otherwise, raise an error.

If we do this, we should also rewrite the "No train_dataloader() method defined" error, as it would be covered by this new error.
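
For illustration, a rough sketch of what that fallback check could look like (not this PR's implementation; it assumes `is_overridden` from `pytorch_lightning.utilities.model_helpers` and that the model/datamodule are already attached to the trainer):

from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.utilities.model_helpers import is_overridden


def _check_dataloader_none_with_fallback(trainer: "pl.Trainer", stage: str, **dataloader_args: Any) -> None:
    # Hypothetical variant: tolerate an explicit `None` when a `*_dataloader` hook
    # exists on the LightningModule or LightningDataModule, otherwise raise.
    for arg_name, value in dataloader_args.items():
        if value is not None:
            continue
        hook_name = arg_name[:-1] if arg_name.endswith("s") else arg_name  # "train_dataloaders" -> "train_dataloader"
        has_fallback = is_overridden(hook_name, trainer.lightning_module) or (
            trainer.datamodule is not None and is_overridden(hook_name, trainer.datamodule)
        )
        if not has_fallback:
            raise ValueError(
                f"`Trainer.{stage}({arg_name}=None, ...)` was passed, but no `{hook_name}()` is defined"
                " on the LightningModule or LightningDataModule to fall back to."
            )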

Contributor Author

I have an alternative implementation here: #14637

My brain is melting.

Contributor

Thanks! Much better




### Deprecated

22 changes: 22 additions & 0 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import pytorch_lightning as pl
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators.ipu import IPUAccelerator
@@ -308,3 +310,23 @@ def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
"`LightningDataModule.on_load_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
)


def _check_dataloader_none(stage: str, **dataloader_args: Any) -> None:
    for arg_name, value in dataloader_args.items():
        dataloader_method = arg_name if not arg_name.endswith("s") else arg_name[:-1]
Contributor

If you pass the RunningStage object instead of the value, you can use stage.dataloader_prefix to construct the method name.

        if value is None:
            raise ValueError(
                f"You explicitly passed `Trainer.{stage}({arg_name}=None, ...)`, but this is not supported."
                " You should either a) pass in valid dataloader(s) or"
                f" b) remove the argument from `.{stage}()` and implement `def {dataloader_method}(self):` in your"
                f" LightningModule/LightningDataModule instead."
            )


def _check_datamodule_none(stage: str, datamodule: Optional["pl.LightningDataModule"]) -> None:
    if datamodule is None:
        raise ValueError(
            f"You explicitly passed `Trainer.{stage}(datamodule=None, ...)`, but this is not supported."
            f" Please pass in valid `LightningDataModule` or remove the argument from `.{stage}()`."
        )
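
As an aside, the inline suggestion above (deriving the hook name from the `RunningStage`) might look roughly like the sketch below for the single-dataloader entry points (`validate`/`test`/`predict`); `fit` and `tune` take two dataloader arguments and would still need the per-argument name. `RunningStage.dataloader_prefix` is assumed to be available; this is an illustration only, not part of the PR:

from typing import Any

from pytorch_lightning.trainer.states import RunningStage


def _check_dataloader_none_by_stage(stage: RunningStage, **dataloader_args: Any) -> None:
    # Hypothetical alternative: build the hook name from the stage's dataloader prefix,
    # e.g. RunningStage.TESTING -> "test_dataloader", instead of trimming the argument name.
    dataloader_method = f"{stage.dataloader_prefix}_dataloader"
    for arg_name, value in dataloader_args.items():
        if value is None:
            raise ValueError(
                f"You explicitly passed `{arg_name}=None`, but this is not supported."
                f" Either pass valid dataloader(s) or implement `def {dataloader_method}(self):`"
                " in your LightningModule/LightningDataModule."
            )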
98 changes: 70 additions & 28 deletions src/pytorch_lightning/trainer/trainer.py
@@ -77,7 +77,11 @@
XLAProfiler,
)
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
from pytorch_lightning.trainer.configuration_validator import (
_check_dataloader_none,
_check_datamodule_none,
verify_loop_configurations,
)
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
@@ -114,6 +118,7 @@
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
_NO_DATA,
_PREDICT_OUTPUT,
EVAL_DATALOADERS,
LRSchedulerConfig,
@@ -670,9 +675,9 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
def fit(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloaders: Union[TRAIN_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
val_dataloaders: Union[EVAL_DATALOADERS, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
) -> None:
r"""
@@ -703,9 +708,9 @@ def fit(
def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloaders: Union[TRAIN_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
val_dataloaders: Union[EVAL_DATALOADERS, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")
@@ -718,11 +723,21 @@ def _fit_impl(
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
train_dataloaders = None
train_dataloaders = _NO_DATA

_check_dataloader_none(
stage=self.state.fn, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders
)
train_dataloaders = None if train_dataloaders is _NO_DATA else train_dataloaders
val_dataloaders = None if val_dataloaders is _NO_DATA else val_dataloaders
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
datamodule = None if datamodule is _NO_DATA else datamodule

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
"You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
"You cannot pass `train_dataloader` or `val_dataloaders` together with `trainer.fit(datamodule=...)`."
" Choose either the datamodule or the raw dataloaders."
)

# links data to the trainer
@@ -744,10 +759,10 @@ def _fit_impl(
def validate(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the validation set.
@@ -781,10 +796,10 @@ def validate(
def _validate_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
) -> _EVALUATE_OUTPUT:
# --------------------
# SETUP HOOK
@@ -799,7 +814,13 @@ def _validate_impl(
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
dataloaders = None
dataloaders = _NO_DATA

_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
dataloaders = None if dataloaders is _NO_DATA else dataloaders
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
datamodule = None if datamodule is _NO_DATA else datamodule

# If you supply a datamodule you can't supply val_dataloaders
if dataloaders is not None and datamodule:
raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")
@@ -833,10 +854,10 @@ def _validate_impl(
def test(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
) -> _EVALUATE_OUTPUT:
r"""
Perform one evaluation epoch over the test set.
@@ -871,10 +892,10 @@ def test(
def _test_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
) -> _EVALUATE_OUTPUT:
# --------------------
# SETUP HOOK
@@ -889,7 +910,13 @@ def _test_impl(
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
dataloaders = None
dataloaders = _NO_DATA

_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
dataloaders = None if dataloaders is _NO_DATA else dataloaders
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
datamodule = None if datamodule is _NO_DATA else datamodule

# If you supply a datamodule you can't supply test_dataloaders
if dataloaders is not None and datamodule:
raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")
@@ -923,8 +950,8 @@ def _test_impl(
def predict(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
@@ -962,8 +989,8 @@ def predict(
def _predict_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None,
dataloaders: Union[EVAL_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
@@ -982,7 +1009,13 @@ def _predict_impl(
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
dataloaders = None
dataloaders = _NO_DATA

_check_dataloader_none(stage=self.state.fn, dataloaders=dataloaders)
dataloaders = None if dataloaders is _NO_DATA else dataloaders
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
datamodule = None if datamodule is _NO_DATA else datamodule

if dataloaders is not None and datamodule:
raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")

@@ -1012,9 +1045,9 @@ def _predict_impl(
def tune(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
train_dataloaders: Union[TRAIN_DATALOADERS, LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
val_dataloaders: Union[EVAL_DATALOADERS, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union[LightningDataModule, Type[_NO_DATA]] = _NO_DATA,
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> _TunerResult:
@@ -1048,7 +1081,16 @@ def tune(
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
train_dataloaders = None
train_dataloaders = _NO_DATA

_check_dataloader_none(
stage=self.state.fn, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders
)
train_dataloaders = None if train_dataloaders is _NO_DATA else train_dataloaders
val_dataloaders = None if val_dataloaders is _NO_DATA else val_dataloaders
_check_datamodule_none(stage=self.state.fn, datamodule=datamodule)
datamodule = None if datamodule is _NO_DATA else datamodule

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
raise MisconfigurationException(
16 changes: 8 additions & 8 deletions src/pytorch_lightning/tuner/tuning.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Type, Union

from typing_extensions import NotRequired, TypedDict

@@ -20,7 +20,7 @@
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
from pytorch_lightning.tuner.lr_finder import _LRFinder, lr_find
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.types import _NO_DATA, EVAL_DATALOADERS, TRAIN_DATALOADERS


class _TunerResult(TypedDict):
@@ -83,9 +83,9 @@ def _run(self, *args: Any, **kwargs: Any) -> None:
def scale_batch_size(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
train_dataloaders: Union[TRAIN_DATALOADERS, "pl.LightningDataModule", Type[_NO_DATA]] = _NO_DATA,
val_dataloaders: Union[EVAL_DATALOADERS, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union["pl.LightningDataModule", Type[_NO_DATA]] = _NO_DATA,
mode: str = "power",
steps_per_trial: int = 3,
init_val: int = 2,
@@ -149,9 +149,9 @@ def scale_batch_size(
def lr_find(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, "pl.LightningDataModule"]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional["pl.LightningDataModule"] = None,
train_dataloaders: Union[TRAIN_DATALOADERS, "pl.LightningDataModule", Type[_NO_DATA]] = _NO_DATA,
val_dataloaders: Union[EVAL_DATALOADERS, Type[_NO_DATA]] = _NO_DATA,
datamodule: Union["pl.LightningDataModule", Type[_NO_DATA]] = _NO_DATA,
min_lr: float = 1e-8,
max_lr: float = 1,
num_training: int = 100,
24 changes: 23 additions & 1 deletion src/pytorch_lightning/utilities/types.py
@@ -25,7 +25,7 @@
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import Metric
from typing_extensions import Protocol, runtime_checkable
from typing_extensions import Literal, Protocol, runtime_checkable

from lightning_lite.utilities.types import _LRScheduler, ReduceLROnPlateau

@@ -139,3 +139,25 @@ class LRSchedulerConfig:
strict: bool = True
# opt_idx assigned internally if not assigned by user
opt_idx: Optional[int] = None


class _SentinelMeta(type):
    """Metaclass representing a sentinel value by the name of the class.

    Reference: https://stackoverflow.com/a/69243488/1162383
    See also: https://peps.python.org/pep-0661/
    """

    def __repr__(cls) -> str:
        return f"<{cls.__name__}>"

    def __bool__(cls) -> Literal[False]:
        return False


class Sentinel(metaclass=_SentinelMeta):
    """Subclass this to create a new sentinel."""


class _NO_DATA(Sentinel):
    """A sentinel representing the default value for 'no dataloader passed to Trainer method'."""
Comment on lines +162 to +163

Contributor Author

@awaelchli awaelchli Sep 9, 2022

Note: Sometimes we see sentinels defined as

DEFAULT = object()

But this does not work for us, because pickling (e.g. with ddp_spawn) will re-instantiate this object, and then checks like `dataloader is DEFAULT` have the wrong value!
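
A minimal sketch of that difference (illustrative only; the example class name is made up):

import pickle

# A module-level `object()` sentinel: unpickling creates a *new* instance, so identity
# checks like `value is DEFAULT` no longer hold after a round-trip through pickle
# (which is what spawn-based multiprocessing does to the arguments).
DEFAULT = object()
assert pickle.loads(pickle.dumps(DEFAULT)) is not DEFAULT


# A class used as the sentinel is pickled by reference (module + qualified name),
# so the same identity check survives the round-trip.
class _EXAMPLE_SENTINEL:  # hypothetical stand-in for `_NO_DATA`
    pass


assert pickle.loads(pickle.dumps(_EXAMPLE_SENTINEL)) is _EXAMPLE_SENTINEL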

10 changes: 8 additions & 2 deletions tests/tests_pytorch/helpers/pipelines.py
@@ -26,7 +26,10 @@ def run_model_test_without_loggers(

# fit model
trainer = Trainer(**trainer_options)
trainer.fit(model, datamodule=data)
if data is not None:
    trainer.fit(model, datamodule=data)
else:
    trainer.fit(model)

# correct result and ok accuracy
assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -59,7 +62,10 @@
trainer_options.update(logger=logger)
trainer = Trainer(**trainer_options)
initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
trainer.fit(model, datamodule=data)
if data is not None:
    trainer.fit(model, datamodule=data)
else:
    trainer.fit(model)
post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

assert trainer.state.finished, f"Training failed with {trainer.state}"
5 changes: 4 additions & 1 deletion tests/tests_pytorch/helpers/test_models.py
@@ -39,7 +39,10 @@ def test_models(tmpdir, data_class, model_class):
model = model_class()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

trainer.fit(model, datamodule=dm)
if dm is not None:
    trainer.fit(model, datamodule=dm)
else:
    trainer.fit(model)

if dm is not None:
    trainer.test(model, datamodule=dm)
5 changes: 4 additions & 1 deletion tests/tests_pytorch/models/test_hparams.py
@@ -100,7 +100,10 @@ def _run_standard_hparams_test(tmpdir, model, cls, datamodule=None, try_overwrit

# verify we can train
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
trainer.fit(model, datamodule=datamodule if issubclass(cls, LightningDataModule) else None)
if issubclass(cls, LightningDataModule):
    trainer.fit(model, datamodule=datamodule)
else:
    trainer.fit(model)

# make sure the raw checkpoint saved the properties
raw_checkpoint_path = _raw_checkpoint_path(trainer)