Labels: bug (Something isn't working)
Description
Describe the bug
In train_dreambooth_lora_sdxl.py, you can see this code:
if not args.train_text_encoder:
    unet_added_conditions = {
        "time_ids": add_time_ids,
        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
    }
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
else:
    unet_added_conditions = {"time_ids": add_time_ids}
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=None,
        prompt=None,
        text_input_ids_list=[tokens_one, tokens_two],
    )
    unet_added_conditions.update(
        {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
    )
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
This means prompt_embeds is repeated so that there is one copy for every image in the batch.
But in collate_fn:
def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    original_sizes = [example["original_size"] for example in examples]
    crop_top_lefts = [example["crop_top_left"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
        original_sizes += [example["original_size"] for example in examples]
        crop_top_lefts += [example["crop_top_left"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "pixel_values": pixel_values,
        "prompts": prompts,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }
    return batch
You can see that the class images are appended directly to the batch, after all of the instance images.
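To make the resulting batch order concrete, here is a minimal sketch that keeps only the prompt handling from collate_fn. The name collate_prompts and the toy examples are assumptions made for illustration; they are not part of the training script.

```python
def collate_prompts(examples, with_prior_preservation=False):
    # Hypothetical, simplified stand-in for collate_fn: only the prompt list is built.
    prompts = [example["instance_prompt"] for example in examples]
    if with_prior_preservation:
        # Class entries are appended after all instance entries.
        prompts += [example["class_prompt"] for example in examples]
    return prompts


# Toy batch of two examples (placeholder strings, not real prompts).
examples = [
    {"instance_prompt": "ins", "class_prompt": "cls"},
    {"instance_prompt": "ins", "class_prompt": "cls"},
]
batch_prompts = collate_prompts(examples, with_prior_preservation=True)
print(batch_prompts)  # ['ins', 'ins', 'cls', 'cls']
```

With two examples, the batch is laid out as all instance entries first, then all class entries.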
But when train_dataset.custom_instance_prompts is not provided, prompt_embeds is built like this:
if not train_dataset.custom_instance_prompts:
    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        unet_add_text_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
    # batch prompts on all training steps
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
So the tokens and embeddings seem to be ordered as [ins_token, cls_token] or [ins_embed, cls_embed].
Now, going back to this line:
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
Is it wrong? Because repeat tiles the whole tensor, the result looks like prompt_embeds_input = [ins_embed, cls_embed, ins_embed, cls_embed, ...],
but I think it should be [ins_embed, ins_embed, ..., cls_embed, cls_embed, ...] to match the order in which collate_fn stacks the images.
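The difference between the two orderings can be sketched with plain Python lists standing in for the first dimension of the embedding tensor. The helpers tile and interleave are hypothetical; they only mimic what Tensor.repeat(n, 1, 1) and torch.repeat_interleave(..., n, dim=0) do along dim 0, and the repeat count of 2 is an assumed value.

```python
def tile(rows, n):
    # Mimics tensor.repeat(n, 1, 1) along dim 0: the whole sequence is copied n times.
    return rows * n


def interleave(rows, n):
    # Mimics torch.repeat_interleave(tensor, n, dim=0): each row is repeated n times in place.
    return [row for row in rows for _ in range(n)]


embeds = ["ins_embed", "cls_embed"]  # placeholder labels for the two concatenated embeddings
repeat_count = 2  # assumed value of elems_to_repeat_text_embeds for this illustration

tiled = tile(embeds, repeat_count)
interleaved = interleave(embeds, repeat_count)
print(tiled)        # ['ins_embed', 'cls_embed', 'ins_embed', 'cls_embed']
print(interleaved)  # ['ins_embed', 'ins_embed', 'cls_embed', 'cls_embed']
```

With a repeat count of 1 the two results coincide, so any mismatch would only show up when more than one instance image is in the batch.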
Reproduction
None, just the question.
Logs