 # See the License for the specific language governing permissions and
 # limitations under the License.
 import tempfile
+from unittest import mock

 import pytest
 import torch
@@ -116,3 +117,29 @@ def test_fsdp_train_save_load(manual_wrapping, precision):
     lite._strategy.save_checkpoint(model.state_dict(), ckpt_path)

     _assert_save_equality(lite, model, ckpt_path)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
+@pytest.mark.parametrize("move_to_device", [True, False])
+@mock.patch("lightning_lite.wrappers._LiteModule")
+def test_setup_module_move_to_device(lite_module_mock, move_to_device):
+    """Test that `move_to_device` has no effect with FSDP: the strategy decides which device each sharded
+    parameter is placed on."""
+    strategy = FSDPStrategy(auto_wrap_policy=_custom_auto_wrap_policy)
+    lite = LightningLite(accelerator="cuda", devices=2, strategy=strategy)
+    lite.launch()
+
+    model = torch.nn.Linear(10, 10, bias=False)  # total params: 10 * 10 = 100
+    lite_model = lite.setup_module(model, move_to_device=move_to_device)
+    lite_module_mock.assert_not_called()
+
+    # after FSDP wrapping, the original module no longer holds any parameters of its own
+    assert list(param.device for param in model.parameters()) == []
+    assert len(list(lite_model.parameters())) == 1
+
+    # the linear layer got sharded and each shard is on the expected device
+    assert next(lite_model.parameters()).device == torch.device("cuda", lite.local_rank)
+    assert next(lite_model.parameters()).numel() == 50  # 100 params split across 2 devices
+
+    # the _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
+    assert lite_model.device == torch.device("cpu")
+    assert lite.device == torch.device("cuda", lite.local_rank)
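The `_custom_auto_wrap_policy` referenced above is defined elsewhere in the test module and is not part of this hunk. As a rough illustration only, a minimal policy following the torch 1.12 auto-wrap callable signature (an assumed sketch, not the actual helper from the file) could look like:

# hypothetical sketch of an FSDP auto-wrap policy (not part of this diff)
# torch 1.12 invokes the policy as policy(module, recurse, unwrapped_params) and
# wraps a submodule whenever the policy returns True
def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_params: int = 2) -> bool:
    # wrap any module that still has at least `min_num_params` unwrapped parameters
    return unwrapped_params >= min_num_params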