Skip to content

Commit c9a318a

Browse files
WIP
1 parent 61d4460 commit c9a318a

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ static __global__ void flash_attn_tile(
158158
return;
159159
#endif // FP16_MMA_AVAILABLE
160160

161-
constexpr int warp_size_physical = ggml_cuda_get_physical_warp_size();
162-
constexpr int warp_size = D/2 < warp_size_physical ? D/2 : warp_size_physical;
161+
constexpr int warp_size = 32;
163162
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
164163
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
165164
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
@@ -527,8 +526,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
527526

528527
const int id = ggml_cuda_get_device();
529528
const int cc = ggml_cuda_info().devices[id].cc;
530-
const int warp_size_physical = ggml_cuda_info().devices[id].warp_size;
531-
const int warp_size = D/2 < warp_size_physical ? D/2 : warp_size_physical;
529+
const int warp_size = 32;
532530
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
533531

534532
constexpr size_t nbytes_shared = 0;

0 commit comments

Comments
 (0)