Skip to content

Commit 8926db9

Browse files
Move checkpoint resharding test to checkpointing DAG (#869)
1 parent 528cda8 commit 8926db9

File tree

2 files changed: 15 additions and 12 deletions

dags/multipod/legacy.py

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -95,15 +95,3 @@
9595
docker_image=DOCKER_IMAGE[test_mode].value,
9696
test_owner=test_owner.MOHIT_K,
9797
).run()
98-
99-
# v4-8 2 slices checkpoint resharding test
100-
gke_config.get_gke_config(
101-
num_slices=2,
102-
time_out_in_min=60,
103-
test_name=f"maxtext-checkpoint-reshard-{SetupMode.STABLE.value}",
104-
run_model_cmds=(
105-
f"bash end_to_end/tpu/test_checkpoint_resharding.sh xlml-checkpoint-resharding-v4-8-2slice-{SetupMode.STABLE.value} gs://maxtext-xlml gs://maxtext-xlml/dataset",
106-
),
107-
docker_image=DOCKER_IMAGE[SetupMode.STABLE].value,
108-
test_owner=test_owner.PRIYANKA_G,
109-
).run()

dags/multipod/maxtext_checkpointing.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,3 +67,18 @@
6767
docker_image=image.value,
6868
test_owner=test_owner.SURBHI_J,
6969
).run()
70+
71+
# Checkpoint resharding test - trains a model with a specific sharding strategy and saves a checkpoint.
72+
# Then trains again by restoring this checkpoint using a different sharding strategy.
73+
# Finally, asserts that the learning metrics are consistent, ensuring that checkpoints can be successfully loaded across different sharding strategies.
74+
gke_config.get_gke_config(
75+
num_slices=2,
76+
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
77+
time_out_in_min=60,
78+
test_name=f"maxtext-checkpoint-resharding-{mode.value}",
79+
run_model_cmds=(
80+
f"bash end_to_end/tpu/test_checkpoint_resharding.sh checkpoint-resharding-{mode.value} {base_output_directory} {dataset_path}",
81+
),
82+
docker_image=image.value,
83+
test_owner=test_owner.SURBHI_J,
84+
).run()

0 commit comments

Comments
 (0)