Skip to content

Commit 4fed7de

Browse files
authored
Merge pull request #2411 from tnybny/fix/add-inf-handling-to-select
[fix] mul trick for tf.where(tf.math.is_inf(x, true_val, false_val)) pattern
2 parents c34ac1d + 617368a commit 4fed7de

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

tests/test_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3673,6 +3673,22 @@ def func(x):
36733673
return tf.identity(picks, name=_TFOUTPUT)
36743674
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
36753675

3676+
@check_opset_min_version(10, "IsInf")
3677+
def test_where_with_isinf_condition(self):
3678+
def func(x, y, z):
3679+
# Use is_inf as condition to trigger the IsInf code path
3680+
condition = tf.math.is_inf(x)
3681+
result = tf.where(condition, y, z)
3682+
return tf.identity(result, name=_TFOUTPUT)
3683+
3684+
# Create test data with some infinite values
3685+
x_val = np.array([1.0, np.inf, 3.0, -np.inf, 5.0], dtype=np.float32)
3686+
y_val = np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
3687+
z_val = np.array([100.0, 200.0, 300.0, 400.0, 500.0], dtype=np.float32)
3688+
3689+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val, _INPUT2: z_val})
3690+
3691+
36763692
@check_opset_min_version(9, "IsNaN")
36773693
def test_where_isnan(self):
36783694
x_val = np.array([1, 2, -3, float('nan'), -5, -6, float('nan'), 8, 9, 0], dtype=np.float32)

tf2onnx/onnx_opset/controlflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def version_9(cls, ctx, node, **kwargs):
195195
handles_nan = node.get_attr_value("handles_nan", False)
196196
if ctx.get_dtype(node.output[0]) in [TensorProto.FLOAT, TensorProto.DOUBLE]:
197197
cond_node = node.inputs[0]
198-
if cond_node.type == "IsNaN":
198+
if cond_node.type in {"IsNaN", "IsInf"}:
199+
# We can't use the mul trick if Inf is involved since Inf * 0 = NaN as per IEEE 754.
199200
handles_nan = True
200201
if cond_node.type == "NotEqual" and cond_node.input[0] == cond_node.input[1]:
201202
handles_nan = True

0 commit comments

Comments
 (0)