@@ -6,7 +6,6 @@
 
 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,29 +77,29 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
     elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )
 
     if kwargs.get("ir", "dynamo") == "torch_compile":
         custom_backend = create_backend(
-            precision=lower_precision,
+            precision=precision,
             debug=debug,
             workspace_size=workspace_size,
             min_block_size=min_block_size,
@@ -114,13 +113,13 @@ def compile(
         )
         model = torch.compile(gm, backend=custom_backend)
         # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
+        model(*torch_inputs)
         return model
 
     else:
         settings = CompilationSettings(
             debug=debug,
-            precision=lower_precision,
+            precision=precision,
             workspace_size=workspace_size,
             min_block_size=min_block_size,
             torch_executed_ops=torch_executed_ops,
@@ -131,20 +130,20 @@ def compile(
             use_python_runtime=use_python_runtime,
         )
 
-        model = trace(gm, inputs, **kwargs)
+        model = trace(gm, torch_inputs, **kwargs)
 
         if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
+            model = lower_model(model, torch_inputs)
+            return _compile_module(model, torch_inputs, settings)
         else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
+            split_result = lower_model_using_trt_splitter(model, torch_inputs)
+            trt_module = _compile_graph(split_result, torch_inputs, settings)
 
             return trt_module
 
 
 def create_backend(
-    precision: LowerPrecision = PRECISION,
+    precision: torch.dtype = PRECISION,
     debug: bool = DEBUG,
     workspace_size: int = WORKSPACE_SIZE,
     min_block_size: int = MIN_BLOCK_SIZE,
@@ -234,7 +233,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)
 
     return lowered_model
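The net effect of this diff is that the Dynamo path now carries precision as a plain torch.dtype rather than the FX LowerPrecision enum, and prepare_inputs now returns two values, with the plain-tensor list (torch_inputs) fed to tracing, partitioning, and compilation while torchtrt_inputs presumably retains the user-facing Input metadata. As a sanity check, here is a minimal standalone sketch of the new precision-resolution branch; resolve_precision and its PRECISION constant are hypothetical stand-ins (the real logic lives inline in compile()), and the torch_tensorrt.dtype.half/.float alias checks are omitted so the snippet runs with only torch installed.

import logging
from typing import Set

import torch

logger = logging.getLogger(__name__)

# Stand-in for the PRECISION default used by compile(); after this change
# it is assumed to be a torch.dtype rather than a LowerPrecision enum value.
PRECISION = torch.float32


def resolve_precision(enabled_precisions: Set[torch.dtype]) -> torch.dtype:
    # Mirrors the if/elif chain in compile() after the diff: fp16 wins if
    # requested, then fp32, then the default; anything else is rejected.
    if torch.float16 in enabled_precisions:
        return torch.float16
    elif torch.float32 in enabled_precisions:
        return torch.float32
    elif len(enabled_precisions) == 0:
        logger.info(f"No precision specified, defaulting to {PRECISION}")
        return PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )


# fp16 takes priority when both are enabled; an empty set falls back to fp32.
assert resolve_precision({torch.float16, torch.float32}) is torch.float16
assert resolve_precision(set()) is torch.float32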