
FullyShardedDataParallel: only return full state dict on rank 0 #11207

@four4fish

Description

🚀 Feature

Currently, FullyShardedDataParallel.state_dict returns the full state dict on all ranks. If state_dict_device is CPU (https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2221), the memory footprint of the operation is [SIZE_OF_FULL_STATE_DICT] * [NUM_DEVICES_PER_NODE], which can cause hosts without ample RAM to OOM.

The issue has already been addressed in fairscale; the PR is in review: facebookresearch/fairscale#885
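For a sense of scale, a quick back-of-the-envelope sketch (the 10 GB model size and 8 ranks per node are made-up numbers, not from this issue):

# Hypothetical numbers for illustration only (not from this issue).
full_state_dict_gb = 10        # size of the consolidated full state dict
num_devices_per_node = 8       # FSDP ranks sharing one host's CPU RAM

# Today every rank materializes the full state dict on host CPU, so the
# peak host RAM needed by the state_dict() call scales with the rank count:
peak_host_ram_gb = full_state_dict_gb * num_devices_per_node
print(peak_host_ram_gb)  # 80 GB on a single node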

Motivation

Avoid OOM on hosts with limited CPU RAM when consolidating the full state dict.

Pitch

Add a state_dict_on_rank_0_only argument (default True) to the FSDP plugin's __init__ and forward it to the FullyShardedDataParallel wrapper parameters.
In __init__:

def __init__(
    self,
    cpu_offload: bool = False,
    flatten_parameters: bool = True,
    reshard_after_forward: bool = True,
    move_grads_to_cpu: Optional[bool] = None,
    fp32_reduce_scatter: Optional[bool] = None,
    compute_dtype: Optional[torch.dtype] = None,
    bucket_cap_mb: int = 25,
    min_num_params: int = 1e8,
    state_dict_to_cpu: bool = True,
    parallel_devices: Optional[List[torch.device]] = None,
    cluster_environment: Optional[ClusterEnvironment] = None,
    checkpoint_io: Optional[CheckpointIO] = None,
    state_dict_on_rank_0_only: bool = True
):
    ...
    self.state_dict_on_rank_0_only = state_dict_on_rank_0_only

and where the model is wrapped, forward the flag to enable_wrap:

with enable_wrap(
    wrapper_cls=FullyShardedDataParallel,
    auto_wrap_policy=wrap_policy,
    process_group=self.process_group,
    cpu_offload=self.cpu_offload,
    move_grads_to_cpu=self.move_grads_to_cpu,
    flatten_parameters=self.flatten_parameters,
    mixed_precision=precision == "mixed",
    reshard_after_forward=self.reshard_after_forward,
    fp32_reduce_scatter=self.fp32_reduce_scatter,
    compute_dtype=self.compute_dtype,
    bucket_cap_mb=self.bucket_cap_mb,
    state_dict_device=self.state_dict_device,
    state_dict_on_rank_0_only=self.state_dict_on_rank_0_only,
):
    ...
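
With the flag forwarded, state_dict() would return the consolidated weights only on rank 0 (and an empty dict on the other ranks), so checkpoint writing can be gated on the rank. A rough usage sketch, assuming wrapped_module is the FSDP-wrapped model and the process group is already initialized (wrapped_module and the file name are illustrative, not the actual Lightning code path):

import torch
import torch.distributed as dist

# Sketch of the intended behavior with state_dict_on_rank_0_only=True:
# only rank 0 holds the full, consolidated state dict after this call.
# wrapped_module is assumed to be the FullyShardedDataParallel-wrapped model.
state_dict = wrapped_module.state_dict()

if dist.get_rank() == 0:
    # Only rank 0 has the consolidated weights, so only it writes the file.
    torch.save(state_dict, "checkpoint.pt")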

Alternatives

Additional context



cc @Borda @SeanNaren @awaelchli
