@@ -57,35 +57,41 @@ def _interpolate_bilinear_with_checks(
57
57
) -> tf .Tensor :
58
58
"""Perform checks on inputs without tf.function decorator to avoid flakiness."""
59
59
if indexing != "ij" and indexing != "xy" :
60
- raise ValueError ("Indexing mode must be 'ij' or 'xy'" )
60
+ raise ValueError ("Indexing mode must be 'ij' or 'xy'" )
61
61
62
62
grid = tf .convert_to_tensor (grid )
63
63
query_points = tf .convert_to_tensor (query_points )
64
64
grid_shape = tf .shape (grid )
65
65
query_shape = tf .shape (query_points )
66
66
67
- with tf .control_dependencies ([
68
- tf .Assert (tf .equal (tf .rank (grid ), 4 ), ["Grid must be 4D Tensor" ]),
69
- tf .Assert (
70
- tf .greater_equal (grid_shape [1 ], 2 ),
71
- ["Grid height must be at least 2." ]),
72
- tf .Assert (
73
- tf .greater_equal (grid_shape [2 ], 2 ),
74
- ["Grid width must be at least 2." ]),
75
- tf .Assert (
76
- tf .equal (tf .rank (query_points ), 3 ),
77
- ["Query points must be 3 dimensional." ]),
78
- tf .Assert (
79
- tf .equal (query_shape [2 ], 2 ),
80
- ["Query points last dimension must be 2." ])
81
- ]):
67
+ with tf .control_dependencies (
68
+ [
69
+ tf .Assert (tf .equal (tf .rank (grid ), 4 ), ["Grid must be 4D Tensor" ]),
70
+ tf .Assert (
71
+ tf .greater_equal (grid_shape [1 ], 2 ), ["Grid height must be at least 2." ]
72
+ ),
73
+ tf .Assert (
74
+ tf .greater_equal (grid_shape [2 ], 2 ), ["Grid width must be at least 2." ]
75
+ ),
76
+ tf .Assert (
77
+ tf .equal (tf .rank (query_points ), 3 ),
78
+ ["Query points must be 3 dimensional." ],
79
+ ),
80
+ tf .Assert (
81
+ tf .equal (query_shape [2 ], 2 ), ["Query points last dimension must be 2." ]
82
+ ),
83
+ ]
84
+ ):
82
85
return _interpolate_bilinear_impl (grid , query_points , indexing , name )
83
86
84
87
85
88
@tf .function
86
- def _interpolate_bilinear_impl (grid : types .TensorLike ,
87
- query_points : types .TensorLike , indexing : str ,
88
- name : Optional [str ]) -> tf .Tensor :
89
+ def _interpolate_bilinear_impl (
90
+ grid : types .TensorLike ,
91
+ query_points : types .TensorLike ,
92
+ indexing : str ,
93
+ name : Optional [str ],
94
+ ) -> tf .Tensor :
89
95
"""tf.function implementation of interpolate_bilinear."""
90
96
with tf .name_scope (name or "interpolate_bilinear" ):
91
97
grid_shape = tf .shape (grid )
0 commit comments