Skip to content

Commit 7d81140

Browse files
committed
[XPU] Support XCCL on deepspeed side
Signed-off-by: yisheng <[email protected]>
1 parent c2c8199 commit 7d81140

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

accelerator/real_accelerator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,17 @@ def get_accelerator():
131131
if accelerator_name is None:
132132
try:
133133
import intel_extension_for_pytorch as ipex
134-
135134
if ipex._C._has_xpu():
136135
accelerator_name = "xpu"
137136
except ImportError as e:
138-
pass
137+
import torch
138+
if hasattr(torch, 'xpu'):
139+
if torch.xpu.is_available():
140+
accelerator_name = "xpu"
141+
else:
142+
pass
143+
else:
144+
pass
139145
if accelerator_name is None:
140146
try:
141147
import torch_npu # noqa: F401,F811 # type: ignore

accelerator/xpu_accelerator.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,32 @@
55

66
import torch
77
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
8-
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9-
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
108
import functools
11-
129
import importlib
1310
import inspect
1411

12+
try:
13+
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14+
oneccl_imported_p = True
15+
except ImportError as e:
16+
oneccl_imported_p = False
17+
18+
try:
19+
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20+
ipex_imported_p = True
21+
except ImportError as e:
22+
ipex_imported_p = False
23+
1524

1625
class XPU_Accelerator(DeepSpeedAccelerator):
1726

1827
def __init__(self):
1928
self._name = 'xpu'
20-
self._communication_backend_name = 'ccl'
29+
if oneccl_imported_p:
30+
self._communication_backend_name = 'ccl'
31+
else:
32+
# fall back to the xccl backend when torch-CCL (oneccl_bindings_for_pytorch) cannot be imported on the XPU device
33+
self._communication_backend_name = 'xccl'
2134
self._compile_backend = "inductor"
2235
self.aligned_tensors = []
2336
self.class_dict = None
@@ -27,10 +40,14 @@ def is_synchronized_device(self):
2740

2841
def use_host_timers(self):
2942
# Workaround (WA): XPU event will be consolidated in IPEX 2.6
30-
if ipex.__version__ < '2.6':
31-
return True
32-
else:
43+
if not ipex_imported_p:
3344
return self.is_synchronized_device()
45+
else:
46+
# Workaround (WA): XPU event will be consolidated in IPEX 2.6
47+
if ipex.__version__ < '2.6':
48+
return True
49+
else:
50+
return self.is_synchronized_device()
3451

3552
def resolves_data_dependency(self):
3653
return self.is_synchronized_device()
@@ -290,10 +307,13 @@ def get_op_builder(self, class_name):
290307
return self.class_dict['NotImplementedBuilder']
291308

292309
def build_extension(self):
293-
try:
294-
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
295-
except ImportError:
296-
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
310+
if ipex_imported_p:
311+
try:
312+
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
313+
except ImportError:
314+
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
315+
else:
316+
from torch.utils.cpp_extension import DpcppBuildExtension
297317
return DpcppBuildExtension
298318

299319
def export_envs(self):

0 commit comments

Comments
 (0)