Skip to content

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Jun 3, 2025

What does this PR do?

This enables ThunderFX to save traces such as execution computation and backward traces as artifacts of TORCH_TRACE.
By running scripts with TORCH_TRACE=/path/to/torch_trace_log_dir, without any changes in the scripts, we can check the traces as well as TorchDynamo output torch.fx.GraphModules and Inductor output, if fallback path is used.

This PR inserts trace_structured and trace_structured_artifact into thunderfx optimization path.

Ref:

In the following capture of an example of TORCH_TRACE, thunder_module_execution_computation_trc_12.txt has the execution forward trace.

image

Example:

$ TORCH_TRACE="torch_trace_test/" python thunder/benchmarks/benchmark_peft.py --model nvidia/Nemotron-Mini-4B-Instruct --compile thunder --max-steps 3 --fixed-num-hidden-layers 2 --trust-remote-code
# This command would open the generated HTML file on your brower
$ tlparse ./torch_trace_test/dedicated_log_torch_trace_<random string>_<random string>.log

@crcrpar crcrpar requested review from mruberry, lantiga and t-vi as code owners June 3, 2025 19:24
@crcrpar crcrpar added the thunderfx for things that could be applicable to the dynamo+thunder frontend label Jun 3, 2025
@crcrpar

This comment was marked as outdated.

@crcrpar crcrpar force-pushed the logging-for-torch_trace-tlparse branch from 84fb0db to 05108cc Compare June 4, 2025 12:40
@crcrpar crcrpar marked this pull request as draft June 13, 2025 09:26
@crcrpar crcrpar force-pushed the logging-for-torch_trace-tlparse branch from 05108cc to 6d8b13c Compare August 12, 2025 18:39
crcrpar

This comment was marked as resolved.

"encoding": "string",
},
payload_fn=lambda: f"{trace_to_store}\n",
compile_id=compile_options.get("torch_compile_compile_id", None),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implicit dependency on compile_options should be improved.

@@ -114,8 +127,29 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor

# The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
# and unsupported sections which are passed to `torch.compile(backend='inductor')`
split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
split_module, subgraph_info = _splitter(gm, wrapped_thunder_jit, self._torch_compile, sample_args)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@crcrpar
Copy link
Collaborator Author

crcrpar commented Aug 12, 2025

This gist might be helpful https://gist.github.com/crcrpar/cfdc7dbb499cbe6b327f04be5856f078 about CompileID

@crcrpar crcrpar requested a review from Copilot August 12, 2025 20:20
@crcrpar crcrpar changed the title Integrate trace_structured and trace_structured_artifact into ThunderCompiler Integrate trace_structured and trace_structured_artifact into ThunderCompiler to use TORCH_TRACE and tlparse Aug 12, 2025
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR integrates structured tracing functionality from PyTorch's torch._logging._internal into Thunder's compiler and splitter components to enable better debugging and analysis through tlparse. The changes add trace artifacts at key compilation and splitting stages to provide visibility into Thunder's internal operations.

Key changes:

  • Added structured tracing to capture graph modules at various stages (original, post-checkpoint conversion, split graphs)
  • Added tracing for split reasons and unsupported nodes/contexts
  • Integrated Thunder execution traces (computation, prologue, epilogue) with torch compile ID tracking

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
thunder/dynamo/splitter.py Adds trace artifacts for unsupported nodes/contexts and original/processed graph modules during splitting
thunder/dynamo/compiler.py Adds tracing for original and split graphs, split reasons, and torch compile ID integration
thunder/init.py Adds structured tracing for Thunder execution traces with compile ID support

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@crcrpar crcrpar marked this pull request as ready for review August 13, 2025 07:05
Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have an example of end-to-end workflow in the PR description of how to use it with tlparse. It would be helpful for those (like me) who are not familiar with it.

@@ -428,6 +438,22 @@ def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor):
last_interpreter_log = jit_results.interpreter_log
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

for name_in_artifact, trace_to_store in (
Copy link
Collaborator

@kshitij12345 kshitij12345 Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should only invoke this logging function (_trace_structured) if logging is specified by the user. This will also prevent failures on main path if this internal API changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. _log_to_torch_trace could be _maybe_log_to_torch_trace which incorporates the check on the user specifications.

Another thought:

def _log_to_torch_trace(
    string_format:str,
    trace_tuples: list[tuple],
    compile_id,
):
    for trc, *format_args in trace_tuples:
        if trc is None:
            continue
        name = string_format.format(*format_args)
        helper(name, trc, compile_id)

where helper is _log_to_torch_trace as defined below would allow these lines to be replaced with

trace_tuples = [
    (computation_trc, "computation"), ...
]
_log_to_torch_trace(
    "thunder_module_initial_{}_trc",
    trace_tuples,
    compile_id,
)

Maybe this isn't so much shorter, but it is less indentation. But up to you @crcrpar, whichever you find more readable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not that inclined to update the helper to take a list of tuples of trace and name. That indentation feels fine to me

@@ -524,7 +550,23 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
if requires_grad:
from thunder.transforms.autodiff import grad_transform_on_trace

_trace_structured(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a helper function log_thunder_trace (or something like that). It would be easier to read and also to use.

@@ -183,9 +207,30 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
)
example_input_metadatas.append(list(example_input_metadata))

trace_structured_artifact(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a helper called log_fx_graph for readability and usage.

@crcrpar crcrpar changed the title Integrate trace_structured and trace_structured_artifact into ThunderCompiler to use TORCH_TRACE and tlparse [thunderfx] Integrate trace_structured and trace_structured_artifact into ThunderCompiler to use TORCH_TRACE and tlparse Aug 13, 2025
@crcrpar
Copy link
Collaborator Author

crcrpar commented Aug 13, 2025

Oops, I noticed the current implementation only saves the last set of execution traces of a GraphModule. So if a graph module is chunked into thunder_0 -> inductor_1 -> thunder_2, then only the traces for thunder_2 seem to be saved.

@crcrpar crcrpar force-pushed the logging-for-torch_trace-tlparse branch from 2bd6c36 to ec7fb0c Compare August 17, 2025 16:24
@t-vi
Copy link
Collaborator

t-vi commented Aug 27, 2025

@kshitij12345 @kiya00 @beverlylytle @riccardofelluga @IvanYashchuk any of you wanting to review this? (or someone else)

crcrpar and others added 9 commits September 3, 2025 13:17
as the former would be more Python-like.

Signed-off-by: Masaki Kozuki <[email protected]>
also, propagating `CompileID` of torch.compile from
`ThunderCompiler.__call__` to `thunder.jit` as compile id does not seem
to be available when saving those artifacts

Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar and others added 16 commits September 3, 2025 13:17
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the logging-for-torch_trace-tlparse branch from acf6999 to b48e9fe Compare September 3, 2025 04:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
thunderfx for things that could be applicable to the dynamo+thunder frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants