Skip to content

Commit 6280314

Browse files
Minor bug fix to MultiHeadAttention registry function.
Covers the case where a layer arg can be passed as a layer kwarg.

PiperOrigin-RevId: 520394717
1 parent abb0c3f commit 6280314

File tree

2 files changed: 154 additions and 39 deletions.

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

Lines changed: 53 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Utility functions that help in the computation of per-example gradient norms."""
1515

16-
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
16+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union
1717

1818
from absl import logging
1919
import tensorflow as tf
@@ -36,19 +36,6 @@ def has_internal_compute_graph(input_object: Any):
3636
)
3737

3838

39-
def _get_internal_layers(
40-
input_layer: tf.keras.layers.Layer,
41-
) -> List[tf.keras.layers.Layer]:
42-
"""Returns a list of layers that are nested within a given layer."""
43-
internal_layers = []
44-
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
45-
for layer in input_layer.layers:
46-
internal_layers.extend(_get_internal_layers(layer))
47-
else:
48-
internal_layers.append(input_layer)
49-
return internal_layers
50-
51-
5239
def model_forward_pass(
5340
input_model: tf.keras.Model,
5441
inputs: PackedTensors,
@@ -114,18 +101,10 @@ def generator_fn(layer_instance, args, kwargs):
114101
generator_outputs_list.extend(node_generator_outputs)
115102
else:
116103
# Otherwise, we parse the node directly.
117-
node_layers = _get_internal_layers(node.layer)
118-
for layer in node_layers:
119-
node_layer_outputs, layer_generator_outputs = generator_fn(
120-
layer, args, kwargs
121-
)
122-
generator_outputs_list.append(layer_generator_outputs)
123-
args = (
124-
node_layer_outputs
125-
if isinstance(node_layer_outputs, tuple)
126-
else (node_layer_outputs,)
127-
)
128-
kwargs = {}
104+
node_layer_outputs, layer_generator_outputs = generator_fn(
105+
node.layer, args, kwargs
106+
)
107+
generator_outputs_list.append(layer_generator_outputs)
129108

130109
# Update the current dictionary of inputs for the next node.
131110
for x_id, y in zip(
@@ -163,9 +142,8 @@ def all_trainable_layers_are_registered(
163142
False otherwise.
164143
"""
165144
for layer in input_model.layers:
166-
for sublayer in _get_internal_layers(layer):
167-
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
168-
return False
145+
if not layer_registry.is_elem(layer) and layer.trainable_variables:
146+
return False
169147
return True
170148

171149

@@ -213,17 +191,53 @@ def add_noise(g):
213191

214192
def generate_model_outputs_using_core_keras_layers(
215193
input_model: tf.keras.Model,
194+
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic
216195
) -> PackedTensors:
217-
"""Returns the model outputs generated by only core Keras layers."""
218-
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
219-
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
196+
"""Returns the model outputs generated by only core Keras layers.
197+
198+
Args:
199+
input_model: A `tf.keras.Model` instance to obtain outputs from.
200+
custom_layer_set: An optional `set` of custom layers to expand. If `None`,
201+
then this is the set of all registered custom Keras layers.
202+
203+
Returns:
204+
A `tf.Tensor` that is the result of `input_model(input_model.inputs)`
205+
using only Keras layers that are not in `custom_layer_set`.
206+
"""
207+
# Set up helper variables and functions.
208+
custom_layer_set = (
209+
custom_layer_set or tf.keras.utils.get_custom_objects().values()
210+
)
211+
212+
def _is_core(layer_instance):
213+
return type(layer_instance) not in custom_layer_set
220214

221215
def generator_fn(layer_instance, args, kwargs):
222-
if hash(layer_instance.__class__) in cust_hash_set:
223-
# Using `.call()` does not register the layer in the compute graph of
224-
# a forward pass.
225-
return layer_instance.call(*args, **kwargs), None
226-
else:
227-
return layer_instance(*args, **kwargs), None
216+
# Using `.call()` does not register the layer in the compute graph of
217+
# a forward pass.
218+
layer_outputs = (
219+
layer_instance(*args, **kwargs)
220+
if _is_core(layer_instance)
221+
else layer_instance.call(*args, **kwargs)
222+
)
223+
return layer_outputs, None
224+
225+
# Return early if all the existing layers contain only core layers.
226+
if all(_is_core(layer) for layer in input_model.layers):
227+
return model_forward_pass(input_model, input_model.inputs)[0]
228228

