Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/ck/utility/amd_buffer_addressing_builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ union BufferResource
};

template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
__device__ int32x4_t make_wave_buffer_resource(T* __restrict__ p_wave, index_t element_space_size)
{
BufferResource<T> wave_buffer_resource;

Expand All @@ -35,7 +35,7 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_
}

template <typename T>
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* __restrict__ p_wave)
{
BufferResource<T> wave_buffer_resource;

Expand Down Expand Up @@ -711,7 +711,7 @@ template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
T* __restrict__ p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
Expand Down
2 changes: 1 addition & 1 deletion include/ck/utility/c_style_pointer_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ck {
template <typename PY,
typename PX,
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
__host__ __device__ PY c_style_pointer_cast(PX __restrict__ p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/arch/amd_buffer_addressing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct __attribute__((packed)) buffer_resource
uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* __restrict__ ptr, uint32_t size = 0xffffffff)
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct __attribute__((packed)) buffer_resource
uint32_t config;
};

CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* __restrict__ ptr, uint32_t size = 0xffffffff)
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/core/tensor/load_tile_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ CK_TILE_DEVICE auto
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
NumCoord>& __restrict__ tile_window)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
Expand Down
19 changes: 15 additions & 4 deletions include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ struct FmhaFwdV3Kernel
{
using namespace ck_tile;

// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];

// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

Expand Down Expand Up @@ -483,14 +480,28 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();

__shared__ char
smem_k0[FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>()];
__shared__ char
smem_k1[FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>()];
__shared__ char
smem_v0[FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>()];
__shared__ char
smem_v1[FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>()];
__shared__ char smem[1];

auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
smem_ptr);
reinterpret_cast<KDataType*>(smem_k0),
reinterpret_cast<KDataType*>(smem_k1),
reinterpret_cast<VDataType*>(smem_v0),
reinterpret_cast<VDataType*>(smem_v1),
smem);
}();

// O DRAM and O DRAM window
Expand Down
137 changes: 68 additions & 69 deletions include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ struct BlockFmhaFwdV3Pipeline
}

template <typename DataType, typename Descriptor>
CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
CK_TILE_DEVICE static constexpr auto make_lds_tile_window(DataType* __restrict__ base,
const Descriptor& desc)
{
using namespace ck_tile;

auto tensor_view =
make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
auto tensor_view = make_tensor_view<address_space_enum::lds>(base, desc);
return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
}

Expand Down Expand Up @@ -343,20 +343,25 @@ struct BlockFmhaFwdV3Pipeline
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
void* smem_ptr) const
CK_TILE_DEVICE auto
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
[[maybe_unused]] const VElementFunction& v_element_func,
LSEDramBlockWindowTmp& __restrict__ lse_dram_window_tmp, // M0*1 tile
const LSEElementFunction& lse_element_func,
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
float scale_s,
KDataType* __restrict__ smem_k0,
KDataType* __restrict__ smem_k1,
VDataType* __restrict__ smem_v0,
VDataType* __restrict__ smem_v1,
void* __restrict__ smem_ptr) const
{
using namespace ck_tile;

Expand All @@ -375,28 +380,22 @@ struct BlockFmhaFwdV3Pipeline

static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<kM0, kN0>());
static_cast<SaccDataType* __restrict__>(smem_ptr), MakeSimpleLdsDesc<kM0, kN0>());
[[maybe_unused]] auto s_lds_window =
make_tile_window(s_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

auto p_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc<kM0, kN0>());
static_cast<PDataType* __restrict__>(smem_ptr), MakeSimpleLdsDesc<kM0, kN0>());
[[maybe_unused]] auto p_lds_window =
make_tile_window(p_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});

auto o_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
MakeSimpleLdsDesc<kM0, kN1>());
static_cast<PDataType* __restrict__>(smem_ptr), MakeSimpleLdsDesc<kM0, kN1>());
[[maybe_unused]] auto o_lds_window =
make_tile_window(o_lds, make_tuple(number<kM0>{}, number<kN1>{}), {0, 0});

auto m_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
Policy::template GetSmemSize<Problem>()),
MakeSimpleLdsDesc1D<kM0>());
static_cast<SMPLComputeDataType* __restrict__>(smem_ptr), MakeSimpleLdsDesc1D<kM0>());
[[maybe_unused]] auto m_lds_window =
make_tile_window(m_lds, make_tuple(number<kM0>{}), {0});

Expand All @@ -413,35 +412,21 @@ struct BlockFmhaFwdV3Pipeline
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

