 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from typing_extensions import Literal

 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
@@ -890,11 +889,9 @@ def tune(
         model: "pl.LightningModule",
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
-        dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-        method: Literal["fit", "validate", "test", "predict"] = "fit",
     ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.
@@ -908,34 +905,44 @@ def tune(
             val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

-            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying val/test/predict
-                samples used for running tuner on validation/testing/prediction.
-
             datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

             scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

             lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
-
-            method: Method to run tuner on. It can be any of ``("fit", "validate", "test", "predict")``.
         """
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.tune()` requires a `LightningModule`, got: {model.__class__.__qualname__}")

         Trainer._log_api_event("tune")

+        self.state.fn = TrainerFn.TUNING
+        self.state.status = TrainerStatus.RUNNING
+        self.tuning = True
+
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloaders, LightningDataModule):
+            datamodule = train_dataloaders
+            train_dataloaders = None
+        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
+            raise MisconfigurationException(
+                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
+            )
+
+        # links data to the trainer
+        self._data_connector.attach_data(
+            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
+        )
+
         with isolate_rng():
             result = self.tuner._tune(
-                model,
-                train_dataloaders,
-                val_dataloaders,
-                dataloaders,
-                datamodule,
-                scale_batch_size_kwargs=scale_batch_size_kwargs,
-                lr_find_kwargs=lr_find_kwargs,
-                method=method,
+                model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs
             )

+        assert self.state.stopped
+        self.tuning = False
+
         return result

     def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
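With this change, `Trainer.tune()` no longer accepts the `dataloaders` or `method` arguments and always runs the fit-stage routines (batch-size scaling and learning-rate finding), handling the tuning state and data attachment itself. Below is a minimal, self-contained usage sketch (not part of this commit): the `BoringModel` class, the `auto_scale_batch_size`/`auto_lr_find` trainer flags, and the kwarg values shown are illustrative assumptions based on the 1.x-era Lightning API, chosen only to show how the slimmed-down signature is called.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, batch_size: int = 32):
        super().__init__()
        # `lr` and `batch_size` are the attributes the tuner adjusts
        self.lr = lr
        self.batch_size = batch_size
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        return DataLoader(dataset, batch_size=self.batch_size)


trainer = pl.Trainer(auto_scale_batch_size=True, auto_lr_find=True, max_epochs=1)

# After this commit, tune() runs only the fit-stage routines; the removed
# `dataloaders=` and `method=` arguments are no longer accepted.
result = trainer.tune(
    BoringModel(),
    scale_batch_size_kwargs={"mode": "power"},
    lr_find_kwargs={"num_training": 20},
)
```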