@@ -46,33 +46,54 @@ def interpolate_bilinear(
46
46
ValueError: if the indexing mode is invalid, or if the shape of the
47
47
inputs invalid.
48
48
"""
49
+ return _interpolate_bilinear_with_checks (grid , query_points , indexing , name )
50
+
51
+
52
+ def _interpolate_bilinear_with_checks (
53
+ grid : types .TensorLike ,
54
+ query_points : types .TensorLike ,
55
+ indexing : str ,
56
+ name : Optional [str ],
57
+ ) -> tf .Tensor :
58
+ """Perform checks on inputs without tf.function decorator to avoid flakiness."""
49
59
if indexing != "ij" and indexing != "xy" :
50
60
raise ValueError ("Indexing mode must be 'ij' or 'xy'" )
51
61
62
+ grid = tf .convert_to_tensor (grid )
63
+ query_points = tf .convert_to_tensor (query_points )
64
+ grid_shape = tf .shape (grid )
65
+ query_shape = tf .shape (query_points )
66
+
67
+ with tf .control_dependencies (
68
+ [
69
+ tf .debugging .assert_equal (tf .rank (grid ), 4 , "Grid must be 4D Tensor" ),
70
+ tf .debugging .assert_greater_equal (
71
+ grid_shape [1 ], 2 , "Grid height must be at least 2."
72
+ ),
73
+ tf .debugging .assert_greater_equal (
74
+ grid_shape [2 ], 2 , "Grid width must be at least 2."
75
+ ),
76
+ tf .debugging .assert_equal (
77
+ tf .rank (query_points ), 3 , "Query points must be 3 dimensional."
78
+ ),
79
+ tf .debugging .assert_equal (
80
+ query_shape [2 ], 2 , "Query points last dimension must be 2."
81
+ ),
82
+ ]
83
+ ):
84
+ return _interpolate_bilinear_impl (grid , query_points , indexing , name )
85
+
86
+
87
+ def _interpolate_bilinear_impl (
88
+ grid : types .TensorLike ,
89
+ query_points : types .TensorLike ,
90
+ indexing : str ,
91
+ name : Optional [str ],
92
+ ) -> tf .Tensor :
93
+ """tf.function implementation of interpolate_bilinear."""
52
94
with tf .name_scope (name or "interpolate_bilinear" ):
53
- grid = tf .convert_to_tensor (grid )
54
- query_points = tf .convert_to_tensor (query_points )
55
-
56
- # grid shape checks
57
- grid_static_shape = grid .shape
58
95
grid_shape = tf .shape (grid )
59
- if grid_static_shape .dims is not None :
60
- if len (grid_static_shape ) != 4 :
61
- raise ValueError ("Grid must be 4D Tensor" )
62
- if grid_static_shape [1 ] is not None and grid_static_shape [1 ] < 2 :
63
- raise ValueError ("Grid height must be at least 2." )
64
- if grid_static_shape [2 ] is not None and grid_static_shape [2 ] < 2 :
65
- raise ValueError ("Grid width must be at least 2." )
66
-
67
- # query_points shape checks
68
- query_static_shape = query_points .shape
69
96
query_shape = tf .shape (query_points )
70
- if query_static_shape .dims is not None :
71
- if len (query_static_shape ) != 3 :
72
- raise ValueError ("Query points must be 3 dimensional." )
73
- query_hw = query_static_shape [2 ]
74
- if query_hw is not None and query_hw != 2 :
75
- raise ValueError ("Query points last dimension must be 2." )
76
97
77
98
batch_size , height , width , channels = (
78
99
grid_shape [0 ],
0 commit comments