@@ -2175,67 +2175,92 @@ class tinyBLAS_PPC {
                 int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }
-
    void matmul(int64_t m, int64_t n) {
-        mnpack(0, m, 0, n);
+        int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
+        if (m%mc == 0 && n%nc == 0 && k%kc == 0) {
+            matmul_tiled(m, n, mc, nc, kc);
+        } else {
+            mnpack(0, m, 0, n);
+        }
    }

  private:

    void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);

+    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
+            }
+        }
+    }
+
+    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
+                *c_ptr += *((float *)&vec_C[I]+J);
+            }
+        }
+    }
+
    inline void vector_permute_store_4(vector float *src, float *vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
-        t1 = vec_mergeh(src[0], src[1]);
-        t2 = vec_mergeh(src[2], src[3]);
-        t3 = vec_mergel(src[0], src[1]);
-        t4 = vec_mergel(src[2], src[3]);
-
-        t5 = vec_xxpermdi(t1, t2, 0);
-        t6 = vec_xxpermdi(t1, t2, 3);
-        t7 = vec_xxpermdi(t3, t4, 0);
-        t8 = vec_xxpermdi(t3, t4, 3);
-
-        vec_xst(t5, 0, vecOffset);
-        vec_xst(t6, 0, vecOffset + 4);
-        vec_xst(t7, 0, vecOffset + 8);
-        vec_xst(t8, 0, vecOffset + 12);
-    }
+        t1 = vec_mergeh(src[0], src[1]);
+        t2 = vec_mergeh(src[2], src[3]);
+        t3 = vec_mergel(src[0], src[1]);
+        t4 = vec_mergel(src[2], src[3]);
+
+        t5 = vec_xxpermdi(t1, t2, 0);
+        t6 = vec_xxpermdi(t1, t2, 3);
+        t7 = vec_xxpermdi(t3, t4, 0);
+        t8 = vec_xxpermdi(t3, t4, 3);
+
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset + 4);
+        vec_xst(t7, 0, vecOffset + 8);
+        vec_xst(t8, 0, vecOffset + 12);
+    }

    inline void vector_permute_store_8(vector float *src, float *vecOffset) {
        vector float t1, t2, t3, t4, t5, t6, t7, t8;
-        t1 = vec_mergeh(src[0], src[1]);
-        t2 = vec_mergeh(src[2], src[3]);
-        t3 = vec_mergeh(src[4], src[5]);
-        t4 = vec_mergeh(src[6], src[7]);
-
-        t5 = vec_xxpermdi(t1, t2, 0);
-        t6 = vec_xxpermdi(t3, t4, 0);
-        t7 = vec_xxpermdi(t1, t2, 3);
-        t8 = vec_xxpermdi(t3, t4, 3);
-
-        vec_xst(t5, 0, vecOffset);
-        vec_xst(t6, 0, vecOffset + 4);
-        vec_xst(t7, 0, vecOffset + 8);
-        vec_xst(t8, 0, vecOffset + 12);
-
-        t1 = vec_mergel(src[0], src[1]);
-        t2 = vec_mergel(src[2], src[3]);
-        t3 = vec_mergel(src[4], src[5]);
-        t4 = vec_mergel(src[6], src[7]);
-
-        t5 = vec_xxpermdi(t1, t2, 0);
-        t6 = vec_xxpermdi(t3, t4, 0);
-        t7 = vec_xxpermdi(t1, t2, 3);
-        t8 = vec_xxpermdi(t3, t4, 3);
-
-        vec_xst(t5, 0, vecOffset + 16);
-        vec_xst(t6, 0, vecOffset + 20);
-        vec_xst(t7, 0, vecOffset + 24);
-        vec_xst(t8, 0, vecOffset + 28);
-    }
-
-    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
+        t1 = vec_mergeh(src[0], src[1]);
+        t2 = vec_mergeh(src[2], src[3]);
+        t3 = vec_mergeh(src[4], src[5]);
+        t4 = vec_mergeh(src[6], src[7]);
+
+        t5 = vec_xxpermdi(t1, t2, 0);
+        t6 = vec_xxpermdi(t3, t4, 0);
+        t7 = vec_xxpermdi(t1, t2, 3);
+        t8 = vec_xxpermdi(t3, t4, 3);
+
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset + 4);
+        vec_xst(t7, 0, vecOffset + 8);
+        vec_xst(t8, 0, vecOffset + 12);
+
+        t1 = vec_mergel(src[0], src[1]);
+        t2 = vec_mergel(src[2], src[3]);
+        t3 = vec_mergel(src[4], src[5]);
+        t4 = vec_mergel(src[6], src[7]);
+
+        t5 = vec_xxpermdi(t1, t2, 0);
+        t6 = vec_xxpermdi(t3, t4, 0);
+        t7 = vec_xxpermdi(t1, t2, 3);
+        t8 = vec_xxpermdi(t3, t4, 3);
+
+        vec_xst(t5, 0, vecOffset + 16);
+        vec_xst(t6, 0, vecOffset + 20);
+        vec_xst(t7, 0, vecOffset + 24);
+        vec_xst(t8, 0, vecOffset + 28);
+    }
+
+    void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
        int64_t i, j;
        float * aoffsets[8];
        float *aoffset = NULL, *boffset = NULL;
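
Note: the new save_acc/add_save_acc helpers spill one 4x4 MMA accumulator tile into C; save_acc overwrites (used for the first k-block, kk == 0) while add_save_acc accumulates (later k-blocks). As a sketch of the store pattern, here is a hypothetical scalar equivalent, assuming __builtin_mma_disassemble_acc yields four 4-float rows and C is column-major with leading dimension ldc (the helper name and tile argument are illustrative, not part of the patch):

    // Scalar sketch: tile[I][J] is row I, column J of the 4x4 accumulator.
    static void save_acc_ref(float *C, int64_t ldc, const float tile[4][4],
                             int64_t ii, int64_t jj, bool accumulate) {
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                float *dst = C + (ii + I) + (jj + J) * ldc;  // same addressing as save_acc
                *dst = accumulate ? *dst + tile[I][J] : tile[I][J];
            }
        }
    }
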
@@ -2247,7 +2272,6 @@ class tinyBLAS_PPC {
        boffset = vec;
        j = (rows >> 3);
        if (j > 0) {
-
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
@@ -2265,10 +2289,13 @@ class tinyBLAS_PPC {
                vector_permute_store_8(c1, boffset);
                vector_permute_store_8(c2, boffset+32);
-                for (int it = 0; it < 4; it++)
-                    aoffsets[it] = aoffsets[it] + 8*lda;
                boffset += 64;
                i--;
+                if (i > 0) {
+                    for (int it = 0; it < 8; it++) {
+                        aoffsets[it] = aoffsets[it] + 8;
+                    }
+                }
            } while (i > 0);
        }
        if (cols & 4) {
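
Note: the hunk above fixes the pointer advance in packTranspose's 8x8 loop: the old code bumped only four of the eight row pointers, and by 8*lda (eight rows down) rather than eight elements along the row; the new code advances all eight pointers by 8 columns, and only while another block remains. A hypothetical scalar reference for the layout the packed buffer appears to use (8x8 blocks, column blocks fastest, each block stored transposed):

    // Pack a rows x cols panel of a (row stride lda) as transposed 8x8 blocks.
    static void pack_transpose_ref(const float *a, int64_t lda,
                                   int rows, int cols, float *vec) {
        for (int r0 = 0; r0 < rows; r0 += 8) {      // 8-row panel
            for (int c0 = 0; c0 < cols; c0 += 8) {  // 8-column block
                for (int c = 0; c < 8; c++) {       // transposed: column first
                    for (int r = 0; r < 8; r++) {
                        *vec++ = a[(r0 + r) * lda + (c0 + c)];
                    }
                }
            }
        }
    }
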
@@ -2401,6 +2428,83 @@ class tinyBLAS_PPC {
            SAVE_ACC(&acc_3, ii+4, jj+4);
    }

+    inline void MMA_16x8(vec_t *vec_A0, vec_t *vec_A1, vec_t *vec_B, acc_t *acc) {
+        for (int x = 0; x < 16; x += 2) {
+            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
+            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
+            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
+        }
+    }
+
+    void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
+        for (int64_t i = 0; i < mc; i += 16) {
+            int A_base_addr = (mc/8) * (i/8) * 16;
+            for (int64_t j = 0; j < nc; j += 8) {
+                int B_base_addr = (nc/8) * (j/8) * 16;
+                acc_t acc[8];
+                for (int x = 0; x < 8; x++)
+                    __builtin_mma_xxsetaccz(&acc[x]);
+                for (int64_t l = 0; l < kc; l += 8) {
+                    int A0_block_idx = A_base_addr + (l/8) * 16;
+                    int A1_block_idx = A0_block_idx + (mc/8) * 16;
+                    int B_block_idx = B_base_addr + (l/8) * 16;
+                    vec_t * A0_block = &vec_A[A0_block_idx];
+                    vec_t * A1_block = &vec_A[A1_block_idx];
+                    vec_t * B_block = &vec_B[B_block_idx];
+                    MMA_16x8(A0_block, A1_block, B_block, acc);
+                }
+                if (kk == 0) {
+                    save_acc(&acc[0], ii + i, jj + j);
+                    save_acc(&acc[1], ii + i, jj + j + 4);
+                    save_acc(&acc[2], ii + i + 4, jj + j);
+                    save_acc(&acc[3], ii + i + 4, jj + j + 4);
+                    save_acc(&acc[4], ii + i + 8, jj + j);
+                    save_acc(&acc[5], ii + i + 8, jj + j + 4);
+                    save_acc(&acc[6], ii + i + 12, jj + j);
+                    save_acc(&acc[7], ii + i + 12, jj + j + 4);
+                } else {
+                    add_save_acc(&acc[0], ii + i, jj + j);
+                    add_save_acc(&acc[1], ii + i, jj + j + 4);
+                    add_save_acc(&acc[2], ii + i + 4, jj + j);
+                    add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
+                    add_save_acc(&acc[4], ii + i + 8, jj + j);
+                    add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
+                    add_save_acc(&acc[6], ii + i + 12, jj + j);
+                    add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
+                }
+            }
+        }
+    }
+
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                vec_t A_pack[kc*mc/4];
+                vec_t B_pack[kc*nc/4];
+                packTranspose(A+(ii*lda)+kk, lda, kc, mc, (float *)A_pack);
+                packTranspose(B+(jj*ldb)+kk, ldb, kc, nc, (float *)B_pack);
+                KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
+            }
+        }
+    }
+
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int m_rem = MIN(m - m0, 8);
        int n_rem = MIN(n - n0, 8);
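
Note: matmul_tiled numbers the (m/mc) x (n/nc) output tiles row-major and hands each thread a contiguous run of ceil(tiles/nth) jobs, clamped to the total; A_pack/B_pack are variable-length arrays (a GNU extension), so with mc = nc = kc = 256 each job keeps two sizable packed panels on the stack. A small standalone illustration of the job split (values chosen for the example, not taken from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
        int64_t m = 512, n = 768, mc = 256, nc = 256, nth = 4;
        int64_t ytiles = m / mc;                   // 2 tile rows
        int64_t xtiles = n / nc;                   // 3 tile columns
        int64_t tiles  = xtiles * ytiles;          // 6 tiles total
        int64_t duty   = (tiles + nth - 1) / nth;  // 2 tiles per thread
        for (int64_t ith = 0; ith < nth; ith++) {
            int64_t start = duty * ith;
            int64_t end   = start + duty;
            if (end > tiles) end = tiles;          // last thread may get fewer (or none)
            for (int64_t job = start; job < end; ++job) {
                printf("thread %lld: ii=%lld jj=%lld\n", (long long)ith,
                       (long long)((job / xtiles) * mc),
                       (long long)((job % xtiles) * nc));
            }
        }
        return 0;
    }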