diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 82d80d2cbfda9..6f6e1f156d169 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253)) +- Fixed the `XLAProfiler` not recording anything due to mismatching of action names ([#15885](https://github.com/Lightning-AI/lightning/pull/15885)) + + ## [1.8.4] - 2022-12-08 ### Changed diff --git a/src/pytorch_lightning/profilers/xla.py b/src/pytorch_lightning/profilers/xla.py index ef103a9a45842..4bfefbc0bacbb 100644 --- a/src/pytorch_lightning/profilers/xla.py +++ b/src/pytorch_lightning/profilers/xla.py @@ -50,12 +50,14 @@ def __init__(self, port: int = 9012) -> None: def start(self, action_name: str) -> None: import torch_xla.debug.profiler as xp - if action_name in self.RECORD_FUNCTIONS: + # The action name is formatted as '[TYPE]{class name}.{hook name}' + # Example: [LightningModule]BoringModel.training_step + if action_name.split(".")[-1] in self.RECORD_FUNCTIONS: if not self._start_trace: self.server = xp.start_server(self.port) self._start_trace = True - if action_name in self.STEP_FUNCTIONS: + if action_name.split(".")[-1] in self.STEP_FUNCTIONS: step = self._get_step_num(action_name) recording = xp.StepTrace(action_name, step_num=step) else: