5
5
6
6
import torch
7
7
from deepspeed .accelerator .abstract_accelerator import DeepSpeedAccelerator
8
- import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9
- import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
10
8
import functools
11
-
12
9
import importlib
13
10
import inspect
14
11
12
+ try :
13
+ import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14
+ oneccl_imported_p = True
15
+ except ImportError as e :
16
+ oneccl_imported_p = False
17
+
18
+ try :
19
+ import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20
+ ipex_imported_p = True
21
+ except ImportError as e :
22
+ ipex_imported_p = False
23
+
15
24
16
25
class XPU_Accelerator (DeepSpeedAccelerator ):
17
26
18
27
def __init__ (self ):
19
28
self ._name = 'xpu'
20
- self ._communication_backend_name = 'ccl'
29
+ if oneccl_imported_p :
30
+ self ._communication_backend_name = 'ccl'
31
+ else :
32
+ # changed to xccl if not using torch-CCL on XPU device
33
+ self ._communication_backend_name = 'xccl'
21
34
self ._compile_backend = "inductor"
22
35
self .aligned_tensors = []
23
36
self .class_dict = None
@@ -27,10 +40,14 @@ def is_synchronized_device(self):
27
40
28
41
def use_host_timers (self ):
29
42
# WA XPU event will be consolidated in 2.6
30
- if ipex .__version__ < '2.6' :
31
- return True
32
- else :
43
+ if not ipex_imported_p :
33
44
return self .is_synchronized_device ()
45
+ else :
46
+ # WA XPU event will be consolidated in 2.6
47
+ if ipex .__version__ < '2.6' :
48
+ return True
49
+ else :
50
+ return self .is_synchronized_device ()
34
51
35
52
def resolves_data_dependency (self ):
36
53
return self .is_synchronized_device ()
@@ -290,10 +307,13 @@ def get_op_builder(self, class_name):
290
307
return self .class_dict ['NotImplementedBuilder' ]
291
308
292
309
def build_extension (self ):
293
- try :
294
- from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
295
- except ImportError :
296
- from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
310
+ if ipex_imported_p :
311
+ try :
312
+ from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
313
+ except ImportError :
314
+ from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
315
+ else :
316
+ from torch .utils .cpp_extension import DpcppBuildExtension
297
317
return DpcppBuildExtension
298
318
299
319
def export_envs (self ):
0 commit comments