@@ -118,7 +118,25 @@ def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
         dataloader.dataset = dataloader._loader.dataset
         return dataloader
 
-    def reduce(
+    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
+        """Function to gather a tensor from several distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        if isinstance(tensor, Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+
+        import torch_xla.core.functions as xf
+        import torch_xla.core.xla_model as xm
+
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
+
+    def all_reduce(
         self, output: Union[Tensor, Any], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
     ) -> Tensor:
         if not isinstance(output, Tensor):
@@ -160,24 +178,6 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
             obj = torch.load(buffer)
         return obj
 
-    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """Function to gather a tensor from several distributed processes.
-
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        if isinstance(tensor, Tensor) and tensor.dim() == 0:
-            tensor = tensor.unsqueeze(0)
-
-        import torch_xla.core.functions as xf
-        import torch_xla.core.xla_model as xm
-
-        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
-
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
     ) -> None:
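
For context (not part of the diff): a hedged, standalone sketch of the sync_grads dispatch that the moved all_gather method performs, driven directly through torch_xla. The helper name xla_all_gather and the spawn-based demo are illustrative assumptions, not Lightning API.

    # Standalone sketch; assumes torch_xla is installed and an XLA/TPU device is available.
    import torch
    from torch import Tensor


    def xla_all_gather(tensor: Tensor, sync_grads: bool = False) -> Tensor:
        # Promote 0-d tensors so the gathered result gains a leading dimension,
        # mirroring the unsqueeze(0) in the diff above.
        if tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        import torch_xla.core.functions as xf   # autograd-aware all_gather
        import torch_xla.core.xla_model as xm   # plain all_gather
        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)


    def _mp_fn(index: int) -> None:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
        # Each process contributes a (2,)-shaped tensor; both all_gather variants
        # concatenate the per-process tensors along dim 0.
        local = torch.full((2,), float(index), device=device, requires_grad=True)
        gathered = xla_all_gather(local, sync_grads=True)  # gradients can flow back
        print(index, gathered.shape)


    if __name__ == "__main__":
        import torch_xla.distributed.xla_multiprocessing as xmp
        xmp.spawn(_mp_fn, args=())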