@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
755
755
return block_idx
756
756
757
757
758
+ def convert_diffusers_to_sai_if_needed (weights_sd ):
759
+ # only supports U-Net LoRA modules
760
+
761
+ found_up_down_blocks = False
762
+ for k in list (weights_sd .keys ()):
763
+ if "down_blocks" in k :
764
+ found_up_down_blocks = True
765
+ break
766
+ if "up_blocks" in k :
767
+ found_up_down_blocks = True
768
+ break
769
+ if not found_up_down_blocks :
770
+ return
771
+
772
+ from library .sdxl_model_util import make_unet_conversion_map
773
+
774
+ unet_conversion_map = make_unet_conversion_map ()
775
+ unet_conversion_map = {hf .replace ("." , "_" )[:- 1 ]: sd .replace ("." , "_" )[:- 1 ] for sd , hf in unet_conversion_map }
776
+
777
+ # # add extra conversion
778
+ # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
779
+
780
+ logger .info (f"Converting LoRA keys from Diffusers to SAI" )
781
+ lora_unet_prefix = "lora_unet_"
782
+ for k in list (weights_sd .keys ()):
783
+ if not k .startswith (lora_unet_prefix ):
784
+ continue
785
+
786
+ unet_module_name = k [len (lora_unet_prefix ) :].split ("." )[0 ]
787
+
788
+ # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
789
+ for hf_module_name , sd_module_name in unet_conversion_map .items ():
790
+ if hf_module_name in unet_module_name :
791
+ new_key = (
792
+ lora_unet_prefix
793
+ + unet_module_name .replace (hf_module_name , sd_module_name )
794
+ + k [len (lora_unet_prefix ) + len (unet_module_name ) :]
795
+ )
796
+ weights_sd [new_key ] = weights_sd .pop (k )
797
+ found = True
798
+ break
799
+
800
+ if not found :
801
+ logger .warning (f"Key { k } is not found in unet_conversion_map" )
802
+
803
+
758
804
# Create network from weights for inference, weights are not loaded here (because can be merged)
759
805
def create_network_from_weights (multiplier , file , vae , text_encoder , unet , weights_sd = None , for_inference = False , ** kwargs ):
760
806
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
768
814
else :
769
815
weights_sd = torch .load (file , map_location = "cpu" )
770
816
817
+ # if keys are Diffusers based, convert to SAI based
818
+ convert_diffusers_to_sai_if_needed (weights_sd )
819
+
771
820
# get dim/alpha mapping
772
821
modules_dim = {}
773
822
modules_alpha = {}
0 commit comments