Conversation

therealnaveenkamal
Contributor

Description

This PR fixes an issue where gradient clipping modifications are not reflected in the global gradient norm calculation when CPU offloading is enabled. The issue occurs because the averaged_gradients are not being updated with the clipped gradients when CPU offloading is active.

Problem

When using CPU offloading with gradient clipping:

  1. The gradients are successfully clipped using `safe_set_local_grad`.
  2. However, the `_global_grad_norm` calculation still uses the original, unclipped gradients.
  3. This leads to incorrect gradient-norm reporting and undermines the effectiveness of gradient clipping.
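To make the failure mode concrete, here is a toy sketch of the mechanism. This is not DeepSpeed's actual code; the class and method names are illustrative. The point is that a per-parameter norm cached at backward time goes stale if a gradient is overwritten afterwards without refreshing the cache:

```python
import math

# Toy model (illustrative names, not DeepSpeed's implementation) of why the
# cached norm goes stale: per-parameter squared norms are recorded during the
# backward pass, so overwriting a gradient later without refreshing the cache
# leaves the global norm computed from pre-clip values.
class ToyOffloadOptimizer:
    def __init__(self):
        self.grads = {}                 # param_id -> gradient values
        self.norm_for_param_grads = {}  # param_id -> squared norm cached at backward time

    def backward_step(self, param_id, grad):
        self.grads[param_id] = grad
        self.norm_for_param_grads[param_id] = sum(g * g for g in grad)

    def set_local_grad(self, param_id, grad, refresh_norm=False):
        self.grads[param_id] = grad
        if refresh_norm:  # the fix: mirror what the backward pass does
            self.norm_for_param_grads[param_id] = sum(g * g for g in grad)

    def global_grad_norm(self):
        return math.sqrt(sum(self.norm_for_param_grads.values()))

opt = ToyOffloadOptimizer()
opt.backward_step(0, [3.0, 4.0])                      # norm 5.0 cached
opt.set_local_grad(0, [0.3, 0.4])                     # "clipped" grad, true norm 0.5
stale = opt.global_grad_norm()                        # still 5.0 -- the bug
opt.set_local_grad(0, [0.3, 0.4], refresh_norm=True)
fixed = opt.global_grad_norm()                        # now ~0.5 -- the fix
```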

Solution

The fix ensures that the averaged_gradients are properly updated with the clipped gradients when CPU offloading is enabled, similar to how it works when CPU offloading is disabled.

Testing

The fix has been tested with:

  • CPU offloading enabled and disabled
  • Different gradient clipping values
  • A simple model with linear layers
  • Both FP16 and BF16

Related Issues

Fixes #7292

@sfc-gh-truwase
Collaborator

sfc-gh-truwase commented May 22, 2025

@therealnaveenkamal thanks for the PR. Unfortunately, the current approach won't work because

  1. perf issue for backward
  2. averaged_gradients (despite the poor naming) is meant for the GPU-only execution.

It seems the out-of-sync data structure is self.norm_for_param_grads. In offload case, norms are computed on-the-fly and maintained in self.norm_for_param_grads. I wonder if safe_set_* APIs can be handled by also calling self.set_norm_for_param_grad_in_gpu(param) similar to during backward pass.

What do you think?

@therealnaveenkamal
Contributor Author

therealnaveenkamal commented May 22, 2025

@sfc-gh-truwase I agree. I've reverted my changes and added a norm update in the local API. Please let me know your thoughts

```python
if self.offload_optimizer:
    self.norm_for_param_grads[self.get_param_id(param)] = self._constant_buffered_norm2(value)
```
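For illustration, a minimal stub (not the real DeepSpeed optimizer class; `_constant_buffered_norm2` is approximated here as a plain squared L2 norm) showing where this offload-aware guard sits inside the local-grad setter:

```python
# Minimal stub (illustrative, not DeepSpeed's actual class) showing the
# offload-aware norm refresh inside the local-grad setter.
class StubOptimizer:
    def __init__(self, offload_optimizer):
        self.offload_optimizer = offload_optimizer
        self.grads = {}
        self.norm_for_param_grads = {}

    def get_param_id(self, param):
        return id(param)

    def _constant_buffered_norm2(self, value):
        # Stand-in for DeepSpeed's buffered squared-norm computation.
        return sum(v * v for v in value)

    def set_local_grad(self, param, value):
        self.grads[self.get_param_id(param)] = value
        if self.offload_optimizer:
            self.norm_for_param_grads[self.get_param_id(param)] = \
                self._constant_buffered_norm2(value)

param = object()
opt = StubOptimizer(offload_optimizer=True)
opt.set_local_grad(param, [0.3, 0.4])
refreshed = opt.norm_for_param_grads[opt.get_param_id(param)]  # ~0.25
```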

@sfc-gh-truwase
Collaborator

@therealnaveenkamal, looks good to me. Did you get a chance to test? Also, is it possible to convert your test case into a unit test?

@therealnaveenkamal
Contributor Author

@sfc-gh-truwase I've added a unit-test file, `test_zero_grad_clip.py`. It runs four tests covering BF16 and FP16.

```python
post_clip_norm = clamped_grad.norm().item()

if pre_clip_norm > clip_value:
    print(f"DEBUG: Param {param.ds_id} - Pre-clip norm: {pre_clip_norm:.6f}, Post-clip norm: {post_clip_norm:.6f}")
```
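The behavior the test exercises is clipping by global norm. A minimal, self-contained sketch of that technique (illustrative helper name, not DeepSpeed's API) for reference:

```python
import math

# Clip-by-global-norm sketch: if the total L2 norm of all gradients exceeds
# clip_value, scale every gradient by clip_value / (total_norm + eps).
def clip_grads_by_global_norm(grads, clip_value, eps=1e-6):
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > clip_value:
        scale = clip_value / (total_norm + eps)
        grads = [[g * scale for g in grad] for grad in grads]
    return grads, total_norm

grads, pre = clip_grads_by_global_norm([[3.0, 4.0]], 0.5)
post = math.sqrt(sum(g * g for grad in grads for g in grad))
# pre is 5.0; post is ~0.5 after clipping
```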
Collaborator


Please remove the print() so the unit test is not noisy

Contributor Author


@sfc-gh-truwase Thanks for letting me know. I've updated the file.

Contributor Author


@sfc-gh-truwase Looks like the test is failing with a module-not-found error for mpi4py. Can you please help me here?

Collaborator


I will take a look on the next run. Usually, mpi4py should have been handled by the requirements.txt.

Also, did you run the formatting checks?
https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md#prerequisites

Contributor Author


@sfc-gh-truwase I updated my script; it works without mpi4py. I also ran the formatting checks. I did sign off, but the DCO check still shows as failed.

and sorry for the trouble caused, I'm a first time contributor!

Collaborator


@therealnaveenkamal, no apologies needed. We love first time contributors :)

Contributor Author


@sfc-gh-truwase Looks like there were issues with the torch.distributed setup. When I test locally, the tests pass. I've handled the exceptions and pushed my updates; hopefully the workflow passes this time.

```
configfile: pytest.ini
plugins: forked-1.6.0
collected 4 items

tests/unit/runtime/zero/test_zero_grad_clip.py::TestZeroGradClip::test_grad_clip_and_norm_update[fp16-0.5-cpu] PASSED   [ 25%]
tests/unit/runtime/zero/test_zero_grad_clip.py::TestZeroGradClip::test_grad_clip_and_norm_update[bf16-0.05-cpu] PASSED  [ 50%]
tests/unit/runtime/zero/test_zero_grad_clip.py::TestZeroGradClip::test_grad_clip_and_norm_update[fp16-0.5-none] PASSED  [ 75%]
tests/unit/runtime/zero/test_zero_grad_clip.py::TestZeroGradClip::test_grad_clip_and_norm_update[bf16-0.05-none] PASSED [100%]
```

Collaborator


Sorry about the flaky CI. I will also keep an eye on it. Thanks!

Contributor Author


@sfc-gh-truwase Thanks for the support. Would love to contribute more.

@sfc-gh-truwase sfc-gh-truwase added this pull request to the merge queue May 27, 2025
Merged via the queue into deepspeedai:master with commit b9af5d8 May 27, 2025
13 checks passed
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Jun 16, 2025
Antlera pushed a commit to Antlera/DeepSpeed that referenced this pull request Jun 27, 2025
Successfully merging this pull request may close these issues.

Question about safe_set_local_grad and safe_get_local_grad