2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -429,6 +429,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue where `DistributedSampler.set_epoch` wasn't getting called during `trainer.predict` ([#16785](https://github.com/Lightning-AI/lightning/pull/16785), [#16826](https://github.com/Lightning-AI/lightning/pull/16826))
 
+- Fixed an issue with comparing torch versions when using a version of torch built from source ([#17030](https://github.com/Lightning-AI/lightning/pull/17030))
+
 
 ## [1.9.4] - 2023-03-01
 
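For context on the entry above (an illustrative sketch, not part of the diff): a torch build from source reports a PEP 440 pre-release/local version such as `1.13.0a0+git123abc` (a made-up example string). Pre-releases sort before the final release, so a plain `>=` comparison against `1.13.0` wrongly reports the source build as too old; comparing base versions avoids that:

```python
# Minimal sketch of the pitfall fixed by #17030; requires the `packaging` package.
from packaging.version import Version

source_build = Version("1.13.0a0+git123abc")  # hypothetical source-built torch version

# Under PEP 440, the "a0" pre-release sorts before the 1.13.0 final release:
print(source_build >= Version("1.13.0"))  # False

# Comparing only the release digits (the "base version") gives the intended result:
print(Version(source_build.base_version) >= Version("1.13.0"))  # True
```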
14 changes: 5 additions & 9 deletions src/lightning/pytorch/core/module.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""
 
 import logging
-import numbers
+import operator
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
@@ -37,7 +37,7 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import compare_version, RequirementCache
 from torch import ScriptModule, Tensor
 from torch.nn import Module
 from torch.optim.optimizer import Optimizer
@@ -51,12 +51,7 @@
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.distributed import _distributed_available
-from lightning.fabric.utilities.imports import (
-    _IS_WINDOWS,
-    _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_2_0,
-    _TORCH_GREATER_EQUAL_2_1,
-)
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -1567,7 +1562,8 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         self._register_state_dict_hook(state_dict_hook)
 
-        if _TORCH_GREATER_EQUAL_1_13:
+        if compare_version("torch", operator.ge, "1.13.0", use_base_version=True):
+            # See https://github.com/Lightning-AI/lightning/issues/16644 for why a base-version check is used here
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
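A hedged sketch of what the new guard evaluates (`compare_version` is the helper imported from `lightning_utilities.core.imports` in the hunks above; the torch version in the comments is hypothetical):

```python
import operator

from lightning_utilities.core.imports import compare_version

# Suppose the installed torch reports "1.13.0a0+git123abc" (a source build):
# - an exact comparison treats the "a0" pre-release as older than 1.13.0,
# - a base-version comparison keeps only the release digits "1.13.0".
torch_ge_1_13 = compare_version("torch", operator.ge, "1.13.0", use_base_version=True)
print(torch_ge_1_13)  # True for the hypothetical source build above
```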