Changes from all commits

Commits (25)
2d618ec
integrate `trace_structured` & `trace_structured_artifact`
crcrpar Jun 3, 2025
c34db96
use `GraphModule.print_readable` not `str(GraphModule.graph)`
crcrpar Jun 3, 2025
8f11a06
(ab)use `trace_structured_artifact` to make split reasons handled by …
crcrpar Jun 4, 2025
04f64b8
Store execution prologue, computation, and epilogue traces
crcrpar Aug 12, 2025
7441d67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2025
a140a17
use `trace_structured` to specify `compile_id`
crcrpar Aug 12, 2025
ddced44
removing `trace_structured` as it requires `tlparse` customization
crcrpar Aug 12, 2025
8e26859
remove trace_structured
crcrpar Aug 12, 2025
032e109
clean up following `trace_structured` removal
crcrpar Aug 12, 2025
2c98d8a
fix name of GraphModules inside splitter
crcrpar Aug 12, 2025
d3b5378
Apply suggestions from code review of copilot
crcrpar Aug 12, 2025
e6209e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2025
5a9f1e5
deduplicate `metadata_fn`
crcrpar Aug 12, 2025
c2eb191
avoid redefinition of trace
crcrpar Aug 12, 2025
b722c0d
store backward_trc if available
crcrpar Aug 13, 2025
23dd2bb
store trace before/after grad_transform
crcrpar Aug 13, 2025
2482974
check `compile_id` arg
crcrpar Aug 13, 2025
5280335
add todo comment
crcrpar Aug 13, 2025
6e1689a
remove `compile_id` from `_trace_structured` callsite
crcrpar Aug 13, 2025
2bb659e
fix
crcrpar Aug 13, 2025
c64b83e
[no ci] todo: use node index
crcrpar Aug 14, 2025
ac7b851
propagate chunk index of GraphModule
crcrpar Aug 14, 2025
aa87d1e
wrap trace_structured for trace/graph-module
crcrpar Aug 18, 2025
59765a6
fix typos
crcrpar Aug 18, 2025
b48e9fe
clean up
crcrpar Aug 18, 2025
43 changes: 42 additions & 1 deletion thunder/__init__.py
@@ -1,9 +1,10 @@
from collections import defaultdict, namedtuple
from collections.abc import Callable, Sequence
from contextvars import ContextVar
from functools import wraps
from functools import partial, wraps
from typing import Any
import dis
import inspect
import os
import time
import warnings
@@ -56,6 +57,7 @@
wrap_return_value_together_with_arguments,
)
from thunder.core.update_aliases import insert_alias_updates
from thunder.dynamo._trace_structured import _log_to_torch_trace
from thunder.executors.torch_autograd import connect_to_autograd
import thunder.extend as extend
from thunder.extend import Executor, add_default_executor
@@ -436,6 +438,19 @@ def acquire_initial_trace(fn, args, kwargs, cd, cs, ad_hoc_executor):
last_interpreter_log = jit_results.interpreter_log
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

for name_in_artifact, trace_to_store in (
kshitij12345 (Collaborator) commented on Aug 13, 2025:

I think we should only invoke this logging function (`_trace_structured`) if logging is specified by the user. This will also prevent failures on the main path if this internal API changes.

Collaborator:

I agree. `_log_to_torch_trace` could be `_maybe_log_to_torch_trace`, which incorporates the check on the user's specification.

Another thought:

def _log_to_torch_trace(
    string_format: str,
    trace_tuples: list[tuple],
    compile_id,
):
    for trc, *format_args in trace_tuples:
        if trc is None:
            continue
        name = string_format.format(*format_args)
        helper(name, trc, compile_id)

where `helper` is `_log_to_torch_trace` as defined below. This would allow these lines to be replaced with:

trace_tuples = [
    (computation_trc, "computation"), ...
]
_log_to_torch_trace(
    "thunder_module_initial_{}_trc",
    trace_tuples,
    compile_id,
)

Maybe this isn't so much shorter, but it is less indentation. But up to you @crcrpar, whichever you find more readable.

crcrpar (Collaborator, Author):

I'm not that inclined to update the helper to take a list of tuples of trace and name. That indentation feels fine to me.
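
For illustration, a minimal sketch of the guarded wrapper the reviewers describe above. The `THUNDER_TORCH_TRACE_ARTIFACTS` environment variable, the `enabled` parameter, and the catch-all `except` are assumptions made for this sketch, not part of the PR:

from __future__ import annotations

import os
from typing import Any


def _maybe_log_to_torch_trace(
    name: str,
    fn: Any,
    compile_id: Any | None = None,
    *,
    enabled: bool | None = None,
) -> None:
    # Opt-in switch (assumed env var) so artifact logging is off by default.
    if enabled is None:
        enabled = os.environ.get("THUNDER_TORCH_TRACE_ARTIFACTS", "0") == "1"
    if not enabled:
        return
    try:
        # Deferred import: torch._logging._internal is a private API and may change,
        # so a failure here must not break the main compilation path.
        from thunder.dynamo._trace_structured import _log_to_torch_trace

        _log_to_torch_trace(name, fn, compile_id=compile_id)
    except Exception:
        pass

The opt-in check keeps the logging path off by default, and the deferred import plus the broad exception handler keep changes to torch's private logging API from affecting the main compilation path, which is the failure mode the first comment calls out.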

("computation", computation_trc),
("prologue", prologue_trc),
("epilogue", epilogue_trc),
):
if trace_to_store is None:
continue
_log_to_torch_trace(
f"thunder_module_initial_{name_in_artifact}_trc",
trace_to_store,
compile_id=compile_options.get("torch_compile_compile_id", None),
)
return prologue_trc, computation_trc, epilogue_trc

def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, computation_trc, epilogue_trc):
@@ -532,7 +547,17 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
if requires_grad:
from thunder.transforms.autodiff import grad_transform_on_trace

_log_to_torch_trace(
"thunder_module_computation_trc_before_grad_transform",
computation_trc,
compile_id=compile_options.get("torch_compile_compile_id", None),
)
computation_trc = grad_transform_on_trace(computation_trc)
_log_to_torch_trace(
"thunder_module_computation_trc_after_grad_transform",
computation_trc,
compile_id=compile_options.get("torch_compile_compile_id", None),
)

from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits
@@ -598,6 +623,22 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
computation_trc = transform_to_torch_types(computation_trc)
comp = computation_trc.python_callable()

for name_in_artifact, trace_to_store in (
("computation", computation_trc),
("prologue", prologue_trc),
("epilogue", epilogue_trc),
("backward", backward_trc),
):
if trace_to_store is None:
continue

_idx_of_graph_module = compile_options.get("graph_module_idx", 0)
_log_to_torch_trace(
f"thunder_module_execution_{name_in_artifact}_trc_of_module_{_idx_of_graph_module}",
trace_to_store,
compile_id=compile_options.get("torch_compile_compile_id", None),
crcrpar (Collaborator, Author):

This implicit dependency on compile_options should be improved.

)

# TODO RC1 Update the cache
cache_entry = CacheEntry(
pro,
59 changes: 59 additions & 0 deletions thunder/dynamo/_trace_structured.py
@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import inspect

from torch.fx import GraphModule
from torch._logging._internal import trace_structured
from torch._logging._internal import trace_structured_artifact

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from torch._guards import CompileId


_SUPPORT_COMPILE_ID_KWARG: bool = "compile_id" in inspect.signature(trace_structured).parameters


def payload_fn_of(fn: GraphModule | Callable[[Any], Any]) -> Callable[[], str]:
if isinstance(fn, GraphModule):

def f() -> str:
return fn.print_readable(
print_output=False,
include_stride=True,
include_device=True,
)

return f

def f() -> str:
return f"{fn}\n"

return f


# TODO: use `trace_structured_artifact` once `compile_id` is merged.
# https://github.com/pytorch/pytorch/pull/160440.
# note: `compile_id` is a kwarg since v2.7.0.
def _log_to_torch_trace(
name: str,
fn: GraphModule | Callable[[Any], Any],
compile_id: CompileId | None = None,
) -> None:
payload_fn = payload_fn_of(fn)
if compile_id is not None and _SUPPORT_COMPILE_ID_KWARG:
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": name,
"encoding": "string",
},
payload_fn=payload_fn,
)
else:
trace_structured_artifact(
name=name,
encoding="string",
payload_fn=payload_fn,
)
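
As a usage illustration only (not part of the diff), a minimal sketch of how the artifacts emitted by this helper could be collected and then inspected with tlparse; the output directory, the toy function, and the environment-variable handling are assumptions:

# Minimal sketch, assuming the usual TORCH_TRACE / tlparse workflow: PyTorch's
# structured-trace logs are written under the directory named by TORCH_TRACE and
# can be rendered offline with `tlparse <log file>`.
import os

os.environ.setdefault("TORCH_TRACE", "/tmp/thunder_torch_trace")  # assumed output directory

import torch
from thunder.dynamo import ThunderCompiler


def fn(x):
    return torch.nn.functional.gelu(x) + 1


backend = ThunderCompiler()
compiled = torch.compile(fn, backend=backend)
compiled(torch.randn(8, 8))

# The artifacts added in this PR ("thunder_original_graph", "thunder_split_graph",
# "thunder_split_reasons" when splits occur, and the "thunder_module_*_trc" traces)
# should now appear in the structured-trace output.
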
28 changes: 26 additions & 2 deletions thunder/dynamo/compiler.py
@@ -7,6 +7,7 @@
import copy

import torch
from torch._logging._internal import trace_structured_artifact

from thunder.dynamo.utils import (
recompile_graph,
@@ -22,6 +23,7 @@
)
from thunder.dynamo.splitter import _splitter
from thunder.dynamo.benchmark_utils import ThunderCompileSpecification
from thunder.dynamo._trace_structured import _log_to_torch_trace
from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform

if TYPE_CHECKING:
@@ -106,16 +108,38 @@ def __init__(self, **thunder_options):
self._torch_compile = torch.compile

def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
remove_empty_autocast(gm)
torch_compile_compile_id = torch._guards.CompileContext.current_compile_id()
thunder_options = {"torch_compile_compile_id": torch_compile_compile_id, **self.thunder_options}

_log_to_torch_trace("thunder_original_graph", gm)

# Dynamo uses lazy generation of the underlying Python code, so we need to
# force recompilation of the GraphModule before passing it to Thunder.
recompile_graph(gm)

# The whole graph may not be supported by `thunder`, so we split it in `thunder` supported sections
# and unsupported sections which are passed to `torch.compile(backend='inductor')`
split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
split_module, subgraph_info = _splitter(
gm,
self._thunder_jit,
self._torch_compile,
sample_args,
thunder_options=thunder_options,
)
self.subgraph_infos.append(subgraph_info)

_log_to_torch_trace("thunder_split_graph", split_module)

if subgraph_info.split_reasons:
trace_structured_artifact(
name="thunder_split_reasons",
encoding="json",
payload_fn=lambda: [
{"reason_type": reason.reason_type.name, "info": reason.info, "exception": reason.exception}
for reason in subgraph_info.split_reasons
],
)

return split_module

def save_reproducer_to_folder(
Expand Down
53 changes: 50 additions & 3 deletions thunder/dynamo/splitter.py
@@ -5,6 +5,7 @@

import torch
from torch.fx.passes.split_module import split_module
from torch._logging._internal import trace_structured_artifact

from thunder.dynamo.utils import (
SubgraphInfo,
@@ -20,8 +21,10 @@
_get_example_inputs_from_placeholder,
_ThunderSplitGraphModule,
)
from thunder.dynamo._trace_structured import _log_to_torch_trace

if TYPE_CHECKING:
from typing import Any
from collections.abc import Callable


@@ -30,6 +33,8 @@ def _splitter(
thunder_jit: Callable,
torch_inductor: Callable,
_unused_sample_args: list[torch.SymInt, torch.Tensor],
*,
thunder_options: dict[str, Any] = {},
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
"""
This method will split graph into multiple graph modules based on thunder supported operations.
@@ -118,11 +123,34 @@ def callback(node) -> int:
info=f"node with name: {node.name} and target: {node.target} is not supported probably because it is in unsupported context.",
)
split_reasons.append(split_reason)

trace_structured_artifact(
name="thunder_unsupported_ctx_regions",
encoding="json",
payload_fn=lambda n=node, r=split_reason: {
"node_name": n.name,
"node_target": str(n.target),
"reason_type": r.reason_type.name,
"reason_info": r.info,
},
)
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
if split_reason is not None:
split_reasons.append(split_reason)

trace_structured_artifact(
name="thunder_unsupported_node",
encoding="json",
payload_fn=lambda n=node, r=split_reason: {
"node_name": n.name,
"node_target": str(n.target),
"reason_type": r.reason_type.name,
"reason_info": r.info,
"exception": r.exception,
},
)

if prev_value == is_thunder_supported: # We are in the same region.
return partition_cnt

@@ -144,7 +172,7 @@ def callback(node) -> int:
gm.recompile()

# `split_module` iterates over nodes and determines the partition to place them based on the callback.
split_gm: torch.fx.GraphModule = split_module(
original_split_gm: torch.fx.GraphModule = split_module(
gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
)

@@ -166,7 +194,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
thunder_compiled_fns = []
example_input_metadatas = []
submodule_to_compiled_fns = {}
for node in split_gm.graph.nodes:
for node_idx, node in enumerate(split_gm.graph.nodes):
node_name = node.name
if is_thunder_supported_partition(node):
graph_module = getattr(split_gm, node.name)
@@ -186,10 +214,27 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
)
example_input_metadatas.append(list(example_input_metadata))

_log_to_torch_trace("thunder_module_original", graph_module)

# Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
checkpoint_converter(split_gm, graph_module)

jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
_log_to_torch_trace("thunder_module_post_checkpoint_converter_applied", graph_module)

if not thunder_options:
jit_fn = thunder_jit(graph_module, is_differentiable_outputs=is_differentiable_outputs)
else:
from thunder import jit

jit_fn = jit(
graph_module,
**{
"is_differentiable_outputs": is_differentiable_outputs,
"graph_module_idx": node_idx,
**thunder_options,
},
)
# Update the node name from "submod_*" to "thunder_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn)
thunder_compiled_fns.append(jit_fn)
@@ -198,6 +243,8 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
)
elif node.name.startswith("submod"): # For inductor
graph_module = getattr(split_gm, node.name)
_log_to_torch_trace("inductor_module_original", graph_module)

jit_fn = torch_inductor(graph_module)
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)