Skip to content

Commit e2ae50d

Browse files
authored
Generate Assertion Ops for interpolate_bilinear (#2609)
1 parent a4113f6 commit e2ae50d

File tree

3 files changed

+46
-30
lines changed

3 files changed

+46
-30
lines changed

tensorflow_addons/image/dense_image_warp.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,54 @@ def interpolate_bilinear(
4646
ValueError: if the indexing mode is invalid, or if the shape of the
4747
inputs invalid.
4848
"""
49+
return _interpolate_bilinear_with_checks(grid, query_points, indexing, name)
50+
51+
52+
def _interpolate_bilinear_with_checks(
53+
grid: types.TensorLike,
54+
query_points: types.TensorLike,
55+
indexing: str,
56+
name: Optional[str],
57+
) -> tf.Tensor:
58+
"""Perform checks on inputs without tf.function decorator to avoid flakiness."""
4959
if indexing != "ij" and indexing != "xy":
5060
raise ValueError("Indexing mode must be 'ij' or 'xy'")
5161

62+
grid = tf.convert_to_tensor(grid)
63+
query_points = tf.convert_to_tensor(query_points)
64+
grid_shape = tf.shape(grid)
65+
query_shape = tf.shape(query_points)
66+
67+
with tf.control_dependencies(
68+
[
69+
tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"),
70+
tf.debugging.assert_greater_equal(
71+
grid_shape[1], 2, "Grid height must be at least 2."
72+
),
73+
tf.debugging.assert_greater_equal(
74+
grid_shape[2], 2, "Grid width must be at least 2."
75+
),
76+
tf.debugging.assert_equal(
77+
tf.rank(query_points), 3, "Query points must be 3 dimensional."
78+
),
79+
tf.debugging.assert_equal(
80+
query_shape[2], 2, "Query points last dimension must be 2."
81+
),
82+
]
83+
):
84+
return _interpolate_bilinear_impl(grid, query_points, indexing, name)
85+
86+
87+
def _interpolate_bilinear_impl(
88+
grid: types.TensorLike,
89+
query_points: types.TensorLike,
90+
indexing: str,
91+
name: Optional[str],
92+
) -> tf.Tensor:
93+
"""tf.function implementation of interpolate_bilinear."""
5294
with tf.name_scope(name or "interpolate_bilinear"):
53-
grid = tf.convert_to_tensor(grid)
54-
query_points = tf.convert_to_tensor(query_points)
55-
56-
# grid shape checks
57-
grid_static_shape = grid.shape
5895
grid_shape = tf.shape(grid)
59-
if grid_static_shape.dims is not None:
60-
if len(grid_static_shape) != 4:
61-
raise ValueError("Grid must be 4D Tensor")
62-
if grid_static_shape[1] is not None and grid_static_shape[1] < 2:
63-
raise ValueError("Grid height must be at least 2.")
64-
if grid_static_shape[2] is not None and grid_static_shape[2] < 2:
65-
raise ValueError("Grid width must be at least 2.")
66-
67-
# query_points shape checks
68-
query_static_shape = query_points.shape
6996
query_shape = tf.shape(query_points)
70-
if query_static_shape.dims is not None:
71-
if len(query_static_shape) != 3:
72-
raise ValueError("Query points must be 3 dimensional.")
73-
query_hw = query_static_shape[2]
74-
if query_hw is not None and query_hw != 2:
75-
raise ValueError("Query points last dimension must be 2.")
7697

7798
batch_size, height, width, channels = (
7899
grid_shape[0],

tensorflow_addons/image/tests/dense_image_warp_test.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ def test_interpolation():
239239
def test_size_exception():
240240
"""Make sure it throws an exception for images that are too small."""
241241
shape = [1, 2, 1, 1]
242-
with pytest.raises(ValueError, match="Grid width must be at least 2."):
242+
with pytest.raises(
243+
tf.errors.InvalidArgumentError, match="Grid width must be at least 2."
244+
):
243245
_check_interpolation_correctness(shape, "float32", "float32")
244246

245247

@@ -250,11 +252,3 @@ def test_unknown_shapes():
250252
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
251253
for shape in shapes_to_try:
252254
_check_interpolation_correctness(shape, "float32", "float32", True)
253-
254-
255-
@pytest.mark.usefixtures("only_run_functions_eagerly")
256-
def test_symbolic_tensor_shape():
257-
image = tf.keras.layers.Input(shape=(7, 7, 192))
258-
flow = tf.ones((1, 7, 7, 2))
259-
interp = dense_image_warp(image, flow)
260-
np.testing.assert_array_equal(interp.shape.as_list(), [None, 7, 7, 192])

tools/testing/source_code_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def test_no_tf_control_dependencies():
178178
allowlist = [
179179
"tensorflow_addons/layers/wrappers.py",
180180
"tensorflow_addons/image/utils.py",
181+
"tensorflow_addons/image/dense_image_warp.py",
181182
"tensorflow_addons/optimizers/average_wrapper.py",
182183
"tensorflow_addons/optimizers/yogi.py",
183184
"tensorflow_addons/optimizers/lookahead.py",

0 commit comments

Comments
 (0)