Skip to content

Commit 8e68a1d

Browse files
committed
Fix formatting
1 parent 376e22e commit 8e68a1d

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

tensorflow_addons/image/dense_image_warp.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,35 +57,41 @@ def _interpolate_bilinear_with_checks(
5757
) -> tf.Tensor:
5858
"""Perform checks on inputs without tf.function decorator to avoid flakiness."""
5959
if indexing != "ij" and indexing != "xy":
60-
raise ValueError("Indexing mode must be 'ij' or 'xy'")
60+
raise ValueError("Indexing mode must be 'ij' or 'xy'")
6161

6262
grid = tf.convert_to_tensor(grid)
6363
query_points = tf.convert_to_tensor(query_points)
6464
grid_shape = tf.shape(grid)
6565
query_shape = tf.shape(query_points)
6666

67-
with tf.control_dependencies([
68-
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"]),
69-
tf.Assert(
70-
tf.greater_equal(grid_shape[1], 2),
71-
["Grid height must be at least 2."]),
72-
tf.Assert(
73-
tf.greater_equal(grid_shape[2], 2),
74-
["Grid width must be at least 2."]),
75-
tf.Assert(
76-
tf.equal(tf.rank(query_points), 3),
77-
["Query points must be 3 dimensional."]),
78-
tf.Assert(
79-
tf.equal(query_shape[2], 2),
80-
["Query points last dimension must be 2."])
81-
]):
67+
with tf.control_dependencies(
68+
[
69+
tf.Assert(tf.equal(tf.rank(grid), 4), ["Grid must be 4D Tensor"]),
70+
tf.Assert(
71+
tf.greater_equal(grid_shape[1], 2), ["Grid height must be at least 2."]
72+
),
73+
tf.Assert(
74+
tf.greater_equal(grid_shape[2], 2), ["Grid width must be at least 2."]
75+
),
76+
tf.Assert(
77+
tf.equal(tf.rank(query_points), 3),
78+
["Query points must be 3 dimensional."],
79+
),
80+
tf.Assert(
81+
tf.equal(query_shape[2], 2), ["Query points last dimension must be 2."]
82+
),
83+
]
84+
):
8285
return _interpolate_bilinear_impl(grid, query_points, indexing, name)
8386

8487

8588
@tf.function
86-
def _interpolate_bilinear_impl(grid: types.TensorLike,
87-
query_points: types.TensorLike, indexing: str,
88-
name: Optional[str]) -> tf.Tensor:
89+
def _interpolate_bilinear_impl(
90+
grid: types.TensorLike,
91+
query_points: types.TensorLike,
92+
indexing: str,
93+
name: Optional[str],
94+
) -> tf.Tensor:
8995
"""tf.function implementation of interpolate_bilinear."""
9096
with tf.name_scope(name or "interpolate_bilinear"):
9197
grid_shape = tf.shape(grid)

0 commit comments

Comments
 (0)