
Commit f1e0fda

Rename Strategy.reduce to Strategy.all_reduce in Lite (#16370)

1 parent 596494b commit f1e0fda

File tree

11 files changed (+47 / -44 lines)

src/lightning_fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+- Renamed `Strategy.reduce` to `Strategy.all_reduce` in all strategies ([#16370](https://github.com/Lightning-AI/lightning/issues/16370))
+
+
 ## [1.9.0] - 2023-01-12
 
 ### Added
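
For downstream code, the rename is a single call-site change. A minimal migration sketch, assuming a `strategy` object obtained from a Fabric/Lite setup (the variable and tensor names here are illustrative, not part of the commit):

```python
# Minimal migration sketch (hypothetical call site, not from this commit).
import torch

loss = torch.tensor(1.0)

# Before this commit:
#   avg_loss = strategy.reduce(loss, reduce_op="mean")

# After this commit:
avg_loss = strategy.all_reduce(loss, reduce_op="mean")
```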

src/lightning_fabric/strategies/ddp.py

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.

src/lightning_fabric/strategies/dp.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) ->
         # DataParallel handles the transfer of batch to the device
         return batch
 
-    def reduce(
+    def all_reduce(
         self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> TReduce:
         def mean(t: Tensor) -> Tensor:

src/lightning_fabric/strategies/fsdp.py

Lines changed: 1 addition & 1 deletion

@@ -245,7 +245,7 @@ def module_sharded_context(self) -> Generator:
         ):
             yield
 
-    def reduce(
+    def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
         if isinstance(tensor, Tensor):

src/lightning_fabric/strategies/parallel.py

Lines changed: 1 addition & 1 deletion

@@ -94,7 +94,7 @@ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
             bool: The reduced boolean decision.
         """
         decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
+        decision = self.all_reduce(decision, reduce_op=ReduceOp.SUM)
         decision = bool(decision == self.world_size) if all else bool(decision)
         return decision
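
The surrounding `reduce_boolean_decision` helper explains the choice of `ReduceOp.SUM`: each rank contributes 0 or 1, so a sum equal to `world_size` means every rank voted True, while any non-zero sum means at least one did. A standalone sketch of the same logic written directly against `torch.distributed` rather than the Strategy API (assumes an initialized process group; names are illustrative):

```python
# Standalone illustration of the boolean-decision reduction above,
# using torch.distributed directly instead of the Strategy API.
import torch
import torch.distributed as dist

def reduce_boolean_decision(decision: bool, require_all: bool = True) -> bool:
    votes = torch.tensor(int(decision))
    dist.all_reduce(votes, op=dist.ReduceOp.SUM)  # votes == number of ranks that said True
    if require_all:
        return bool(votes.item() == dist.get_world_size())
    return bool(votes.item() > 0)
```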

src/lightning_fabric/strategies/single_device.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def is_global_zero(self) -> bool:
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)
 
-    def reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
+    def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only
         operates with a single device, the reduction is simply the identity.
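
Since this strategy runs in a single process, the renamed method keeps its identity semantics. A quick hypothetical usage check (mirroring the unit test further below, and assuming `SingleDeviceStrategy` is importable from `lightning_fabric.strategies` as laid out in this tree):

```python
# Hypothetical usage: on a single device there are no peers, so all_reduce
# returns its input unchanged.
import torch
from lightning_fabric.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy()
x = torch.tensor([1.0, 2.0])
assert torch.equal(strategy.all_reduce(x), x)
```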

src/lightning_fabric/strategies/strategy.py

Lines changed: 11 additions & 11 deletions

@@ -169,7 +169,17 @@ def optimizer_step(
         return self.precision.optimizer_step(optimizer, **kwargs)
 
     @abstractmethod
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Perform an all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+        """
+
+    @abstractmethod
+    def all_reduce(
         self,
         tensor: Union[Tensor, Any],
         group: Optional[Any] = None,

@@ -201,16 +211,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
             src: source rank
         """
 
-    @abstractmethod
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Perform an all_gather on all processes.
-
-        Args:
-            tensor: the tensor to all_gather
-            group: the process group to gather results from
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
-        """
-
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
         """Reduce a boolean decision across all processes."""
         return decision
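
With the rename, a concrete strategy now has to provide both `all_gather` and `all_reduce` with these signatures. A minimal sketch of a hypothetical implementation with single-process semantics (the class name is illustrative and not part of the commit):

```python
# Hypothetical minimal implementation of the two abstract collectives above,
# with single-process (no-op) semantics for illustration.
from typing import Any, Optional, Union
from torch import Tensor

class _NoOpCollectives:
    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        # No peers to gather from: return the local tensor as-is.
        return tensor

    def all_reduce(
        self,
        tensor: Union[Tensor, Any],
        group: Optional[Any] = None,
        reduce_op: Optional[Union[str, Any]] = "mean",
    ) -> Union[Tensor, Any]:
        # With a single process the aggregate equals the local value.
        return tensor
```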

src/lightning_fabric/strategies/xla.py

Lines changed: 19 additions & 19 deletions

@@ -118,7 +118,25 @@ def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
         dataloader.dataset = dataloader._loader.dataset
         return dataloader
 
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Function to gather a tensor from several distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+
+    def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
         if not isinstance(output, Tensor):

@@ -160,24 +178,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
             obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Function to gather a tensor from several distributed processes.
-
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
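
The `all_gather` body that moved above `all_reduce` first unsqueezes 0-dim tensors so there is a dimension to gather along, then picks `xf.all_gather` (differentiable) or `xm.all_gather` depending on `sync_grads`. A hypothetical call pattern inside a per-core function (names are illustrative; requires a TPU environment with `torch_xla`):

```python
# Hypothetical per-TPU-core usage of XLAStrategy.all_gather / all_reduce.
import torch

def per_core_fn(strategy):
    local = torch.tensor(1.0, device=strategy.root_device)  # 0-dim, gets unsqueezed internally
    gathered = strategy.all_gather(local)   # one contribution per core along dim 0
    summed = strategy.all_reduce(local, reduce_op="sum")
    return gathered, summed
```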

tests/tests_fabric/strategies/launchers/test_xla.py

Lines changed: 4 additions & 4 deletions

@@ -60,14 +60,14 @@ def test_broadcast_on_tpu():
 
 def tpu_reduce_fn(strategy):
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op="undefined")
+        strategy.all_reduce(1, reduce_op="undefined")
 
     with pytest.raises(ValueError, match="XLAStrategy only supports"):
-        strategy.reduce(1, reduce_op=ReduceOp.MAX)
+        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)
 
     # it is faster to loop over here than to parameterize the test
     for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
-        result = strategy.reduce(1, reduce_op=reduce_op)
+        result = strategy.all_reduce(1, reduce_op=reduce_op)
         if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
             assert result.item() == 1
         else:

@@ -77,7 +77,7 @@ def tpu_reduce_fn(strategy):
 @RunIf(tpu=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_reduce():
-    """Test tpu spawn reduce operation."""
+    """Test tpu spawn all_reduce operation."""
     xla_launch(tpu_reduce_fn)

tests/tests_fabric/strategies/test_single_device.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def test_single_device_collectives():
     strategy = SingleDeviceStrategy()
     tensor = Mock()
     assert strategy.all_gather(tensor) == tensor
-    assert strategy.reduce(tensor) == tensor
+    assert strategy.all_reduce(tensor) == tensor
     assert strategy.broadcast(tensor) == tensor
