diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index b8624daac3fa3..8b2387fcea481 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1472,6 +1472,10 @@ def forward(self, x):
                     )
                 example_inputs = self.example_input_array
 
+            if kwargs.get("check_inputs") is not None:
+                kwargs["check_inputs"] = self._on_before_batch_transfer(kwargs["check_inputs"])
+                kwargs["check_inputs"] = self._apply_batch_transfer_handler(kwargs["check_inputs"])
+
             # automatically send example inputs to the right device and use trace
             example_inputs = self._on_before_batch_transfer(example_inputs)
             example_inputs = self._apply_batch_transfer_handler(example_inputs)
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index 10a19974971eb..29f251044c0b5 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
     assert script_output.device == device
 
 
+@pytest.mark.parametrize(
+    "device_str",
+    [
+        "cpu",
+        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
+        pytest.param("mps:0", marks=RunIf(mps=True)),
+    ],
+)
+def test_torchscript_device_with_check_inputs(device_str):
+    """Test that tracing with ``check_inputs`` works when the module lives on a non-CPU device."""
+    device = torch.device(device_str)
+    model = BoringModel().to(device)
+    model.example_input_array = torch.randn(5, 32)
+
+    check_inputs = torch.rand(5, 32)
+
+    script = model.to_torchscript(method="trace", check_inputs=check_inputs)
+    assert isinstance(script, torch.jit.ScriptModule)
+
+
 def test_torchscript_retain_training_state():
     """Test that torchscript export does not alter the training mode of original model."""
     model = BoringModel()
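
For context on what the patch enables, below is a minimal usage sketch (not part of the diff). It mirrors the new test: with the change, a check_inputs argument forwarded through **kwargs passes through the same _on_before_batch_transfer / _apply_batch_transfer_handler hooks as example_input_array, so it lands on the module's device before torch.jit.trace runs its trace check. The BoringModel import path and the CPU-only setup are illustrative assumptions, not something the patch itself prescribes.

import torch
from lightning.pytorch.demos.boring_classes import BoringModel  # assumed import path for illustration

model = BoringModel()                            # LightningModule with a Linear(32, 2) layer
model.example_input_array = torch.randn(5, 32)   # used as the trace example inputs

# check_inputs is created on CPU; with the patch it is routed through the batch-transfer
# hooks so it matches the module's device before torch.jit.trace verifies the traced graph.
scripted = model.to_torchscript(method="trace", check_inputs=torch.rand(5, 32))
assert isinstance(scripted, torch.jit.ScriptModule)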