6 changes: 4 additions & 2 deletions run_train.sh
@@ -10,8 +10,10 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+# NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
+export LOG_RANK=${LOG_RANK:-0,1,2,3}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
 
8 changes: 4 additions & 4 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
 from torchtitan.models.moe import MoEArgs
 
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -32,10 +32,10 @@
 deepseekv3_configs = {
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=2000,
-        dim=256,
+        dim=4,
         inter_dim=1024,
         moe_inter_dim=256,
-        n_layers=6,
+        n_layers=16,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(
@@ -166,7 +166,7 @@
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
-        pipelining_fn=pipeline_llama,
+        pipelining_fn=pipeline_llama_tracer,
         build_optimizers_fn=build_optimizers_with_moe_load_balancing,
         build_lr_schedulers_fn=build_lr_schedulers,
         build_dataloader_fn=build_hf_dataloader,
19 changes: 10 additions & 9 deletions torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 5
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
@@ -36,22 +36,23 @@ decay_type = "linear"
 min_lr_factor = 0.0
 
 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 10
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"
 
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
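Note on the [parallelism] values above: with NGPU=4 from run_train.sh, pipeline_parallel_degree = 2 leaves two ranks for the FSDP mesh, which data_parallel_shard_degree = -1 infers automatically. Below is a minimal illustrative sketch of that inference, assuming the replicate/shard/cp/tp/pp degrees must multiply to the world size; expert_parallel_degree is left out on the assumption that EP reuses data-parallel ranks rather than adding a mesh dimension. This is not the torchtitan ParallelDims code itself.

def infer_dp_shard(world_size: int, dp_replicate: int, cp: int, tp: int, pp: int) -> int:
    # Mirrors the "-1 means infer from world size" convention of
    # data_parallel_shard_degree in the [parallelism] section above (illustrative only).
    others = dp_replicate * cp * tp * pp
    assert world_size % others == 0, "parallel degrees must divide the world size"
    return world_size // others

# NGPU=4 from run_train.sh with the degrees in this debug config:
print(infer_dp_shard(world_size=4, dp_replicate=1, cp=1, tp=1, pp=2))  # -> 2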
75 changes: 75 additions & 0 deletions torchtitan/models/llama3/infra/pipeline.py
@@ -25,6 +25,9 @@
     pipeline_module_split,
 )
 
+from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining.stage import _PipelineStage
+
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
@@ -148,3 +151,75 @@ def pipeline_llama(
             has_last_stage = True
 
     return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def pipeline_llama_tracer(
+    model: nn.Module,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: torch.device,
+    model_args: BaseModelArgs,
+    parallelize_fn: ParallelizeFunction,
+    loss_fn: LossFunction,
+):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    # if job_config.model.norm_type == "fused_rmsnorm":
+    #     # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+    #     # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    #     raise NotImplementedError(
+    #         "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+    #     )
+    pp_mesh = parallel_dims.world_mesh["pp"]
+    pp_rank = pp_mesh.get_local_rank()
+    stage_idx = pp_rank
+    # Place a split point at the start of every `layers_per_rank`-th block so the
+    # tracer cuts the model into `pp` equally sized stages.
+    layers_per_rank = model_args.n_layers // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input on the meta device for tracing
+    input_shape = (job_config.training.local_batch_size, job_config.training.seq_len)
+    assert hasattr(model_args, "vocab_size")
+    input_ids = torch.randint(
+        model_args.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(model, mb_args=(input_ids,), split_spec=split_spec)
+    model = pipe.get_stage_module(stage_idx)
+    stage = _PipelineStage(
+        stage_module=model,
+        stage_index=pp_rank,
+        pipe_info=pipe.pipe_info,
+        device=device,
+        group=pp_mesh.get_group(),
+    )
+    # The tracer frontend yields exactly one stage (and one model chunk) per PP rank.
+    stages = [stage]
+    model_parts = [model]
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(m, parallel_dims, job_config)
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
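To make the tracer path concrete: with the new debugmodel config (n_layers = 16) and pipeline_parallel_degree = 2, layers_per_rank is 8, so split_spec becomes {"layers.8": SplitPoint.BEGINNING}. The following self-contained sketch shows the same torch.distributed.pipelining frontend splitting a toy stand-in model on CPU. ToyModel and its sizes are assumptions for illustration, not the real DeepSeek/Llama model, and no process group or pipeline stage is built here.

import torch
import torch.nn as nn
from torch.distributed.pipelining import SplitPoint, pipeline


class ToyModel(nn.Module):
    # Stand-in for the real transformer: embedding, index-keyed blocks, output head.
    def __init__(self, n_layers: int = 16, dim: int = 4, vocab_size: int = 2000):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        return self.output(h)


pp_degree, n_layers = 2, 16
layers_per_rank = n_layers // pp_degree  # 8 layers per stage
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, pp_degree)
}  # -> {"layers.8": SplitPoint.BEGINNING}

model = ToyModel(n_layers=n_layers)
# Example micro-batch shaped like (local_batch_size=10, seq_len=4) from debug_model.toml.
example_tokens = torch.randint(0, 2000, (10, 4), dtype=torch.int64)

pipe = pipeline(model, mb_args=(example_tokens,), split_spec=split_spec)

# Stage 0 should own the embedding plus layers 0..7; stage 1 owns layers 8..15 and the head.
stage0 = pipe.get_stage_module(0)
print(sorted({int(n.split(".")[1]) for n, _ in stage0.named_parameters() if n.startswith("layers.")}))

In pipeline_llama_tracer, each rank then keeps only its own chunk via pipe.get_stage_module(stage_idx) and wraps it in a _PipelineStage bound to the pp process group before building the schedule.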
3 changes: 2 additions & 1 deletion torchtitan/train.py
@@ -636,7 +636,8 @@ def close(self) -> None:
         if self.metrics_processor:
             self.metrics_processor.close()
 
-
+import fbvscode
+fbvscode.attach_debugger()
 if __name__ == "__main__":
     init_logger()
     config_manager = ConfigManager()