@@ -55,7 +55,7 @@ def preprocess_weights(
55
55
56
56
# (M // bits, K, bits)
57
57
w = np .stack ([(w >> ib ) & 1 for ib in range (bits )], axis = - 1 )
58
- # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g)
58
+ # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g)
59
59
w = w .transpose (0 , 2 , 1 ).reshape (M // bits , bits , K // g , g )
60
60
w = sum ([(w [:, :, :, ig ] << ig ) for ig in range (g )])
61
61
# 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
@@ -65,7 +65,7 @@ def preprocess_weights(
65
65
w = w .reshape (M // bits // simd_n_out , simd_n_out , bits , K // g ).transpose (0 , 2 , 1 , 3 )
66
66
mgroup = ngroups_per_elem * simd_n_in
67
67
w = w .reshape (M // mgroup , ngroups_per_elem , simd_n_in , K // g ).transpose (0 , 2 , 1 , 3 )
68
- # 0 1 2 3 4 5
68
+ # 0 1 2 3 4 5
69
69
w = w .reshape (M // bm , bm // mgroup , simd_n_in , ngroups_per_elem , K // g // kfactor , kfactor ).transpose (0 , 4 , 1 , 5 , 2 , 3 )
70
70
w = sum ([(w [:, :, :, :, :, ng ] << (ng * g )) for ng in range (ngroups_per_elem )])
71
71
w = w .reshape (M // bm , K // g // kfactor , bm // mgroup , kfactor , simd_n_in )
0 commit comments