|
23 | 23 | from tensorflow_gnn.graph import graph_constants as const
|
24 | 24 | from tensorflow_gnn.graph import graph_tensor as gt
|
25 | 25 | from tensorflow_gnn.keras.layers import graph_ops
|
26 |
| -# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top |
27 |
| -if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this. |
28 |
| - # The following import crashes with tf-nightly~=2.20.0. |
29 |
| - from ai_edge_litert import interpreter as tfl_interpreter |
30 |
| -# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top |
| 26 | +# pylint: disable=g-direct-tensorflow-import |
| 27 | +from ai_edge_litert import interpreter as tfl_interpreter |
| 28 | +# pylint: enable=g-direct-tensorflow-import |
31 | 29 |
|
32 | 30 |
|
33 | 31 | class ReadoutTest(tf.test.TestCase, parameterized.TestCase):
|
@@ -172,9 +170,6 @@ def testTFLite(self, location):
|
172 | 170 |
|
173 | 171 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
174 | 172 | model_content = converter.convert()
|
175 |
| - if tf.__version__.startswith("2.20."): |
176 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
177 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
178 | 173 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
179 | 174 | signature_runner = interpreter.get_signature_runner("serving_default")
|
180 | 175 | obtained = signature_runner(**test_graph_134_dict)["test_readout"]
|
@@ -303,12 +298,8 @@ def testTFLite(self):
|
303 | 298 | model = tf.keras.Model(inputs, outputs)
|
304 | 299 | expected = model(test_graph_22_dict)
|
305 | 300 |
|
306 |
| - # TODO(b/276291104): Remove when TF 2.11+ is required by all of TFGNN |
307 | 301 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
308 | 302 | model_content = converter.convert()
|
309 |
| - if tf.__version__.startswith("2.20."): |
310 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
311 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
312 | 303 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
313 | 304 | signature_runner = interpreter.get_signature_runner("serving_default")
|
314 | 305 | obtained = signature_runner(**test_graph_22_dict)["test_readout_first"]
|
@@ -436,9 +427,6 @@ def testTFLite(self):
|
436 | 427 |
|
437 | 428 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
438 | 429 | model_content = converter.convert()
|
439 |
| - if tf.__version__.startswith("2.20."): |
440 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
441 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
442 | 430 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
443 | 431 | signature_runner = interpreter.get_signature_runner("serving_default")
|
444 | 432 | actual = signature_runner(
|
@@ -573,9 +561,6 @@ def testTFLite(self):
|
573 | 561 |
|
574 | 562 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
575 | 563 | model_content = converter.convert()
|
576 |
| - if tf.__version__.startswith("2.20."): |
577 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
578 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
579 | 564 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
580 | 565 | signature_runner = interpreter.get_signature_runner("serving_default")
|
581 | 566 | actual = signature_runner(
|
@@ -639,9 +624,6 @@ def testTFLite(self):
|
639 | 624 |
|
640 | 625 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
641 | 626 | model_content = converter.convert()
|
642 |
| - if tf.__version__.startswith("2.20."): |
643 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
644 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
645 | 627 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
646 | 628 | signature_runner = interpreter.get_signature_runner("serving_default")
|
647 | 629 | actual = signature_runner(
|
@@ -760,9 +742,6 @@ def testTFLite(self):
|
760 | 742 |
|
761 | 743 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
762 | 744 | model_content = converter.convert()
|
763 |
| - if tf.__version__.startswith("2.20."): |
764 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
765 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
766 | 745 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
767 | 746 | signature_runner = interpreter.get_signature_runner("serving_default")
|
768 | 747 | obtained = signature_runner(**test_graph_134_dict)["final_edge_states"]
|
@@ -961,9 +940,6 @@ def testTFLite(self, tag, location):
|
961 | 940 |
|
962 | 941 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
963 | 942 | model_content = converter.convert()
|
964 |
| - if tf.__version__.startswith("2.20."): |
965 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
966 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
967 | 943 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
968 | 944 | signature_runner = interpreter.get_signature_runner("serving_default")
|
969 | 945 | obtained = signature_runner(**test_values)["test_broadcast"]
|
@@ -1268,9 +1244,6 @@ def testTFLite(self, tag, location, reduce_type):
|
1268 | 1244 |
|
1269 | 1245 | converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
1270 | 1246 | model_content = converter.convert()
|
1271 |
| - if tf.__version__.startswith("2.20."): |
1272 |
| - self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported " |
1273 |
| - f"next to tf-nightly~=2.20.0; got TF {tf.__version__}") |
1274 | 1247 | interpreter = tfl_interpreter.Interpreter(model_content=model_content)
|
1275 | 1248 | signature_runner = interpreter.get_signature_runner("serving_default")
|
1276 | 1249 | obtained = signature_runner(**test_values)["test_pool"]
|
|
0 commit comments