Description
🚀 Feature
Currently, FullyShardedDataParallel.state_dict returns the full state dict on every rank. If state_dict_device is CPU (https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2221), the memory footprint of the operation is [SIZE_OF_FULL_STATE_DICT] * [NUM_DEVICES_PER_NODE] of host RAM per node, which can cause hosts with limited RAM to OOM.
The issue has already been addressed in fairscale; the PR is in review: facebookresearch/fairscale#885.
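For reference, a minimal sketch of using the fairscale-side flag directly, assuming the state_dict_on_rank_0_only argument from facebookresearch/fairscale#885 lands with that name and the semantics described there (full state dict materialized on rank 0 only):

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Sketch only: assumes torch.distributed is already initialized and one GPU per rank.
    model = FSDP(nn.Linear(1024, 1024).cuda(), state_dict_on_rank_0_only=True)

    # Only rank 0 should gather the full (unsharded) state dict, avoiding the
    # [SIZE_OF_FULL_STATE_DICT] * [NUM_DEVICES_PER_NODE] host-memory footprint.
    state = model.state_dict()
    if torch.distributed.get_rank() == 0:
        torch.save(state, "checkpoint.pt")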
Motivation
Avoid host-side OOM when gathering the full state dict for checkpointing with FSDP.
Pitch
Add a state_dict_on_rank_0_only argument (default True) to the FSDP plugin's __init__ and pass it through to the FullyShardedDataParallel wrapper parameters.
In __init__:

    def __init__(
        self,
        cpu_offload: bool = False,
        flatten_parameters: bool = True,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        min_num_params: int = 1e8,
        state_dict_to_cpu: bool = True,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        state_dict_on_rank_0_only: bool = True,
    ):
        ...
        self.state_dict_on_rank_0_only = state_dict_on_rank_0_only

and in the wrapping context:

    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        auto_wrap_policy=wrap_policy,
        process_group=self.process_group,
        cpu_offload=self.cpu_offload,
        move_grads_to_cpu=self.move_grads_to_cpu,
        flatten_parameters=self.flatten_parameters,
        mixed_precision=precision == "mixed",
        reshard_after_forward=self.reshard_after_forward,
        fp32_reduce_scatter=self.fp32_reduce_scatter,
        compute_dtype=self.compute_dtype,
        bucket_cap_mb=self.bucket_cap_mb,
        state_dict_device=self.state_dict_device,
        state_dict_on_rank_0_only=self.state_dict_on_rank_0_only,
    ):
        ...
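On the user side, a hypothetical usage sketch, assuming the flag is added to DDPFullyShardedPlugin as proposed (argument name per this pitch; not available in any release yet):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPFullyShardedPlugin

    # state_dict_on_rank_0_only is the flag proposed above.
    plugin = DDPFullyShardedPlugin(state_dict_on_rank_0_only=True)
    trainer = Trainer(gpus=8, plugins=plugin)
    # Checkpoint saving would then gather the full state dict on rank 0 only,
    # instead of on every rank of the node.

Note that with a default of True, non-zero ranks would no longer receive the full state dict from state_dict(), so user code that reads it on every rank would need a rank guard (or to pass False).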
Alternatives
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.