Commit da14c3b

Fix exclude_from_weight_decay performance regression (#2676)
1 parent a5cd76d commit da14c3b
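
Before this change, `_do_use_weight_decay` re-ran the `exclude_from_weight_decay` regexes against a variable's name on every optimizer step. The commit below precomputes the set of decayable variables once per `minimize`/`apply_gradients` call, so the per-variable check becomes a set-membership lookup. Nothing changes on the user side; a minimal usage sketch (the two-variable model is hypothetical, the optimizer arguments are the library's):

    import tensorflow as tf
    import tensorflow_addons as tfa

    # Hypothetical model: decay the kernel, exclude the bias by name pattern.
    w = tf.Variable(tf.ones([4, 2]), name="kernel")
    b = tf.Variable(tf.zeros([2]), name="bias")

    opt = tfa.optimizers.AdamW(
        learning_rate=1e-3,
        weight_decay=1e-4,
        exclude_from_weight_decay=["bias"],  # regexes matched against variable names
    )

    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(tf.matmul(tf.ones([1, 4]), w) + b)
    grads = tape.gradient(loss, [w, b])
    opt.apply_gradients(zip(grads, [w, b]))  # weight decay applied to "kernel" only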

2 files changed: +33 −14 lines

tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py

Lines changed: 16 additions & 6 deletions
@@ -246,9 +246,14 @@ def test_exclude_weight_decay_adamw():
     optimizer = weight_decay_optimizers.AdamW(
         learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
     )
-    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+    var0 = tf.Variable([], name="var0")
+    var1 = tf.Variable([], name="var1")
+    var1_weight = tf.Variable([], name="var1_weight")
+
+    optimizer._set_decay_var_list([var0, var1, var1_weight])
+    assert optimizer._do_use_weight_decay(var0)
+    assert not optimizer._do_use_weight_decay(var1)
+    assert not optimizer._do_use_weight_decay(var1_weight)


 @pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])

@@ -371,9 +376,14 @@ def test_exclude_weight_decay_sgdw():
     optimizer = weight_decay_optimizers.SGDW(
         learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
     )
-    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
-    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+    var0 = tf.Variable([], name="var0")
+    var1 = tf.Variable([], name="var1")
+    var1_weight = tf.Variable([], name="var1_weight")
+
+    optimizer._set_decay_var_list([var0, var1, var1_weight])
+    assert optimizer._do_use_weight_decay(var0)
+    assert not optimizer._do_use_weight_decay(var1)
+    assert not optimizer._do_use_weight_decay(var1_weight)


 @pytest.mark.usefixtures("maybe_run_functions_eagerly")
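
The updated tests create the variables up front and call `_set_decay_var_list` explicitly, because `_do_use_weight_decay` no longer evaluates the regexes itself (see the optimizer diff below). The expected results encode unanchored regex matching: the pattern "var1" excludes "var1_weight" as well. A rough sketch of that matching rule (an approximation of `is_variable_matched_by_regexes`, whose source is not part of this diff):

    import re

    def matched_by_any(var_name: str, patterns) -> bool:
        # Unanchored search: a variable is excluded when any pattern
        # occurs anywhere in its name.
        return any(re.search(p, var_name) for p in patterns)

    assert not matched_by_any("var0", ["var1"])
    assert matched_by_any("var1", ["var1"])
    assert matched_by_any("var1_weight", ["var1"])  # matches, as the tests expect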

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 17 additions & 8 deletions
@@ -157,9 +157,7 @@ def minimize(
         Raises:
             ValueError: If some of the variables are not `Variable` objects.
         """
-        self._decay_var_list = (
-            set([v.ref() for v in decay_var_list]) if decay_var_list else False
-        )
+        self._set_decay_var_list(var_list, decay_var_list)
         return super().minimize(
             loss, var_list=var_list, grad_loss=grad_loss, name=name, tape=tape
         )

@@ -186,9 +184,8 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
             TypeError: If `grads_and_vars` is malformed.
             ValueError: If none of the variables have gradients.
         """
-        self._decay_var_list = (
-            set([v.ref() for v in decay_var_list]) if decay_var_list else False
-        )
+        grads_and_vars = list(grads_and_vars)
+        self._set_decay_var_list((v for _, v in grads_and_vars), decay_var_list)
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)

     def _decay_weights_op(self, var, apply_state=None):

@@ -245,11 +242,23 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
             grad, var, indices, apply_state=apply_state
         )

+    def _set_decay_var_list(self, var_list, decay_var_list=None):
+        if decay_var_list:
+            self._decay_var_list = set(v.ref() for v in decay_var_list)
+        elif self.exclude_from_weight_decay:
+            self._decay_var_list = set(
+                v.ref()
+                for v in var_list
+                if not is_variable_matched_by_regexes(v, self.exclude_from_weight_decay)
+            )
+        else:
+            self._decay_var_list = None
+
     def _do_use_weight_decay(self, var):
         """Whether to use L2 weight decay for `var`."""
-        if self._decay_var_list and var.ref() in self._decay_var_list:
+        if self._decay_var_list is None:
             return True
-        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+        return var.ref() in self._decay_var_list


 @typechecked
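
The heart of the fix: `_set_decay_var_list` evaluates the regexes once over the variable list and caches a set of `Variable.ref()` handles, with `None` meaning "decay everything"; `_do_use_weight_decay` then reduces to a constant-time membership test inside the training loop. The same pattern as a self-contained sketch over plain names (a hypothetical `DecaySet` class, not the library code):

    import re

    class DecaySet:
        """Precompute which names get weight decay; answer queries in O(1)."""

        def __init__(self, exclude_patterns=None):
            self.exclude_patterns = exclude_patterns or []
            self._decay_names = None  # None means "decay everything"

        def set_decay_list(self, names, decay_names=None):
            if decay_names:
                # An explicit allow-list wins, mirroring decay_var_list.
                self._decay_names = set(decay_names)
            elif self.exclude_patterns:
                # Regexes run once here, not on every training step.
                self._decay_names = {
                    n for n in names
                    if not any(re.search(p, n) for p in self.exclude_patterns)
                }
            else:
                self._decay_names = None

        def do_use_weight_decay(self, name):
            if self._decay_names is None:
                return True
            return name in self._decay_names

    ds = DecaySet(exclude_patterns=["var1"])
    ds.set_decay_list(["var0", "var1", "var1_weight"])
    assert ds.do_use_weight_decay("var0")
    assert not ds.do_use_weight_decay("var1_weight")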
