23 changes: 22 additions & 1 deletion src/ATen/CMakeLists.txt
@@ -3,7 +3,7 @@
file(GLOB xpu_h "xpu/*.h")
file(GLOB xpu_cpp "xpu/*.cpp")
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/cutlass/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
@@ -25,3 +25,24 @@ endforeach()
foreach(HEADER ${xpu_ops_generated_headers})
install(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/ops)
endforeach()

include(ExternalProject)
ExternalProject_Add(
cutlass_sycl_kernels_proj
SOURCE_DIR ${TORCH_XPU_OPS_ROOT}/src/ATen/native/cutlass/sycl
CMAKE_ARGS
-DCMAKE_C_COMPILER=icx
-DCMAKE_CXX_COMPILER=icpx
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
BUILD_ALWAYS TRUE
INSTALL_COMMAND ""
BUILD_BYPRODUCTS "cutlass_sycl_kernels_proj-prefix/src/cutlass_sycl_kernels_proj-build/libcutlass_kernels.so"
)

ExternalProject_Get_Property(cutlass_sycl_kernels_proj SOURCE_DIR BINARY_DIR)
set(CUTLASS_SYCL_KERNELS_LIBRARIES ${BINARY_DIR}/libcutlass_kernels.so)

add_library(cutlass_sycl_kernels INTERFACE)
add_dependencies(cutlass_sycl_kernels cutlass_sycl_kernels_proj)
target_link_libraries(cutlass_sycl_kernels INTERFACE ${CUTLASS_SYCL_KERNELS_LIBRARIES})
install(FILES ${CUTLASS_SYCL_KERNELS_LIBRARIES} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
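
The external project above produces libcutlass_kernels.so, which the wrapper in Attention.cpp (below) calls through cutlass_sdpa_backward. For orientation, here is a minimal sketch of the declaration that native/cutlass/sycl/AttentionKernels.h presumably provides, inferred purely from the call site in this diff; the parameter names are hypothetical and the real header in the SYCL subproject is authoritative:

// Hypothetical declaration inferred from the call in Attention.cpp; the
// actual signature lives in native/cutlass/sycl/AttentionKernels.h.
// Pointer arguments are raw device pointers obtained via Tensor::data_ptr().
void cutlass_sdpa_backward(
    int batch_size, int num_head_q, int num_head_kv,
    int seq_len_q, int seq_len_kv, int head_dim_qk, int head_dim_v,
    void* grad_out, void* query, void* key, void* value,
    void* softmax_probs, void* attn_mask,   // attn_mask may be nullptr
    void* grad_query, void* grad_key, void* grad_value,
    void* grad_softmax_probs);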
65 changes: 65 additions & 0 deletions src/ATen/native/cutlass/Attention.cpp
@@ -0,0 +1,65 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/softmax.h>
#endif

#include <ATen/native/cutlass/Attention.h>
#include <ATen/native/cutlass/sycl/AttentionKernels.h>

#include <comm/SYCLContext.h>

namespace at {
namespace native {
namespace cutlass_sycl {

void sdpa_backward(
int batch_size,
int num_head_q,
int num_head_kv,
int seq_len_q,
int seq_len_kv,
int head_dim_qk,
int head_dim_v,
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
std::optional<at::Tensor> attn_mask,
bool is_causal,
double scale,
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value) {

// Recompute the softmax probabilities needed by the kernel:
// P = softmax(scale * Q K^T). Note: `scale` is treated here as the softmax
// scaling factor itself (typically 1/sqrt(head_dim_qk)), so it multiplies
// the logits directly rather than being passed through sqrt.
auto ps = at::matmul(query, key.transpose(-2, -1));
ps = ps * scale;
ps = at::softmax(ps, -1).to(query.dtype());
auto dps = at::empty_like(ps);
cutlass_sdpa_backward(batch_size, num_head_q, num_head_kv, seq_len_q, seq_len_kv,
head_dim_qk, head_dim_v,
grad_out.data_ptr(),
query.data_ptr(),
key.data_ptr(),
value.data_ptr(),
ps.data_ptr(),
nullptr,
grad_query.data_ptr(),
grad_key.data_ptr(),
grad_value.data_ptr(),
dps.data_ptr());
}
} // namespace cutlass_sycl
} // namespace native
} // namespace at
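
The composite ops above only recompute the probabilities; the gradients themselves come from the fused kernel. For reference, the kernel is expected to reproduce the standard softmax-attention backward, sketched here with composite ATen ops. This is an illustration under the assumption P = softmax(scale * Q K^T) as recomputed above, not code from this diff; attn_mask, is_causal, and grouped KV heads are ignored for brevity.

#include <ATen/ATen.h>

// Reference softmax-attention backward (illustration only):
//   dV = P^T dO
//   dP = dO V^T
//   dS = P * (dP - rowsum(dP * P))   (softmax backward)
//   dQ = scale * dS K,  dK = scale * dS^T Q
static void sdpa_backward_reference(
    const at::Tensor& grad_out, const at::Tensor& query,
    const at::Tensor& key, const at::Tensor& value,
    const at::Tensor& probs, double scale,
    at::Tensor& grad_query, at::Tensor& grad_key, at::Tensor& grad_value) {
  grad_value = at::matmul(probs.transpose(-2, -1), grad_out);
  auto dp = at::matmul(grad_out, value.transpose(-2, -1));
  auto ds = probs * (dp - (dp * probs).sum({-1}, /*keepdim=*/true));
  grad_query = at::matmul(ds, key) * scale;
  grad_key = at::matmul(ds.transpose(-2, -1), query) * scale;
}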
32 changes: 32 additions & 0 deletions src/ATen/native/cutlass/Attention.h
@@ -0,0 +1,32 @@
#pragma once

#include <ATen/ATen.h>

namespace at {
namespace native {
namespace cutlass_sycl {

void sdpa_backward(
int batch_size,
int num_head_q,
int num_head_kv,
int seq_len_q,
int seq_len_kv,
int head_dim_qk,
int head_dim_v,
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
std::optional<at::Tensor> attn_mask,
bool is_causal,
double scale,
Tensor& grad_query,
Tensor& grad_key,
Tensor& grad_value);

} // namespace cutlass_sycl
} // namespace native
} // namespace at
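
A hedged usage sketch of the declared entry point. The [batch, heads, seq, head_dim] layout, XPU device, and half-precision dtype here are assumptions for illustration; the diff itself does not pin down the calling convention.

#include <cmath>
#include <optional>

#include <ATen/ATen.h>
#include <ATen/native/cutlass/Attention.h>

// Hypothetical caller exercising the backward entry point.
void example_sdpa_backward() {
  const int B = 2, Hq = 8, Hkv = 8, Sq = 128, Skv = 128, Dqk = 64, Dv = 64;
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);
  auto q = at::randn({B, Hq, Sq, Dqk}, opts);
  auto k = at::randn({B, Hkv, Skv, Dqk}, opts);
  auto v = at::randn({B, Hkv, Skv, Dv}, opts);
  auto grad_out = at::randn({B, Hq, Sq, Dv}, opts);
  auto out = at::empty_like(grad_out);  // forward output (placeholder)
  auto logsumexp = at::empty({B, Hq, Sq}, opts.dtype(at::kFloat));
  auto dq = at::empty_like(q);
  auto dk = at::empty_like(k);
  auto dv = at::empty_like(v);
  at::native::cutlass_sycl::sdpa_backward(
      B, Hq, Hkv, Sq, Skv, Dqk, Dv,
      grad_out, q, k, v, out, logsumexp,
      /*attn_mask=*/std::nullopt, /*is_causal=*/false,
      /*scale=*/1.0 / std::sqrt(static_cast<double>(Dqk)),
      dq, dk, dv);
}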