 
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes
 
 from typeguard import typechecked
-from typing import Union, Callable, Type
+from typing import Union, Callable, Type, Optional, List
 
 
 class DecoupledWeightDecayExtension:
@@ -71,24 +72,40 @@ def __init__(self, weight_decay, *args, **kwargs):
     """
 
     @typechecked
-    def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
+    def __init__(
+        self,
+        weight_decay: Union[FloatTensorLike, Callable],
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        **kwargs,
+    ):
         """Extension class that adds weight decay to an optimizer.
 
         Args:
             weight_decay: A `Tensor`, a floating point value, or a schedule
                 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                 to decay the variable by, in the update step.
+            exclude_from_weight_decay: List of regex patterns of
+                variables excluded from weight decay. Variables whose name
+                contains a substring matching the pattern will be excluded.
+                Note `decay_var_list` in `minimize` or `apply_gradients` takes
+                priority over `exclude_from_weight_decay` if specified.
            **kwargs: Optional list or tuple or set of `Variable` objects to
                decay.
        """
        wd = kwargs.pop("weight_decay", weight_decay)
        super().__init__(**kwargs)
        self._decay_var_list = None  # is set in minimize or apply_gradients
        self._set_hyper("weight_decay", wd)
+        self.exclude_from_weight_decay = exclude_from_weight_decay
 
    def get_config(self):
        config = super().get_config()
-        config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
+        config.update(
+            {
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
+                "exclude_from_weight_decay": self.exclude_from_weight_decay,
+            }
+        )
         return config
 
     @classmethod
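As a quick illustration of the constructor and `get_config` changes in this hunk, the following is a minimal sketch, assuming a `tensorflow_addons` build that already contains this patch; the optimizer choice (`AdamW`) and the pattern strings are arbitrary examples:

```python
import tensorflow_addons as tfa

# Hypothetical patterns: variables whose names contain "bias" or
# "layer_norm" would be exempt from decoupled weight decay.
opt = tfa.optimizers.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-4,
    exclude_from_weight_decay=["bias", "layer_norm"],
)

# The new field is serialized alongside `weight_decay`, so it survives a
# get_config / from_config round trip.
config = opt.get_config()
assert config["exclude_from_weight_decay"] == ["bias", "layer_norm"]
restored = tfa.optimizers.AdamW.from_config(config)
```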
@@ -130,7 +147,8 @@ def minimize(
             grad_loss: Optional. A `Tensor` holding the gradient computed for
                 `loss`.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             name: Optional name for the returned operation.
             tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                 `Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
 
         Args:
             grads_and_vars: List of (gradient, variable) pairs.
-            name: Optional name for the returned operation. Default to the
+            name: Optional name for the returned operation. Default to the
                 name passed to the `Optimizer` constructor.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Additional arguments to pass to the base optimizer's
                 apply_gradient method, e.g., TF2.2 added an argument
                 `experimental_aggregate_gradients`.
@@ -173,7 +192,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)
 
     def _decay_weights_op(self, var, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -183,7 +202,7 @@ def _decay_weights_op(self, var, apply_state=None):
         return tf.no_op()
 
     def _decay_weights_sparse_op(self, var, indices, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
             grad, var, indices, apply_state=apply_state
         )
 
+    def _do_use_weight_decay(self, var):
+        """Whether to use L2 weight decay for `var`."""
+        if self._decay_var_list and var.ref() in self._decay_var_list:
+            return True
+        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+
 
 @typechecked
 def extend_with_decoupled_weight_decay(
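The helper added above decides, per variable, whether decay is applied: membership in `decay_var_list` wins outright, otherwise the variable is decayed unless one of the exclusion regexes matches a substring of its name. The snippet below is a simplified stand-in for `is_variable_matched_by_regexes`, written only to make that substring-matching behaviour concrete; it is not the actual `tensorflow_addons` implementation.

```python
import re
from typing import List, Optional

import tensorflow as tf


def matched_by_regexes_sketch(var: tf.Variable, regexes: Optional[List[str]]) -> bool:
    """Return True if any pattern matches a substring of `var.name`."""
    if regexes:
        for pattern in regexes:
            if re.search(pattern, var.name):
                return True
    return False


kernel = tf.Variable(tf.zeros([3, 3]), name="dense_kernel")
bias = tf.Variable(tf.zeros([3]), name="dense_bias")

assert not matched_by_regexes_sketch(kernel, ["bias"])  # kernel stays decayed
assert matched_by_regexes_sketch(bias, ["bias"])        # bias is excluded
```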
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
     The API of the new optimizer class slightly differs from the API of the
     base optimizer:
     - The first argument to the constructor is the weight decay rate.
+    - Optional keyword argument `exclude_from_weight_decay` accepts a list of
+      regex patterns of variables excluded from weight decay. Variables whose
+      name contains a substring matching the pattern will be excluded.
     - `minimize` and `apply_gradients` accept the optional keyword argument
       `decay_var_list`, which specifies the variables that should be decayed.
-      If `None`, all variables that are optimized are decayed.
+      Note this takes priority over `exclude_from_weight_decay` if specified.
+      If both are `None`, all variables that are optimized are decayed.
 
     Usage example:
     ```python
@@ -376,12 +405,14 @@ def __init__(
             nesterov: boolean. Whether to apply Nesterov momentum.
             name: Optional name prefix for the operations created when applying
                 gradients. Defaults to 'SGD'.
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
@@ -466,12 +497,14 @@ def __init__(
                 beyond".
             name: Optional name for the operations created when applying
                 gradients. Defaults to "AdamW".
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
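Putting the pieces together, a training step could exercise the new keyword roughly as below. This is a sketch under the assumption that the installed `tensorflow_addons` includes this patch; the model, data, and variable choices are invented for illustration, and the comments restate the priority rule documented in the docstrings above.

```python
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(4, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ]
)
data = tf.random.normal([16, 8])
labels = tf.random.normal([16, 1])

# Bias variables are excluded from decay via regex matching on their names.
opt = tfa.optimizers.SGDW(
    learning_rate=0.1,
    weight_decay=1e-4,
    exclude_from_weight_decay=["bias"],
)


def loss_fn():
    return tf.reduce_mean(tf.square(model(data) - labels))


# Regex-based exclusion applies to this step.
opt.minimize(loss_fn, var_list=model.trainable_variables)

# Variables named in `decay_var_list` take priority over
# `exclude_from_weight_decay`, per the docstrings above.
opt.minimize(
    loss_fn,
    var_list=model.trainable_variables,
    decay_var_list=[model.layers[0].kernel],
)
```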