File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -75,11 +75,13 @@ class CpuPlatform(Platform):
75
75
def supported_dtypes (self ) -> list [torch .dtype ]:
76
76
if self .get_cpu_architecture () == CpuArchEnum .POWERPC :
77
77
return [torch .bfloat16 , torch .float32 ]
78
- elif sys .platform .startswith (
79
- "darwin" ) and self .get_cpu_architecture () == CpuArchEnum .ARM :
80
- # TODO: change this condition to check if the platform support bf16
81
- # instead of checking the OS.
82
- return [torch .bfloat16 , torch .float16 , torch .float32 ]
78
+ elif (self .get_cpu_architecture () == CpuArchEnum .ARM
79
+ and sys .platform .startswith ("darwin" )):
80
+ if (subprocess .check_output (
81
+ ["sysctl" , "-n" , "-e" , "hw.optional.arm.FEAT_BF16" ]
82
+ ).strip () == b"1" ):
83
+ return [torch .bfloat16 , torch .float16 , torch .float32 ]
84
+ return [torch .float16 , torch .float32 ]
83
85
# x86/aarch64 CPU has supported both bf16 and fp16 natively.
84
86
return [torch .bfloat16 , torch .float16 , torch .float32 ]
85
87
You can’t perform that action at this time.
0 commit comments