
Conversation

@Flink-ddd (Contributor) commented on Jun 16, 2025:

This PR fixes issue #7303.

### 1. Description of the Bug

Currently, when using the `WarmupLR` scheduler, if `warmup_max_lr` is not explicitly set in the scheduler's parameters, it incorrectly falls back to its internal default value (`0.001`), ignoring the learning rate set in the optimizer's parameters. This can lead to unexpected training behavior and diverges from user expectations.
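For illustration only, a config shaped like the Python dict below (the `5e-4` learning rate and the other values are made up, not taken from the PR) would previously warm up toward `0.001` rather than the optimizer's `5e-4`:

```python
# Hypothetical config, used only to illustrate the mismatch; not taken from the PR.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 5e-4},  # the learning rate the user actually wants
    },
    "scheduler": {
        "type": "WarmupLR",
        # warmup_max_lr omitted: before this fix, warmup ramped toward the
        # hard-coded 0.001 instead of the optimizer's 5e-4.
        "params": {"warmup_num_steps": 100},
    },
}
```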

### 2. Description of the Fix

This fix modifies the `__init__` method of the `WarmupLR` scheduler in `deepspeed/runtime/lr_schedules.py`.

- The default value for the `warmup_max_lr` argument in the function signature is changed from `0.001` to `None`.
- Logic is added to check whether `warmup_max_lr` is `None` upon initialization. If it is, the scheduler now correctly inherits the learning rate from the optimizer's parameter groups.

This change ensures that the optimizer's learning rate is respected as the default `warmup_max_lr`, aligning the scheduler's behavior with the user's configuration intent; a simplified sketch of the fallback logic is shown below.
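A minimal sketch of this fallback, assuming a plain linear warmup for brevity (class and attribute names here are illustrative, not the exact code in `deepspeed/runtime/lr_schedules.py`):

```python
import torch


class WarmupLRSketch:
    """Simplified stand-in for WarmupLR, showing only the fallback behavior."""

    def __init__(self, optimizer, warmup_min_lr=0.0, warmup_max_lr=None, warmup_num_steps=1000):
        self.optimizer = optimizer
        self.warmup_min_lr = warmup_min_lr
        self.warmup_num_steps = warmup_num_steps
        self.last_step = 0
        if warmup_max_lr is None:
            # The fix: inherit the LR the user already configured on the
            # optimizer's parameter groups instead of a hard-coded 0.001.
            self.warmup_max_lrs = [group["lr"] for group in optimizer.param_groups]
        else:
            self.warmup_max_lrs = [warmup_max_lr] * len(optimizer.param_groups)

    def step(self):
        # Linear ramp from warmup_min_lr to the per-group warmup target.
        self.last_step += 1
        scale = min(1.0, self.last_step / self.warmup_num_steps)
        for group, max_lr in zip(self.optimizer.param_groups, self.warmup_max_lrs):
            group["lr"] = self.warmup_min_lr + scale * (max_lr - self.warmup_min_lr)


# Usage: without warmup_max_lr, warmup now targets the optimizer's 5e-4.
model = torch.nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(), lr=5e-4)
sched = WarmupLRSketch(opt, warmup_num_steps=10)
for _ in range(10):
    sched.step()
print(opt.param_groups[0]["lr"])  # 0.0005, not 0.001
```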

### 3. Verification

The fix has been verified using a minimal reproduction script that clearly demonstrates the behavioral change.

**Before Fix:**
Without `warmup_max_lr` in the scheduler config, the learning rate incorrectly defaults to `0.001`.
<img width="1711" alt="Screenshot 2025-06-16 at 18 34 31" src="https://github.com/user-attachments/assets/fe68f39e-2bbc-4f94-b322-546d9ce43bb0" />

**After Workaround (Demonstrating the Mechanism):**
By explicitly adding `warmup_max_lr` to the scheduler config, the learning rate behaves as expected. My code change makes this the default behavior.
<img width="1195" alt="Screenshot 2025-06-16 at 20 17 11" src="https://github.com/user-attachments/assets/cc170246-fdac-4a56-8b9c-f204ebb47895" />
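As a rough, self-contained illustration of the mechanism the screenshots demonstrate (not the author's actual reproduction script; the `5e-4` value is hypothetical), the behavioral difference reduces to where the missing `warmup_max_lr` key falls back:

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

scheduler_params = {"warmup_num_steps": 100}  # warmup_max_lr deliberately omitted

# Old behavior: the missing key silently became the hard-coded default.
old_target = scheduler_params.get("warmup_max_lr", 0.001)

# New behavior: the missing key falls back to the optimizer's configured LR.
new_target = scheduler_params.get("warmup_max_lr")
if new_target is None:
    new_target = optimizer.param_groups[0]["lr"]

print(old_target)  # 0.001  -> the old hard-coded fallback
print(new_target)  # 0.0005 -> the optimizer's LR, now the warmup target
```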

@Flink-ddd force-pushed the fix/warmup-lr-override branch 2 times, most recently from 1144e16 to d73e624, on June 16, 2025 at 12:34
@Flink-ddd force-pushed the fix/warmup-lr-override branch 2 times, most recently from 1ccba7b to 04d33d3, on June 18, 2025 at 02:06
This commit ensures that the WarmupLR scheduler correctly inherits the
learning rate from the optimizer's parameters when `warmup_max_lr`
is not explicitly provided in the scheduler's configuration. This
prevents the scheduler from falling back to a hard-coded default value,
aligning its behavior with user expectations.

Fixes deepspeedai#7303

Signed-off-by: Vensenmu <[email protected]>
@Flink-ddd force-pushed the fix/warmup-lr-override branch from 04d33d3 to 6ef55d9 on June 18, 2025 at 02:15
@Flink-ddd (Contributor, Author) commented:

Hi @tjruwase, thanks again for the earlier approval!

After rebasing my branch onto the latest master to resolve the previous test failures, it seems a new, unrelated test, `test_autofill_dsconfig_from_ds_plugin`, is now failing with a `ValueError`. This appears to be related to recent changes on the master branch.

Could you please take a look when you have a moment? My changes are confined to `lr_schedules.py` and should not affect this config autofill logic. Thank you!

@Flink-ddd (Contributor, Author) commented:

Hi @tjruwase, thank you so much for merging the master branch and helping get the CI to pass! It looks like all checks have passed now. Is there anything else needed from my end for this PR to be merged?

Thank you again for your time and guidance!

@tjruwase (Contributor) commented:

@Flink-ddd, thanks for your contribution. It is greatly appreciated.

@tjruwase merged commit 22cf1a4 into deepspeedai:master on Jun 19, 2025 (9 checks passed).
@Flink-ddd (Contributor, Author) commented:

Hi @tjruwase, thank you! It's my pleasure, and an honor to contribute to the DeepSpeed project. I've learned a lot from the process and look forward to contributing more in the future. Thanks again for your timely response!

Antlera pushed a commit to Antlera/DeepSpeed that referenced this pull request Jun 27, 2025
…epspeedai#7360)

Signed-off-by: Vensenmu <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>