Commit 3328b04

Inline the ModelIO interface (#16999)
1 parent 98f9708 · commit 3328b04

File tree: 6 files changed (+123, -111 lines)

docs/source-pytorch/api_references.rst
docs/source-pytorch/common/checkpointing_intermediate.rst
src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/core/mixins/hparams_mixin.py
src/lightning/pytorch/core/module.py
src/lightning/pytorch/core/saving.py

docs/source-pytorch/api_references.rst

Lines changed: 0 additions & 1 deletion

@@ -80,7 +80,6 @@ core
     LightningModule
     ~mixins.HyperparametersMixin
     ~optimizer.LightningOptimizer
-    ~saving.ModelIO


 loggers

docs/source-pytorch/common/checkpointing_intermediate.rst

Lines changed: 1 addition & 1 deletion

@@ -147,7 +147,7 @@ Save checkpoints manually
 *************************

 You can manually save checkpoints and restore your model from the checkpointed state using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint`
-and :meth:`~lightning.pytorch.core.saving.ModelIO.load_from_checkpoint`.
+and :meth:`~lightning.pytorch.core.module.LightningModule.load_from_checkpoint`.

 .. code-block:: python

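The workflow this docs section describes is unchanged by the retargeted reference; a minimal runnable sketch of the documented save/restore pattern (the toy `MyLightningModule`, the data, and the checkpoint path are illustrative placeholders, not part of this diff):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.pytorch import LightningModule, Trainer


    class MyLightningModule(LightningModule):  # illustrative toy model
        def __init__(self, num_layers: int = 2):
            super().__init__()
            self.save_hyperparameters()
            self.model = nn.Sequential(*[nn.Linear(4, 4) for _ in range(num_layers)])

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.model(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    train_loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(MyLightningModule(), train_loader)

    # save a checkpoint manually ...
    trainer.save_checkpoint("example.ckpt")

    # ... and restore it; the classmethod is now documented on LightningModule
    # rather than on the removed ModelIO interface
    new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")
    print(new_model.hparams.num_layers)  # 2
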
src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -399,6 +399,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed registration of `ShardedTensor` state dict hooks in `LightningModule.__init__` with `torch>=2.1` ([#16892](https://github.com/Lightning-AI/lightning/pull/16892))

+- Removed the `lightning.pytorch.core.saving.ModelIO` class interface ([#16999](https://github.com/Lightning-AI/lightning/pull/16999))


 ### Fixed

src/lightning/pytorch/core/mixins/hparams_mixin.py

Lines changed: 6 additions & 4 deletions

@@ -17,9 +17,11 @@
 from argparse import Namespace
 from typing import Any, List, MutableMapping, Optional, Sequence, Union

-from lightning.pytorch.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
 from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters

+_PRIMITIVE_TYPES = (bool, int, float, str)
+_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
+

 class HyperparametersMixin:

@@ -123,9 +125,9 @@ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[Mutable
             hp = vars(hp)
         if isinstance(hp, dict):
             hp = AttributeDict(hp)
-        elif isinstance(hp, PRIMITIVE_TYPES):
-            raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
-        elif not isinstance(hp, ALLOWED_CONFIG_TYPES):
+        elif isinstance(hp, _PRIMITIVE_TYPES):
+            raise ValueError(f"Primitives {_PRIMITIVE_TYPES} are not allowed.")
+        elif not isinstance(hp, _ALLOWED_CONFIG_TYPES):
             raise ValueError(f"Unsupported config type of {type(hp)}.")
         return hp

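The renamed constants are internal to the mixin; its public behavior is unchanged. A minimal sketch of how this validation is typically exercised through `save_hyperparameters` (the `LitModel` class and values are illustrative, not part of this diff):

    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):  # illustrative module
        def __init__(self, drop_prob: float = 0.2, batch_size: int = 32):
            super().__init__()
            # dict/Namespace/AttributeDict-style configs are accepted; passing a bare
            # primitive as the hparams object trips the ValueError shown in the hunk above
            self.save_hyperparameters()


    model = LitModel()
    print(model.hparams.drop_prob)   # 0.2
    print(model.hparams.batch_size)  # 32
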
src/lightning/pytorch/core/module.py

Lines changed: 113 additions & 5 deletions

@@ -18,7 +18,21 @@
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    IO,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    overload,
+    Sequence,
+    Tuple,
+    Union,
+)

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection

@@ -27,21 +41,23 @@
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
 from torchmetrics import Metric, MetricCollection
+from typing_extensions import Self

 import lightning.fabric as lf
 import lightning.pytorch as pl
 from lightning.fabric.loggers import Logger as FabricLogger
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from lightning.fabric.utilities.distributed import _distributed_available, _sync_ddp
+from lightning.fabric.utilities.distributed import _distributed_available
 from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from lightning.pytorch.core.mixins import HyperparametersMixin
 from lightning.pytorch.core.optimizer import LightningOptimizer
-from lightning.pytorch.core.saving import ModelIO
+from lightning.pytorch.core.saving import _load_from_checkpoint
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator

@@ -65,7 +81,6 @@
 class LightningModule(
     _DeviceDtypeModuleMixin,
     HyperparametersMixin,
-    ModelIO,
     ModelHooks,
     DataHooks,
     CheckpointHooks,

@@ -92,6 +107,10 @@ class LightningModule(
     )
     _jit_is_scripting = False

+    CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
+    CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
+    CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

@@ -480,7 +499,7 @@ def log(
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and _distributed_available(),
-            sync_dist_fn=trainer.strategy.reduce or _sync_ddp,
+            sync_dist_fn=trainer.strategy.reduce,
             sync_dist_group=sync_dist_group,
             metric_attribute=metric_attribute,
             rank_zero_only=rank_zero_only,

@@ -1419,6 +1438,95 @@ def to_torchscript(

         return torchscript_module

+    @classmethod
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: Union[_PATH, IO],
+        map_location: _MAP_LOCATION_TYPE = None,
+        hparams_file: Optional[_PATH] = None,
+        strict: bool = True,
+        **kwargs: Any,
+    ) -> Self:  # type: ignore[valid-type]
+        r"""
+        Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
+        it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
+
+        Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
+
+        Args:
+            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
+            map_location:
+                If your checkpoint saved a GPU model and you now load on CPUs
+                or a different number of GPUs, use this to map to the new setup.
+                The behaviour is the same as in :func:`torch.load`.
+            hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
+                as in this example::
+
+                    drop_prob: 0.2
+                    dataloader:
+                        batch_size: 32
+
+                You most likely won't need this since Lightning will always save the hyperparameters
+                to the checkpoint.
+                However, if your checkpoint weights don't have the hyperparameters saved,
+                use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
+                These will be converted into a :class:`~dict` and passed into your
+                :class:`LightningModule` for use.
+
+                If your model's ``hparams`` argument is :class:`~argparse.Namespace`
+                and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
+                ``hparams`` as :class:`~dict`.
+            strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
+                returned by this module's state dict.
+            \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
+                hyperparameter values.
+
+        Return:
+            :class:`LightningModule` instance with loaded weights and hyperparameters (if available).
+
+        Note:
+            ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
+            **class** to call it instead of the :class:`LightningModule` instance.
+
+        Example::
+
+            # load weights without mapping ...
+            model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
+
+            # or load weights mapping all weights from GPU 1 to GPU 0 ...
+            map_location = {'cuda:1':'cuda:0'}
+            model = MyLightningModule.load_from_checkpoint(
+                'path/to/checkpoint.ckpt',
+                map_location=map_location
+            )
+
+            # or load weights and hyperparameters from separate files.
+            model = MyLightningModule.load_from_checkpoint(
+                'path/to/checkpoint.ckpt',
+                hparams_file='/path/to/hparams_file.yaml'
+            )
+
+            # override some of the params with new values
+            model = MyLightningModule.load_from_checkpoint(
+                PATH,
+                num_layers=128,
+                pretrained_ckpt_path=NEW_PATH,
+            )
+
+            # predict
+            pretrained_model.eval()
+            pretrained_model.freeze()
+            y_hat = pretrained_model(x)
+        """
+        return _load_from_checkpoint(
+            cls,
+            checkpoint_path,
+            map_location,
+            hparams_file,
+            strict,
+            **kwargs,
+        )
+
     @contextmanager
     def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True

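With the `CHECKPOINT_HYPER_PARAMS_*` constants moved onto `LightningModule` itself, code that inspects a raw checkpoint can reference them from that class. A small sketch (the checkpoint path is an illustrative placeholder):

    import torch

    from lightning.pytorch import LightningModule

    # read the raw checkpoint dict and look up the stored hyperparameters using
    # the class-level key that previously lived on ModelIO
    checkpoint = torch.load("example.ckpt", map_location="cpu")
    hparams = checkpoint.get(LightningModule.CHECKPOINT_HYPER_PARAMS_KEY, {})
    print(hparams)  # e.g. {'num_layers': 2}
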
src/lightning/pytorch/core/saving.py

Lines changed: 2 additions & 100 deletions

@@ -21,12 +21,11 @@
 from copy import deepcopy
 from enum import Enum
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
+from typing import Any, Callable, cast, Dict, IO, Optional, Type, Union
 from warnings import warn

 import yaml
 from lightning_utilities.core.apply_func import apply_to_collection
-from typing_extensions import Self

 import lightning.pytorch as pl
 from lightning.fabric.utilities.cloud_io import _load as pl_load

@@ -39,8 +38,6 @@
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn

 log = logging.getLogger(__name__)
-PRIMITIVE_TYPES = (bool, int, float, str)
-ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)

 if _OMEGACONF_AVAILABLE:
     from omegaconf import OmegaConf

@@ -51,103 +48,8 @@
 CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6


-class ModelIO:
-    CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
-    CHECKPOINT_HYPER_PARAMS_NAME = "hparams_name"
-    CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"
-
-    @classmethod
-    def load_from_checkpoint(
-        cls,
-        checkpoint_path: Union[_PATH, IO],
-        map_location: _MAP_LOCATION_TYPE = None,
-        hparams_file: Optional[_PATH] = None,
-        strict: bool = True,
-        **kwargs: Any,
-    ) -> Self:  # type: ignore[valid-type]
-        r"""
-        Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
-        it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
-
-        Any arguments specified through \*\*kwargs will override args stored in ``"hyper_parameters"``.
-
-        Args:
-            checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
-            map_location:
-                If your checkpoint saved a GPU model and you now load on CPUs
-                or a different number of GPUs, use this to map to the new setup.
-                The behaviour is the same as in :func:`torch.load`.
-            hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
-                as in this example::
-
-                    drop_prob: 0.2
-                    dataloader:
-                        batch_size: 32
-
-                You most likely won't need this since Lightning will always save the hyperparameters
-                to the checkpoint.
-                However, if your checkpoint weights don't have the hyperparameters saved,
-                use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
-                These will be converted into a :class:`~dict` and passed into your
-                :class:`LightningModule` for use.
-
-                If your model's ``hparams`` argument is :class:`~argparse.Namespace`
-                and ``.yaml`` file has hierarchical structure, you need to refactor your model to treat
-                ``hparams`` as :class:`~dict`.
-            strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
-                returned by this module's state dict.
-            \**kwargs: Any extra keyword args needed to init the model. Can also be used to override saved
-                hyperparameter values.
-
-        Return:
-            :class:`LightningModule` instance with loaded weights and hyperparameters (if available).
-
-        Note:
-            ``load_from_checkpoint`` is a **class** method. You should use your :class:`LightningModule`
-            **class** to call it instead of the :class:`LightningModule` instance.
-
-        Example::
-
-            # load weights without mapping ...
-            model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
-
-            # or load weights mapping all weights from GPU 1 to GPU 0 ...
-            map_location = {'cuda:1':'cuda:0'}
-            model = MyLightningModule.load_from_checkpoint(
-                'path/to/checkpoint.ckpt',
-                map_location=map_location
-            )
-
-            # or load weights and hyperparameters from separate files.
-            model = MyLightningModule.load_from_checkpoint(
-                'path/to/checkpoint.ckpt',
-                hparams_file='/path/to/hparams_file.yaml'
-            )
-
-            # override some of the params with new values
-            model = MyLightningModule.load_from_checkpoint(
-                PATH,
-                num_layers=128,
-                pretrained_ckpt_path=NEW_PATH,
-            )
-
-            # predict
-            pretrained_model.eval()
-            pretrained_model.freeze()
-            y_hat = pretrained_model(x)
-        """
-        return _load_from_checkpoint(
-            cls,
-            checkpoint_path,
-            map_location,
-            hparams_file,
-            strict,
-            **kwargs,
-        )
-
-
 def _load_from_checkpoint(
-    cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint_path: Union[_PATH, IO],
     map_location: _MAP_LOCATION_TYPE = None,
     hparams_file: Optional[_PATH] = None,

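For downstream code, the migration is mechanical: the interface did not change, only where it lives. A hedged before/after sketch (the commented-out lines show the pre-commit spelling; `MyLightningModule` is an illustrative user subclass):

    # Before this commit:
    #     from lightning.pytorch.core.saving import ModelIO
    # After it, the same classmethod and constants live on LightningModule directly:
    from lightning.pytorch import LightningModule

    assert hasattr(LightningModule, "load_from_checkpoint")
    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY == "hyper_parameters"

    # loading through a user subclass is unchanged, e.g.:
    #     model = MyLightningModule.load_from_checkpoint("path/to/checkpoint.ckpt")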