Currently, we rely on `AllGatherGrad` to perform the all-gather on GPUs.

TODO:
- [ ] Extend this class to support TPUs
- [ ] Add tests
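For reference, the gradient rule that an all-gather with autograd support is generally expected to follow can be simulated in a single process. The forward pass gives every rank the stacked inputs of all ranks. In the backward pass, the upstream gradients are summed across ranks (an all-reduce), and each rank keeps only the slice that corresponds to its own input. This is a minimal sketch under those assumptions, using plain Python lists instead of tensors. The function names `all_gather_forward` and `all_gather_backward` are hypothetical. It is not the actual `AllGatherGrad` implementation and does not use a GPU or `torch.distributed`.

```python
from typing import List

Vector = List[float]


def all_gather_forward(local_tensors: List[Vector]) -> List[List[Vector]]:
    """Simulated all-gather (hypothetical helper): local_tensors[r] is rank r's input.

    Returns one gathered result per rank. Every rank receives a copy of the
    inputs of all ranks, stacked in rank order.
    """
    world_size = len(local_tensors)
    return [[list(t) for t in local_tensors] for _ in range(world_size)]


def all_gather_backward(grad_outputs: List[List[Vector]], rank: int) -> Vector:
    """Simulated backward of the all-gather for one rank (hypothetical helper).

    grad_outputs[r][s] is the gradient that rank r computed with respect to
    slice s of its gathered output. Input s feeds slice s on every rank, so its
    gradient is the sum over ranks (an all-reduce with SUM). The caller's rank
    then keeps only its own slice.
    """
    world_size = len(grad_outputs)
    # All-reduce (sum) the upstream gradients, slice by slice.
    summed = [
        [sum(grad_outputs[r][s][i] for r in range(world_size))
         for i in range(len(grad_outputs[0][s]))]
        for s in range(world_size)
    ]
    # Each rank returns the gradient for its own input only.
    return summed[rank]
```

With two ranks holding `[1.0, 2.0]` and `[3.0, 4.0]`, both ranks see `[[1.0, 2.0], [3.0, 4.0]]` after the forward pass. If rank 0 reports gradients `[[1.0, 1.0], [0.0, 0.0]]` and rank 1 reports `[[2.0, 2.0], [5.0, 5.0]]`, rank 0 ends up with `[3.0, 3.0]` and rank 1 with `[5.0, 5.0]`. A TPU version would need the same forward/backward pairing on top of that backend's collectives.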