Skip to content

Commit 146edce

Browse files
committed
support Diffusers' based SDXL LoRA key for inference
1 parent 153764a commit 146edce

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

networks/lora.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
755755
return block_idx
756756

757757

758+
def convert_diffusers_to_sai_if_needed(weights_sd):
759+
# only supports U-Net LoRA modules
760+
761+
found_up_down_blocks = False
762+
for k in list(weights_sd.keys()):
763+
if "down_blocks" in k:
764+
found_up_down_blocks = True
765+
break
766+
if "up_blocks" in k:
767+
found_up_down_blocks = True
768+
break
769+
if not found_up_down_blocks:
770+
return
771+
772+
from library.sdxl_model_util import make_unet_conversion_map
773+
774+
unet_conversion_map = make_unet_conversion_map()
775+
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
776+
777+
# # add extra conversion
778+
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
779+
780+
logger.info(f"Converting LoRA keys from Diffusers to SAI")
781+
lora_unet_prefix = "lora_unet_"
782+
for k in list(weights_sd.keys()):
783+
if not k.startswith(lora_unet_prefix):
784+
continue
785+
786+
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
787+
788+
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
789+
for hf_module_name, sd_module_name in unet_conversion_map.items():
790+
if hf_module_name in unet_module_name:
791+
new_key = (
792+
lora_unet_prefix
793+
+ unet_module_name.replace(hf_module_name, sd_module_name)
794+
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
795+
)
796+
weights_sd[new_key] = weights_sd.pop(k)
797+
found = True
798+
break
799+
800+
if not found:
801+
logger.warning(f"Key {k} is not found in unet_conversion_map")
802+
803+
758804
# Create network from weights for inference, weights are not loaded here (because can be merged)
759805
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
760806
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
768814
else:
769815
weights_sd = torch.load(file, map_location="cpu")
770816

817+
# if keys are Diffusers based, convert to SAI based
818+
convert_diffusers_to_sai_if_needed(weights_sd)
819+
771820
# get dim/alpha mapping
772821
modules_dim = {}
773822
modules_alpha = {}

0 commit comments

Comments
 (0)