Support llama3 autoparallel + pipelining #1657
base: autoparallel
Conversation
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

pp_degree = job_config.parallelism.pipeline_parallel_degree
Unused pp_degree config; this should probably raise an error when it's not the local world size.
I deleted it (it was unused/unneeded). I don't think we need to raise any error: pp_degree does not need to equal any particular size, and PP can even be disabled.
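For context, the invariant that actually matters is the product of all the degrees, not pp_degree on its own. A minimal sketch of that check, with hypothetical names (`validate_degrees` is not an existing torchtitan API):

```python
# Hypothetical sketch, not torchtitan's actual validation code:
# pp_degree is only one factor of the world size, alongside the SPMD dims,
# so the useful invariant is that the product of all degrees equals world_size.
def validate_degrees(world_size: int, pp: int, dp_replicate: int, dp_shard: int, tp: int) -> None:
    product = pp * dp_replicate * dp_shard * tp
    if product != world_size:
        raise ValueError(
            f"pp({pp}) * dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * tp({tp}) "
            f"= {product} != world_size({world_size})"
        )

# 8 ranks with pp=2, dp_shard=2, tp=2 is valid even though pp != world size
validate_degrees(world_size=8, pp=2, dp_replicate=1, dp_shard=2, tp=2)
```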
spmd_dims.append("tp")
spmd_mesh = world_mesh[spmd_dims]

dp_degree = 1
Same here; the config could specify dp_degree.
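A sketch of what that could look like, reusing the expression that already appears commented out further down (`dp_replicate * dp_shard`); the attribute names are assumptions about `parallel_dims`, not verified against this branch:

```python
# Sketch only: derive the data-parallel degree from the parallelism config
# instead of hardcoding dp_degree = 1. The replicate/shard factoring mirrors
# the commented-out code later in this diff.
def compute_dp_degree(dp_replicate: int, dp_shard: int) -> int:
    return dp_replicate * dp_shard

# e.g. 2-way replication over 4-way sharding gives an 8-way data-parallel degree
assert compute_dp_degree(dp_replicate=2, dp_shard=4) == 8
```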
inputs, target=targets, losses=losses, input_batch=inputs
# TODO: input_batch kwarg only needed for CP, but
# autoparallel doesn't accept kwargs in its forward
inputs, target=targets, losses=losses #, input_batch=inputs
Curious, why does CP need `input_batch`?
I assumed you would know. Am I wrong?
Oh, there was a change to remove the need for `input_batch`. We may want to make the same change to autoparallel.
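A toy illustration (not autoparallel's real code) of the constraint the TODO describes: a module whose forward only takes positional arguments cannot be fed an extra `input_batch` keyword, which is why the kwarg is commented out above:

```python
import torch
import torch.nn as nn

class PositionalOnlyForward(nn.Module):
    # Stand-in for a traced/parallelized module whose forward signature
    # was fixed to positional arguments only.
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

m = PositionalOnlyForward()
m(torch.ones(2, 4))  # fine
try:
    m(torch.ones(2, 4), input_batch=torch.ones(2, 4))
except TypeError as e:
    print(e)  # e.g. "got an unexpected keyword argument 'input_batch'"
```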
pp_degree = job_config.parallelism.pipeline_parallel_degree
local_batch_size = job_config.training.local_batch_size
spmd_batch_size = local_batch_size
Oops, this is a bug for the non-PP case: it should be local * dp_degree, and go in an 'else' branch.
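For clarity, a sketch of the fix being described (names taken from the diff above; whether the PP branch keeps the plain local batch size is an assumption):

```python
# Sketch of the described fix, not the actual committed code: in the non-PP
# case the SPMD batch covers all data-parallel ranks, so it scales by dp_degree.
def compute_spmd_batch_size(local_batch_size: int, dp_degree: int, pp_enabled: bool) -> int:
    if pp_enabled:
        return local_batch_size
    else:
        return local_batch_size * dp_degree

assert compute_spmd_batch_size(local_batch_size=8, dp_degree=4, pp_enabled=False) == 32
assert compute_spmd_batch_size(local_batch_size=8, dp_degree=4, pp_enabled=True) == 8
```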
fixed
self.pp_schedule.step(
    inputs, target=targets, losses=losses, input_batch=inputs
# TODO: input_batch kwarg only needed for CP, but
# autoparallel doesn't accept kwargs in its forward
Can we just fix this LOL
# # step.
# dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
# global_batch_size = job_config.training.local_batch_size * dp_degree
if parallel_dims.pp_enabled and pp_rank > 0:
What a mess. No action needed here, but it's definitely worth thinking about what the terminal UX state should be.
So far just tested locally:
`LOG_RANK=4 CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.pipeline_parallel_degree 2 --training.steps 100`
Runs and loss converges. Left one TODO about global batch size and gradient accumulation.
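For reference, the bookkeeping that TODO refers to, as a generic sketch (not torchtitan's actual config fields or implementation):

```python
# Generic sketch: the effective global batch size is the local batch size
# multiplied by the data-parallel degree and any gradient-accumulation steps.
def global_batch_size(local_batch_size: int, dp_degree: int, grad_accum_steps: int = 1) -> int:
    return local_batch_size * dp_degree * grad_accum_steps

# e.g. local batch 8, dp_degree 4, 2 accumulation steps -> 64 samples per optimizer step
assert global_batch_size(8, 4, grad_accum_steps=2) == 64
```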