18
18
from datetime import timedelta
19
19
from typing import Optional , Union
20
20
21
- from lightning_utilities import module_available
21
+ from lightning_utilities . core . imports import RequirementCache
22
22
23
23
import lightning .pytorch as pl
24
24
from lightning .fabric .utilities .registry import _load_external_callbacks
@@ -93,7 +93,7 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
93
93
" but found `ModelCheckpoint` in callbacks list."
94
94
)
95
95
elif enable_checkpointing :
96
- if module_available ("litmodels" ) and self .trainer ._model_registry :
96
+ if RequirementCache ("litmodels >=0.1.7 " ) and self .trainer ._model_registry :
97
97
trainer_source = inspect .getmodule (self .trainer )
98
98
if trainer_source is None or not isinstance (trainer_source .__package__ , str ):
99
99
raise RuntimeError ("Unable to determine the source of the trainer." )
@@ -103,12 +103,11 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
103
103
else :
104
104
from litmodels .integrations .checkpoints import LightningModelCheckpoint as LitModelCheckpoint
105
105
106
- model_checkpoint = LitModelCheckpoint (model_name = self .trainer ._model_registry )
106
+ model_checkpoint = LitModelCheckpoint (model_registry = self .trainer ._model_registry )
107
107
else :
108
108
rank_zero_info (
109
- "You are using the default ModelCheckpoint callback."
110
- " Install `litmodels` package to use the `LitModelCheckpoint` instead"
111
- " for seamless uploading to the Lightning model registry."
109
+ "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable"
110
+ " `LitModelCheckpoint` for automatic upload to the Lightning model registry."
112
111
)
113
112
model_checkpoint = ModelCheckpoint ()
114
113
self .trainer .callbacks .append (model_checkpoint )
0 commit comments