@@ -585,7 +585,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()

+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True if the model has at
+            # least one trainable layer and all of those trainable layers have
+            # the fp16 dtype, e.g. if pytorch_model.half() has been explicitly used.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
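
The detection added above can be reproduced outside the converter. The following is a minimal sketch, assuming a plain torch.nn.Module converted to fp16 with .half(); the model and variable names are illustrative and not part of the coremltools API.

import torch
import torch.nn as nn

# Illustrative model; .half() converts every parameter to torch.float16.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).half()

num_trainable_layers = 0
num_trainable_fp16_layers = 0
for param in model.parameters():
    if param.requires_grad:
        num_trainable_layers += 1
        if param.dtype == torch.float16:
            num_trainable_fp16_layers += 1

# Mirrors the flag set in __init__: True only when there is at least one
# trainable parameter and every trainable parameter is stored as fp16.
src_model_has_all_fp16_weights = (
    num_trainable_layers > 0 and num_trainable_layers == num_trainable_fp16_layers
)
print(src_model_has_all_fp16_weights)  # True for the .half() model above
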
@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1272,7 +1293,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(
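
The net effect of the condition change can be summarized in isolation. This is a hedged sketch of the new gating logic, not the converter's actual API; the helper name and boolean arguments are made up for illustration.

def should_cast_input_to_fp32(
    input_is_fp16: bool,
    all_fp16_inputs: bool,
    src_model_has_all_fp16_weights: bool,
) -> bool:
    # Previously: every fp16 input was cast to fp32 before conversion.
    # Now: an fp16 input is left untouched when every graph input is fp16
    # and the source model's trainable weights are all fp16 as well.
    return input_is_fp16 and not (all_fp16_inputs and src_model_has_all_fp16_weights)

assert should_cast_input_to_fp32(True, True, False)      # fp32 weights: still cast
assert should_cast_input_to_fp32(True, False, True)      # mixed-dtype inputs: still cast
assert not should_cast_input_to_fp32(True, True, True)   # all-fp16 model: skip the cast
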