Multi-tier checkpointing + orbax replicator #1332
base: main
Conversation
FLAGS = flags.FLAGS

flags.DEFINE_integer(
    "assume_data_parallelism",
Future follow-up: I think Orbax has a way of figuring this out automatically, since it also needs this information. As far as I recall, Orbax requires you to specify the batch dimension so it can determine this.
MaxText sets it to the number of slices. However, that may not be correct if there is intra-slice DDP, so we plan to make it configurable.
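A minimal sketch of what the flag definition under discussion could look like, using `absl.flags`. The default value of `1` and the help text are assumptions for illustration, not the PR's actual values:

```python
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    "assume_data_parallelism",
    1,  # Assumed default; a MaxText-style setup would pass the slice count.
    "Assumed data-parallel replication factor. May be wrong when there is "
    "intra-slice DDP, hence it is configurable rather than inferred.",
)

# Mark flags as parsed with no extra argv so the default applies.
FLAGS(["prog"])
print(FLAGS.assume_data_parallelism)
```

Making this a flag (rather than inferring it from the mesh) keeps the replicator decoupled from the training topology until automatic inference lands.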
Integrate multi-tier checkpointer + orbax replicator into axlearn