@@ -141,14 +141,14 @@ def produce_guards_expression(self, *args, **kwargs):
         return ""
 
 
-def wrap_inductor(graph,
+def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
-                  use_inductor: bool = True):
+                  use_inductor: bool = True) -> Any:
     if graph_index == 0:
         # before compiling the first graph, record the start time
         global compilation_start_time
@@ -209,7 +209,7 @@ def wrap_inductor(graph,
     returns_tuple = graph_returns_tuple(graph)
 
     # this is the graph we return to Dynamo to run
-    def compiled_graph(*args):
+    def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]:
         # convert args to list
         list_args = list(args)
         graph_output = inductor_compiled_graph(list_args)
@@ -247,7 +247,7 @@ def _check_can_cache(*args, **kwargs):
         # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
         return
 
-    def _get_shape_env():
+    def _get_shape_env() -> AlwaysHitShapeEnv:
        return AlwaysHitShapeEnv()
 
    with patch(  # for hijacking the hash of the compiled graph
@@ -537,7 +537,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             example_inputs[x].clone() for x in self.sym_tensor_indices
         ]
 
-        def copy_and_call(*args):
+        def copy_and_call(*args) -> fx.GraphModule:
             list_args = list(args)
             for i, index in enumerate(self.sym_tensor_indices):
                 runtime_tensor = list_args[index]
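For context on the annotated functions: `compiled_graph` is the closure that `wrap_inductor` hands back to Dynamo, and it forwards the flat positional args as a list to the Inductor-compiled artifact. Below is a minimal, self-contained sketch of that pattern only; `wrap_backend` and its inner stand-in compiled callable are hypothetical names for illustration and do not reproduce the real vLLM caching or hash-hijacking logic shown in the diff.

from typing import Any, Callable, List, Optional

import torch
from torch import fx


def wrap_backend(graph: fx.GraphModule,
                 example_inputs: List[Any]) -> Callable:
    """Hypothetical stand-in for wrap_inductor: compile once, return a closure."""

    # In the real code the compiled artifact comes from Inductor and is cached
    # via the hijacked FX-graph hash; here we simply reuse the GraphModule so
    # the sketch stays runnable without Inductor.
    def inductor_compiled_graph(list_args: List[Any]) -> Optional[Any]:
        return graph(*list_args)

    # This is the callable handed back to Dynamo, mirroring compiled_graph in
    # the diff: positional args are converted to a list before the call.
    def compiled_graph(*args) -> Optional[Any]:
        list_args = list(args)
        return inductor_compiled_graph(list_args)

    return compiled_graph


if __name__ == "__main__":
    gm = fx.symbolic_trace(torch.nn.Linear(4, 4))
    runner = wrap_backend(gm, [torch.randn(2, 4)])
    out = runner(torch.randn(2, 4))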