Skip to content

Commit 0f8e33f

Browse files
authored
Extend XDL kernel to Support RDNA3/4 - Part 3 (#2723)
Support Wave32/Wave64 in all XDL kernels. 1. Add the following helper functions/macros in device_base.hpp - GET_NXDL_PER_WAVE_IMPL and GetNXdlPerWave2 - INVOKER_RUN_IMPL and INVOKER_RUN3_IMPL - IsValidGemmCompilationParameter and IS_VALID_COMPILATION_PARAMETER_IMPL 2. Replace GridwiseGemm with GridwiseGemm32 and GridwiseGemm64, and use one of them according to the current GPU target 3. Move gridwise-gemm-related variables from Argument members to local variables in RunImp - This avoids duplicated GridwiseGemm::CheckValidity calls 4. Add IsValidGemmCompilationParameter to all XDL kernels. Known issues: - DeviceBatchedGemmXdl and DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle are incorrect on gfx11. - DeviceGemmMultipleDLayernorm_Xdl_CShuffle is incorrect on both gfx11 and gfx12.
1 parent e4a7728 commit 0f8e33f

File tree

131 files changed

+8827
-5425
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

131 files changed

+8827
-5425
lines changed

codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
55
#include "ck/host/stringutils.hpp"
@@ -76,28 +76,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
7676
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
7777
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
7878
// | | | | | | | | | | | Wave| Wave| Wave| |
79-
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
80-
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
81-
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
82-
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
83-
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
84-
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
85-
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
86-
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
87-
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
88-
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
89-
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
90-
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
79+
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 16, 16, 4, 8, 4, 1},
80+
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 16, 16, 4, 8, 8, 1},
81+
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
82+
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
83+
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
84+
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 16, 16, 2, 8, 4, 1},
85+
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
86+
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
87+
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
88+
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
89+
{ 256, 128, 256, 64, 128, 32, 8, 8, 2, 16, 16, 2, 16, 8, 1},
90+
{ 256, 128, 256, 64, 64, 32, 8, 8, 2, 16, 16, 2, 16, 4, 1},
9191
// Padded fallback kernel
92-
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
93-
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
92+
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 16, 16, 2, 8, 8, 1},
93+
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 16, 16, 2, 4, 8, 1},
9494
// Irregular k
95-
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
96-
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
97-
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
98-
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
99-
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
100-
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
95+
{ 256, 256, 128, 48, 64, 32, 4, 4, 2, 16, 16, 4, 8, 4, 1},
96+
{ 256, 256, 128, 48, 128, 32, 4, 4, 2, 16, 16, 4, 8, 8, 1},
97+
{ 256, 128, 256, 48, 64, 32, 4, 4, 2, 16, 16, 2, 16, 4, 1},
98+
{ 256, 128, 256, 48, 128, 32, 4, 4, 2, 16, 16, 2, 16, 8, 1},
99+
{ 256, 128, 128, 48, 64, 32, 4, 4, 2, 16, 16, 2, 8, 4, 1},
100+
{ 256, 128, 128, 48, 128, 32, 4, 4, 2, 16, 16, 2, 8, 8, 1},
101101
// clang-format on
102102
};
103103

@@ -200,28 +200,28 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
200200
// _MBlock_MWaveMPerXdl| ScalarPerVector
201201
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
202202
// |
203-
{ S<1, 32, 1, 8>, 8},
204-
{ S<1, 32, 1, 8>, 8},
205-
{ S<1, 32, 1, 8>, 8},
206-
{ S<1, 32, 1, 8>, 8},
207-
{ S<1, 32, 1, 8>, 8},
208-
{ S<1, 32, 1, 8>, 8},
209-
{ S<1, 32, 1, 8>, 8},
210-
{ S<1, 32, 1, 8>, 8},
211-
{ S<1, 16, 1,16>, 8},
212-
{ S<1, 32, 1, 8>, 8},
213-
{ S<1, 16, 1,16>, 8},
214-
{ S<1, 32, 1, 8>, 8},
203+
{ S<1, 32, 1, 8>, 4},
204+
{ S<1, 32, 1, 8>, 4},
205+
{ S<1, 32, 1, 8>, 4},
206+
{ S<1, 32, 1, 8>, 4},
207+
{ S<1, 32, 1, 8>, 4},
208+
{ S<1, 32, 1, 8>, 4},
209+
{ S<1, 32, 1, 8>, 4},
210+
{ S<1, 32, 1, 8>, 4},
211+
{ S<1, 16, 1,16>, 4},
212+
{ S<1, 32, 1, 8>, 4},
213+
{ S<1, 16, 1,16>, 4},
214+
{ S<1, 32, 1, 8>, 4},
215215
// Padded fallback kernel
216-
{ S<1, 32, 1, 8>, 8},
217-
{ S<1, 32, 1, 8>, 8},
216+
{ S<1, 32, 1, 8>, 4},
217+
{ S<1, 32, 1, 8>, 4},
218218
// Irregular k
219-
{ S<1, 32, 1, 8>, 8},
220-
{ S<1, 32, 1, 8>, 8},
221-
{ S<1, 32, 1, 8>, 8},
222-
{ S<1, 32, 1, 8>, 8},
223-
{ S<1, 32, 1, 8>, 8},
224-
{ S<1, 32, 1, 8>, 8},
219+
{ S<1, 32, 1, 8>, 4},
220+
{ S<1, 32, 1, 8>, 4},
221+
{ S<1, 32, 1, 8>, 4},
222+
{ S<1, 32, 1, 8>, 4},
223+
{ S<1, 32, 1, 8>, 4},
224+
{ S<1, 32, 1, 8>, 4},
225225
// clang-format on
226226
};
227227

codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include "ck/host/device_gemm_multiple_d/operation.hpp"
55
#include "ck/host/stringutils.hpp"
@@ -81,16 +81,16 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
8181
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
8282
// | | | | | | | | Wave| Wave| Stage|
8383
// | | | | | | | | | | |
84-
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
85-
{ 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1},
86-
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1},
87-
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
88-
{ 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1},
89-
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
90-
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
91-
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
84+
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
85+
{ 256, 128, 256, 32, 8, 8, 16, 16, 4, 8, 1},
86+
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1},
87+
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
88+
{ 128, 128, 64, 32, 8, 8, 16, 16, 4, 4, 1},
89+
{ 128, 64, 128, 32, 8, 8, 16, 16, 4, 4, 1},
90+
{ 256, 128, 64, 32, 8, 8, 16, 16, 4, 2, 1},
91+
{ 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, 1},
9292
// Irregular tile
93-
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
93+
{ 64, 32, 32, 32, 8, 8, 16, 16, 2, 2, 1},
9494
// clang-format on
9595
};
9696

@@ -194,14 +194,14 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
194194
// _MBlock_MWaveMPerXdl| ScalarPerVector
195195
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
196196
// |
197-
{ S<1, 32, 1, 8>, 8},
198-
{ S<1, 32, 1, 8>, 8},
199-
{ S<1, 16, 1, 8>, 8},
200-
{ S<1, 32, 1, 8>, 8},
201-
{ S<1, 32, 1, 4>, 8},
202-
{ S<1, 16, 1, 8>, 8},
203-
{ S<1, 32, 1, 8>, 8},
204-
{ S<1, 32, 1, 8>, 8},
197+
{ S<1, 32, 1, 8>, 4},
198+
{ S<1, 32, 1, 8>, 4},
199+
{ S<1, 16, 1, 8>, 4},
200+
{ S<1, 32, 1, 8>, 4},
201+
{ S<1, 32, 1, 4>, 4},
202+
{ S<1, 16, 1, 8>, 4},
203+
{ S<1, 32, 1, 8>, 4},
204+
{ S<1, 32, 1, 8>, 4},
205205
// Irregular tile
206206
{ S<1, 16, 1, 4>, 1},
207207
// clang-format on

codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
55
#include <iostream>
@@ -55,12 +55,12 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
5555
// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch|
5656
// | | | | | | | | Wave| Wave| Stage|
5757
// | | | | | | | | | | |
58-
{ 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 1},
59-
{ 256, 128, 256, 32, 8, 8, 32, 32, 4, 2, 1},
60-
{ 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1},
61-
{ 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, 1},
62-
{ 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1},
63-
{ 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}
58+
{ 64, 64, 32, 32, 8, 8, 16, 16, 4, 2, 1},
59+
{ 256, 128, 256, 32, 8, 8, 16, 16, 8, 4, 1},
60+
{ 256, 128, 128, 32, 8, 8, 16, 16, 4, 4, 1},
61+
{ 64, 64, 64, 32, 8, 8, 16, 16, 4, 4, 1},
62+
{ 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, 1},
63+
{ 128, 128, 128, 32, 8, 8, 16, 16, 8, 4, 1}
6464
// clang-format on
6565
};
6666

@@ -116,11 +116,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
116116
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
117117
// |
118118
{ S<1, 16, 1, 4>, 1},
119-
{ S<1, 32, 1, 8>, 8},
120-
{ S<1, 32, 1, 8>, 8},
119+
{ S<1, 16, 1, 16>, 4},
120+
{ S<1, 32, 1, 8>, 4},
121121
{ S<1, 16, 1, 4>, 1},
122-
{ S<1, 32, 1, 8>, 8},
123-
{ S<1, 16, 1, 8>, 8}
122+
{ S<1, 32, 1, 8>, 4},
123+
{ S<1, 16, 1, 8>, 4}
124124
// clang-format on
125125
};
126126

