@@ -157,9 +157,7 @@ def minimize(
157
157
Raises:
158
158
ValueError: If some of the variables are not `Variable` objects.
159
159
"""
160
- self ._decay_var_list = (
161
- set ([v .ref () for v in decay_var_list ]) if decay_var_list else False
162
- )
160
+ self ._set_decay_var_list (var_list , decay_var_list )
163
161
return super ().minimize (
164
162
loss , var_list = var_list , grad_loss = grad_loss , name = name , tape = tape
165
163
)
@@ -186,9 +184,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
186
184
TypeError: If `grads_and_vars` is malformed.
187
185
ValueError: If none of the variables have gradients.
188
186
"""
189
- self ._decay_var_list = (
190
- set ([v .ref () for v in decay_var_list ]) if decay_var_list else False
191
- )
187
+ grads_and_vars = list (grads_and_vars )
188
+ self ._set_decay_var_list ((v for _ , v in grads_and_vars ), decay_var_list )
192
189
return super ().apply_gradients (grads_and_vars , name = name , ** kwargs )
193
190
194
191
def _decay_weights_op (self , var , apply_state = None ):
@@ -245,11 +242,23 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
245
242
grad , var , indices , apply_state = apply_state
246
243
)
247
244
245
+ def _set_decay_var_list (self , var_list , decay_var_list = None ):
246
+ if decay_var_list :
247
+ self ._decay_var_list = set (v .ref () for v in decay_var_list )
248
+ elif self .exclude_from_weight_decay :
249
+ self ._decay_var_list = set (
250
+ v .ref ()
251
+ for v in var_list
252
+ if not is_variable_matched_by_regexes (v , self .exclude_from_weight_decay )
253
+ )
254
+ else :
255
+ self ._decay_var_list = None
256
+
248
257
def _do_use_weight_decay (self , var ):
249
258
"""Whether to use L2 weight decay for `var`."""
250
- if self ._decay_var_list and var . ref () in self . _decay_var_list :
259
+ if self ._decay_var_list is None :
251
260
return True
252
- return not is_variable_matched_by_regexes ( var , self .exclude_from_weight_decay )
261
+ return var . ref () in self ._decay_var_list
253
262
254
263
255
264
@typechecked
0 commit comments