 # limitations under the License.

 """Utility functions that help in the computation of per-example gradient norms."""

-from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union

 from absl import logging
 import tensorflow as tf
@@ -36,19 +36,6 @@ def has_internal_compute_graph(input_object: Any):
   )


-def _get_internal_layers(
-    input_layer: tf.keras.layers.Layer,
-) -> List[tf.keras.layers.Layer]:
-  """Returns a list of layers that are nested within a given layer."""
-  internal_layers = []
-  if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
-    for layer in input_layer.layers:
-      internal_layers.extend(_get_internal_layers(layer))
-  else:
-    internal_layers.append(input_layer)
-  return internal_layers
-
-
 def model_forward_pass(
     input_model: tf.keras.Model,
     inputs: PackedTensors,
@@ -114,18 +101,10 @@ def generator_fn(layer_instance, args, kwargs):
         generator_outputs_list.extend(node_generator_outputs)
       else:
         # Otherwise, we parse the node directly.
-        node_layers = _get_internal_layers(node.layer)
-        for layer in node_layers:
-          node_layer_outputs, layer_generator_outputs = generator_fn(
-              layer, args, kwargs
-          )
-          generator_outputs_list.append(layer_generator_outputs)
-          args = (
-              node_layer_outputs
-              if isinstance(node_layer_outputs, tuple)
-              else (node_layer_outputs,)
-          )
-          kwargs = {}
+        node_layer_outputs, layer_generator_outputs = generator_fn(
+            node.layer, args, kwargs
+        )
+        generator_outputs_list.append(layer_generator_outputs)

       # Update the current dictionary of inputs for the next node.
       for x_id, y in zip(
@@ -163,9 +142,8 @@ def all_trainable_layers_are_registered(
     False otherwise.
   """
   for layer in input_model.layers:
-    for sublayer in _get_internal_layers(layer):
-      if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
-        return False
+    if not layer_registry.is_elem(layer) and layer.trainable_variables:
+      return False
   return True


@@ -213,17 +191,53 @@ def add_noise(g):

 def generate_model_outputs_using_core_keras_layers(
     input_model: tf.keras.Model,
+    custom_layer_set: Optional[Set[type]] = None,  # pylint: disable=g-bare-generic
 ) -> PackedTensors:
-  """Returns the model outputs generated by only core Keras layers."""
-  cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
-  cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
+  """Returns the model outputs generated by only core Keras layers.
+
+  Args:
+    input_model: A `tf.keras.Model` instance to obtain outputs from.
+    custom_layer_set: An optional `set` of custom layers to expand. If `None`,
+      then this is the set of all registered custom Keras layers.
+
+  Returns:
+    A `tf.Tensor` that is the result of `input_model(input_model.inputs)`
+    using only Keras layers that are not in `custom_layer_set`.
+  """
+  # Set up helper variables and functions.
+  custom_layer_set = (
+      custom_layer_set or tf.keras.utils.get_custom_objects().values()
+  )
+
+  def _is_core(layer_instance):
+    return type(layer_instance) not in custom_layer_set

   def generator_fn(layer_instance, args, kwargs):
-    if hash(layer_instance.__class__) in cust_hash_set:
-      # Using `.call()` does not register the layer in the compute graph of
-      # a forward pass.
-      return layer_instance.call(*args, **kwargs), None
-    else:
-      return layer_instance(*args, **kwargs), None
+    # Using `.call()` does not register the layer in the compute graph of
+    # a forward pass.
+    layer_outputs = (
+        layer_instance(*args, **kwargs)
+        if _is_core(layer_instance)
+        else layer_instance.call(*args, **kwargs)
+    )
+    return layer_outputs, None
+
+  # Return early if all the existing layers contain only core layers.
+  if all(_is_core(layer) for layer in input_model.layers):
+    return model_forward_pass(input_model, input_model.inputs)[0]

-  return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
+  # Do a forward pass to expand the outermost layers.
+  candidate_outputs, _ = model_forward_pass(
+      input_model, input_model.inputs, generator_fn
+  )
+
+  # The following recursion is inefficient because it recursively builds `n`
+  # Keras model graphs, where `n` is the number of recursive calls. However,
+  # it appears to be the only valid approach without accessing Keras's internal
+  # functions (e.g., `keras.engine.functional._map_graph_network()`).
+  cleaned_model = tf.keras.Model(
+      inputs=input_model.inputs, outputs=candidate_outputs
+  )
+  return generate_model_outputs_using_core_keras_layers(
+      cleaned_model, custom_layer_set
+  )
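For context, a minimal usage sketch of the changed function follows. The model and layer names are hypothetical, and it assumes this module is importable as `gradient_clipping_utils` and that the custom layer is registered with `tf.keras.utils.register_keras_serializable()`, so it appears in `tf.keras.utils.get_custom_objects()` (the default `custom_layer_set` expanded above).

import tensorflow as tf

# Hypothetical custom layer that wraps a core Dense layer. Registering it adds
# it to `tf.keras.utils.get_custom_objects()`, so the function above treats it
# as non-core and expands it via `.call()`.
@tf.keras.utils.register_keras_serializable(package='Example')
class DenseBlock(tf.keras.layers.Layer):

  def __init__(self, units=4, **kwargs):
    super().__init__(**kwargs)
    self._dense = tf.keras.layers.Dense(units)

  def call(self, inputs):
    return self._dense(inputs)

# A small functional model whose graph contains the custom layer.
model_inputs = tf.keras.Input(shape=(8,))
model = tf.keras.Model(
    inputs=model_inputs, outputs=DenseBlock(units=4)(model_inputs)
)

# Expand the custom layer so the outputs are expressed only through core Keras
# layers, then rebuild a functional model from those outputs.
core_outputs = gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
    model
)
core_model = tf.keras.Model(inputs=model.inputs, outputs=core_outputs)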