Skip to content

Commit 12b4dc1

Browse files
authored
Fix DeepCompile for PyTorch v2.8 (#7496)
This PR updates the kernel generation function arguments in Inductor to ensure DeepCompile is compatible with PyTorch v2.8. It also fixes the logging output of DeepCompile.
1 parent 1c03d1b commit 12b4dc1

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

deepspeed/compile/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def make_bw_graph(gm, sample_inputs):
311311
graph_index = get_index_by_graph_id(graph_order, graph_id)
312312
log_rank0(
313313
f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
314-
enable=True)
314+
enable=debug_log)
315315

316316
bwd_inputs_stack = get_backward_inputs()
317317

deepspeed/compile/inductor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
except ImportError:
1818
pass
1919

20+
from deepspeed.utils.torch import required_torch_version
2021
from .util import get_input_nodes
2122
from .graph_param import DSGraphParamManager
2223

@@ -172,7 +173,11 @@ def codegen(self, wrapper):
172173
self.codegen_comment(wrapper)
173174
args = [*self.codegen_args(), *self.codegen_kwargs()]
174175

175-
V.graph.wrapper_code.generate_fallback_kernel(self, args)
176+
if required_torch_version(min_version=2.8):
177+
V.graph.wrapper_code.generate_fallback_kernel(self)
178+
else:
179+
V.graph.wrapper_code.generate_fallback_kernel(self, args)
180+
176181
if isinstance(self.layout, Layout):
177182
self.codegen_size_asserts(wrapper)
178183

0 commit comments

Comments
 (0)