|
1 | 1 | // SPDX-License-Identifier: MIT
|
2 |
| -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
3 | 3 |
|
4 | 4 | #include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
5 | 5 | #include "ck/host/stringutils.hpp"
|
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
76 | 76 | // Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
|
77 | 77 | // | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
|
78 | 78 | // | | | | | | | | | | | Wave| Wave| Wave| |
|
79 |
| - { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, |
80 |
| - { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, |
81 |
| - { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, |
82 |
| - { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, |
83 |
| - { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, |
84 |
| - { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, |
85 |
| - { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
86 |
| - { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
87 |
| - { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, |
88 |
| - { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, |
89 |
| - { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, |
90 |
| - { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, |
| 79 | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1}, |
| 80 | + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1}, |
| 81 | + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
| 82 | + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 83 | + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, |
| 84 | + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1}, |
| 85 | + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 86 | + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 87 | + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 88 | + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
| 89 | + { 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1}, |
| 90 | + { 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1}, |
91 | 91 | // Padded fallback kernel
|
92 |
| - { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, |
93 |
| - { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, |
| 92 | + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1}, |
| 93 | + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1}, |
94 | 94 | // Irregular k
|
95 |
| - { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, |
96 |
| - { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, |
97 |
| - { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, |
98 |
| - { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, |
99 |
| - { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, |
100 |
| - { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, |
| 95 | + { 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1}, |
| 96 | + { 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1}, |
| 97 | + { 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1}, |
| 98 | + { 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1}, |
| 99 | + { 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1}, |
| 100 | + { 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1}, |
101 | 101 | // clang-format on
|
102 | 102 | };
|
103 | 103 |
|
@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
200 | 200 | // _MBlock_MWaveMPerXdl| ScalarPerVector
|
201 | 201 | // _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
202 | 202 | // |
|
203 |
| - { S<1, 32, 1, 8>, 8}, |
204 |
| - { S<1, 32, 1, 8>, 8}, |
205 |
| - { S<1, 32, 1, 8>, 8}, |
206 |
| - { S<1, 32, 1, 8>, 8}, |
207 |
| - { S<1, 32, 1, 8>, 8}, |
208 |
| - { S<1, 32, 1, 8>, 8}, |
209 |
| - { S<1, 32, 1, 8>, 8}, |
210 |
| - { S<1, 32, 1, 8>, 8}, |
211 |
| - { S<1, 16, 1,16>, 8}, |
212 |
| - { S<1, 32, 1, 8>, 8}, |
213 |
| - { S<1, 16, 1,16>, 8}, |
214 |
| - { S<1, 32, 1, 8>, 8}, |
| 203 | + { S<1, 32, 1, 8>, 4}, |
| 204 | + { S<1, 32, 1, 8>, 4}, |
| 205 | + { S<1, 32, 1, 8>, 4}, |
| 206 | + { S<1, 32, 1, 8>, 4}, |
| 207 | + { S<1, 32, 1, 8>, 4}, |
| 208 | + { S<1, 32, 1, 8>, 4}, |
| 209 | + { S<1, 32, 1, 8>, 4}, |
| 210 | + { S<1, 32, 1, 8>, 4}, |
| 211 | + { S<1, 16, 1,16>, 4}, |
| 212 | + { S<1, 32, 1, 8>, 4}, |
| 213 | + { S<1, 16, 1,16>, 4}, |
| 214 | + { S<1, 32, 1, 8>, 4}, |
215 | 215 | // Padded fallback kernel
|
216 |
| - { S<1, 32, 1, 8>, 8}, |
217 |
| - { S<1, 32, 1, 8>, 8}, |
| 216 | + { S<1, 32, 1, 8>, 4}, |
| 217 | + { S<1, 32, 1, 8>, 4}, |
218 | 218 | // Irregular k
|
219 |
| - { S<1, 32, 1, 8>, 8}, |
220 |
| - { S<1, 32, 1, 8>, 8}, |
221 |
| - { S<1, 32, 1, 8>, 8}, |
222 |
| - { S<1, 32, 1, 8>, 8}, |
223 |
| - { S<1, 32, 1, 8>, 8}, |
224 |
| - { S<1, 32, 1, 8>, 8}, |
| 219 | + { S<1, 32, 1, 8>, 4}, |
| 220 | + { S<1, 32, 1, 8>, 4}, |
| 221 | + { S<1, 32, 1, 8>, 4}, |
| 222 | + { S<1, 32, 1, 8>, 4}, |
| 223 | + { S<1, 32, 1, 8>, 4}, |
| 224 | + { S<1, 32, 1, 8>, 4}, |
225 | 225 | // clang-format on
|
226 | 226 | };
|
227 | 227 |
|
|
0 commit comments