@@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
1033
1033
return lr_weight
1034
1034
1035
1035
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
1036
- def prepare_optimizer_params (self , text_encoder_lr , unet_lr , default_lr , , unet_lora_plus_ratio = None , text_encoder_lora_plus_ratio = None ):
1036
+ def prepare_optimizer_params (
1037
+ self ,
1038
+ text_encoder_lr ,
1039
+ unet_lr ,
1040
+ default_lr ,
1041
+ unet_lora_plus_ratio = None ,
1042
+ text_encoder_lora_plus_ratio = None
1043
+ ):
1037
1044
self .requires_grad_ (True )
1038
1045
all_params = []
1039
1046
@@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
1068
1075
return params
1069
1076
1070
1077
if self .text_encoder_loras :
1071
- params = assemble_params (self .text_encoder_loras , text_encoder_lr , text_encoder_lora_plus_ratio )
1078
+ params = assemble_params (
1079
+ self .text_encoder_loras ,
1080
+ text_encoder_lr if text_encoder_lr is not None else default_lr ,
1081
+ text_encoder_lora_plus_ratio
1082
+ )
1072
1083
all_params .extend (params )
1073
1084
1074
1085
if self .unet_loras :
@@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
1083
1094
1084
1095
# blockごとにパラメータを設定する
1085
1096
for idx , block_loras in block_idx_to_lora .items ():
1086
- if unet_lr is not None :
1087
- params = assemble_params (block_loras , unet_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1088
- elif default_lr is not None :
1089
- params = assemble_params (block_loras , default_lr * self .get_lr_weight (block_loras [0 ]), unet_lora_plus_ratio )
1097
+ params = assemble_params (
1098
+ block_loras ,
1099
+ (unet_lr if unet_lr is not None else default_lr ) * self .get_lr_weight (block_loras [0 ]),
1100
+ unet_lora_plus_ratio
1101
+ )
1090
1102
all_params .extend (params )
1091
1103
1092
1104
else :
1093
- params = assemble_params (self .unet_loras , unet_lr , unet_lora_plus_ratio )
1105
+ params = assemble_params (
1106
+ self .unet_loras ,
1107
+ default_lr if unet_lr is None else unet_lr ,
1108
+ unet_lora_plus_ratio
1109
+ )
1094
1110
all_params .extend (params )
1095
1111
1096
1112
return all_params
0 commit comments