From 63c2e69168b76d008945bbd726a5e5429b0aec09 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 4 Jun 2025 01:18:38 +0800 Subject: [PATCH 1/2] fix: move `check_inputs` if available during `to_torchscript`. --- src/lightning/pytorch/core/module.py | 4 ++++ .../tests_pytorch/models/test_torchscript.py | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..44ebe01dcf376 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1472,6 +1472,10 @@ def forward(self, x): ) example_inputs = self.example_input_array + if kwargs.get("check_inputs") is not None: + check_inputs = self._on_before_batch_transfer(kwargs["check_inputs"]) + kwargs["check_inputs"] = self._apply_batch_transfer_handler(check_inputs) + # automatically send example inputs to the right device and use trace example_inputs = self._on_before_batch_transfer(example_inputs) example_inputs = self._apply_batch_transfer_handler(example_inputs) diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 10a19974971eb..29f251044c0b5 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -105,6 +105,26 @@ def test_torchscript_device(device_str): assert script_output.device == device +@pytest.mark.parametrize( + "device_str", + [ + "cpu", + pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)), + pytest.param("mps:0", marks=RunIf(mps=True)), + ], +) +def test_torchscript_device_with_check_inputs(device_str): + """Test that scripted module is on the correct device.""" + device = torch.device(device_str) + model = BoringModel().to(device) + model.example_input_array = torch.randn(5, 32) + + check_inputs = torch.rand(5, 32) + + script = model.to_torchscript(method="trace", check_inputs=check_inputs) + assert isinstance(script, torch.jit.ScriptModule) + + def test_torchscript_retain_training_state(): """Test that torchscript export does not alter the training mode of original model.""" model = BoringModel() From e21f0cc863324672673954473d53da98ed99c9e2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 6 Jun 2025 04:09:22 +0800 Subject: [PATCH 2/2] fix: move `check_inputs` if available during `to_torchscript`. --- src/lightning/pytorch/core/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 44ebe01dcf376..8b2387fcea481 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1473,8 +1473,8 @@ def forward(self, x): example_inputs = self.example_input_array if kwargs.get("check_inputs") is not None: - check_inputs = self._on_before_batch_transfer(kwargs["check_inputs"]) - kwargs["check_inputs"] = self._apply_batch_transfer_handler(check_inputs) + kwargs["check_inputs"] = self._on_before_batch_transfer(kwargs["check_inputs"]) + kwargs["check_inputs"] = self._apply_batch_transfer_handler(kwargs["check_inputs"]) # automatically send example inputs to the right device and use trace example_inputs = self._on_before_batch_transfer(example_inputs)