
Commit 657bfc5

Fix device placement when setting up FSDP model in Lite (#15822)
* fix
* debug test
* simplify

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3fad651 commit 657bfc5

2 files changed: +30 -2 lines


src/lightning_lite/lite.py

Lines changed: 3 additions & 2 deletions
@@ -218,8 +218,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _LiteModule:
         module = self._strategy.setup_module(module)
         module = _LiteModule(module, self._precision, original_module=original_module)
 
-        # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters()).device)
+        if not isinstance(self._strategy, FSDPStrategy):
+            # Update the _DeviceDtypeModuleMixin's device parameter
+            module.to(self.device if move_to_device else next(module.parameters()).device)
 
         self._models_setup += 1
         return module
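
For context, a minimal sketch of how this code path is exercised with FSDP (the import paths and the bare FSDPStrategy() arguments are assumptions for illustration, not taken verbatim from the commit):

    import torch
    from lightning_lite import LightningLite
    from lightning_lite.strategies import FSDPStrategy

    # With FSDP, each rank holds only a shard of the parameters, so Lite now skips
    # the `module.to(...)` call after wrapping (the change above); FSDP itself
    # decides where each shard lives.
    lite = LightningLite(accelerator="cuda", devices=2, strategy=FSDPStrategy())
    lite.launch()

    model = torch.nn.Linear(10, 10, bias=False)
    lite_model = lite.setup_module(model)  # no manual device move when using FSDP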

tests/tests_lite/strategies/test_fsdp_integration.py

Lines changed: 27 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tempfile
+from unittest import mock
 
 import pytest
 import torch
@@ -116,3 +117,29 @@ def test_fsdp_train_save_load(manual_wrapping, precision):
     lite._strategy.save_checkpoint(model.state_dict(), ckpt_path)
 
     _assert_save_equality(lite, model, ckpt_path)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("move_to_device", [True, False])
+@mock.patch("lightning_lite.wrappers._LiteModule")
+def test_setup_module_move_to_device(lite_module_mock, move_to_device):
+    """Test that `move_to_device` has no effect with FSDP; FSDP decides onto which device each parameter
+    shard gets moved (sharding)."""
+    strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
+    lite = LightningLite(accelerator="cuda", devices=2, strategy=strategy)
+    lite.launch()
+
+    model = torch.nn.Linear(10, 10, bias=False)  # total params: 10 * 10 = 100
+    lite_model = lite.setup_module(model, move_to_device=move_to_device)
+    lite_module_mock.assert_not_called()
+
+    assert list(param.device for param in model.parameters()) == []
+    assert len(list(lite_model.parameters())) == 1
+
+    # the linear layer got sharded and each part is on the expected device
+    assert next(lite_model.parameters()).device == torch.device("cuda", lite.local_rank)
+    assert next(lite_model.parameters()).numel() == 50
+
+    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
+    assert lite_model.device == torch.device("cpu")
+    assert lite.device == torch.device("cuda", lite.local_rank)
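
A quick sanity check of the sharding arithmetic asserted in the test above (a standalone sketch, not part of the commit):

    # Linear(10, 10, bias=False) has a single 10x10 weight: 100 parameters in total.
    total_params = 10 * 10
    world_size = 2  # the test runs on 2 CUDA devices
    shard_size = total_params // world_size
    assert shard_size == 50  # matches `next(lite_model.parameters()).numel() == 50` per rank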
