Skip to content

Conversation

feifeibear
Copy link

No description provided.

@feifeibear
Copy link
Author

fix #21

zhuzilin pushed a commit that referenced this pull request Apr 8, 2025
Fix "Cannot call numel() on tensor with symbolic sizes/strides"
when torch compile is enabled.

Full log:

```log
Traceback (most recent call last):
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 212, in <module>
    benchmark(
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 175, in benchmark
    out = wrapper(i)
          ^^^^^^^^^^
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 155, in wrapper
    return f(
           ^^
  File "torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 1164, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 1022, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 759, in call_method
    return self.call_apply(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 681, in call_apply
    ).call_function(tx, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/higher_order_ops.py", line 2422, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
                                            ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/higher_order_ops.py", line 556, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2536, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
           ^^^^
  File "torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method zeros of type object at 0x1464c03bff60>(*((FakeTensor(..., device='cuda:1', size=(), dtype=torch.int32),),), **{'dtype': torch.bool}):
Cannot call numel() on tensor with symbolic sizes/strides
Exception raised from throw_cannot_call_with_symbolic at /pytorch/c10/core/TensorImpl.cpp:291 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x146474c491b6 in torch/lib/libc10.so)
frame #1: c10::TensorImpl::throw_cannot_call_with_symbolic(char const*) const + 0x9c (0x146474bf0f92 in torch/lib/libc10.so)
frame #2: <unknown function> + 0x666af (0x146474c216af in torch/lib/libc10.so)
frame #3: <unknown function> + 0x64b79e (0x1464bf67579e in torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x6b4b7b (0x1464bf6deb7b in torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x224918 (0x564f1b9fb918 in rlhf/.conda/bin/python)
frame #6: _PyObject_Call + 0xb5 (0x564f1ba0c515 in rlhf/.conda/bin/python)
frame #7: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #8: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #9: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #10: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #11: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #12: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #13: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #14: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #15: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #16: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #17: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #18: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #19: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #20: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #22: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #23: _PyObject_Call + 0x12b (0x564f1ba0c58b in rlhf/.conda/bin/python)
frame #24: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #25: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #26: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #27: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #28: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #29: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #30: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #31: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #32: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #33: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #34: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #35: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #36: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #37: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #38: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #39: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #40: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #41: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #42: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #43: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #44: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #45: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #46: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #47: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #48: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #49: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #50: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #51: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #52: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #53: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #54: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #55: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #56: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #57: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #58: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #59: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #60: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #61: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #62: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)

from user code:
   File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 459, in zigzag_ring_flash_attn_varlen_kvpacked_func
    return ZigZagRingFlashAttnVarlenFunc.apply(
  File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 344, in forward
    half_index0 = get_half_index(cu_seqlens, front=True)
  File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 31, in get_half_index
    index = torch.zeros((cu_seqlens[-1],), dtype=bool)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

Signed-off-by: Hollow Man <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant