diff --git a/tensorflow_addons/optimizers/lamb.py b/tensorflow_addons/optimizers/lamb.py
index d3f9abbd75..b166657251 100644
--- a/tensorflow_addons/optimizers/lamb.py
+++ b/tensorflow_addons/optimizers/lamb.py
@@ -18,7 +18,6 @@
 76 minutes](https://arxiv.org/abs/1904.00962).
 """

-import re
 import warnings
 from typing import Optional, Union, Callable, List

@@ -26,6 +25,7 @@
 import tensorflow as tf

 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes


 @tf.keras.utils.register_keras_serializable(package="Addons")
@@ -163,12 +163,11 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
         v_sqrt = tf.sqrt(v_t_hat)
         update = m_t_hat / (v_sqrt + coefficients["epsilon"])

-        var_name = self._get_variable_name(var.name)
-        if self._do_use_weight_decay(var_name):
+        if self._do_use_weight_decay(var):
             update += coefficients["weight_decay"] * var

         ratio = 1.0
-        if self._do_layer_adaptation(var_name):
+        if self._do_layer_adaptation(var):
             w_norm = tf.norm(var, ord=2)
             g_norm = tf.norm(update, ord=2)
             ratio = tf.where(
@@ -206,12 +205,11 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
         v_sqrt = tf.sqrt(v_t_hat)
         update = m_t_hat / (v_sqrt + coefficients["epsilon"])

-        var_name = self._get_variable_name(var.name)
-        if self._do_use_weight_decay(var_name):
+        if self._do_use_weight_decay(var):
             update += coefficients["weight_decay"] * var

         ratio = 1.0
-        if self._do_layer_adaptation(var_name):
+        if self._do_layer_adaptation(var):
             w_norm = tf.norm(var, ord=2)
             g_norm = tf.norm(update, ord=2)
             ratio = tf.where(
@@ -241,26 +239,15 @@ def get_config(self):
         )
         return config

-    def _do_use_weight_decay(self, param_name):
-        """Whether to use L2 weight decay for `param_name`."""
-        if self.exclude_from_weight_decay:
-            for r in self.exclude_from_weight_decay:
-                if re.search(r, param_name) is not None:
-                    return False
-        return True
+    def _do_use_weight_decay(self, variable):
+        """Whether to use L2 weight decay for `variable`."""
+        return not is_variable_matched_by_regexes(
+            variable, self.exclude_from_weight_decay
+        )

-    def _do_layer_adaptation(self, param_name):
-        """Whether to do layer-wise learning rate adaptation for `param_name`."""
-        if self.exclude_from_layer_adaptation:
-            for r in self.exclude_from_layer_adaptation:
-                if re.search(r, param_name) is not None:
-                    return False
-        return True
-
-    def _get_variable_name(self, param_name):
-        """Get the variable name from the tensor name."""
-        m = re.match("^(.*):\\d+$", param_name)
-        if m is not None:
-            param_name = m.group(1)
-        return param_name
+    def _do_layer_adaptation(self, variable):
+        """Whether to do layer-wise learning rate adaptation for `variable`."""
+        return not is_variable_matched_by_regexes(
+            variable, self.exclude_from_layer_adaptation
+        )
diff --git a/tensorflow_addons/optimizers/tests/lamb_test.py b/tensorflow_addons/optimizers/tests/lamb_test.py
index 631aed99a5..e80f7e4ae6 100644
--- a/tensorflow_addons/optimizers/tests/lamb_test.py
+++ b/tensorflow_addons/optimizers/tests/lamb_test.py
@@ -335,20 +335,22 @@ def test_get_config():

 def test_exclude_weight_decay():
     opt = lamb.LAMB(0.01, weight_decay=0.01, exclude_from_weight_decay=["var1"])
-    assert opt._do_use_weight_decay("var0")
-    assert not opt._do_use_weight_decay("var1")
-    assert not opt._do_use_weight_decay("var1_weight")
+    assert opt._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not opt._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight"))


 def test_exclude_layer_adaptation():
     opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
-    assert opt._do_layer_adaptation("var0")
-    assert not opt._do_layer_adaptation("var1")
-    assert not opt._do_layer_adaptation("var1_weight")
+    assert opt._do_layer_adaptation(tf.Variable([], name="var0"))
+    assert not opt._do_layer_adaptation(tf.Variable([], name="var1"))
+    assert not opt._do_layer_adaptation(tf.Variable([], name="var1_weight"))


 def test_serialization():
-    optimizer = lamb.LAMB(1e-4)
+    optimizer = lamb.LAMB(
+        1e-4, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+    )
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()
diff --git a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
index 6a832585c7..d9ac1fe0f1 100644
--- a/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
+++ b/tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py
@@ -80,6 +80,7 @@ def do_test(
         opt = optimizer(**optimizer_kwargs)
         # Create the update op.
         # Run 3 steps of the optimizer
+        optimizer_kwargs.pop("exclude_from_weight_decay", None)
         for _ in range(3):
             if do_decay_var_list:
                 opt.apply_gradients(
@@ -241,6 +242,31 @@ def test_basic_decay_var_list_adamw(dtype):
     )


+def test_exclude_weight_decay_adamw():
+    optimizer = weight_decay_optimizers.AdamW(
+        learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
+    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+
+
+@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
+def test_var_list_with_exclude_list_adamw(dtype):
+    do_test(
+        dtype,
+        weight_decay_optimizers.AdamW,
+        adamw_update_numpy,
+        do_decay_var_list=True,
+        learning_rate=0.001,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-8,
+        weight_decay=WEIGHT_DECAY,
+        exclude_from_weight_decay=["var0_*", "var1_*"],
+    )
+
+
 def test_keras_fit():
     """Check if calling model.fit works."""
     model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
@@ -341,6 +367,30 @@ def test_basic_decay_var_list_sgdw(dtype):
     )


+def test_exclude_weight_decay_sgdw():
+    optimizer = weight_decay_optimizers.SGDW(
+        learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
+    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
+    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])
+def test_var_list_with_exclude_list_sgdw(dtype):
+    do_test(
+        dtype,
+        weight_decay_optimizers.SGDW,
+        sgdw_update_numpy,
+        do_decay_var_list=True,
+        learning_rate=0.001,
+        momentum=0.9,
+        weight_decay=WEIGHT_DECAY,
+        exclude_from_weight_decay=["var0_*", "var1_*"],
+    )
+
+
 @pytest.mark.parametrize(
     "optimizer",
     [
@@ -379,7 +429,9 @@ def test_optimizer_sparse(dtype, optimizer):


 def test_serialization():
-    optimizer = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
+    optimizer = weight_decay_optimizers.AdamW(
+        learning_rate=1e-4, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
+    )
     config = tf.keras.optimizers.serialize(optimizer)
     new_optimizer = tf.keras.optimizers.deserialize(config)
     assert new_optimizer.get_config() == optimizer.get_config()
diff --git a/tensorflow_addons/optimizers/utils.py b/tensorflow_addons/optimizers/utils.py
index af365edbff..91d0d92b1e 100644
--- a/tensorflow_addons/optimizers/utils.py
+++ b/tensorflow_addons/optimizers/utils.py
@@ -14,7 +14,9 @@
 # ==============================================================================
 """Additional Utilities used for tfa.optimizers."""

+import re
 import tensorflow as tf
+from typing import List


 def fit_bn(model, *args, **kwargs):
@@ -51,3 +53,23 @@ def fit_bn(model, *args, **kwargs):

     model.trainable = _trainable
     model._metrics = _metrics
+
+
+def get_variable_name(variable) -> str:
+    """Get the variable name from the variable tensor."""
+    param_name = variable.name
+    m = re.match("^(.*):\\d+$", param_name)
+    if m is not None:
+        param_name = m.group(1)
+    return param_name
+
+
+def is_variable_matched_by_regexes(variable, regexes: List[str]) -> bool:
+    """Whether variable is matched in regexes list by its name."""
+    if regexes:
+        # Match against the raw `variable.name` (it keeps the ":0" output suffix).
+        var_name = variable.name
+        for r in regexes:
+            if re.search(r, var_name):
+                return True
+    return False
diff --git a/tensorflow_addons/optimizers/weight_decay_optimizers.py b/tensorflow_addons/optimizers/weight_decay_optimizers.py
index bf26d03bfd..3d882b0169 100644
--- a/tensorflow_addons/optimizers/weight_decay_optimizers.py
+++ b/tensorflow_addons/optimizers/weight_decay_optimizers.py
@@ -16,9 +16,10 @@

 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
+from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes

 from typeguard import typechecked
-from typing import Union, Callable, Type
+from typing import Union, Callable, Type, Optional, List


 class DecoupledWeightDecayExtension:
@@ -71,13 +72,23 @@ def __init__(self, weight_decay, *args, **kwargs):
     """

     @typechecked
-    def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
+    def __init__(
+        self,
+        weight_decay: Union[FloatTensorLike, Callable],
+        exclude_from_weight_decay: Optional[List[str]] = None,
+        **kwargs,
+    ):
         """Extension class that adds weight decay to an optimizer.

         Args:
             weight_decay: A `Tensor`, a floating point value, or a schedule
                 that is a `tf.keras.optimizers.schedules.LearningRateSchedule`
                 to decay the variable by, in the update step.
+            exclude_from_weight_decay: List of regex patterns of
+                variables excluded from weight decay. Variables whose name
+                contains a substring matching the pattern will be excluded.
+                Note `decay_var_list` in `minimize` or `apply_gradients` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Optional list or tuple or set of `Variable` objects to
                 decay.
         """
@@ -85,10 +96,16 @@ def __init__(self, weight_decay: Union[FloatTensorLike, Callable], **kwargs):
         super().__init__(**kwargs)
         self._decay_var_list = None  # is set in minimize or apply_gradients
         self._set_hyper("weight_decay", wd)
+        self.exclude_from_weight_decay = exclude_from_weight_decay

     def get_config(self):
         config = super().get_config()
-        config.update({"weight_decay": self._serialize_hyperparameter("weight_decay")})
+        config.update(
+            {
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
+                "exclude_from_weight_decay": self.exclude_from_weight_decay,
+            }
+        )
         return config

     @classmethod
@@ -130,7 +147,8 @@ def minimize(
             grad_loss: Optional. A `Tensor` holding the gradient computed for
                 `loss`.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             name: Optional name for the returned operation.
             tape: (Optional) `tf.GradientTape`. If `loss` is provided as a
                 `Tensor`, the tape that computed the `loss` must be provided.
@@ -154,10 +172,11 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar

         Args:
             grads_and_vars: List of (gradient, variable) pairs.
-            name: Optional name for the returned operation. Default to the
+            name: Optional name for the returned operation. Defaults to the
                 name passed to the `Optimizer` constructor.
             decay_var_list: Optional list of variables to be decayed. Defaults
-                to all variables in var_list.
+                to all variables in var_list. Note `decay_var_list` takes
+                priority over `exclude_from_weight_decay` if specified.
             **kwargs: Additional arguments to pass to the base optimizer's
                 apply_gradient method, e.g., TF2.2 added an argument
                 `experimental_aggregate_gradients`.
@@ -173,7 +192,7 @@ def apply_gradients(self, grads_and_vars, name=None, decay_var_list=None, **kwar
         return super().apply_gradients(grads_and_vars, name=name, **kwargs)

     def _decay_weights_op(self, var, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -183,7 +202,7 @@ def _decay_weights_op(self, var, apply_state=None):
         return tf.no_op()

     def _decay_weights_sparse_op(self, var, indices, apply_state=None):
-        if not self._decay_var_list or var.ref() in self._decay_var_list:
+        if self._do_use_weight_decay(var):
             var_device, var_dtype = var.device, var.dtype.base_dtype
             coefficients = (apply_state or {}).get(
                 (var_device, var_dtype)
@@ -226,6 +245,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
             grad, var, indices, apply_state=apply_state
         )

+    def _do_use_weight_decay(self, var):
+        """Whether to use L2 weight decay for `var`."""
+        if self._decay_var_list and var.ref() in self._decay_var_list:
+            return True
+        return not is_variable_matched_by_regexes(var, self.exclude_from_weight_decay)
+

 @typechecked
 def extend_with_decoupled_weight_decay(
@@ -243,9 +268,13 @@ def extend_with_decoupled_weight_decay(
     The API of the new optimizer class slightly differs from the API of the
     base optimizer:
     - The first argument to the constructor is the weight decay rate.
+    - Optional keyword argument `exclude_from_weight_decay` accepts a list of
+      regex patterns of variables excluded from weight decay. Variables whose
+      name contains a substring matching the pattern will be excluded.
     - `minimize` and `apply_gradients` accept the optional keyword argument
       `decay_var_list`, which specifies the variables that should be decayed.
-      If `None`, all variables that are optimized are decayed.
+      Note this takes priority over `exclude_from_weight_decay` if specified.
+      If both are `None`, all variables that are optimized are decayed.

     Usage example:
     ```python
@@ -376,12 +405,14 @@ def __init__(
             nesterov: boolean. Whether to apply Nesterov momentum.
             name: Optional name prefix for the operations created when applying
                 gradients. Defaults to 'SGD'.
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
@@ -466,12 +497,14 @@ def __init__(
                 beyond".
             name: Optional name for the operations created when applying
                 gradients. Defaults to "AdamW".
-            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
-                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
-                norm; `clipvalue` is clip gradients by value, `decay` is
-                included for backward compatibility to allow time inverse decay
-                of learning rate. `lr` is included for backward compatibility,
-                recommended to use `learning_rate` instead.
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+                `lr`, `decay`, `exclude_from_weight_decay`}. `clipnorm` is clip
+                gradients by norm; `clipvalue` is clip gradients by value.
+                `decay` is included for backward compatibility to allow time
+                inverse decay of learning rate. `lr` is included for backward
+                compatibility, recommended to use `learning_rate` instead.
+                `exclude_from_weight_decay` accepts a list of regex patterns of
+                variables excluded from weight decay.
         """
         super().__init__(
             weight_decay,
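
Usage sketch for reviewers: a minimal example of how the new `exclude_from_weight_decay` argument is meant to interact with `decay_var_list`, based on the docstrings in this patch. Variable names, shapes, and hyperparameter values below are illustrative only, not part of the change itself.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Toy variables; the pattern below matches any variable whose name contains "bias".
kernel = tf.Variable([[1.0, 2.0]], name="dense/kernel")
bias = tf.Variable([0.5], name="dense/bias")

# `exclude_from_weight_decay` takes regex patterns matched against variable names.
opt = tfa.optimizers.AdamW(
    learning_rate=1e-3,
    weight_decay=1e-4,
    exclude_from_weight_decay=["bias"],
)

grads = [tf.constant([[0.1, 0.1]]), tf.constant([0.1])]

# `kernel` is decayed; `bias` is skipped because its name matches the exclude list.
opt.apply_gradients(zip(grads, [kernel, bias]))

# Variables passed via `decay_var_list` are always decayed, so here `bias`
# is decayed even though it matches the exclude pattern.
opt.apply_gradients(zip(grads, [kernel, bias]), decay_var_list=[kernel, bias])
```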