auto k_lds_window_store = generate_tuple(
[&](auto i_buf) {
return make_lds_tile_window<KDataType>(
smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
},
number<2>{});

auto v_lds_window_store = generate_tuple(
[&](auto i_buf) {
return make_lds_tile_window<KDataType>(
smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
},
number<2>{});

statically_indexed_array<decltype(make_tile_window(
make_lds_tile_window<KDataType>(
nullptr,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeKLdsBlockDescriptor<Problem>()),
Policy::template MakeKRegTileDistribution<Problem>())),
2>
k_lds_window_load;
k_lds_window;

statically_indexed_array<decltype(make_tile_window(
make_lds_tile_window<VDataType>(
nullptr,
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeVLdsBlockDescriptor<Problem>()),
Policy::template MakeVRegTileDistribution<Problem>())),
2>
v_lds_window_load;
v_lds_window;

decltype(make_static_distributed_tensor<QDataType>(
Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
Expand All @@ -450,9 +435,9 @@ struct BlockFmhaFwdV3Pipeline
{
CK_TILE_DEVICE kv_tile_type() {}

decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
decltype(load_tile(k_lds_window(number<0>{}))) k_tile;

decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
decltype(load_tile_transpose(v_lds_window(number<0>{}))) v_tile;
} kv_tile;

union sp_compute_type
Expand All @@ -476,19 +461,28 @@ struct BlockFmhaFwdV3Pipeline

// initialize k_lds_window and v_lds_window
static_for<0, 2, 1>{}([&](auto idx) {
k_lds_window_load(idx) = make_tile_window(
make_lds_tile_window<KDataType>(
static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
Policy::template MakeKRegTileDistribution<Problem>());
k_lds_window(idx) =
make_tile_window(make_lds_tile_window(
[&] {
if constexpr(idx == 0)
return smem_k0;
else
return smem_k1;
}(),
Policy::template MakeKLdsBlockDescriptor<Problem>()),
Policy::template MakeKRegTileDistribution<Problem>());
});

static_for<0, 2, 1>{}([&](auto idx) {
v_lds_window_load(idx) =
make_tile_window(make_lds_tile_window<VDataType>(
static_cast<char*>(smem_ptr) +
(idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
v_lds_window(idx) =
make_tile_window(make_lds_tile_window(
[&] {
if constexpr(idx == 0)
return smem_v0;
else
return smem_v1;
}(),
Policy::template MakeVLdsBlockDescriptor<Problem>()),
Policy::template MakeVRegTileDistribution<Problem>());
});

Expand Down Expand Up @@ -536,14 +530,12 @@ struct BlockFmhaFwdV3Pipeline
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
k_dram_window.init_raw();

auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
v_dram_window.init_raw();

// prefetch K tile
index_t i_total_loops = 0;
Expand Down Expand Up @@ -635,27 +627,26 @@ struct BlockFmhaFwdV3Pipeline
static constexpr int V_mem_su_ld_insts = 1;

auto K_mem_load = [&](auto k_lds_write_idx) {
async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
async_load_tile(k_lds_window(k_lds_write_idx), k_dram_window);

/// FIXME: use the future-predicting method to move the window
// move K tile windows
move_tile_window(k_dram_window, {kN0, 0});
};

auto K_lds_load = [&](auto k_lds_read_idx) {
kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
kv_tile.k_tile = load_tile(k_lds_window(k_lds_read_idx));
};

auto V_mem_load = [&](auto v_lds_write_idx) {
async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
__builtin_amdgcn_sched_barrier(0);
async_load_tile(v_lds_window(v_lds_write_idx), v_dram_window);

/// FIXME: use the future-predicting method to move the window
move_tile_window(v_dram_window, {kK1, 0});
};

auto V_lds_load = [&](auto v_lds_read_idx) {
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
kv_tile.v_tile = load_tile_transpose(v_lds_window(v_lds_read_idx));
};

decltype(m) m_old;
Expand Down Expand Up @@ -1168,13 +1159,17 @@ struct BlockFmhaFwdV3Pipeline
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
LSEDramBlockWindowTmp& __restrict__ lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
void* smem_ptr) const
KDataType* __restrict__ smem_k0,
KDataType* __restrict__ smem_k1,
VDataType* __restrict__ smem_v0,
VDataType* __restrict__ smem_v1,
void* __restrict__ smem_ptr) const
{
using namespace ck_tile;

Expand All @@ -1191,6 +1186,10 @@ struct BlockFmhaFwdV3Pipeline
identity{},
mask,
scale_s,
smem_k0,
smem_k1,
smem_v0,
smem_v1,
smem_ptr);
}
};
Expand Down
Loading
Loading