@@ -278,7 +278,6 @@ def any_lightning_module_function_or_hook(self):
278
278
prefix: A string to put at the beginning of metric keys.
279
279
experiment: WandB experiment object. Automatically set when creating a run.
280
280
checkpoint_name: Name of the model checkpoint artifact being logged.
281
- add_file_policy: If "mutable", copies file to tempdirectory before upload.
282
281
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
283
282
284
283
Raises:
@@ -305,7 +304,6 @@ def __init__(
305
304
experiment : Union ["Run" , "RunDisabled" , None ] = None ,
306
305
prefix : str = "" ,
307
306
checkpoint_name : Optional [str ] = None ,
308
- add_file_policy : Literal ["mutable" , "immutable" ] = "mutable" ,
309
307
** kwargs : Any ,
310
308
) -> None :
311
309
if not _WANDB_AVAILABLE :
@@ -324,8 +322,7 @@ def __init__(
324
322
self ._prefix = prefix
325
323
self ._experiment = experiment
326
324
self ._logged_model_time : dict [str , float ] = {}
327
- self ._checkpoint_callbacks : dict [int , ModelCheckpoint ] = {}
328
- self .add_file_policy = add_file_policy
325
+ self ._checkpoint_callback : Optional [ModelCheckpoint ] = None
329
326
330
327
# paths are processed as strings
331
328
if save_dir is not None :
@@ -594,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
594
591
if self ._log_model == "all" or self ._log_model is True and checkpoint_callback .save_top_k == - 1 :
595
592
self ._scan_and_log_checkpoints (checkpoint_callback )
596
593
elif self ._log_model is True :
597
- self ._checkpoint_callbacks [ id ( checkpoint_callback )] = checkpoint_callback
594
+ self ._checkpoint_callback = checkpoint_callback
598
595
599
596
@staticmethod
600
597
@rank_zero_only
@@ -647,9 +644,8 @@ def finalize(self, status: str) -> None:
647
644
# Currently, checkpoints only get logged on success
648
645
return
649
646
# log checkpoints as artifacts
650
- if self ._experiment is not None :
651
- for checkpoint_callback in self ._checkpoint_callbacks .values ():
652
- self ._scan_and_log_checkpoints (checkpoint_callback )
647
+ if self ._checkpoint_callback and self ._experiment is not None :
648
+ self ._scan_and_log_checkpoints (self ._checkpoint_callback )
653
649
654
650
def _scan_and_log_checkpoints (self , checkpoint_callback : ModelCheckpoint ) -> None :
655
651
import wandb
@@ -679,7 +675,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
679
675
if not self ._checkpoint_name :
680
676
self ._checkpoint_name = f"model-{ self .experiment .id } "
681
677
artifact = wandb .Artifact (name = self ._checkpoint_name , type = "model" , metadata = metadata )
682
- artifact .add_file (p , name = "model.ckpt" , policy = self . add_file_policy )
678
+ artifact .add_file (p , name = "model.ckpt" )
683
679
aliases = ["latest" , "best" ] if p == checkpoint_callback .best_model_path else ["latest" ]
684
680
self .experiment .log_artifact (artifact , aliases = aliases )
685
681
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
0 commit comments