Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using WarpTile = ck_tile::sequence<32, 128>;
using ThreadTile = ck_tile::sequence<8, 8>;

constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;

ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
Expand All @@ -352,7 +351,8 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
using Problem =
ck_tile::Reduce2dProblem<CDataType, ComputeDataType, CDataType, Shape, ReduceOp>;
using Kernel = ck_tile::Reduce<Problem>;
using Kernel = ck_tile::Reduce<Problem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();

if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides))
{
Expand Down Expand Up @@ -992,7 +992,11 @@ int main(int argc, char* argv[])

try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{
Expand Down
5 changes: 2 additions & 3 deletions example/ck_tile/05_reduce/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;

constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kept_dim_len_prod = N * C;
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
Expand All @@ -99,8 +98,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Porblem =
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;

using Kernel = ck_tile::Reduce<Porblem>;

using Kernel = ck_tile::Reduce<Porblem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
// Create input tensor shape and strides
auto input_shape =
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ struct matrix_core_swizzle_kernel
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;

static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = BLOCK_SIZE / ck_tile::get_warp_size();
static constexpr int WavesPerBlock_K = 1;
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;
Expand Down
4 changes: 4 additions & 0 deletions example/ck_tile/06_permute/permute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "permute.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/utility/json_dump.hpp"

#include <array>
#include <cstring>
Expand Down Expand Up @@ -128,6 +129,7 @@ auto create_args(int argc, char* argv[])
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "permute.json", "json file name to dump results");

bool result = arg_parser.parse(argc, argv);
Expand Down Expand Up @@ -257,6 +259,7 @@ bool run(const ck_tile::ArgParser& arg_parser)

return permute(t, a, stream_config);
};
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
Expand Down Expand Up @@ -345,6 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
#endif
#endif
{
ave_time = run_permute();
Expand Down
3 changes: 1 addition & 2 deletions example/ck_tile/21_elementwise/elementwise_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// This is often a multiple of the wavefront size, 64 on CDNA.
// Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize.
// Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512).
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();

// kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU).
// This can influence occupancy and performance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(auto d : problem_shape)
total_elements *= d;

constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();

constexpr ck_tile::index_t kBlockPerCu = 2;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ bool run(const ck_tile::ArgParser& arg_parser)

ck_tile::index_t total_elements = M * N;

constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(auto d : shape)
total_elements *= d;

constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;

constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ struct ElementWiseKernel
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;

static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;

CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(const Dims lens,
const Dims input_strides,
Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ struct Reduce
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;

static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}

private:
// Helper function to calculate optimal vector size for input tensor
Expand Down
17 changes: 12 additions & 5 deletions include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ struct Reduce2dShape
static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{});
static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{});

static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N;

static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / ThreadTile_M / ThreadTile_N / ck_tile::get_warp_size();
static constexpr index_t RepeatInWarp_M =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / ThreadTile_M > Warp_N / ThreadTile_N) ? 1 : RepeatInWarp;

static constexpr index_t ThreadPerWarp_M = Warp_M / ThreadTile_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / ThreadTile_N / RepeatInWarp_N;

static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);

static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
Expand Down
Loading