
Commit 87f3b8c

Skip casting model inputs to fp32 if weights and inputs are all fp16
1 parent e6be416 commit 87f3b8c

File tree: 2 files changed (+46, -1 lines)

2 files changed

+46
-1
lines changed

coremltools/converters/mil/frontend/torch/converter.py

Lines changed: 22 additions & 1 deletion
@@ -585,7 +585,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()
 
+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if there is at least one trainable layer in the model
+            # and all of those trainable layers have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly used.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1272,7 +1293,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(

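For reference, the weight check introduced in __init__ above can be read as a standalone helper. The sketch below is illustrative only (the helper name is not part of coremltools); it mirrors the added logic: the flag is True only when the model has at least one trainable parameter and every trainable parameter is fp16.

import torch


def has_all_fp16_trainable_weights(model: torch.nn.Module) -> bool:
    # Count the trainable parameters and how many of them are fp16; the flag
    # holds only when the two counts match and at least one parameter is
    # trainable.
    num_trainable = 0
    num_trainable_fp16 = 0
    for param in model.parameters():
        if param.requires_grad:
            num_trainable += 1
            if param.dtype == torch.float16:
                num_trainable_fp16 += 1
    return num_trainable > 0 and num_trainable == num_trainable_fp16


# has_all_fp16_trainable_weights(torch.nn.Linear(10, 20).half())  -> True
# has_all_fp16_trainable_weights(torch.nn.Linear(10, 20))         -> False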
coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py

Lines changed: 24 additions & 0 deletions
@@ -1522,6 +1522,30 @@ def forward(self, x, y):
                 result[name], expected.detach().numpy(), rtol=rtol, atol=atol
             )
 
+    @staticmethod
+    @pytest.mark.parametrize(
+        "backend",
+        backends,
+    )
+    def test_torch_fp16_model_with_fp16_inputs(torch_model, backend):
+        if backend[0] == "neuralnetwork":
+            pytest.skip(
+                "Input float16 needs target >= iOS16, which doesn't support neuralnetwork."
+            )
+        traced_torch_model = torch.jit.trace(torch_model.half(), torch.rand(1, 10).half())
+        ct.convert(
+            traced_torch_model,
+            source="pytorch",
+            inputs=[
+                ct.TensorType(
+                    shape=(1, 10),
+                )
+            ],
+            outputs=[ct.TensorType(dtype=np.float16)],
+            convert_to=backend[0],
+            minimum_deployment_target=ct.target.macOS13,
+        )
+
 
 @pytest.fixture
 def int32_input_model():

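For context, a minimal end-to-end sketch of the conversion this commit targets is shown below. The model, input name, and shapes are illustrative assumptions (not taken from the commit); with all-fp16 weights and an fp16 input type, the converter no longer inserts a cast-to-fp32 on the model inputs.

import numpy as np
import torch
import coremltools as ct

# A toy model whose weights are all fp16 after .half().
# Note: tracing in fp16 on CPU assumes a PyTorch build that supports
# fp16 ops on CPU for the traced layers.
model = torch.nn.Linear(10, 20).eval().half()
traced = torch.jit.trace(model, torch.rand(1, 10).half())

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=(1, 10), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    minimum_deployment_target=ct.target.iOS16,  # fp16 I/O requires iOS16+ / macOS13+
)

# Prediction with an fp16 input array (runs on macOS 13+ hosts).
out = mlmodel.predict({"x": np.random.rand(1, 10).astype(np.float16)})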