
Conversation

@k223kim (Contributor) commented Jun 18, 2025

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

This PR fixes #1670; it adds a test showing that Thunder's current no_sync context manager already behaves like PyTorch's. If accepted, I will open a follow-up PR that removes sync_grads entirely and updates the docs.
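
For reference, a minimal sketch of the PyTorch DDP semantics being matched (reference behavior, not code from this PR): inside DistributedDataParallel.no_sync(), backward passes accumulate gradients locally without the all-reduce, and the first backward outside the context synchronizes the accumulated gradients across ranks.

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def accumulate_then_sync(model: DDP, batches, criterion: nn.Module):
    # Accumulate gradients locally for all but the last micro-batch.
    with model.no_sync():
        for inputs, targets in batches[:-1]:
            criterion(model(inputs), targets).backward()  # no gradient all-reduce here
    # The first backward outside no_sync() performs the all-reduce,
    # synchronizing the accumulated gradients across ranks.
    inputs, targets = batches[-1]
    criterion(model(inputs), targets).backward()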

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@k223kim k223kim changed the title feat: match no_sync with pytorch [wip] feat: match no_sync with pytorch Jun 18, 2025
@k223kim k223kim marked this pull request as ready for review June 20, 2025 15:41
@k223kim k223kim changed the title [wip] feat: match no_sync with pytorch fix: match no_sync with pytorch Jun 20, 2025
device = torch.device("cuda", test_case.rank)

gradients = defaultdict(list)
for use_no_sync in (True, False):
Member

Can we rather have it as an argument, or do we really need this loop?

Contributor Author

Hmm, we do not need this loop. Would it make a difference if we use arguments?

Collaborator

From the perspective of what we test, there's no difference, but making use_no_sync an argument would make the test logs much better: they would clearly tell us which setting fails. I'd have use_no_sync as an argument, as @Borda suggests.
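
A hedged sketch of the suggestion above; the decorator shown is pytest's parametrize, while this repo's distributed test harness may use torch.testing's own parametrization helpers instead, and the test body is elided:

import pytest

@pytest.mark.parametrize("use_no_sync", [True, False])
def test_no_sync_matches_pytorch(use_no_sync):
    # Build the DDP model and run the training step with or without no_sync().
    ...
    # A failing run is now reported as e.g. test_no_sync_matches_pytorch[True],
    # which names the configuration that diverged.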

@t-vi (Collaborator) commented Jun 21, 2025

I think we should split the test.
That said, I would expect the no-sync test to roughly do the following:

for j in range(2):
    for i in range(2):
        with no_sync():
            fw
            loss
            bw
        grab grad.clones for comparison  # these would not be synced, even outside the context manager.
    fw
    loss
    bw  # sync happens only(!) here and below
    grab grad.clones for comparison  # these are synced
    zero_grad
    fw
    loss
    bw  # sync happens only(!) here and above
    grab grad.clones for comparison  # these are synced
    zero_grad

and then the gradients are compared, in order, between plain PyTorch DDP and Thunder DDP.

Collaborator

Note that the grabbing of the grads happens outside the no_sync.
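
A hedged, runnable-shaped sketch of the schedule described above; the helpers make_batch and loss_fn are hypothetical placeholders, and gradient clones are grabbed outside no_sync(), per the note above:

def run_no_sync_schedule(model, make_batch, loss_fn, num_outer=2, num_accum=2):
    # Return gradient snapshots in a fixed order so that a plain PyTorch DDP run
    # and a Thunder DDP run can be compared element-wise afterwards.
    snapshots = []

    def grab_grads():
        # Cloned outside no_sync(); gradients that were not synced stay unsynced here.
        return [p.grad.detach().clone() for p in model.parameters()]

    for _ in range(num_outer):
        for _ in range(num_accum):
            with model.no_sync():
                loss_fn(model(make_batch())).backward()  # no all-reduce inside the context
            snapshots.append(grab_grads())               # unsynced gradients
        loss_fn(model(make_batch())).backward()          # sync happens only here and below
        snapshots.append(grab_grads())                   # synced gradients
        model.zero_grad()
        loss_fn(model(make_batch())).backward()          # synced again, from zeroed gradients
        snapshots.append(grab_grads())                   # synced gradients
        model.zero_grad()
    return snapshots

The two snapshot lists (one from a torch DDP model, one from a Thunder DDP model) could then be compared pairwise with torch.testing.assert_close.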

Co-authored-by: Jirka Borovec <[email protected]>
test_case.assertGreater(len(no_sync_bwd_trc.bound_symbols), 1)
assert torch.allclose(torch_loss, loss, atol=1e-4, rtol=1e-4)

torch.testing.assert_close(torch_grad, thunder_grad, atol=1e-3, rtol=1e-3)
Contributor Author

@t-vi I think this should fail if thunder and torch are different
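
For reference (not part of the PR diff): torch.testing.assert_close raises an AssertionError with a mismatch report when the tensors differ beyond the given tolerances, whereas torch.allclose only returns a bool and needs a bare assert around it, so its failure message carries less detail.

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.5])

torch.allclose(a, b, atol=1e-3, rtol=1e-3)       # returns False; no error on its own
try:
    torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)
except AssertionError as err:
    print(err)  # reports how many elements mismatched and the largest difference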

@k223kim k223kim marked this pull request as draft June 23, 2025 09:25
Development

Successfully merging this pull request may close these issues.

Make ThunderModule.no_sync behave as PyTorch's Distributed one