@@ -18,7 +18,21 @@
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    IO,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    overload,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -27,21 +41,23 @@
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric, MetricCollection
+from typing_extensions import Self
 
 import lightning.fabric as lf
 import lightning.pytorch as pl
 from lightning.fabric.loggers import Logger as FabricLogger
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
+from lightning.fabric.utilities.distributed import _distributed_available
 from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from lightning.pytorch.core.mixins import HyperparametersMixin
 from lightning.pytorch.core.optimizer import LightningOptimizer
-from lightning.pytorch.core.saving import ModelIO
+from lightning.pytorch.core.saving import _load_from_checkpoint
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
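The new ``typing_extensions.Self`` import exists so that ``load_from_checkpoint`` (added further down) can be annotated to return the calling subclass rather than the base ``LightningModule``. A minimal sketch of that annotation pattern, with illustrative class names that are not part of this diff:

    from typing_extensions import Self


    class Base:
        @classmethod
        def create(cls) -> Self:
            # Annotated with Self, Base.create() is typed as Base while
            # Sub.create() is typed as Sub -- no per-subclass override needed.
            return cls()


    class Sub(Base):
        ...


    model = Sub.create()  # static checkers infer Sub here, not Base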
@@ -65,7 +81,6 @@
 class LightningModule(
     _DeviceDtypeModuleMixin,
     HyperparametersMixin,
-    ModelIO,
     ModelHooks,
     DataHooks,
     CheckpointHooks,
@@ -92,6 +107,10 @@ class LightningModule(
     )
     _jit_is_scripting = False
 
+    CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
+    CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
+    CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
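These class attributes name the keys Lightning writes into the checkpoint dictionary. A minimal inspection sketch, assuming a checkpoint file produced by Lightning (the file name is illustrative):

    import torch

    # A Lightning checkpoint is a torch-serialized dict; the weights live
    # under "state_dict" and the saved __init__ arguments under the key
    # named by CHECKPOINT_HYPER_PARAMS_KEY ("hyper_parameters").
    ckpt = torch.load("example.ckpt", map_location="cpu")
    print(sorted(ckpt["state_dict"]))
    print(ckpt.get("hyper_parameters", {}))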
@@ -480,7 +499,7 @@ def log(
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and _distributed_available(),
-            sync_dist_fn=trainer.strategy.reduce or _sync_ddp,
+            sync_dist_fn=trainer.strategy.reduce,
             sync_dist_group=sync_dist_group,
             metric_attribute=metric_attribute,
             rank_zero_only=rank_zero_only,
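With this change the reduction function always comes from the active strategy; from user code the behaviour is still controlled by the ``sync_dist`` flag of ``self.log``, which, as the line above shows, only takes effect when a distributed environment is available. A small illustrative module (the class name, layer, and loss are placeholders, not part of this diff):

    import torch
    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # sync_dist=True asks the strategy to reduce the value across
            # processes before it is logged; on a single process it is a no-op.
            self.log("val_loss", loss, sync_dist=True)
            return loss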
@@ -1419,6 +1438,95 @@ def to_torchscript(
 
         return torchscript_module
 
+    @classmethod
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: Union[_PATH, IO],
+        map_location: _MAP_LOCATION_TYPE = None,
+        hparams_file: Optional[_PATH] = None,
+        strict: bool = True,
+        **kwargs: Any,
+    ) -> Self:  # type: ignore[valid-type]
+        r"""
+        Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
+        it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
+
+        Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
+
+        Args:
+            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
+            map_location:
+                If your checkpoint saved a GPU model and you now load on CPUs
+                or a different number of GPUs, use this to map to the new setup.
+                The behaviour is the same as in :func:`torch.load`.
+            hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
+                as in this example::
+
+                    drop_prob: 0.2
+                    dataloader:
+                        batch_size: 32
+
+                You most likely won't need this since Lightning will always save the hyperparameters
+                to the checkpoint.
+                However, if your checkpoint weights don't have the hyperparameters saved,
+                use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
+                These will be converted into a :class:`~dict` and passed into your
+                :class:`LightningModule` for use.
+
+                If your model's ``hparams`` argument is :class:`~argparse.Namespace`
+                and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
+                ``hparams`` as :class:`~dict`.
+            strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
+                returned by this module's state dict.
+            \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
+                hyperparameter values.
+
+        Return:
+            :class:`LightningModule` instance with loaded weights and hyperparameters (if available).
+
+        Note:
+            ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
+            **class** to call it instead of the :class:`LightningModule` instance.
+
+        Example::
+
+            # load weights without mapping ...
+            model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
+
+            # or load weights mapping all weights from GPU 1 to GPU 0 ...
+            map_location = {'cuda:1':'cuda:0'}
+            model = MyLightningModule.load_from_checkpoint(
+                'path/to/checkpoint.ckpt',
+                map_location=map_location
+            )
+
+            # or load weights and hyperparameters from separate files.
+            model = MyLightningModule.load_from_checkpoint(
+                'path/to/checkpoint.ckpt',
+                hparams_file='/path/to/hparams_file.yaml'
+            )
+
+            # override some of the params with new values
+            model = MyLightningModule.load_from_checkpoint(
+                PATH,
+                num_layers=128,
+                pretrained_ckpt_path=NEW_PATH,
+            )
+
+            # predict
+            pretrained_model.eval()
+            pretrained_model.freeze()
+            y_hat = pretrained_model(x)
+        """
+        return _load_from_checkpoint(
+            cls,
+            checkpoint_path,
+            map_location,
+            hparams_file,
+            strict,
+            **kwargs,
+        )
+
     @contextmanager
     def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True