Commit 37a368a

exclude_from_weight_decay for AdamW and SGDW (#2624)
* exclude_from_weight_decay for AdamW and SGDW
1 parent ef80dc4 · commit 37a368a

File tree: 5 files changed, +151 −55 lines

tensorflow_addons/optimizers/lamb.py

Lines changed: 13 additions & 26 deletions
@@ -18,14 +18,14 @@
 76 minutes](https://arxiv.org/abs/1904.00962).
 """

-import re
 import warnings

 from typing import Optional, Union, Callable, List
 from typeguard import typechecked

 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


 @tf.keras.utils.register_keras_serializable(package="Addons")
@@ -163,12 +163,11 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
         v_sqrt = tf.sqrt(v_t_hat)
         update = m_t_hat / (v_sqrt + coefficients["epsilon"])

-        var_name = self._get_variable_name(var.name)
-        if self._do_use_weight_decay(var_name):
+        if self._do_use_weight_decay(var):
             update += coefficients["weight_decay"] * var

         ratio = 1.0
-        if self._do_layer_adaptation(var_name):
+        if self._do_layer_adaptation(var):
             w_norm = tf.norm(var, ord=2)
             g_norm = tf.norm(update, ord=2)
             ratio = tf.where(
@@ -206,12 +205,11 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
         v_sqrt = tf.sqrt(v_t_hat)
         update = m_t_hat / (v_sqrt + coefficients["epsilon"])

-        var_name = self._get_variable_name(var.name)
-        if self._do_use_weight_decay(var_name):
+        if self._do_use_weight_decay(var):
             update += coefficients["weight_decay"] * var

         ratio = 1.0
-        if self._do_layer_adaptation(var_name):
+        if self._do_layer_adaptation(var):
             w_norm = tf.norm(var, ord=2)
             g_norm = tf.norm(update, ord=2)
             ratio = tf.where(
@@ -241,26 +239,15 @@ def get_config(self):
         )
         return config

-    def _do_use_weight_decay(self, param_name):
+    def _do_use_weight_decay(self, variable):
         """Whether to use L2 weight decay for `param_name`."""
-        if self.exclude_from_weight_decay:
-            for r in self.exclude_from_weight_decay:
-                if re.search(r, param_name) is not None:
-                    return False
-        return True
+        return not is_variable_matched_by_regexes(
+            variable, self.exclude_from_weight_decay
+        )

-    def _do_layer_adaptation(self, param_name):
+    def _do_layer_adaptation(self, variable):
         """Whether to do layer-wise learning rate adaptation for
         `param_name`."""
-        if self.exclude_from_layer_adaptation:
-            for r in self.exclude_from_layer_adaptation:
-                if re.search(r, param_name) is not None:
-                    return False
-        return True
-
-    def _get_variable_name(self, param_name):
-        """Get the variable name from the tensor name."""
-        m = re.match("^(.*):\\d+$", param_name)
-        if m is not None:
-            param_name = m.group(1)
-        return param_name
+        return not is_variable_matched_by_regexes(
+            variable, self.exclude_from_layer_adaptation
+        )
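
A quick illustration of the refactor above (not part of the commit; the variable names are made up): the exclusion helpers now receive the `tf.Variable` itself, and the regexes are matched against `variable.name` by the shared utility.

    import tensorflow as tf
    import tensorflow_addons as tfa

    # Hypothetical variables; the names are only illustrative.
    kernel = tf.Variable(tf.zeros([3, 3]), name="dense/kernel")
    bias = tf.Variable(tf.zeros([3]), name="dense/bias")

    opt = tfa.optimizers.LAMB(
        learning_rate=1e-3,
        weight_decay_rate=0.01,
        exclude_from_weight_decay=["bias"],  # regexes, matched against variable.name
    )

    # The private helpers now take the variable object instead of a stripped name.
    assert opt._do_use_weight_decay(kernel)    # decayed
    assert not opt._do_use_weight_decay(bias)  # matches "bias", so excluded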

tensorflow_addons/optimizers/tests/lamb_test.py

Lines changed: 9 additions & 7 deletions
@@ -335,20 +335,22 @@ def test_get_config():

 def test_exclude_weight_decay():
     opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"])
-    assert opt._do_use_weight_decay("var0")
-    assert not opt._do_use_weight_decay("var1")
-    assert not opt._do_use_weight_decay("var1_weight")
+    assert opt._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not opt._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight"))


 def test_exclude_layer_adaptation():
     opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
-    assert opt._do_layer_adaptation("var0")
-    assert not opt._do_layer_adaptation("var1")
-    assert not opt._do_layer_adaptation("var1_weight")
+    assert opt._do_layer_adaptation(tf.Variable([], name="var0"))
+    assert not opt._do_layer_adaptation(tf.Variable([], name="var1"))
+    assert not opt._do_layer_adaptation(tf.Variable([], name="var1_weight"))


 def test_serialization():
-    optimizer = lamb.LAMB(1e-4)
+    optimizer = lamb.LAMB(
+        1e-4, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+    )
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()

tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py

Lines changed: 53 additions & 1 deletion
@@ -80,6 +80,7 @@ def do_test(
     opt = optimizer(**optimizer_kwargs)
     # Create the update op.
     # Run 3 steps of the optimizer
+    optimizer_kwargs.pop("exclude_from_weight_decay", None)
     for _ in range(3):
         if do_decay_var_list:
             opt.apply_gradients(
@@ -241,6 +242,31 @@ def test_basic_decay_var_list_adamw(dtype):
     )


+def test_exclude_weight_decay_adamw():
+    optimizer = weight_decay_optimizers.AdamW(
+        learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
+    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+
+
+@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
+def test_var_list_with_exclude_list_adamw(dtype):
+    do_test(
+        dtype,
+        weight_decay_optimizers.AdamW,
+        adamw_update_numpy,
+        do_decay_var_list=True,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-8,
+        weight_decay=WEIGHT_DECAY,
+        exclude_from_weight_decay=["var0_*", "var1_*"],
+    )
+
+
 def test_keras_fit():
     """Check if calling model.fit works."""
     model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
@@ -341,6 +367,30 @@ def test_basic_decay_var_list_sgdw(dtype):
     )


+def test_exclude_weight_decay_sgdw():
+    optimizer = weight_decay_optimizers.SGDW(
+        learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
+    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
+def test_var_list_with_exclude_list_sgdw(dtype):
+    do_test(
+        dtype,
+        weight_decay_optimizers.SGDW,
+        sgdw_update_numpy,
+        do_decay_var_list=True,
+        learning_rate=0.001,
+        momentum=0.9,
+        weight_decay=WEIGHT_DECAY,
+        exclude_from_weight_decay=["var0_*", "var1_*"],
+    )
+
+
 @pytest.mark.parametrize(
     "optimizer",
     [
@@ -379,7 +429,9 @@ def test_optimizer_sparse(dtype, optimizer):


 def test_serialization():
-    optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
+    optimizer = weight_decay_optimizers.AdamW(
+        learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()
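
One reading of the new `test_var_list_with_exclude_list_*` cases (an interpretation, not spelled out in the commit): `do_test(..., do_decay_var_list=True)` passes both test variables through `decay_var_list`, while `exclude_from_weight_decay=["var0_*", "var1_*"]` would otherwise exclude them, so the tests exercise the documented priority of `decay_var_list` over the exclusion patterns. The `optimizer_kwargs.pop("exclude_from_weight_decay", None)` added to `do_test` presumably keeps the extra keyword from being forwarded to the numpy reference update functions, which do not accept it. Note that the patterns are regexes, not globs, so `"var0_*"` already matches the bare variable name:

    import re

    # "_*" means "zero or more underscores", so the pattern matches "var0"
    # itself as well as "var0_weight"; variable.name also carries a ":0" suffix.
    assert re.search("var0_*", "var0:0") is not None
    assert re.search("var0_*", "var0_weight:0") is not None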

tensorflow_addons/optimizers/utils.py

Lines changed: 22 additions & 0 deletions
@@ -14,7 +14,9 @@
 # ==============================================================================
 """Additional Utilities used for tfa.optimizers."""

+import re
 import tensorflow as tf
+from typing import List


 def fit_bn(model, *args, **kwargs):
@@ -51,3 +53,23 @@ def fit_bn(model, *args, **kwargs):

     model.trainable = _trainable
     model._metrics = _metrics
+
+
+def get_variable_name(variable) -> str:
+    """Get the variable name from the variable tensor."""
+    param_name = variable.name
+    m = re.match("^(.*):\\d+$", param_name)
+    if m is not None:
+        param_name = m.group(1)
+    return param_name
+
+
+def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
+    """Whether variable is matched in regexes list by its name."""
+    if regexes:
+        # var_name = get_variable_name(variable)
+        var_name = variable.name
+        for r in regexes:
+            if re.search(r, var_name):
+                return True
+    return False
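
A minimal usage sketch of the two new helpers (mirroring the behaviour shown in the diff): the regexes are applied to `variable.name`, which still carries the `:N` suffix, while `get_variable_name` strips it; as the commented-out line indicates, the stripped form is not currently used for matching.

    import tensorflow as tf
    from tensorflow_addons.optimizers.utils import (
        get_variable_name,
        is_variable_matched_by_regexes,
    )

    v = tf.Variable([1.0], name="encoder/layer_norm/beta")

    print(v.name)                # "encoder/layer_norm/beta:0"
    print(get_variable_name(v))  # "encoder/layer_norm/beta"

    assert is_variable_matched_by_regexes(v, ["layer_norm", "bias"])
    assert not is_variable_matched_by_regexes(v, ["kernel"])
    assert not is_variable_matched_by_regexes(v, None)  # None/empty matches nothing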

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 54 additions & 21 deletions
@@ -16,9 +16,10 @@

 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

 from typeguard import typechecked
-from typing import Union, Callable, Type
+from typing import Union, Callable, Type, Optional, List


 class DecoupledWeightDecayExtension:
@@ -71,24 +72,40 @@ def __init__(self, weight_decay, *args, **kwargs):
     """

     @typechecked
-    def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
+    def __init__(
+        self,
+        weight_decay: Union[FloatTensorLike, Callable],
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        **kwargs,
+    ):
         """Extension class that adds weight decay to an optimizer.

         Args:
             weight_decay: A `Tensor`, a floating point value, or a schedule
                 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                 to decay the variable by, in the update step.
+            exclude_from_weight_decay: List of regex patterns of
+                variables excluded from weight decay. Variables whose name
+                contain a substring matching the pattern will be excluded.
+                Note `decay_var_list` in `minimize` or `apply_gradients` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Optional list or tuple or set of `Variable` objects to
                 decay.
         """
         wd = kwargs.pop("weight_decay", weight_decay)
         super().__init__(**kwargs)
         self._decay_var_list = None  # is set in minimize or apply_gradients
         self._set_hyper("weight_decay", wd)
+        self.exclude_from_weight_decay = exclude_from_weight_decay

     def get_config(self):
         config = super().get_config()
-        config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
+        config.update(
+            {
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
+                "exclude_from_weight_decay": self.exclude_from_weight_decay,
+            }
+        )
         return config

     @classmethod
@@ -130,7 +147,8 @@ def minimize(
             grad_loss: Optional. A `Tensor` holding the gradient computed for
                 `loss`.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             name: Optional name for the returned operation.
             tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                 `Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar

         Args:
             grads_and_vars: List of (gradient, variable) pairs.
-            name: Optional name for the returned operation. Default to the
+            name: Optional name for the returned operation. Default to the
                 name passed to the `Optimizer` constructor.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Additional arguments to pass to the base optimizer's
                 apply_gradient method, e.g., TF2.2 added an argument
                 `experimental_aggregate_gradients`.
@@ -173,7 +192,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)

     def _decay_weights_op(self, var, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -183,7 +202,7 @@ def _decay_weights_sparse_op(self, var, indices, apply_state=None):
         return tf.no_op()

     def _decay_weights_sparse_op(self, var, indices, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
            var_device, var_dtype = var.device, var.dtype.base_dtype
            coefficients = (apply_state or {}).get(
                (var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
                 grad, var, indices, apply_state=apply_state
             )

+    def _do_use_weight_decay(self, var):
+        """Whether to use L2 weight decay for `var`."""
+        if self._decay_var_list and var.ref() in self._decay_var_list:
+            return True
+        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+

 @typechecked
 def extend_with_decoupled_weight_decay(
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
     The API of the new optimizer class slightly differs from the API of the
     base optimizer:
     - The first argument to the constructor is the weight decay rate.
+    - Optional keyword argument `exclude_from_weight_decay` accepts list of
+      regex patterns of variables excluded from weight decay. Variables whose
+      name contain a substring matching the pattern will be excluded.
     - `minimize` and `apply_gradients` accept the optional keyword argument
       `decay_var_list`, which specifies the variables that should be decayed.
-      If `None`, all variables that are optimized are decayed.
+      Note this takes priority over `exclude_from_weight_decay` if specified.
+      If both `None`, all variables that are optimized are decayed.

     Usage example:
     ```python
@@ -376,12 +405,14 @@
             nesterov: boolean. Whether to apply Nesterov momentum.
             name: Optional name prefix for the operations created when applying
                 gradients. Defaults to 'SGD'.
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
@@ -466,12 +497,14 @@
                 beyond".
             name: Optional name for the operations created when applying
                 gradients. Defaults to "AdamW".
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
weight_decay,
