Commit 0d822e4

Authored by: stekiri, pre-commit-ci[bot], Borda, awaelchli, carmocca
Make gradients available for all_gather on TPU (#15003)

* Make gradients available for all_gather on TPU
* Modify switch and tests
* Apply suggestions from code review
* Modify tests
* Fix test
* Drop test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Parent: d5b9c67 · Commit: 0d822e4

File tree: 3 files changed, 34 additions and 8 deletions


src/lightning_lite/strategies/xla.py

Lines changed: 6 additions & 4 deletions

@@ -156,20 +156,22 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
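
To see what the change above enables, here is a minimal illustrative sketch, not part of the commit: it gathers a leaf tensor with sync_grads=True and backpropagates through the result, mirroring the new test added further down. The helper name gather_with_grads is invented here, and the code assumes it runs inside each XLA-spawned process on a machine with torch_xla and TPU devices.

import torch

from lightning_lite.strategies import XLAStrategy


def gather_with_grads(strategy: XLAStrategy) -> None:
    # Illustrative only: a leaf tensor on this process's TPU device.
    tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)

    # With sync_grads=True the gathered result stays attached to the autograd
    # graph, so backward() on a reduction of it populates tensor.grad locally.
    gathered = strategy.all_gather(tensor, sync_grads=True)
    gathered.sum().backward()
    assert tensor.grad is not None

With sync_grads=False the previous behaviour is kept: the gather goes through xm.all_gather and the result is not connected to the local tensor's autograd graph.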

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 6 additions & 4 deletions

@@ -289,20 +289,22 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
         self.checkpoint_io.remove_checkpoint(filepath)

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def teardown(self) -> None:
         super().teardown()
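
On the PyTorch Lightning side, LightningModule.all_gather forwards sync_grads to the strategy method patched above, so a loss computed on gathered tensors can now backpropagate when training on TPUs. The sketch below is a hypothetical end-user example, not part of this commit; the module, layer sizes, and loss are invented purely for illustration.

import torch
from pytorch_lightning import LightningModule


class GatherAwareModule(LightningModule):
    """Hypothetical module; only the all_gather call relates to this commit."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(32, 8)

    def training_step(self, batch, batch_idx):
        embeddings = self.encoder(batch)
        # sync_grads=True keeps the gather differentiable, so a loss computed on
        # the gathered embeddings backpropagates into every core's encoder copy.
        gathered = self.all_gather(embeddings, sync_grads=True)
        return gathered.pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)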

tests/tests_lite/strategies/test_xla.py

Lines changed: 22 additions & 0 deletions

@@ -17,6 +17,7 @@
 from unittest.mock import Mock

 import pytest
+import torch
 from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
 from tests_lite.helpers.runif import RunIf
@@ -113,3 +114,24 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch):

     with pytest.raises(TypeError, match="TPUs do not currently support"):
         XLAStrategy().process_dataloader(dataloader)
+
+
+def tpu_all_gather_fn(strategy):
+    for sync_grads in [True, False]:
+        tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
+        result = strategy.all_gather(tensor, sync_grads=sync_grads)
+        summed = result.sum()
+        assert torch.equal(summed, torch.tensor(8.0))
+        summed.backward()
+        if sync_grads:
+            assert torch.equal(tensor.grad, torch.tensor(1.0))
+        else:
+            # As gradients are not synced, the original tensor will not have gradients.
+            assert tensor.grad is None
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_tpu_all_gather():
+    """Test the all_gather operation on TPU."""
+    xla_launch(tpu_all_gather_fn)
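
A note on the new test (an observation, not part of the diff): gathering a scalar 1.0 from every process and expecting a sum of 8.0 implies the eight processes that the TPU test environment spawns, and mock, os, and xla_launch are assumed to come from the test module's existing imports and helpers, which this commit leaves untouched and which therefore do not appear in the hunks.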
