Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/lightning/pytorch/callbacks/rich_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,21 @@ def summarize(
from rich.table import Table

console = get_console()
column_names = list(zip(*summary_data))[0]

header_style: str = summarize_kwargs.get("header_style", "bold magenta")
table = Table(header_style=header_style)
table.add_column(" ", style="dim")
table.add_column("Name", justify="left", no_wrap=True)
table.add_column("Type")
table.add_column("Params", justify="right")

if "Params per Device" in column_names:
table.add_column("Params per Device", justify="right")

table.add_column("Mode")
table.add_column("FLOPs", justify="right")

column_names = list(zip(*summary_data))[0]

for column_name in ["In sizes", "Out sizes"]:
if column_name in column_names:
table.add_column(column_name, justify="right", style="white")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
("Params", list(map(get_human_readable_count, self.param_nums))),
("Params per Device", list(map(get_human_readable_count, self.parameters_per_layer))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
]
if self._model.example_input_array is not None:
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
Expand Down
39 changes: 39 additions & 0 deletions tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import torch

import lightning.pytorch as pl
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
Expand Down Expand Up @@ -51,3 +55,38 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
)

trainer.fit(model)


@RunIf(min_cuda_gpus=1, deepspeed=True, rich=True)
@mock.patch("rich.table.Table.add_row", autospec=True)
def test_deepspeed_summary_with_rich_model_summary(mock_table_add_row, tmp_path):
from lightning.pytorch.callbacks import RichModelSummary

model = BoringModel()
model.example_input_array = torch.randn(4, 32)

trainer = Trainer(
strategy=DeepSpeedStrategy(stage=3),
default_root_dir=tmp_path,
accelerator="gpu",
fast_dev_run=True,
devices=1,
enable_model_summary=True,
callbacks=[RichModelSummary()],
)

trainer.fit(model)

# assert that the input summary data was converted correctly
args, _ = mock_table_add_row.call_args_list[0]
assert args[1:] == (
"0",
"layer",
"Linear",
"66 ",
"66 ",
"train",
"512 ",
"[4, 32]",
"[4, 2]",
)
Loading