 import torch
 import torch_tensorrt
 from torch.fx.passes.pass_manager import PassManager
-from torch.fx.passes.splitter_base import SplitResult
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ ... @@
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     TRUNCATE_LONG_AND_DOUBLE,
+    USE_FAST_PARTITIONER,
     USE_PYTHON_RUNTIME,
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
 from torch_tensorrt.dynamo.backend.backends import _compile_module
-from torch_tensorrt.dynamo.conversion import convert_module
 from torch_tensorrt.dynamo.lowering._fusers import (
     fuse_permute_linear,
     fuse_permute_matmul,
 )
 from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
-from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

 logger = logging.getLogger(__name__)

@@ -64,6 +62,7 @@ def compile(
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     if debug:
@@ -75,7 +74,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
     )

     if not isinstance(inputs, collections.abc.Sequence):
@@ -115,55 +114,12 @@ def compile(
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
         "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
     }

     settings = CompilationSettings(**compilation_options)
-    if kwargs.get("use_capability_partitioner", None):
-        model = lower_model(gm, torch_inputs)
-        return _compile_module(model, torch_inputs, settings)
-    else:
-        split_result = lower_model_using_trt_splitter(gm, torch_inputs)
-        trt_module = _compile_graph(split_result, torch_inputs, settings)
-
-        return trt_module

-
-def _compile_graph(
-    split_result: SplitResult,
-    inputs: Any,
-    settings: CompilationSettings = CompilationSettings(),
-    **kwargs: Any,
-) -> torch.fx.GraphModule:
-    for submod_name, submod_inputs in split_result.submodule_inputs.items():
-        submod = getattr(split_result.split_module, submod_name)
-        # Only acc submodules will be lowered.
-        if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-            # Create TRT Module from submodule
-            trt_mod = convert_module(
-                submod,
-                submod_inputs,
-                settings=settings,
-                name=submod_name,
-            )
-            setattr(split_result.split_module, submod_name, trt_mod)
-
-    return split_result.split_module
-
-
-def lower_model_using_trt_splitter(
-    model: torch.nn.Module, inputs: Any, **kwargs: Any
-) -> SplitResult:
-    # Perform basic lowering
-    model = lower_model(model, inputs)
-    splitter_setting = TRTSplitterSetting()
-    splitter_setting.use_implicit_batch_dim = False
-    splitter_setting.min_acc_module_size = 1
-    splitter_setting.use_experimental_rt = False
-    splitter = TRTSplitter(model, inputs, settings=splitter_setting)
-    splitter.node_support_preview()
-    split_result = splitter.generate_split_results()
-
-    return split_result
+    return _compile_module(gm, torch_inputs, settings)


 def lower_model(
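For context, a minimal usage sketch of the new flag. This assumes the compile() entry point shown above is exposed as torch_tensorrt.dynamo.compile (the exact import path may differ on this branch); the toy model, shapes, and keyword values are illustrative only, not part of this change.

import torch
import torch_tensorrt

# Illustrative toy model; any traceable nn.Module would do (assumption).
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
).eval().cuda()

inputs = [torch.randn(4, 16).cuda()]

# With this change, compilation always routes through _compile_module();
# use_fast_partitioner only selects which partitioner implementation is used.
trt_gm = torch_tensorrt.dynamo.compile(  # assumed import path for the compile() above
    model,
    inputs=inputs,
    use_fast_partitioner=True,  # new keyword introduced by this diff
    min_block_size=1,
    debug=True,
)

print(trt_gm(*inputs).shape)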