
Commit 4dbcef4

update for corner cases
1 parent 321e24d

2 files changed: 17 additions, 9 deletions

library/train_util.py

Lines changed: 3 additions & 0 deletions
@@ -663,6 +663,7 @@ def set_current_epoch(self, epoch):
             for _ in range(num_epochs):
                 self.current_epoch += 1
                 self.shuffle_buckets()
+            # self.current_epoch seems to be set to 0 again in the next epoch; it may be caused by skipped_dataloader?
         else:
             logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
             self.current_epoch = epoch
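
The added comment flags a resume-related corner case: a restarted run can hand set_current_epoch an epoch that is not ahead of current_epoch, possibly via the skipped dataloader. A minimal sketch of the guard's intent, with a hypothetical EpochTracker standing in for the dataset class (only current_epoch, shuffle_buckets, and the warning text come from the diff):

import logging

logger = logging.getLogger(__name__)

class EpochTracker:  # hypothetical stand-in for the dataset class
    def __init__(self):
        self.current_epoch = 0

    def shuffle_buckets(self):
        pass  # placeholder: reshuffle bucketed batches for the new epoch

    def set_current_epoch(self, epoch):
        if epoch > self.current_epoch:
            # advance one epoch at a time so shuffle_buckets() runs once per epoch
            for _ in range(epoch - self.current_epoch):
                self.current_epoch += 1
                self.shuffle_buckets()
        else:
            # corner case: a resumed run can pass an epoch that is not ahead of
            # current_epoch; warn and set it directly without reshuffling
            logger.warning(
                "epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)
            )
            self.current_epoch = epoch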
@@ -5560,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
         if epoch == 0:
             self.loss_list.append(loss)
         else:
+            while len(self.loss_list) <= step:
+                self.loss_list.append(0.0)
             self.loss_total -= self.loss_list[step]
             self.loss_list[step] = loss
         self.loss_total += loss
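
The padding guards against an IndexError when a resumed run records its first loss at a step index that epoch 0 never reached. A minimal sketch, assuming a LossRecorder shaped like the one in library/train_util.py (only the two fields used here are reproduced; the moving-average accessor is omitted):

class LossRecorder:
    def __init__(self):
        self.loss_list: list[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0:
            self.loss_list.append(loss)
        else:
            # resumed runs can hit a step index that epoch 0 never appended;
            # pad with zeros so loss_list[step] cannot raise an IndexError
            while len(self.loss_list) <= step:
                self.loss_list.append(0.0)
            self.loss_total -= self.loss_list[step]
            self.loss_list[step] = loss
        self.loss_total += loss

recorder = LossRecorder()
recorder.add(epoch=1, step=5, loss=0.25)  # without the while-loop this raises IndexError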

train_network.py

Lines changed: 14 additions & 9 deletions
@@ -493,13 +493,15 @@ def train(self, args):
         # before resuming make hook for saving/loading to save/load the network weights only
         def save_model_hook(models, weights, output_dir):
             # pop weights of other models than network to save only network weights
-            if accelerator.is_main_process:
+            # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
+            if accelerator.is_main_process or args.deepspeed:
                 remove_indices = []
                 for i, model in enumerate(models):
                     if not isinstance(model, type(accelerator.unwrap_model(network))):
                         remove_indices.append(i)
                 for i in reversed(remove_indices):
-                    weights.pop(i)
+                    if len(weights) > i:
+                        weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

             # save current epoch and step
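
Per the linked diffusers issue, DeepSpeed invokes the save hook beyond the main process, and the weights list handed in can be shorter than models, which is what the len(weights) > i guard anticipates. A standalone sketch of the pattern with made-up list contents (none of these values come from the repository):

models = ["unet", "text_encoder", "network"]
weights = ["unet_state"]  # shorter than models, as can happen under DeepSpeed

# collect indices of everything that is not the network, as the hook does
remove_indices = [i for i, m in enumerate(models) if m != "network"]

for i in reversed(remove_indices):  # reverse order keeps lower indices valid
    if len(weights) > i:  # guard: the index may not exist in the shorter list
        weights.pop(i)

print(weights)  # [] -- index 1 was skipped by the guard, index 0 was popped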
@@ -813,11 +815,12 @@ def load_model_hook(models, input_dir):
                 )
                 logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
                 initial_step *= args.gradient_accumulation_steps
+
+                # set epoch to start to make initial_step less than len(train_dataloader)
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
             else:
                 # if not, only epoch no is skipped for informative purpose
-                epoch_to_start = initial_step // math.ceil(
-                    len(train_dataloader) / args.gradient_accumulation_steps
-                )
+                epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                 initial_step = 0  # do not skip

         global_step = 0
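
Both branches now compute epoch_to_start the same way: initial_step divided by the optimizer steps per epoch, math.ceil(len(train_dataloader) / gradient_accumulation_steps). A worked example of the else branch with assumed numbers (none from the repository):

import math

len_train_dataloader = 100       # batches per epoch (assumed)
gradient_accumulation_steps = 4  # assumed
initial_step = 60                # optimizer steps already taken (assumed)

steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 25
epoch_to_start = initial_step // steps_per_epoch

print(steps_per_epoch, epoch_to_start)  # 25 2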
@@ -878,9 +881,11 @@ def remove_model(old_ckpt_name):
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for skip_epoch in range(epoch_to_start):  # skip epochs
-            logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-            initial_step -= len(train_dataloader)
+        if initial_step > 0:  # only if skip_until_initial_step is specified
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+            global_step = initial_step

         for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
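
Guarding the skip loop with `if initial_step > 0` means whole epochs are only consumed when skip_until_initial_step is actually in effect; otherwise initial_step is 0 and epoch_to_start stays informational. A small simulation mirroring the new logic, in the simplest case of gradient_accumulation_steps == 1 (all names and numbers here are assumed):

import math

train_dataloader = list(range(1000))  # stand-in: 1000 batches per epoch
gradient_accumulation_steps = 1
initial_step = 2500  # batches to skip (assumed), already multiplied upstream

epoch_to_start = initial_step // math.ceil(len(train_dataloader) / gradient_accumulation_steps)

global_step = 0
if initial_step > 0:  # only if skip_until_initial_step is specified
    for skip_epoch in range(epoch_to_start):  # skip whole epochs up front
        initial_step -= len(train_dataloader)
    global_step = initial_step  # as in the commit, the counter resumes here

print(epoch_to_start, initial_step, global_step)  # 2 500 500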
@@ -892,7 +897,7 @@ def remove_model(old_ckpt_name):

             skipped_dataloader = None
             if initial_step > 0:
-                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1)
+                skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1

             for step, batch in enumerate(skipped_dataloader or train_dataloader):
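
The reformatted line calls accelerate's Accelerator.skip_first_batches, which wraps a dataloader so that its first N batches are dropped. A minimal usage sketch with a toy dataset (numbers are illustrative; the real call passes train_dataloader and initial_step - 1):

from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
loader = accelerator.prepare(DataLoader(list(range(10)), batch_size=2))

skipped = accelerator.skip_first_batches(loader, 2)  # drop the first 2 batches
for batch in skipped:
    print(batch)  # tensor([4, 5]), tensor([6, 7]), tensor([8, 9])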
