Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions requirements/fabric/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ pytest-cov ==6.3.0
pytest-timeout ==2.4.0
pytest-rerunfailures ==16.0.1
pytest-random-order ==1.2.0
click ==8.1.8; python_version < "3.11"
click ==8.2.1; python_version > "3.10"
jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0
tensorboardX >=2.6, <2.7.0 # todo: relax it back to `>=2.2` after fixing tests
173 changes: 97 additions & 76 deletions src/lightning/fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
import logging
import os
import re
import sys
from argparse import Namespace
from typing import Any, Optional
from typing import Optional

import torch
from lightning_utilities.core.imports import RequirementCache
Expand All @@ -31,9 +32,12 @@

_log = logging.getLogger(__name__)

_CLICK_AVAILABLE = RequirementCache("click")
_JSONARGPARSE_AVAILABLE = RequirementCache("jsonargparse")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

if _JSONARGPARSE_AVAILABLE:
from jsonargparse import ArgumentParser

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")


Expand All @@ -45,127 +49,112 @@ def _get_supported_strategies() -> list[str]:
return [strategy for strategy in available_strategies if not re.match(excluded, strategy)]


if _CLICK_AVAILABLE:
import click
def _build_parser() -> "ArgumentParser":
    """Build the jsonargparse-based CLI parser with subcommands.

    Two subcommands are registered on the top-level parser:

    - ``run``: options describing the hardware/cluster setup plus a positional ``script`` path.
    - ``consolidate``: a positional ``checkpoint_folder`` and an optional ``--output_file``.

    Returns:
        The top-level parser with both subcommands attached.

    Raises:
        RuntimeError: If ``jsonargparse`` is not installed.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        raise RuntimeError(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed. "
            "Install it by running `pip install -U jsonargparse`."
        )

    parser = ArgumentParser(description="Lightning Fabric command line tool")
    subcommands = parser.add_subcommands()

    # run subcommand
    run_parser = ArgumentParser(description="Run a Lightning Fabric script.")
    run_parser.add_argument(
        "--accelerator",
        type=str,
        choices=_SUPPORTED_ACCELERATORS,
        default=None,
        help="The hardware accelerator to run on.",
    )
    run_parser.add_argument(
        "--strategy",
        type=str,
        choices=_get_supported_strategies(),
        default=None,
        help="Strategy for how to run across multiple devices.",
    )
    run_parser.add_argument(
        "--devices",
        type=str,
        default="1",
        help=(
            "Number of devices to run on (int), which devices to run on (list or str), or 'auto'. "
            "The value applies per node."
        ),
    )
    # The multi-word options below accept both the underscore and the hyphen spelling.
    run_parser.add_argument(
        "--num_nodes",
        "--num-nodes",
        type=int,
        default=1,
        help="Number of machines (nodes) for distributed execution.",
    )
    run_parser.add_argument(
        "--node_rank",
        "--node-rank",
        type=int,
        default=0,
        help=(
            "The index of the machine (node) this command gets started on. Must be a number in the range "
            "0, ..., num_nodes - 1."
        ),
    )
    run_parser.add_argument(
        "--main_address",
        "--main-address",
        type=str,
        default="127.0.0.1",
        help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
    )
    run_parser.add_argument(
        "--main_port",
        "--main-port",
        type=int,
        default=29400,
        help="The main port to connect to the main machine.",
    )
    run_parser.add_argument(
        "--precision",
        type=str,
        # Accept both the canonical precision strings and their aliases.
        choices=list(get_args(_PRECISION_INPUT_STR)) + list(get_args(_PRECISION_INPUT_STR_ALIAS)),
        default=None,
        help=(
            "Double precision ('64-true' or '64'), full precision ('32-true' or '32'), "
            "half precision ('16-mixed' or '16') or bfloat16 precision ('bf16-mixed' or 'bf16')."
        ),
    )
    run_parser.add_argument(
        "script",
        type=str,
        help="Path to the Python script with the code to run. The script must contain a Fabric object.",
    )
    subcommands.add_subcommand("run", run_parser, help="Run a Lightning Fabric script")

    # consolidate subcommand
    con_parser = ArgumentParser(
        description="Convert a distributed/sharded checkpoint into a single file that can be loaded with torch.load()."
    )
    con_parser.add_argument(
        "checkpoint_folder",
        type=str,
        help="Path to the checkpoint folder to consolidate.",
    )
    con_parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist. "
            "If not provided, the file will be saved next to the input checkpoint folder with the same name and a "
            "'.consolidated' suffix."
        ),
    )
    subcommands.add_subcommand("consolidate", con_parser, help="Consolidate a distributed checkpoint")

    return parser


def _set_env_variables(args: Namespace) -> None:
Expand Down Expand Up @@ -234,12 +223,44 @@ def main(args: Namespace, script_args: Optional[list[str]] = None) -> None:
_torchrun_launch(args, script_args or [])


if __name__ == "__main__":
if not _CLICK_AVAILABLE: # pragma: no cover
def _run_command(cfg: Namespace, script_args: list[str]) -> None:
    """Handle the ``run`` subcommand.

    Args:
        cfg: Parsed options of the ``run`` subcommand. They are re-packed into a fresh ``Namespace``
            before being handed to ``main``.
        script_args: Extra command-line arguments forwarded to ``main`` for the user script.

    """
    launch_args = Namespace(**cfg)
    main(args=launch_args, script_args=script_args)


def _consolidate_command(cfg: Namespace) -> None:
    """Handle the ``consolidate`` subcommand.

    Args:
        cfg: Parsed options of the ``consolidate`` subcommand; ``checkpoint_folder`` and
            ``output_file`` are read from it.

    """
    folder = cfg.checkpoint_folder
    target = cfg.output_file
    # Only the two relevant fields are passed on; the processed result supplies the paths used below.
    resolved = _process_cli_args(Namespace(checkpoint_folder=folder, output_file=target))
    consolidated = _load_distributed_checkpoint(resolved.checkpoint_folder)
    torch.save(consolidated, resolved.output_file)


def cli_main(argv: Optional[list[str]] = None) -> None:
    """Entry point for the Fabric CLI using jsonargparse.

    Args:
        argv: Command-line arguments without the program name. The value is passed unchanged to the
            parser's ``parse_known_args``.

    Raises:
        SystemExit: With exit code 1 if ``jsonargparse`` is not installed.

    """
    if not _JSONARGPARSE_AVAILABLE:  # pragma: no cover
        _log.error(
            "To use the Lightning Fabric CLI, you must have `jsonargparse` installed."
            " Install it by running `pip install -U jsonargparse`."
        )
        raise SystemExit(1)

    parser = _build_parser()
    # parse_known_args so that for 'run' we can forward unknown args to the user script
    cfg, unknown = parser.parse_known_args(argv)

    # No subcommand selected: show usage instead of failing.
    if not getattr(cfg, "subcommand", None):
        parser.print_help()
        return

    if cfg.subcommand == "run":
        # unknown contains the script's own args
        _run_command(cfg.run, unknown)
    elif cfg.subcommand == "consolidate":
        _consolidate_command(cfg.consolidate)
    else:  # pragma: no cover
        parser.print_help()


# Script entry point: forward everything after the program name to the CLI.
if __name__ == "__main__":
    cli_main(argv=sys.argv[1:])
Loading
Loading