Skip to content

Commit e4279c4

Browse files
author
Jack Fan
authored
Make MultiOptimizer serializable (#2719)
Make MultiOptimizer serializable
1 parent ee0df43 commit e4279c4

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

tensorflow_addons/optimizers/discriminative_layer_training.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,15 @@ def apply_gradients(self, grads_and_vars, **kwargs):
145145

146146
def get_config(self):
147147
config = super(MultiOptimizer, self).get_config()
148-
config.update({"optimizer_specs": self.optimizer_specs})
148+
optimizer_specs_without_gv = []
149+
for optimizer_spec in self.optimizer_specs:
150+
optimizer_specs_without_gv.append(
151+
{
152+
"optimizer": optimizer_spec["optimizer"],
153+
"weights": optimizer_spec["weights"],
154+
}
155+
)
156+
config.update({"optimizer_specs": optimizer_specs_without_gv})
149157
return config
150158

151159
@classmethod

tensorflow_addons/optimizers/tests/discriminative_layer_training_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,38 @@ def test_serialization():
286286

287287
new_optimizer = tf.keras.optimizers.deserialize(config)
288288
assert new_optimizer.get_config() == optimizer.get_config()
289+
290+
291+
def test_serialization_after_training(tmpdir):
292+
x = np.array(np.ones([100]))
293+
y = np.array(np.ones([100]))
294+
model = tf.keras.Sequential(
295+
[tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
296+
)
297+
298+
opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
299+
opt2 = tf.keras.optimizers.SGD(learning_rate=0)
300+
301+
opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]
302+
303+
optimizer = MultiOptimizer(opt_layer_pairs)
304+
305+
# Train the model for a few epochs.
306+
model.compile(loss="categorical_crossentropy", optimizer=optimizer)
307+
model.fit(x, y)
308+
309+
# Verify the optimizer can still be serialized (saved).
310+
model.save(str(tmpdir))
311+
loaded_model = tf.keras.models.load_model(str(tmpdir))
312+
old_config = model.optimizer.get_config()
313+
new_config = loaded_model.optimizer.get_config()
314+
# Verify the loaded model has the same optimizer as before.
315+
assert len(old_config["optimizer_specs"]) == len(new_config["optimizer_specs"])
316+
for old_optimizer_spec, new_optimizer_spec in zip(
317+
old_config["optimizer_specs"], new_config["optimizer_specs"]
318+
):
319+
assert old_optimizer_spec["weights"] == new_optimizer_spec["weights"]
320+
assert (
321+
old_optimizer_spec["optimizer"].get_config()
322+
== new_optimizer_spec["optimizer"].get_config()
323+
)

0 commit comments

Comments
 (0)