229-
return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
229+
# Do a forward pass to expand the outermost layers.
230+
candidate_outputs, _ = model_forward_pass(
231+
input_model, input_model.inputs, generator_fn
232+
)
233+
234+
# The following recursion is inefficient because it recursively builds `n`
235+
# Keras model graphs, where `n` is the number of recursive calls. However,
236+
# it appears to be the only valid approach without accessing Keras's internal
237+
# functions (e.g., `keras.engine.functional._map_graph_network()`).
238+
cleaned_model = tf.keras.Model(
239+
inputs=input_model.inputs, outputs=candidate_outputs
240+
)
241+
return generate_model_outputs_using_core_keras_layers(
242+
cleaned_model, custom_layer_set
243+
)

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py

Lines changed: 101 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,72 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Any
16+
1517
from absl.testing import parameterized
1618
import tensorflow as tf
1719

1820
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
1921

2022

23+
# ==============================================================================
24+
# Helper functions and classes.
25+
# ==============================================================================
26+
@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
27+
class DoubleDense(tf.keras.layers.Layer):
28+
"""Generates two dense layers nested together."""
29+
30+
def __init__(self, units: int):
31+
super().__init__()
32+
self.dense1 = tf.keras.layers.Dense(units, name='DDense_ext_1')
33+
self.dense2 = tf.keras.layers.Dense(1, name='DDense_ext_2')
34+
35+
def call(self, inputs: Any):
36+
x = self.dense1(inputs)
37+
return self.dense2(x)
38+
39+
40+
@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
41+
class TripleDense(tf.keras.layers.Layer):
42+
"""Generates three dense layers nested together."""
43+
44+
def __init__(self, units: int):
45+
super().__init__()
46+
self.dense1 = tf.keras.layers.Dense(units, name='TDense_ext_1')
47+
self.dense2 = tf.keras.layers.Dense(units, name='TDense_ext_2')
48+
self.dense3 = tf.keras.layers.Dense(1, name='TDense_ext_3')
49+
50+
def call(self, inputs: Any):
51+
x1 = self.dense1(inputs)
52+
x2 = self.dense2(x1)
53+
return self.dense3(x2)
54+
55+
56+
def get_reduced_model(sample_inputs, hidden_layer_list, new_custom_layers=None):
57+
"""Reduces a set of layers to only core Keras layers in a model."""
58+
sample_outputs = sample_inputs
59+
for l in hidden_layer_list:
60+
sample_outputs = l(sample_outputs)
61+
custom_model = tf.keras.Model(inputs=sample_inputs, outputs=sample_outputs)
62+
if new_custom_layers:
63+
reduced_outputs = (
64+
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
65+
custom_model,
66+
custom_layer_set=new_custom_layers,
67+
)
68+
)
69+
else:
70+
reduced_outputs = (
71+
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
72+
custom_model
73+
)
74+
)
75+
return tf.keras.Model(inputs=custom_model.inputs, outputs=reduced_outputs)
76+
77+
78+
# ==============================================================================
79+
# Main tests.
80+
# ==============================================================================
2181
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):
2282

2383
@parameterized.product(
@@ -75,5 +135,46 @@ def test_outputs_are_consistent(
75135
self.assertAllClose(computed_outputs, true_outputs)
76136

77137

138+
class GenerateOutputsUsingCoreKerasLayers(
139+
tf.test.TestCase, parameterized.TestCase
140+
):
141+
142+
def test_single_custom_layer_is_reduced(self):
143+
num_units = 5
144+
num_dims = 3
145+
reduced_model = get_reduced_model(
146+
tf.keras.Input(num_dims),
147+
[DoubleDense(num_units)],
148+
)
149+
# Ignore the input layer.
150+
for l in reduced_model.layers[1:]:
151+
self.assertIsInstance(l, tf.keras.layers.Dense)
152+
153+
def test_two_distinct_custom_layers_are_reduced(self):
154+
num_units = 5
155+
num_dims = 3
156+
reduced_model = get_reduced_model(
157+
tf.keras.Input(num_dims),
158+
[DoubleDense(num_units), TripleDense(num_units)],
159+
)
160+
# Ignore the input layer.
161+
for l in reduced_model.layers[1:]:
162+
self.assertIsInstance(l, tf.keras.layers.Dense)
163+
164+
def test_new_custom_layer_spec(self):
165+
num_units = 5
166+
num_dims = 3
167+
reduced_model = get_reduced_model(
168+
tf.keras.Input(num_dims),
169+
[DoubleDense(num_units), TripleDense(num_units)],
170+
new_custom_layers=set([DoubleDense]),
171+
)
172+
# Ignore the input layer.
173+
for l in reduced_model.layers[1:]:
174+
self.assertTrue(
175+
isinstance(l, tf.keras.layers.Dense) or isinstance(l, TripleDense)
176+
)
177+
178+
78179
if __name__ == '__main__':
79180
tf.test.main()

0 commit comments

Comments
 (0)