|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | """The LightningModule - an nn.Module with many additional features."""
|
15 |
| - |
16 | 15 | import logging
|
17 | 16 | import numbers
|
| 17 | +import operator |
18 | 18 | import weakref
|
19 | 19 | from contextlib import contextmanager
|
20 | 20 | from pathlib import Path
|
|
37 | 37 |
|
38 | 38 | import torch
|
39 | 39 | from lightning_utilities.core.apply_func import apply_to_collection
|
40 |
| -from lightning_utilities.core.imports import RequirementCache |
| 40 | +from lightning_utilities.core.imports import compare_version, RequirementCache |
41 | 41 | from torch import ScriptModule, Tensor
|
42 | 42 | from torch.nn import Module
|
43 | 43 | from torch.optim.optimizer import Optimizer
|
|
51 | 51 | from lightning.fabric.utilities.cloud_io import get_filesystem
|
52 | 52 | from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
|
53 | 53 | from lightning.fabric.utilities.distributed import _distributed_available
|
54 |
| -from lightning.fabric.utilities.imports import ( |
55 |
| - _IS_WINDOWS, |
56 |
| - _TORCH_GREATER_EQUAL_1_13, |
57 |
| - _TORCH_GREATER_EQUAL_2_0, |
58 |
| - _TORCH_GREATER_EQUAL_2_1, |
59 |
| -) |
| 54 | +from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1 |
60 | 55 | from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
61 | 56 | from lightning.fabric.wrappers import _FabricOptimizer
|
62 | 57 | from lightning.pytorch.callbacks.callback import Callback
|
@@ -1576,7 +1571,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
|
1576 | 1571 |
|
1577 | 1572 | self._register_state_dict_hook(state_dict_hook)
|
1578 | 1573 |
|
1579 |
| - if _TORCH_GREATER_EQUAL_1_13: |
| 1574 | + if compare_version("torch", operator.ge, "1.13.0", use_base_version=True): |
| 1575 | + # See https://github.com/Lightning-AI/lightning/issues/16644 for why a base-version check is used here |
1580 | 1576 | self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
|
1581 | 1577 | else:
|
1582 | 1578 | # We need to make sure the self inside the method is a weakref proxy
|
|
0 commit comments