Skip to content

Commit 4ebc3f9

Browse files
arnoegwtensorflower-gardener
authored andcommitted
Reinstate the tests that were skipped for TF/Keras [nightly] 2.20 by
commit f9b8f5c PiperOrigin-RevId: 803321430
1 parent d62bc48 commit 4ebc3f9

File tree

12 files changed

+36
-124
lines changed

12 files changed

+36
-124
lines changed

tensorflow_gnn/graph/graph_tensor_ops_test.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@
2525
from tensorflow_gnn.graph import graph_tensor_ops as ops
2626
from tensorflow_gnn.graph import pool_ops
2727
from tensorflow_gnn.graph import readout
28-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
29-
if not tf.__version__.startswith('2.20.'): # TODO: b/441006328 - Remove this.
30-
# The following import crashes with tf-nightly~=2.20.0.
31-
from ai_edge_litert import interpreter as tfl_interpreter
32-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
28+
# pylint: disable=g-direct-tensorflow-import
29+
from ai_edge_litert import interpreter as tfl_interpreter
30+
# pylint: enable=g-direct-tensorflow-import
3331

3432
as_tensor = tf.convert_to_tensor
3533
as_ragged = tf.ragged.constant
@@ -643,9 +641,6 @@ def testTFLite(self):
643641

644642
converter = tf.lite.TFLiteConverter.from_keras_model(model)
645643
model_content = converter.convert()
646-
if tf.__version__.startswith('2.20.'):
647-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
648-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
649644
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
650645
signature_runner = interpreter.get_signature_runner('serving_default')
651646
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']
@@ -1308,9 +1303,6 @@ def testTFLite(self):
13081303

13091304
converter = tf.lite.TFLiteConverter.from_keras_model(model)
13101305
model_content = converter.convert()
1311-
if tf.__version__.startswith('2.20.'):
1312-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
1313-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
13141306
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
13151307
signature_runner = interpreter.get_signature_runner('serving_default')
13161308
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']
@@ -1679,9 +1671,6 @@ def testTFLite(self):
16791671

16801672
converter = tf.lite.TFLiteConverter.from_keras_model(model)
16811673
model_content = converter.convert()
1682-
if tf.__version__.startswith('2.20.'):
1683-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
1684-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
16851674
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
16861675
signature_runner = interpreter.get_signature_runner('serving_default')
16871676
obtained = signature_runner(**test_graph_dict)['final_edge_adjacency']

tensorflow_gnn/graph/graph_tensor_test.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@
2424
from tensorflow_gnn.graph import graph_tensor as gt
2525
from tensorflow_gnn.graph import graph_tensor_test_utils as tu
2626

27-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
28-
if not tf.__version__.startswith('2.20.'): # TODO: b/441006328 - Remove this.
29-
# The following import crashes with tf-nightly~=2.20.0.
30-
from ai_edge_litert import interpreter as tfl_interpreter
27+
# pylint: disable=g-direct-tensorflow-import
28+
from ai_edge_litert import interpreter as tfl_interpreter
3129
from tensorflow.python.framework import type_spec
32-
# pylint: enable=g-import-not-at-top,g-direct-tensorflow-import
30+
# pylint: enable=g-direct-tensorflow-import
3331

3432
as_tensor = tf.convert_to_tensor
3533
as_ragged = tf.ragged.constant
@@ -1548,9 +1546,6 @@ def testTFLite(self):
15481546

