Commit b8b2771

[Fix] Fix shape broadcast for multi-batch lut preprocessor (#55)
1 parent a234dc1 commit b8b2771

1 file changed: +4 -1 lines changed

python/t_mac/ops/qgemm.py

Lines changed: 4 additions & 1 deletion
@@ -508,7 +508,10 @@ def recp(s):
             return 1.0 / s if s != 0 else 0
 
         ils = np.vectorize(recp)(lut_scales).astype(self.out_dtype)
-        qlut = np.rint((qlut.transpose(0, 2, 1) * ils).transpose(0, 2, 1).reshape(N, K // self.g, 1 << self.g)).astype(self.dtype)
+        qlut = np.rint(
+            (qlut.transpose(2, 0, 1).reshape(-1, qlut.shape[0] * qlut.shape[1]) * ils.reshape(1, qlut.shape[0] * qlut.shape[1]))
+            .reshape(qlut.shape[2], qlut.shape[0], qlut.shape[1]).transpose(1, 2, 0).reshape(N, K // self.g, 1 << self.g)
+        ).astype(self.dtype)
 
         return [b_t, lut_scales, lut_biases, qlut]
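
The broadcasting change can be checked in isolation. The sketch below is not part of the commit. It assumes `qlut` has shape `(N, G, C)` at this point and `ils` has shape `(N, G)`, which is what the `reshape` calls in the new expression imply. The sizes `N, G, C = 2, 3, 4` and the helper names `scale_old`, `scale_new`, and `scale_simple` are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical small shapes: N batches, G scale groups, C table entries per group.
N, G, C = 2, 3, 4
rng = np.random.default_rng(0)
qlut = rng.standard_normal((N, G, C)).astype(np.float32)
# Assumption: one inverse scale per (batch, group) pair.
ils = rng.standard_normal((N, G)).astype(np.float32)


def scale_old(qlut, ils):
    # Pre-fix expression: (N, C, G) * (N, G) aligns N against C,
    # so it only broadcasts when N == 1 (or when N happens to equal C).
    return (qlut.transpose(0, 2, 1) * ils).transpose(0, 2, 1)


def scale_new(qlut, ils):
    # Post-fix expression: move C to the front and flatten (N, G),
    # so each scale multiplies its own (batch, group) column.
    n, g, c = qlut.shape
    flat = qlut.transpose(2, 0, 1).reshape(-1, n * g) * ils.reshape(1, n * g)
    return flat.reshape(c, n, g).transpose(1, 2, 0)


def scale_simple(qlut, ils):
    # Equivalent formulation under the same shape assumption.
    return qlut * ils[:, :, None]


# The old expression fails to broadcast for N > 1.
try:
    scale_old(qlut, ils)
    old_failed = False
except ValueError:
    old_failed = True

# With a single batch, old and new agree; with several, new matches the simple form.
single_ok = np.allclose(scale_old(qlut[:1], ils[:1]), scale_new(qlut[:1], ils[:1]))
multi_ok = np.allclose(scale_new(qlut, ils), scale_simple(qlut, ils))
```

Under these assumptions, `qlut * ils[:, :, None]` gives the same result as the committed expression and may be easier to read. The commit itself keeps the transpose and reshape form.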
