Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
- fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))


---
Expand Down
10 changes: 6 additions & 4 deletions src/lightning/pytorch/profilers/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import pstats
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -66,14 +67,15 @@ def __init__(
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: dict[str, cProfile.Profile] = {}
self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile)
self.line_count_restriction = line_count_restriction
self.dump_stats = dump_stats

@override
def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
# Disable all profilers before starting a new one
for pr in self.profiled_actions.values():
pr.disable()
self.profiled_actions[action_name].enable()

@override
Expand Down Expand Up @@ -114,7 +116,7 @@ def summary(self) -> str:
@override
def teardown(self, stage: Optional[str]) -> None:
super().teardown(stage=stage)
self.profiled_actions = {}
self.profiled_actions.clear()

def __reduce__(self) -> tuple:
# avoids `TypeError: cannot pickle 'cProfile.Profile' object`
Expand Down
6 changes: 6 additions & 0 deletions tests/tests_pytorch/profilers/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,12 @@ def test_advanced_profiler_deepcopy(advanced_profiler):
assert deepcopy(advanced_profiler)


def test_advanced_profiler_nested(advanced_profiler):
"""Ensure AdvancedProfiler does not raise ValueError for nested profiling actions (Python 3.12+ compatibility)."""
with advanced_profiler.profile("outer"), advanced_profiler.profile("inner"):
pass # Should not raise ValueError


@pytest.fixture
def pytorch_profiler(tmp_path):
return PyTorchProfiler(dirpath=tmp_path, filename="profiler")
Expand Down
Loading