15491547
converter = tf.lite.TFLiteConverter.from_keras_model(model)
15501548
model_content = converter.convert()
1551-
if tf.__version__.startswith('2.20.'):
1552-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
1553-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
15541549
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
15551550
signature_runner = interpreter.get_signature_runner('serving_default')
15561551
obtained = signature_runner(
@@ -1754,9 +1749,6 @@ def testTFLite(self):
17541749

17551750
converter = tf.lite.TFLiteConverter.from_keras_model(model)
17561751
model_content = converter.convert()
1757-
if tf.__version__.startswith('2.20.'):
1758-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
1759-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
17601752
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
17611753
signature_runner = interpreter.get_signature_runner('serving_default')
17621754
obtained = signature_runner(

tensorflow_gnn/keras/layers/graph_ops_test.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@
2323
from tensorflow_gnn.graph import graph_constants as const
2424
from tensorflow_gnn.graph import graph_tensor as gt
2525
from tensorflow_gnn.keras.layers import graph_ops
26-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
27-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
28-
# The following import crashes with tf-nightly~=2.20.0.
29-
from ai_edge_litert import interpreter as tfl_interpreter
30-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
26+
# pylint: disable=g-direct-tensorflow-import
27+
from ai_edge_litert import interpreter as tfl_interpreter
28+
# pylint: enable=g-direct-tensorflow-import
3129

3230

3331
class ReadoutTest(tf.test.TestCase, parameterized.TestCase):
@@ -172,9 +170,6 @@ def testTFLite(self, location):
172170

173171
converter = tf.lite.TFLiteConverter.from_keras_model(model)
174172
model_content = converter.convert()
175-
if tf.__version__.startswith("2.20."):
176-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
177-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
178173
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
179174
signature_runner = interpreter.get_signature_runner("serving_default")
180175
obtained = signature_runner(**test_graph_134_dict)["test_readout"]
@@ -303,12 +298,8 @@ def testTFLite(self):
303298
model = tf.keras.Model(inputs, outputs)
304299
expected = model(test_graph_22_dict)
305300

306-
# TODO(b/276291104): Remove when TF 2.11+ is required by all of TFGNN
307301
converter = tf.lite.TFLiteConverter.from_keras_model(model)
308302
model_content = converter.convert()
309-
if tf.__version__.startswith("2.20."):
310-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
311-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
312303
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
313304
signature_runner = interpreter.get_signature_runner("serving_default")
314305
obtained = signature_runner(**test_graph_22_dict)["test_readout_first"]
@@ -436,9 +427,6 @@ def testTFLite(self):
436427

437428
converter = tf.lite.TFLiteConverter.from_keras_model(model)
438429
model_content = converter.convert()
439-
if tf.__version__.startswith("2.20."):
440-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
441-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
442430
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
443431
signature_runner = interpreter.get_signature_runner("serving_default")
444432
actual = signature_runner(
@@ -573,9 +561,6 @@ def testTFLite(self):
573561

574562
converter = tf.lite.TFLiteConverter.from_keras_model(model)
575563
model_content = converter.convert()
576-
if tf.__version__.startswith("2.20."):
577-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
578-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
579564
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
580565
signature_runner = interpreter.get_signature_runner("serving_default")
581566
actual = signature_runner(
@@ -639,9 +624,6 @@ def testTFLite(self):
639624

640625
converter = tf.lite.TFLiteConverter.from_keras_model(model)
641626
model_content = converter.convert()
642-
if tf.__version__.startswith("2.20."):
643-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
644-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
645627
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
646628
signature_runner = interpreter.get_signature_runner("serving_default")
647629
actual = signature_runner(
@@ -760,9 +742,6 @@ def testTFLite(self):
760742

761743
converter = tf.lite.TFLiteConverter.from_keras_model(model)
762744
model_content = converter.convert()
763-
if tf.__version__.startswith("2.20."):
764-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
765-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
766745
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
767746
signature_runner = interpreter.get_signature_runner("serving_default")
768747
obtained = signature_runner(**test_graph_134_dict)["final_edge_states"]
@@ -961,9 +940,6 @@ def testTFLite(self, tag, location):
961940

962941
converter = tf.lite.TFLiteConverter.from_keras_model(model)
963942
model_content = converter.convert()
964-
if tf.__version__.startswith("2.20."):
965-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
966-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
967943
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
968944
signature_runner = interpreter.get_signature_runner("serving_default")
969945
obtained = signature_runner(**test_values)["test_broadcast"]
@@ -1268,9 +1244,6 @@ def testTFLite(self, tag, location, reduce_type):
12681244

12691245
converter = tf.lite.TFLiteConverter.from_keras_model(model)
12701246
model_content = converter.convert()
1271-
if tf.__version__.startswith("2.20."):
1272-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
1273-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
12741247
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
12751248
signature_runner = interpreter.get_signature_runner("serving_default")
12761249
obtained = signature_runner(**test_values)["test_pool"]

tensorflow_gnn/keras/layers/next_state_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
from tensorflow_gnn.graph import graph_constants as const
2020
from tensorflow_gnn.keras.layers import next_state as next_state_lib
2121
from tensorflow_gnn.utils import tf_test_utils as tftu
22-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
23-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
24-
# The following import crashes with tf-nightly~=2.20.0.
25-
from ai_edge_litert import interpreter as tfl_interpreter
26-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
22+
# pylint: disable=g-direct-tensorflow-import
23+
from ai_edge_litert import interpreter as tfl_interpreter
24+
# pylint: enable=g-direct-tensorflow-import
2725

2826

2927
class NextStateFromConcatTest(tf.test.TestCase, parameterized.TestCase):
@@ -184,9 +182,6 @@ def testTFLite(self):
184182

185183
converter = tf.lite.TFLiteConverter.from_keras_model(model)
186184
model_content = converter.convert()
187-
if tf.__version__.startswith("2.20."):
188-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
189-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
190185
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
191186
signature_runner = interpreter.get_signature_runner("serving_default")
192187
obtained = signature_runner(**test_input_dict)["residual_next_state"]

tensorflow_gnn/keras/layers/padding_ops_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@
2323
from tensorflow_gnn.keras import keras_tensors # For registration. pylint: disable=unused-import
2424
from tensorflow_gnn.keras.layers import padding_ops
2525
from tensorflow_gnn.utils import tf_test_utils as tftu
26-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
27-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
28-
# The following import crashes with tf-nightly~=2.20.0.
29-
from ai_edge_litert import interpreter as tfl_interpreter
30-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
26+
# pylint: disable=g-direct-tensorflow-import
27+
from ai_edge_litert import interpreter as tfl_interpreter
28+
# pylint: enable=g-direct-tensorflow-import
3129

3230

3331
class PadToTotalSizesTest(tf.test.TestCase, parameterized.TestCase):

tensorflow_gnn/models/gat_v2/layers_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
import tensorflow_gnn as tfgnn
1818
from tensorflow_gnn.models import gat_v2
1919
from tensorflow_gnn.utils import tf_test_utils as tftu
20-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
21-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
22-
# The following import crashes with tf-nightly~=2.20.0.
23-
from ai_edge_litert import interpreter as tfl_interpreter
24-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
20+
# pylint: disable=g-direct-tensorflow-import
21+
from ai_edge_litert import interpreter as tfl_interpreter
22+
# pylint: enable=g-direct-tensorflow-import
2523

2624

2725
class GATv2Test(tf.test.TestCase, parameterized.TestCase):
@@ -707,9 +705,6 @@ def testBasic(self):
707705

708706
converter = tf.lite.TFLiteConverter.from_keras_model(model)
709707
model_content = converter.convert()
710-
if tf.__version__.startswith("2.20."):
711-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
712-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
713708
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
714709
signature_runner = interpreter.get_signature_runner("serving_default")
715710
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]

tensorflow_gnn/models/gcn/gcn_conv_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
import tensorflow_gnn as tfgnn
2020
from tensorflow_gnn.models.gcn import gcn_conv
2121
from tensorflow_gnn.utils import tf_test_utils as tftu
22-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
23-
if not tf.__version__.startswith('2.20.'): # TODO: b/441006328 - Remove this.
24-
# The following import crashes with tf-nightly~=2.20.0.
25-
from ai_edge_litert import interpreter as tfl_interpreter
26-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
22+
# pylint: disable=g-direct-tensorflow-import
23+
from ai_edge_litert import interpreter as tfl_interpreter
24+
# pylint: enable=g-direct-tensorflow-import
2725

2826

2927
class GcnConvTest(tf.test.TestCase, parameterized.TestCase):
@@ -873,9 +871,6 @@ def testBasic(self, add_self_loops, edge_weight_feature_name):
873871

874872
converter = tf.lite.TFLiteConverter.from_keras_model(model)
875873
model_content = converter.convert()
876-
if tf.__version__.startswith('2.20.'):
877-
self.skipTest('TODO: b/441006328 - tfl_interpreter cannot be imported '
878-
f'next to tf-nightly~=2.20.0; got TF {tf.__version__}')
879874
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
880875
signature_runner = interpreter.get_signature_runner('serving_default')
881876
obtained = signature_runner(**test_graph_1_dict)['final_node_states']

tensorflow_gnn/models/graph_sage/layers_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
import tensorflow_gnn as tfgnn
2020
from tensorflow_gnn.models.graph_sage import layers as graph_sage
2121
from tensorflow_gnn.utils import tf_test_utils as tftu
22-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
23-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
24-
# The following import crashes with tf-nightly~=2.20.0.
25-
from ai_edge_litert import interpreter as tfl_interpreter
26-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
22+
# pylint: disable=g-direct-tensorflow-import
23+
from ai_edge_litert import interpreter as tfl_interpreter
24+
# pylint: enable=g-direct-tensorflow-import
2725

2826
_FEATURE_NAME = "f"
2927

@@ -624,9 +622,6 @@ def testBasic(self, use_pooling, hidden_units, combine_type):
624622

625623
converter = tf.lite.TFLiteConverter.from_keras_model(model)
626624
model_content = converter.convert()
627-
if tf.__version__.startswith("2.20."):
628-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
629-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
630625
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
631626
signature_runner = interpreter.get_signature_runner("serving_default")
632627
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]

tensorflow_gnn/models/hgt/layers_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
import tensorflow_gnn as tfgnn
2020
from tensorflow_gnn.models.hgt import layers
2121
from tensorflow_gnn.utils import tf_test_utils as tftu
22-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
23-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
24-
# The following import crashes with tf-nightly~=2.20.0.
25-
from ai_edge_litert import interpreter as tfl_interpreter
26-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
22+
# pylint: disable=g-direct-tensorflow-import
23+
from ai_edge_litert import interpreter as tfl_interpreter
24+
# pylint: enable=g-direct-tensorflow-import
2725

2826

2927
def _homogeneous_cycle_graph(node_state, edge_state=None):
@@ -810,9 +808,6 @@ def testBasic(self):
810808

811809
converter = tf.lite.TFLiteConverter.from_keras_model(model)
812810
model_content = converter.convert()
813-
if tf.__version__.startswith("2.20."):
814-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
815-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
816811
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
817812
signature_runner = interpreter.get_signature_runner("serving_default")
818813
obtained = signature_runner(**test_graph_1_dict)["final_engine_states"]

tensorflow_gnn/models/mt_albis/layers_test.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@
2222
import tensorflow_gnn as tfgnn
2323
from tensorflow_gnn.models.mt_albis import layers
2424
from tensorflow_gnn.utils import tf_test_utils as tftu
25-
# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
26-
if not tf.__version__.startswith("2.20."): # TODO: b/441006328 - Remove this.
27-
# The following import crashes with tf-nightly~=2.20.0.
28-
from ai_edge_litert import interpreter as tfl_interpreter
29-
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top
25+
# pylint: disable=g-direct-tensorflow-import
26+
from ai_edge_litert import interpreter as tfl_interpreter
27+
# pylint: enable=g-direct-tensorflow-import
3028

3129

3230
class MtAlbisNextNodeStateTest(tf.test.TestCase, parameterized.TestCase):
@@ -383,9 +381,6 @@ def test(self,
383381

384382
converter = tf.lite.TFLiteConverter.from_keras_model(model)
385383
model_content = converter.convert()
386-
if tf.__version__.startswith("2.20."):
387-
self.skipTest("TODO: b/441006328 - tfl_interpreter cannot be imported "
388-
f"next to tf-nightly~=2.20.0; got TF {tf.__version__}")
389384
interpreter = tfl_interpreter.Interpreter(model_content=model_content)
390385
signature_runner = interpreter.get_signature_runner("serving_default")
391386
obtained = signature_runner(**test_graph_1_dict)["final_node_states"]

0 commit comments

Comments
 (0)