add flash-attn install in setup.py #25
Open
feifeibear wants to merge 1 commit into zhuzilin:main from feifeibear:0327
fix #21
zhuzilin pushed a commit that referenced this pull request on Apr 8, 2025
Fix "Cannot call numel() on tensor with symbolic sizes/strides" when torch compile is enabled.

Full log:

```log
Traceback (most recent call last):
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 212, in <module>
    benchmark(
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 175, in benchmark
    out = wrapper(i)
          ^^^^^^^^^^
  File "ring-flash-attention/benchmark/benchmark_varlen_kvpacked_func.py", line 155, in wrapper
    return f(
           ^^
  File "torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 1164, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 1022, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 759, in call_method
    return self.call_apply(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/misc.py", line 681, in call_apply
    ).call_function(tx, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/higher_order_ops.py", line 2422, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
                                            ^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/higher_order_ops.py", line 556, in speculate_subgraph
    output = f.call_function(tx, args, sub_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/symbolic_convert.py", line 2341, in CALL
    self._call(inst)
  File "torch/_dynamo/symbolic_convert.py", line 2335, in _call
    self.call_function(fn, args, kwargs)
  File "torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/torch.py", line 953, in call_function
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
           ^^^^^^^^^^^^^^^
  File "torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2536, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "torch/_dynamo/utils.py", line 2471, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
    return fn()
           ^^^^
  File "torch/_dynamo/utils.py", line 2472, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/utils.py", line 2604, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "torch/_dynamo/utils.py", line 2586, in run_node
    return node.target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method zeros of type object at 0x1464c03bff60>(*((FakeTensor(..., device='cuda:1', size=(), dtype=torch.int32),),), **{'dtype': torch.bool}):
Cannot call numel() on tensor with symbolic sizes/strides
Exception raised from throw_cannot_call_with_symbolic at /pytorch/c10/core/TensorImpl.cpp:291 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x146474c491b6 in torch/lib/libc10.so)
frame #1: c10::TensorImpl::throw_cannot_call_with_symbolic(char const*) const + 0x9c (0x146474bf0f92 in torch/lib/libc10.so)
frame #2: <unknown function> + 0x666af (0x146474c216af in torch/lib/libc10.so)
frame #3: <unknown function> + 0x64b79e (0x1464bf67579e in torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x6b4b7b (0x1464bf6deb7b in torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x224918 (0x564f1b9fb918 in rlhf/.conda/bin/python)
frame #6: _PyObject_Call + 0xb5 (0x564f1ba0c515 in rlhf/.conda/bin/python)
frame #7: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #8: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #9: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #10: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #11: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #12: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #13: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #14: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #15: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #16: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #17: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #18: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #19: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #20: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #21: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #22: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #23: _PyObject_Call + 0x12b (0x564f1ba0c58b in rlhf/.conda/bin/python)
frame #24: <unknown function> + 0x113479 (0x564f1b8ea479 in rlhf/.conda/bin/python)
frame #25: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #26: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #27: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #28: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #29: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #30: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #31: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #32: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #33: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #34: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #35: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #36: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #37: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #38: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #39: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #40: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #41: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #42: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #43: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #44: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #45: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #46: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #47: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #48: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #49: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #50: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #51: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #52: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #53: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #54: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #55: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #56: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #57: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #58: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #59: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)
frame #60: <unknown function> + 0x1127e4 (0x564f1b8e97e4 in rlhf/.conda/bin/python)
frame #61: <unknown function> + 0x92bb14 (0x1464bf955b14 in torch/lib/libtorch_python.so)
frame #62: PyObject_Vectorcall + 0x2e (0x564f1b9f055e in rlhf/.conda/bin/python)

from user code:
  File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 459, in zigzag_ring_flash_attn_varlen_kvpacked_func
    return ZigZagRingFlashAttnVarlenFunc.apply(
  File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 344, in forward
    half_index0 = get_half_index(cu_seqlens, front=True)
  File "ring_flash_attn/zigzag_ring_flash_attn_varlen.py", line 31, in get_half_index
    index = torch.zeros((cu_seqlens[-1],), dtype=bool)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

Signed-off-by: Hollow Man <[email protected]>
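
For readers following the log: the failing call sizes a tensor with `cu_seqlens[-1]`, a 0-dim tensor element (`size=()`, `dtype=torch.int32` in the error) whose value is not available while Dynamo traces with fake tensors. The pure-Python sketch below illustrates the general idea of working from plain integer boundaries instead, so the mask length is a concrete number. It assumes, since the traceback does not show the body, that `get_half_index` marks the front or back half of each packed sequence; `half_index_mask` is a hypothetical name, and this is not necessarily the change the commit makes.

```python
from typing import List


def half_index_mask(cu_seqlens: List[int], front: bool) -> List[bool]:
    """Mark the front (or back) half of every packed sequence.

    Assumption: cu_seqlens holds cumulative boundaries as plain Python
    ints, e.g. [0, 4, 10] for two sequences of lengths 4 and 6.
    """
    total = cu_seqlens[-1]  # a concrete int, so the mask size is known up front
    mask = [False] * total
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mid = (start + end) // 2
        lo, hi = (start, mid) if front else (mid, end)
        for i in range(lo, hi):
            mask[i] = True
    return mask


front = half_index_mask([0, 4, 10], front=True)
back = half_index_mask([0, 4, 10], front=False)
```

In a tensor-based version the boundaries could be converted once, outside the compiled region, for example with `Tensor.tolist()`; whether that suits this code path is a judgment call for the maintainers.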