
Commit bc8c0f6

Use base version check before calling _register_load_state_dict_pre_hook (#17030)
1 parent f3a20d0 commit bc8c0f6
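
The motivation: a torch build from source typically reports a PEP 440 pre-release/local version (for example `1.13.0a0+git1234`), and a strict `>=` comparison sorts such a version *before* the corresponding release, so a "torch >= 1.13" feature check wrongly fails. Comparing against the base version avoids this. A quick illustration using `packaging` directly (the version string below is a hypothetical source build):

```python
from packaging.version import Version

# Hypothetical version string reported by a torch built from source.
built_from_source = Version("1.13.0a0+git1234")

# Strict PEP 440 comparison: a pre-release sorts before the final release,
# so the ">= 1.13.0" check fails for this build.
assert not built_from_source >= Version("1.13.0")

# The base version drops the pre-release and local segments
# ("1.13.0a0+git1234" -> "1.13.0"), so the same build passes the check.
assert Version(built_from_source.base_version) >= Version("1.13.0")
```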

2 files changed: +7 -9 lines


src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -436,6 +436,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
 
+- Fixed an issue with comparing torch versions when using a version of torch built from source ([#17030](https://github.com/Lightning-AI/lightning/pull/17030))
+
 
 ## [1.9.4] - 2023-03-01
 
```

src/lightning/pytorch/core/module.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""
-
 import logging
 import numbers
+import operator
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
@@ -37,7 +37,7 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import compare_version, RequirementCache
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
@@ -51,12 +51,7 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.distributed import _distributed_available
-from lightning.fabric.utilities.imports import (
-    _IS_WINDOWS,
-    _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_2_0,
-    _TORCH_GREATER_EQUAL_2_1,
-)
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -1576,7 +1571,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         self._register_state_dict_hook(state_dict_hook)
 
-        if _TORCH_GREATER_EQUAL_1_13:
+        if compare_version("torch", operator.ge, "1.13.0", use_base_version=True):
+            # See https://github.com/Lightning-AI/lightning/issues/16644 for why a base-version check is used here
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
```
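
The key change is the last hunk. A minimal usage sketch of the two checks, assuming an installed torch that reports a hypothetical pre-release source-build version such as `1.13.0a0+git1234`:

```python
import operator

from lightning_utilities.core.imports import compare_version

# Strict comparison (roughly what the removed _TORCH_GREATER_EQUAL_1_13 flag
# computed): PEP 440 sorts "1.13.0a0+git1234" *before* "1.13.0", so this
# returns False even though the source build already has the 1.13 APIs.
compare_version("torch", operator.ge, "1.13.0")  # False for the source build

# The check introduced by this commit: the installed version is reduced to its
# base version ("1.13.0a0+git1234" -> "1.13.0") before comparing, so the same
# source build is treated as >= 1.13.0 and the newer hook registration is used.
compare_version("torch", operator.ge, "1.13.0", use_base_version=True)  # True
```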
