Conversation

@ehorning (Contributor) commented Aug 12, 2025

Integrate multi-tier checkpointer + orbax replicator into axlearn

@ehorning ehorning marked this pull request as ready for review September 16, 2025 19:39
@ehorning ehorning requested review from a team as code owners September 16, 2025 19:39
FLAGS = flags.FLAGS

flags.DEFINE_integer(
"assume_data_parallelism",
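
The flag definition above is truncated in the diff view. A plausible completion using `absl.flags` might look like the following; the default value and help text are illustrative guesses, not taken from the PR:

```python
from absl import flags

FLAGS = flags.FLAGS

# Hypothetical completion of the truncated definition; default and help
# text are illustrative, not the PR's actual values.
flags.DEFINE_integer(
    "assume_data_parallelism",
    1,  # assumed default: a single data-parallel replica
    "Assumed degree of data parallelism, used when replicating checkpoints "
    "across replicas.",
)
```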
Contributor:
Future follow-up: I think Orbax has a way of figuring this out automatically, since it also needs this info. Orbax requires you to specify the batch dimension, AFAIR, so it can infer it.

Reply:
MaxText sets it to the number of slices. However, that may not be correct if there is intra-slice DDP (distributed data parallelism within a slice), so we plan to make it configurable.
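
The slice-counting heuristic discussed above could be sketched as follows. This is a minimal sketch, not AXLearn's or MaxText's actual implementation: `infer_data_parallelism` is a hypothetical helper, and the `slice_index` attribute is assumed based on how TPU devices are typically exposed in multislice JAX environments.

```python
from types import SimpleNamespace

def infer_data_parallelism(devices):
    """Hypothetical heuristic: one data-parallel replica per slice.

    Assumes each device exposes a `slice_index` attribute (as TPU
    devices do in multislice JAX setups); devices without one are
    counted as slice 0. As noted in the review thread, this undercounts
    when there is intra-slice DDP, hence the explicit flag override.
    """
    slice_ids = {getattr(d, "slice_index", 0) for d in devices}
    return max(len(slice_ids), 1)

# Example: eight devices spread evenly over two slices.
devices = [SimpleNamespace(slice_index=i // 4) for i in range(8)]
print(infer_data_parallelism(devices))  # -> 2
```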

3 participants