Skip to content
Merged
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998)


- Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100)


### Removed

- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def __init__(
self._duration = duration.total_seconds() if duration is not None else None
self._interval = interval
self._verbose = verbose
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._offset = 0

def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
Expand Down Expand Up @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -
self._check_time_remaining(trainer)

def state_dict(self) -> Dict[str, Any]:
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}}
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
time_elapsed = state_dict.get("time_elapsed", {})
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn != TrainerFn.FITTING:
return
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
# Separate models are instantiated for different stages, but they share the same weights on host.
# When validation/test models are run, weights are synced first.
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer_fn == TrainerFn.FITTING:
# Create model for training and validation which will run on fit
training_opts = self.training_opts
inference_opts = self.inference_opts
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn != TrainerFn.FITTING:
return
assert self.lightning_module is not None
self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:

if trainer.state.fn is None:
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn == TrainerFn.FITTING:
__verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model)
__check_training_step_requires_dataloader_iter(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def restore_loops(self) -> None:
assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.trainer.state.fn == TrainerFn.FITTING:
fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
Expand Down
62 changes: 52 additions & 10 deletions src/pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum, EnumMeta
from typing import Any, List, Optional

from lightning_utilities.core.rank_zero import rank_zero_deprecation

from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode


class _DeprecationManagingEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.

Adapted from: https://stackoverflow.com/a/62309159/208880
"""

def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj

def __getitem__(cls, name: str) -> Any:
member: _DeprecationManagingEnumMeta = super().__getitem__(name)
member.deprecate()
return member

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
obj = super().__call__(*args, **kwargs)
if isinstance(obj, Enum):
obj.deprecate()
return obj


class TrainerStatus(LightningEnum):
"""Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`"""

Expand All @@ -31,7 +59,7 @@ def stopped(self) -> bool:
return self in (self.FINISHED, self.INTERRUPTED)


class TrainerFn(LightningEnum):
class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
Expand All @@ -44,16 +72,19 @@ class TrainerFn(LightningEnum):
PREDICTING = "predict"
TUNING = "tune"

@property
def _setup_fn(self) -> "TrainerFn":
"""``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
)

This is used for the ``setup()`` and ``teardown()`` hooks
"""
return TrainerFn.FITTING if self == TrainerFn.TUNING else self
@classmethod
def _without_tune(cls) -> List["TrainerFn"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


class RunningStage(LightningEnum):
class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
"""Enum for the current running stage.

This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
Expand All @@ -79,12 +110,23 @@ def evaluating(self) -> bool:

@property
def dataloader_prefix(self) -> Optional[str]:
if self in (self.SANITY_CHECKING, self.TUNING):
if self == self.SANITY_CHECKING:
return None
if self == self.VALIDATING:
return "val"
return self.value

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
)

@classmethod
def _without_tune(cls) -> List["RunningStage"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


@dataclass
class TrainerState:
Expand Down
13 changes: 8 additions & 5 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.state.fn == TrainerFn.FITTING:
min_epochs, max_epochs = _parse_loop_limits(
self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
)
Expand Down Expand Up @@ -1233,7 +1233,7 @@ def _run_sanity_check(self) -> None:

def _call_setup_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
fn = self.state.fn

self.strategy.barrier("pre_setup")

Expand All @@ -1256,7 +1256,7 @@ def _call_configure_sharded_model(self) -> None:

def _call_teardown_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
fn = self.state.fn

if self.datamodule is not None:
self._call_lightning_datamodule_hook("teardown", stage=fn)
Expand Down Expand Up @@ -1449,7 +1449,7 @@ def __setup_profiler(self) -> None:
assert self.state.fn is not None
local_rank = self.local_rank if self.world_size > 1 else None
self.profiler._lightning_module = proxy(self.lightning_module)
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir)

"""
Data loading methods
Expand Down Expand Up @@ -1965,10 +1965,13 @@ def predicting(self, val: bool) -> None:

@property
def tuning(self) -> bool:
rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.")
return self.state.stage == RunningStage.TUNING

@tuning.setter
def tuning(self, val: bool) -> None:
rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.")

if val:
self.state.stage = RunningStage.TUNING
elif self.tuning:
Expand Down Expand Up @@ -2097,7 +2100,7 @@ def predict_loop(self, loop: PredictionLoop) -> None:

@property
def _evaluation_loop(self) -> EvaluationLoop:
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.state.fn == TrainerFn.FITTING:
return self.fit_loop.epoch_loop.val_loop
if self.state.fn == TrainerFn.VALIDATING:
return self.validate_loop
Expand Down
12 changes: 1 addition & 11 deletions src/pytorch_lightning/tuner/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
Expand Down Expand Up @@ -77,9 +77,7 @@ def _tune(

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.trainer.state.fn = TrainerFn.TUNING
self.trainer.state.status = TrainerStatus.RUNNING
self.tuning = True

# TODO: Remove this once LRFinder is converted to a Callback
# if a datamodule comes in as the second arg, then fix it for the user
Expand Down Expand Up @@ -112,7 +110,6 @@ def _run(self, *args: Any, **kwargs: Any) -> None:
self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED`
self.trainer.training = True
self.trainer._run(*args, **kwargs)
self.trainer.tuning = True

def scale_batch_size(
self,
Expand Down Expand Up @@ -170,10 +167,6 @@ def scale_batch_size(
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
"""
# TODO: Remove TrainerFn.TUNING since we are now calling fit/validate/test/predict methods directly
self.trainer.state.fn = TrainerFn.TUNING
self.tuning = True

_check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method)

batch_size_finder: Callback = BatchSizeFinder(
Expand Down Expand Up @@ -254,9 +247,6 @@ def lr_find(
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
or if you are using more than one optimizer.
"""
self.trainer.state.fn = TrainerFn.TUNING
self.tuning = True

if method != "fit":
raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.")

Expand Down
25 changes: 25 additions & 0 deletions tests/tests_pytorch/deprecated_api/test_remove_1-10.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
Expand Down Expand Up @@ -297,3 +298,27 @@ def test_lite_convert_deprecated_tpus_argument(tpu_available):
def test_lightningCLI_save_config_init_params_deprecation_warning(name, value):
with mock.patch("sys.argv", ["any.py"]), pytest.deprecated_call(match=f".*{name!r} init parameter is deprecated.*"):
LightningCLI(BoringModel, run=False, **{name: value})


def test_tuning_enum():
with pytest.deprecated_call(
match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
TrainerFn.TUNING

with pytest.deprecated_call(
match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
RunningStage.TUNING


def test_tuning_trainer_property():
trainer = Trainer()

with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."):
trainer.tuning

with pytest.deprecated_call(
match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
trainer.tuning = True
27 changes: 8 additions & 19 deletions tests/tests_pytorch/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,39 +183,28 @@ def _check_model_state_dict(self):
for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
)

def _test_on_val_test_predict_tune_start(self):
def _test_on_val_test_predict_start(self):
assert self.trainer.current_epoch == state_dict["epoch"]
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()

# no optimizes and schedulers are loaded otherwise
if self.trainer.state.fn != TrainerFn.TUNING:
return

assert not self._check_optimizers()
assert not self._check_schedulers()

def on_train_start(self):
if self.trainer.state.fn == TrainerFn.TUNING:
self._test_on_val_test_predict_tune_start()
else:
assert self.trainer.current_epoch == state_dict["epoch"] + 1
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()
assert self._check_optimizers()
assert self._check_schedulers()
assert self.trainer.current_epoch == state_dict["epoch"] + 1
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()
assert self._check_optimizers()
assert self._check_schedulers()

def on_validation_start(self):
if self.trainer.state.fn == TrainerFn.VALIDATING:
self._test_on_val_test_predict_tune_start()
self._test_on_val_test_predict_start()

def on_test_start(self):
self._test_on_val_test_predict_tune_start()
self._test_on_val_test_predict_start()

for fn in ("fit", "validate", "test", "predict"):
model = CustomClassifModel()
dm = ClassifDataModule()
trainer_args["auto_scale_batch_size"] = (fn == "tune",)
trainer = Trainer(**trainer_args)
trainer_fn = getattr(trainer, fn)
trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
Expand Down
4 changes: 1 addition & 3 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,7 @@ def test_ddp_configure_ddp():


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(
"trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TUNING, TrainerFn.TESTING, TrainerFn.PREDICTING)
)
@pytest.mark.parametrize("trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING))
def test_ddp_dont_configure_sync_batchnorm(trainer_fn):
model = BoringModelGPU()
model.layer = torch.nn.BatchNorm1d(10)
Expand Down
Loading