File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -75,12 +75,12 @@ class CpuPlatform(Platform):
75
75
def supported_dtypes (self ) -> list [torch .dtype ]:
76
76
if self .get_cpu_architecture () == CpuArchEnum .POWERPC :
77
77
return [torch .bfloat16 , torch .float32 ]
78
- elif sys . platform . startswith (
79
- "darwin" ) and self . get_cpu_architecture () == CpuArchEnum . ARM :
80
- # TODO: change this condition to check if the platform support bf16
81
- # instead of checking the OS. For instance M2 shall supports bf16
82
- # already. But we need to modify `cpu_extension.cmake` to activate
83
- # the feature in the build.
78
+ elif ( self . get_cpu_architecture () == CpuArchEnum . ARM
79
+ and sys . platform . startswith ( "darwin" )) :
80
+ if ( subprocess . check_output (
81
+ [ "sysctl -n hw.optional.arm.FEAT_BF16" ],
82
+ shell = True ). strip () == b"1" ):
83
+ return [ torch . bfloat16 , torch . float16 , torch . float32 ]
84
84
return [torch .float16 , torch .float32 ]
85
85
# x86/aarch64 CPU has supported both bf16 and fp16 natively.
86
86
return [torch .bfloat16 , torch .float16 , torch .float32 ]
You can’t perform that action at this time.
0 commit comments