@@ -223,8 +223,9 @@ extern "C" __global__ void run_${name}(
223223
constexpr ck::LoopScheduler LoopSched = ck::make_default_loop_scheduler();
224224
225225
// GridwiseGemm
226-
using GridwiseGemm = DeviceConv::GridwiseGemm;
227-
226+
using GridwiseGemm = ck::conditional_t<ck::get_warp_size() == 64,
227+
typename DeviceConv::GridwiseGemm64,
228+
typename DeviceConv::GridwiseGemm32>;
228229
static constexpr auto I0 = ck::Number<0>{};
229230
230231
ck::tensor_operation::device::device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<

codegen/src/utils.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#include "ck/host/utils.hpp"
55

@@ -13,7 +13,8 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
1313

1414
const std::unordered_set<std::string>& get_xdlop_archs()
1515
{
16-
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx942"};
16+
static std::unordered_set<std::string> supported_archs{
17+
"gfx90a", "gfx908", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
1718
return supported_archs;
1819
}
1920

codegen/test/grouped_conv_fwd_multiple_d_v1.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ struct Epilogue
160160
Epilogue{1.0f, 1.0f});
161161
out_host.SetZero();
162162
ref_invoker.Run(ref_argument);**/
163-
163+
int i = 0;
164164
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
165165
{
166+
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
166167
// substitute instance values into the template
167168
auto src = ck::host::InterpolateString(
168169
conv_compile_check,

codegen/test/grouped_conv_fwd_multiple_d_v2.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ struct Epilogue
160160
Epilogue{1.0f, 1.0f});
161161
out_host.SetZero();
162162
ref_invoker.Run(ref_argument);**/
163-
163+
int i = 0;
164164
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
165165
{
166+
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
166167
// substitute instance values into the template
167168
auto src = ck::host::InterpolateString(
168169
conv_compile_check,

codegen/test/grouped_conv_fwd_multiple_d_v3.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ struct Epilogue
160160
Epilogue{1.0f, 1.0f});
161161
out_host.SetZero();
162162
ref_invoker.Run(ref_argument);**/
163-
163+
int i = 0;
164164
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
165165
{
166+
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
166167
// substitute instance values into the template
167168
auto src = ck::host::InterpolateString(
168169
conv_compile_check,

codegen/test/grouped_conv_fwd_multiple_d_v4.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ struct Epilogue
160160
Epilogue{1.0f, 1.0f});
161161
out_host.SetZero();
162162
ref_invoker.Run(ref_argument);**/
163-
163+
int i = 0;
164164
for(auto solution : prob.GetSolutions("gfx908", prologue, epilogue))
165165
{
166+
std::cout << "Testing solution " << std::to_string(++i) << std::endl;
166167
// substitute instance values into the template
167168
auto src = ck::host::InterpolateString(
168169
conv_compile_check,

include/ck/host_utility/device_prop.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,34 @@ inline bool is_xdl_supported()
7575
;
7676
}
7777

78+
// Reports whether the current device can run a kernel instantiated with the
// given A/B data types and per-XDL tile sizes.
//
// Template parameters:
//   ADataType, BDataType - element types of the A and B operands; only their
//                          sizeof() is inspected here.
//   MPerXDL, NPerXDL     - per-XDL tile extents along M and N.
//
// Returns:
//   - true when ck::get_device_name() is "gfx908", "gfx90a", "gfx942" or
//     "gfx950", regardless of the template arguments.
//   - when CK_ENABLE_DYNAMIC_WARP_SIZE is defined and is_gfx12_supported() or
//     is_gfx11_supported() reports true: true only if MPerXDL == 16,
//     NPerXDL == 16 and both sizeof(ADataType) and sizeof(BDataType) are at
//     most 2; false otherwise.
//   - false for every other device (including the gfx11/gfx12 case when
//     CK_ENABLE_DYNAMIC_WARP_SIZE is not defined, since that branch is then
//     compiled out).
template <typename ADataType, typename BDataType, index_t MPerXDL, index_t NPerXDL>
79+
inline bool is_xdl_wmma_supported()
80+
{
81+
if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
82+
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")
83+
{
84+
return true;
85+
}
86+
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
87+
else if(is_gfx12_supported() || is_gfx11_supported())
88+
{
89+
if constexpr((MPerXDL != 16) || (NPerXDL != 16))
90+
{
91+
return false;
92+
}
93+
if constexpr(sizeof(ADataType) > 2 || sizeof(BDataType) > 2)
94+
{
95+
return false;
96+
}
97+
return true;
98+
}
99+
#endif
100+
else
101+
{
102+
return false;
103+
}
104+
}
105+
78106
inline bool is_lds_direct_load_supported()
79107
{
80108
// Check if direct loads from global memory to LDS are supported.

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,8 @@ struct BlockwiseGemmXdlops_pipeline_v4
108108

109109
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
110110

111-
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
112-
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
113-
static_assert(MWaves > 0);
114-
static_assert(NWaves > 0);
111+
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
112+
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
115113
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
116114

117115
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);

0 commit comments

Comments
 (0)