@@ -286,3 +286,38 @@ def test_serialization():
286
286
287
287
new_optimizer = tf .keras .optimizers .deserialize (config )
288
288
assert new_optimizer .get_config () == optimizer .get_config ()
289
+
290
+
291
def test_serialization_after_training(tmpdir):
    """Regression test: a model compiled with MultiOptimizer stays serializable after fit().

    Trains briefly so each wrapped optimizer accumulates slot variables
    (weights), then saves and reloads the model and checks that the
    MultiOptimizer config round-trips intact.

    Args:
        tmpdir: pytest-provided temporary directory used as the SavedModel path.
    """
    # np.ones already returns an ndarray; no extra np.array wrapper needed.
    x = np.ones([100])
    y = np.ones([100])
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=[1]), tf.keras.layers.Dense(1), tf.keras.layers.Dense(1)]
    )

    # Two distinct optimizers, one per layer, to exercise the multi-optimizer path.
    opt1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
    opt2 = tf.keras.optimizers.SGD(learning_rate=0)

    opt_layer_pairs = [(opt1, model.layers[0]), (opt2, model.layers[1])]

    optimizer = MultiOptimizer(opt_layer_pairs)

    # Train the model for a few epochs so optimizer slot weights are created.
    model.compile(loss="categorical_crossentropy", optimizer=optimizer)
    model.fit(x, y)

    # Verify the optimizer can still be serialized (saved).
    model.save(str(tmpdir))
    loaded_model = tf.keras.models.load_model(str(tmpdir))
    old_config = model.optimizer.get_config()
    new_config = loaded_model.optimizer.get_config()

    # Verify the loaded model has the same optimizer as before:
    # spec count, per-spec slot weights, and each inner optimizer's config.
    assert len(old_config["optimizer_specs"]) == len(new_config["optimizer_specs"])
    for old_optimizer_spec, new_optimizer_spec in zip(
        old_config["optimizer_specs"], new_config["optimizer_specs"]
    ):
        assert old_optimizer_spec["weights"] == new_optimizer_spec["weights"]
        assert (
            old_optimizer_spec["optimizer"].get_config()
            == new_optimizer_spec["optimizer"].get_config()
        )
0 commit comments