Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning_fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* `Fabric.load` returns a dictionary of objects that weren't loaded into the state
* `Strategy.save_checkpoint` and `Fabric.load_checkpoint` are now responsible for accessing the state of the model and optimizers

- `DataParallelStrategy.get_module_state_dict()` and `DDPStrategy.get_module_state_dict()` now correctly extract the state dict without keys prefixed with 'module' ([#16487](https://github.com/Lightning-AI/lightning/pull/16487))


### Deprecated

Expand Down
5 changes: 5 additions & 0 deletions src/lightning_fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
    """Return the state dict of ``module``, unwrapping a DDP-wrapped model first.

    Unwrapping ensures the returned keys are not prefixed with ``module.`` by the
    :class:`~torch.nn.parallel.DistributedDataParallel` wrapper.
    """
    target = module.module if isinstance(module, DistributedDataParallel) else module
    return super().get_module_state_dict(target)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
entries = (
Expand Down
5 changes: 5 additions & 0 deletions src/lightning_fabric/strategies/dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
return decision

def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
    """Return the state dict of ``module``, unwrapping a DP-wrapped model first.

    Unwrapping ensures the returned keys are not prefixed with ``module.`` by the
    :class:`~torch.nn.DataParallel` wrapper.
    """
    target = module.module if isinstance(module, DataParallel) else module
    return super().get_module_state_dict(target)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
    """Register this strategy in the registry under the ``"dp"`` name.

    Note: the original used ``cls.__class__.__name__``, which resolves to the
    *metaclass* name (e.g. ``ABCMeta``) rather than the strategy class name;
    ``cls.__name__`` yields the intended human-readable description.
    """
    strategy_registry.register("dp", cls, description=cls.__name__)
20 changes: 20 additions & 0 deletions tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,23 @@ def test_ddp_extra_kwargs(ddp_mock):
strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True)
strategy.setup_module(module)
ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True)


def test_ddp_module_state_dict():
    """The retrieved module state dict must not carry the prefixed wrapper keys added by DDP."""

    class DistributedDataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # let the strategy's `isinstance(model, DistributedDataParallel)` check succeed
            # even though the "class" here is a mock
            return True

    strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    original_module = torch.nn.Linear(2, 3)
    expected_keys = original_module.state_dict().keys()

    # Before setup: the bare module's keys pass through unchanged
    assert strategy.get_module_state_dict(original_module).keys() == expected_keys

    # After setup: the DDP wrapper must be stripped before the state dict is extracted
    with mock.patch("lightning_fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        assert strategy.get_module_state_dict(wrapped_module).keys() == expected_keys
22 changes: 21 additions & 1 deletion tests/tests_fabric/strategies/test_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock

import torch

Expand Down Expand Up @@ -48,3 +48,23 @@ def test_data_parallel_module_to_device():
module = Mock()
strategy.module_to_device(module)
module.to.assert_called_with(torch.device("cuda", 2))


def test_dp_module_state_dict():
    """The retrieved module state dict must not carry the prefixed wrapper keys added by DP."""

    class DataParallelMock(MagicMock):
        def __instancecheck__(self, instance):
            # let the strategy's `isinstance(model, DataParallel)` check succeed
            # even though the "class" here is a mock
            return True

    strategy = DataParallelStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
    original_module = torch.nn.Linear(2, 3)
    expected_keys = original_module.state_dict().keys()

    # Before setup: the bare module's keys pass through unchanged
    assert strategy.get_module_state_dict(original_module).keys() == expected_keys

    # After setup: the DP wrapper must be stripped before the state dict is extracted
    with mock.patch("lightning_fabric.strategies.dp.DataParallel", DataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
        assert strategy.get_module_state_dict(wrapped_module).keys() == expected_keys