diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index 324dfc069a..f200332588 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -343,7 +343,6 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config using WarpTile = ck_tile::sequence<32, 128>; using ThreadTile = ck_tile::sequence<8, 8>; - constexpr ck_tile::index_t kBlockSize = 256; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) / @@ -352,7 +351,8 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config using Shape = ck_tile::Reduce2dShape; using Problem = ck_tile::Reduce2dProblem; - using Kernel = ck_tile::Reduce; + using Kernel = ck_tile::Reduce; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides)) { @@ -992,7 +992,11 @@ int main(int argc, char* argv[]) try { +#if CK_TILE_USE_WMMA + return !run_gemm_example(arg_parser); +#else return !run_gemm_example(arg_parser); +#endif } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index ea3253b629..297ff03992 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -88,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser) // using WarpTile = ck_tile::sequence<1, 512>; // using Vector = ck_tile::sequence<1, 8>; - constexpr ck_tile::index_t kBlockSize = 256; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kept_dim_len_prod = N * C; ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) / @@ -99,8 +98,8 @@ bool run(const ck_tile::ArgParser& arg_parser) using Porblem = ck_tile::Reduce2dProblem; - using Kernel = ck_tile::Reduce; - + using Kernel = ck_tile::Reduce; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); // Create input tensor shape and strides auto input_shape = ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index d486196fc3..aa9fd97171 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -88,10 +88,9 @@ struct matrix_core_swizzle_kernel using karg = matrix_core_swizzle_host_args; using harg = matrix_core_swizzle_host_args; - static constexpr int BLOCK_SIZE = BLOCK_SIZE_; - static constexpr int WavesPerBlock_N = 4; - static constexpr int WavesPerBlock_K = 1; - static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE); + static constexpr int BLOCK_SIZE = BLOCK_SIZE_; + static constexpr int WavesPerBlock_N = BLOCK_SIZE / ck_tile::get_warp_size(); + static constexpr int WavesPerBlock_K = 1; static constexpr int NPerBlock = NPerBlock_; static constexpr int KPerBlock = KPerBlock_; static constexpr matrix_core_permute_style pstyle = pstyle_; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index e68fe4bac3..c4c6f077d7 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -3,6 +3,7 @@ #include "permute.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/utility/json_dump.hpp" #include #include @@ -128,6 +129,7 @@ auto create_args(int argc, char* argv[]) "non-deterministic seed") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "permute.json", "json file name to dump results"); bool result = arg_parser.parse(argc, argv); @@ -257,6 +259,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return permute(t, a, stream_config); }; +#if !CK_TILE_USE_WMMA #ifdef PERMUTE_USE_ALTERNATIVE_IMPL // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || @@ -345,6 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } else +#endif #endif { ave_time = run_permute(); diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 511449c0de..94d3e70be1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -137,8 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // This is often a multiple of the wavefront size, 64 on CDNA. // Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize. // Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512). - constexpr ck_tile::index_t kBlockSize = - ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); // kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU). // This can influence occupancy and performance. diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index a7b9fb494d..ff7ec1517e 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -84,8 +84,7 @@ bool run(const ck_tile::ArgParser& arg_parser) for(auto d : problem_shape) total_elements *= d; - constexpr ck_tile::index_t kBlockSize = - ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 2; diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index abf2435221..16e9832c07 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -89,8 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t total_elements = M * N; - constexpr ck_tile::index_t kBlockSize = - ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index 54533e01b2..c5a08d910e 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -78,8 +78,7 @@ bool run(const ck_tile::ArgParser& arg_parser) for(auto d : shape) total_elements *= d; - constexpr ck_tile::index_t kBlockSize = - ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index e19ff4006b..351c2a0fcf 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -21,7 +21,10 @@ struct ElementWiseKernel using ElementWiseOperation = ck_tile::remove_cvref_t; static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize; - + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? kBlockSize / 2 : kBlockSize; + } template CK_TILE_DEVICE void operator()(const Dims lens, const Dims input_strides, diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 5755f38475..92a71a42c8 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -26,6 +26,10 @@ struct Reduce using YDataType = ck_tile::remove_cvref_t; static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? kBlockSize / 2 : kBlockSize; + } private: // Helper function to calculate optimal vector size for input tensor diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp index 31eb1f2f4f..0499fe370b 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp @@ -25,11 +25,18 @@ struct Reduce2dShape static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); - static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N; - - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + static constexpr index_t RepeatInWarp = + Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size(); + static constexpr index_t RepeatInWarp_M = + (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1; + static constexpr index_t RepeatInWarp_N = + (Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp; + + static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N; + + static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N); static constexpr index_t BlockSize = ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});