@@ -493,13 +493,15 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            if accelerator.is_main_process:
+            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+            if accelerator.is_main_process or args.deepspeed:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    weights.pop(i)
+                    if len(weights) > i:
+                        weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")
 
         # save current epoch and step
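
Note on this hunk: popping in reversed index order keeps the remaining indices valid, and the new len(weights) guard covers the DeepSpeed case, where the hook can receive fewer weight entries than registered models (see the linked diffusers issue). A minimal standalone sketch of that pruning pattern, with a stand-in Network class and invented model/weight lists rather than the script's real objects:

    class Network:  # stand-in for the trained network's class
        pass

    models = [Network(), "unet", "text_encoder"]            # hypothetical registered models
    weights = ["network_sd", "unet_sd", "text_encoder_sd"]  # hypothetical matching state dicts

    # collect indices of everything that is not the network
    remove_indices = [i for i, m in enumerate(models) if not isinstance(m, Network)]

    # pop from the end so earlier indices are not shifted by earlier pops
    for i in reversed(remove_indices):
        if len(weights) > i:  # under deepspeed, weights may be shorter than models
            weights.pop(i)

    assert weights == ["network_sd"]  # only the network weights remain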
@@ -813,11 +815,12 @@ def load_model_hook(models, input_dir):
             )
             logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
             initial_step *= args.gradient_accumulation_steps
+
+            # set epoch to start to make initial_step less than len(train_dataloader)
+            epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
         else:
             # if not, only epoch no is skipped for informative purpose
-            epoch_to_start = initial_step // math.ceil(
-                len(train_dataloader) / args.gradient_accumulation_steps
-            )
+            epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
             initial_step = 0  # do not skip
 
         global_step = 0
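
A worked instance of the arithmetic the new branch performs, with invented numbers (none of these come from the diff): after the multiply, initial_step counts dataloader batches, and the divisor is the number of optimizer steps per epoch.

    import math

    len_train_dataloader = 100        # batches per epoch (hypothetical)
    gradient_accumulation_steps = 4
    initial_step = 60                 # optimizer steps to skip (hypothetical)

    initial_step *= gradient_accumulation_steps                                       # 240
    steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)   # 25
    epoch_to_start = initial_step // steps_per_epoch                                  # 240 // 25 = 9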
@@ -878,9 +881,11 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
         # training loop
-        for skip_epoch in range(epoch_to_start):  # skip epochs
-            logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-            initial_step -= len(train_dataloader)
+        if initial_step > 0:  # only if skip_until_initial_step is specified
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+            global_step = initial_step
 
         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
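
The guard added here matters because epoch_to_start is now also computed in the skip_until_initial_step branch (second hunk): without the initial_step > 0 check, the skip loop would run on any resume. A sketch of what the loop leaves behind, with its own invented numbers: each skipped epoch subtracts one epoch's worth of batches, so initial_step ends up as the offset into the first resumed epoch, and global_step is seeded from it.

    initial_step = 240           # batches to skip (hypothetical)
    epoch_to_start = 2           # hypothetical
    len_train_dataloader = 100
    global_step = 0

    if initial_step > 0:
        for skip_epoch in range(epoch_to_start):
            initial_step -= len_train_dataloader
        global_step = initial_step

    print(initial_step, global_step)  # 40 40: resume 40 batches into epoch 3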
@@ -892,7 +897,7 @@ def remove_model(old_ckpt_name):
 
             skipped_dataloader = None
             if initial_step > 0:
-                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1
 
             for step, batch in enumerate(skipped_dataloader or train_dataloader):
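
accelerator.skip_first_batches(dataloader, num_batches) is an accelerate API that returns a dataloader with the first num_batches batches dropped, so the or-fallback means only the first resumed epoch starts mid-epoch. A self-contained sketch of the pattern, with toy data and a hypothetical initial_step:

    from torch.utils.data import DataLoader
    from accelerate import Accelerator

    accelerator = Accelerator()
    train_dataloader = accelerator.prepare(DataLoader(list(range(8)), batch_size=2))

    initial_step = 3  # hypothetical: resume at the 3rd batch of the epoch
    skipped_dataloader = None
    if initial_step > 0:
        # drop the first initial_step - 1 batches so the loop resumes at batch initial_step
        skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)

    for step, batch in enumerate(skipped_dataloader or train_dataloader):
        print(step, batch)  # step 0 here is the epoch's 3rd batch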