fix: match no_sync with pytorch #2255
base: main
Conversation
for more information, see https://pre-commit.ci
thunder/tests/distributed/helper.py
Outdated
device = torch.device("cuda", test_case.rank)

gradients = defaultdict(list)
for use_no_sync in (True, False):
Can we rather have it as an argument, or do we really need this loop?
Hmm, we do not need this loop. Would it make a difference if we used an argument?
From the perspective of what we test there's no difference, but making use_no_sync an argument would make the test logs much better: they would clearly tell us which setting fails. I'd have use_no_sync as an argument, as @Borda suggests.
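For illustration, a minimal sketch of what that parametrization could look like (the test name and placeholder body are hypothetical, not taken from this PR):

import pytest

@pytest.mark.parametrize("use_no_sync", (True, False), ids=("no_sync", "sync"))
def test_grad_parity(use_no_sync):
    # pytest reports each parametrization separately, e.g.
    # test_grad_parity[no_sync] vs. test_grad_parity[sync],
    # so a failure immediately names the setting that broke.
    assert isinstance(use_no_sync, bool)  # placeholder body for the sketch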
I think we should split the test.
That said, I would expect the no-sync test to roughly do the following:
for j in range(2):
    for i in range(2):
        with no_sync():
            fw
            loss
            bw
        grab grad.clones for comparison  # these would not be synced, even outside the context manager.
    fw
    loss
    bw  # sync happens only(!) here and below
    grab grad.clones for comparison  # these are synced
    zero_grad
    fw
    loss
    bw  # sync happens only(!) here and above
    grab grad.clones for comparison  # these are synced
    zero_grad
and then the gradients are compared, in order, between plain PyTorch DDP and Thunder DDP.
Note that the grabbing of the grads happens outside the no_sync.
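A minimal runnable sketch of that structure, assuming a plain PyTorch DDP model in which every parameter receives a gradient (the function and variable names are illustrative, not the PR's helper code):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def collect_grads_in_order(ddp_model: DDP, x, y, rounds: int = 2):
    # Runs the fw/loss/bw pattern sketched above and returns gradient clones in order.
    loss_fn = torch.nn.MSELoss()
    grads = []
    for _ in range(rounds):
        with ddp_model.no_sync():
            loss_fn(ddp_model(x), y).backward()
        # Grabbed outside the context manager, but still unsynced:
        # no all-reduce has been issued yet.
        grads.append([p.grad.clone() for p in ddp_model.parameters()])

        # The first backward outside no_sync() all-reduces the
        # accumulated gradients across ranks.
        loss_fn(ddp_model(x), y).backward()
        grads.append([p.grad.clone() for p in ddp_model.parameters()])  # synced
        ddp_model.zero_grad()
    return grads

The same list collected from the Thunder DDP run would then be compared entry by entry, e.g. with torch.testing.assert_close.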
Co-authored-by: Jirka Borovec <[email protected]>
thunder/tests/distributed/helper.py
Outdated
test_case.assertGreater(len(no_sync_bwd_trc.bound_symbols), 1)
assert torch.allclose(torch_loss, loss, atol=1e-4, rtol=1e-4)

torch.testing.assert_close(torch_grad, thunder_grad, atol=1e-3, rtol=1e-3)
@t-vi I think this should fail if thunder and torch are different
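As a small standalone illustration (not PR code): torch.testing.assert_close raises an AssertionError with a detailed mismatch report when the tensors differ beyond the given tolerances, so a divergence between the two backends cannot pass silently:

import torch

a = torch.zeros(3)
b = torch.full((3,), 1e-2)

try:
    # Differences larger than atol/rtol make this raise, with a report of
    # how many elements mismatch and by how much.
    torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-3)
except AssertionError as err:
    print(err)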
for more information, see https://pre-commit.ci
Before submitting
What does this PR do?
This PR fixes #1670; it adds a test that shows how the current thunder's no_sync context manager already behaves like torch. If accepted, I will have a follow-up PR that removes sync_grads entirely and updates the docs.
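For reference, this is the plain PyTorch DDP behaviour the test compares against: the usual gradient-accumulation pattern with no_sync (a sketch only; model, optimizer, and batch names are placeholders, not code from this PR):

from torch.nn.parallel import DistributedDataParallel as DDP

def accumulate_then_step(ddp_model: DDP, optimizer, loss_fn, batches):
    *accumulation_batches, last_batch = batches
    with ddp_model.no_sync():
        for x, y in accumulation_batches:
            # Gradients accumulate locally; no all-reduce happens here.
            loss_fn(ddp_model(x), y).backward()
    # The first backward after leaving no_sync() synchronizes the
    # accumulated gradients across ranks.
    x, y = last_batch
    loss_fn(ddp_model(x), y).backward()
    optimizer.step()
    optimizer.zero_grad()

The claim of this PR is that thunder's no_sync context manager already follows these semantics, which would make sync_grads redundant.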
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