Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))


- Fixed misalignment column while using rich model summary in `DeepSpeedstrategy` ([#21100](https://github.com/Lightning-AI/pytorch-lightning/pull/21100))

---

## [2.5.3] - 2025-08-13
Expand Down
7 changes: 5 additions & 2 deletions src/lightning/pytorch/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,21 @@ def summarize(
from rich.table import Table

console = get_console()
column_names = list(zip(*summary_data))[0]

header_style: str = summarize_kwargs.get("header_style", "bold magenta")
table = Table(header_style=header_style)
table.add_column(" ", style="dim")
table.add_column("Name", justify="left", no_wrap=True)
table.add_column("Type")
table.add_column("Params", justify="right")

if "Params per Device" in column_names:
table.add_column("Params per Device", justify="right")

table.add_column("Mode")
table.add_column("FLOPs", justify="right")

column_names = list(zip(*summary_data))[0]

for column_name in ["In sizes", "Out sizes"]:
if column_name in column_names:
table.add_column(column_name, justify="right", style="white")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
("Params", list(map(get_human_readable_count, self.param_nums))),
("Params per Device", list(map(get_human_readable_count, self.parameters_per_layer))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
]
if self._model.example_input_array is not None:
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
Expand Down
39 changes: 39 additions & 0 deletions tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import torch

import lightning.pytorch as pl
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -51,3 +55,38 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
)

trainer.fit(model)


@RunIf(min_cuda_gpus=1, deepspeed=True, rich=True)
@mock.patch("rich.table.Table.add_row", autospec=True)
def test_deepspeed_summary_with_rich_model_summary(mock_table_add_row, tmp_path):
from lightning.pytorch.callbacks import RichModelSummary

model = BoringModel()
model.example_input_array = torch.randn(4, 32)

trainer = Trainer(
strategy=DeepSpeedStrategy(stage=3),
default_root_dir=tmp_path,
accelerator="gpu",
fast_dev_run=True,
devices=1,
enable_model_summary=True,
callbacks=[RichModelSummary()],
)

trainer.fit(model)

# assert that the input summary data was converted correctly
args, _ = mock_table_add_row.call_args_list[0]
assert args[1:] == (
"0",
"layer",
"Linear",
"66 ",
"66 ",
"train",
"512 ",
"[4, 32]",
"[4, 2]",
)
Loading