Skip to content

Commit e52d6c5

Browse files
authored
Fix TensorBoardLogger's validation of example input when logging graph (#15323)
1 parent be9f4ee commit e52d6c5

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4646

4747
### Fixed
4848

49-
-
49+
- Fixed `TensorBoardLogger` not validating the input array type when logging the model graph ([#15323](https://github.com/Lightning-AI/lightning/pull/15323))
5050

5151
-
5252

src/pytorch_lightning/core/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
109109
self.precision: Union[int, str] = 32
110110

111111
# optionally can be set by user
112-
self._example_input_array = None
112+
self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
113113
self._current_fx_name: Optional[str] = None
114114
self._automatic_optimization: bool = True
115115
self._truncated_bptt_steps: int = 0
@@ -189,7 +189,7 @@ def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
189189
self._trainer = trainer
190190

191191
@property
192-
def example_input_array(self) -> Any:
192+
def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]:
193193
"""The example input array is a specification of what the module can consume in the :meth:`forward` method.
194194
The return type is interpreted as follows:
195195
@@ -203,7 +203,7 @@ def example_input_array(self) -> Any:
203203
return self._example_input_array
204204

205205
@example_input_array.setter
206-
def example_input_array(self, example: Any) -> None:
206+
def example_input_array(self, example: Optional[Union[Tensor, Tuple, Dict]]) -> None:
207207
self._example_input_array = example
208208

209209
@property

src/pytorch_lightning/loggers/tensorboard.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,21 +235,27 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
235235

236236
@rank_zero_only
237237
def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
238-
if self._log_graph:
239-
if input_array is None:
240-
input_array = model.example_input_array
241-
242-
if input_array is not None:
243-
input_array = model._on_before_batch_transfer(input_array)
244-
input_array = model._apply_batch_transfer_handler(input_array)
245-
with pl.core.module._jit_is_scripting():
246-
self.experiment.add_graph(model, input_array)
247-
else:
248-
rank_zero_warn(
249-
"Could not log computational graph since the"
250-
" `model.example_input_array` attribute is not set"
251-
" or `input_array` was not given",
252-
)
238+
if not self._log_graph:
239+
return
240+
241+
input_array = model.example_input_array if input_array is None else input_array
242+
243+
if input_array is None:
244+
rank_zero_warn(
245+
"Could not log computational graph to TensorBoard: The `model.example_input_array` attribute"
246+
" is not set or `input_array` was not given."
247+
)
248+
elif not isinstance(input_array, (Tensor, tuple)):
249+
rank_zero_warn(
250+
"Could not log computational graph to TensorBoard: The `input_array` or `model.example_input_array`"
251+
f" has type {type(input_array)} which can't be traced by TensorBoard. Make the input array a tuple"
252+
f" representing the positional arguments to the model's `forward()` implementation."
253+
)
254+
else:
255+
input_array = model._on_before_batch_transfer(input_array)
256+
input_array = model._apply_batch_transfer_handler(input_array)
257+
with pl.core.module._jit_is_scripting():
258+
self.experiment.add_graph(model, input_array)
253259

254260
@rank_zero_only
255261
def save(self) -> None:

tests/tests_pytorch/loggers/test_tensorboard.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,13 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
227227
logger = TensorBoardLogger(tmpdir, log_graph=True)
228228
with pytest.warns(
229229
UserWarning,
230-
match="Could not log computational graph since the `model.example_input_array`"
231-
" attribute is not set or `input_array` was not given",
230+
match="Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given",
231+
):
232+
logger.log_graph(model)
233+
234+
model.example_input_array = dict(x=1, y=2)
235+
with pytest.warns(
236+
UserWarning, match="Could not log computational graph to TensorBoard: .* can't be traced by TensorBoard"
232237
):
233238
logger.log_graph(model)
234239

0 commit comments

Comments
 (0)