Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 102 additions & 8 deletions src/xccl/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ namespace {
#define XCCL_HAS_AVG 1
#endif // oneCCL version >= 2021.15

// Defines ENABLE_XCCL_PREMUL_SUM_SUPPORT when the oneCCL headers report
// version 2021.17 or newer; the pre-mul-sum code paths in this file are
// compiled only under that macro. `&&` binds tighter than `||`, so the
// condition reads: MAJOR > 2021 || (MAJOR == 2021 && MINOR >= 17).
#if defined(CCL_MAJOR_VERSION) && \
((CCL_MAJOR_VERSION > 2021) || \
(CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 17))
#define ENABLE_XCCL_PREMUL_SUM_SUPPORT
#endif // oneCCL version >= 2021.17

const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
{ReduceOp::MIN, ccl::reduction::min},
{ReduceOp::MAX, ccl::reduction::max},
Expand Down Expand Up @@ -44,6 +50,33 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
};

struct xcclRedOpRAII {
// Default state: value-initialized reduction, null communicator pointer, and
// premul_sum_ == false (see the member initializers).
xcclRedOpRAII() = default;
// Wraps a plain reduction value and leaves premul_sum_ false. This
// constructor is not `explicit`; getXcclReduceOp returns bare ccl::reduction
// values that convert through it.
xcclRedOpRAII(ccl::reduction op) : op_(op) {}
// Wraps a reduction together with the communicator pointer and sets
// premul_sum_, the flag the destructor tests before calling
// ccl::reduction_destroy(op_, *comm_).
// NOTE(review): `comm` is stored as-is and later dereferenced without a null
// check -- assumes callers pass a valid communicator that outlives this
// object; confirm.
xcclRedOpRAII(ccl::reduction op, const xcclComm_t* comm)
: op_(op), comm_(comm), premul_sum_(true) {}
// Copy construction and copy assignment are deleted; only the move
// constructor transfers the wrapped state.
xcclRedOpRAII(const xcclRedOpRAII&) = delete;
xcclRedOpRAII& operator=(const xcclRedOpRAII&) = delete;
// Move constructor: takes over the source's reduction, communicator pointer
// and pre-mul-sum flag, then puts the source back into the default state
// (value-initialized op, null comm, flag cleared) so that its destructor
// does not act on the transferred reduction.
xcclRedOpRAII(xcclRedOpRAII&& tmp) noexcept
    : op_(tmp.op_), comm_(tmp.comm_), premul_sum_(tmp.premul_sum_) {
  tmp.op_ = ccl::reduction{};
  tmp.comm_ = nullptr;
  tmp.premul_sum_ = false;
}
// The user-defined destructor exists only when pre-mul-sum support is
// compiled in; without the macro the struct falls back to the implicit
// destructor.
#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
// Calls ccl::reduction_destroy on the wrapped reduction, but only for
// objects whose premul_sum_ flag is set. Default-constructed, plain-op and
// moved-from objects have the flag cleared and are skipped.
~xcclRedOpRAII() {
if (premul_sum_) {
// comm_ is dereferenced unchecked.
// NOTE(review): assumes the communicator is still alive here -- confirm.
ccl::reduction_destroy(op_, *comm_);
}
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
Copy link
Preview

Copilot AI Aug 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The destructor is only defined when ENABLE_XCCL_PREMUL_SUM_SUPPORT is defined, but the class is used regardless of this macro. This will cause linking errors when the macro is not defined. The destructor should be defined unconditionally with appropriate conditional logic inside.

Suggested change
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
~xcclRedOpRAII() {
#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
if (premul_sum_) {
ccl::reduction_destroy(op_, *comm_);
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
}

Copilot uses AI. Check for mistakes.

// Implicit conversion to the wrapped ccl::reduction value.
operator ccl::reduction() const {
return op_;
}
// The wrapped reduction value; value-initialized by default.
ccl::reduction op_{};
// Communicator pointer handed to ccl::reduction_destroy by the destructor.
// Null unless set through the (op, comm) constructor or moved in.
const xcclComm_t* comm_ = nullptr;
// Set only by the (op, comm) constructor; tells the destructor whether to
// call ccl::reduction_destroy on op_.
bool premul_sum_ = false;
};

bool computeLengthsAndCheckAndGetFlat(
const std::vector<at::Tensor>& tensors,
std::vector<size_t>& lengths,
Expand Down Expand Up @@ -152,7 +185,37 @@ ccl::datatype getXcclDataType(
return it->second;
}

ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
// Creates a pre-multiplied-sum reduction from the supplement attached to
// `reduceOp` and returns it wrapped in xcclRedOpRAII, whose destructor calls
// ccl::reduction_destroy on it.
//
// T        - C++ type the factor is read as (tensor form) or converted to
//            (double form).
// dataType - ccl datatype forwarded to ccl::reduction_create_pre_mul_sum.
// reduceOp - op whose supplement_ carries the factor.
//            NOTE(review): assumes supplement_ is non-null and really holds an
//            XCCLPreMulSumSupplement; nothing here verifies that -- confirm
//            callers only reach this for ops built by makeXCCLPreMulSum.
// comm     - communicator the reduction is created on. Its address is stored
//            in the returned wrapper, so it must outlive that wrapper.
template <typename T, ccl::datatype dataType>
xcclRedOpRAII unpackPreMulSum(
    const ReduceOp& reduceOp,
    const xcclComm_t& comm) {
  // XCCLPreMulSumSupplement derives from _SupplementBase, so this is an
  // ordinary base-to-derived downcast. static_cast is the cast defined for
  // that conversion; reinterpret_cast merely reuses the address.
  const auto* preMulSupplement =
      static_cast<const XCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
  const bool has_tensor = preMulSupplement->tensor_factor.defined();

  // Host-side copy of the scalar factor. It lives at function scope so the
  // pointer passed below stays valid for the duration of the create call.
  T scalar_factor = T(preMulSupplement->double_factor);

  // Default to the host scalar; switch to the tensor's storage when a tensor
  // factor was supplied.
  T* factor_ptr = &scalar_factor;
  auto residence = ccl::scalar_residence_type::scalar_host_immediate;
  if (has_tensor) {
    // The pointer is passed through a non-const T* parameter, hence the
    // const_cast. NOTE(review): presumably the library only reads it.
    factor_ptr =
        const_cast<T*>(preMulSupplement->tensor_factor.const_data_ptr<T>());
    residence = ccl::scalar_residence_type::scalar_device;
  }

  ccl::reduction preMulSum{};
  ccl::reduction_create_pre_mul_sum(
      &preMulSum,
      /*scalar=*/factor_ptr,
      dataType,
      residence,
      comm);
  return xcclRedOpRAII(preMulSum, &comm);
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT

xcclRedOpRAII getXcclReduceOp(
const ReduceOp& reduceOp,
at::Tensor& input,
const ccl::datatype& dataType,
xcclComm_t& comm) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
Expand All @@ -171,6 +234,30 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
return ccl::reduction::sum;
}
#endif
if (reduceOp == ReduceOp::PREMUL_SUM) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
switch (dataType) {
case ccl::datatype::float16:
return unpackPreMulSum<at::Half, ccl::datatype::float16>(
reduceOp, comm);
case ccl::datatype::float32:
return unpackPreMulSum<float, ccl::datatype::float32>(reduceOp, comm);
case ccl::datatype::bfloat16:
return unpackPreMulSum<float, ccl::datatype::bfloat16>(
Copy link
Preview

Copilot AI Aug 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For bfloat16 data type, the template should use at::BFloat16 instead of float. Using float for bfloat16 data will cause type mismatch issues when accessing the tensor data.

Suggested change
return unpackPreMulSum<float, ccl::datatype::bfloat16>(
return unpackPreMulSum<at::BFloat16, ccl::datatype::bfloat16>(

Copilot uses AI. Check for mistakes.

reduceOp, comm);
case ccl::datatype::float64:
return unpackPreMulSum<double, ccl::datatype::float64>(
reduceOp, comm);
default:
C10_THROW_ERROR(
TypeError,
"PreMulSum Data type must be half, float, bfloat16 or double");
return ccl::reduction{};
}
#else
C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17");
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
Expand Down Expand Up @@ -1266,7 +1353,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1363,7 +1451,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1525,7 +1614,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
const auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1569,7 +1659,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
const auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1829,7 +1920,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1924,7 +2016,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down Expand Up @@ -1983,7 +2076,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
Expand Down
18 changes: 18 additions & 0 deletions src/xccl/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,24 @@ TORCH_API std::string dump_xccl_trace(
bool onlyActive);

TORCH_API std::string getXcclVersion();

// Supplement carried by a PREMUL_SUM ReduceOp. The factor is held either as
// a host double or as a tensor; the tensor form is checked to contain
// exactly one element.
struct XCCLPreMulSumSupplement : _SupplementBase {
  double double_factor = 0.0;
  at::Tensor tensor_factor;

  // Scalar form: tensor_factor stays default-constructed.
  XCCLPreMulSumSupplement(double f) : double_factor(f) {}

  // Tensor form: double_factor keeps its 0.0 default.
  XCCLPreMulSumSupplement(at::Tensor t) : tensor_factor(std::move(t)) {
    TORCH_CHECK_EQ(tensor_factor.numel(), 1);
  }
};

// Builds a ReduceOp tagged PREMUL_SUM whose supplement is an
// XCCLPreMulSumSupplement constructed from `factor`.
template <typename T>
ReduceOp makeXCCLPreMulSum(const T& factor) {
  ReduceOp premulOp;
  premulOp.supplement_ = c10::make_intrusive<XCCLPreMulSumSupplement>(factor);
  premulOp.op_ = ReduceOp::PREMUL_SUM;
  return premulOp;
}

} // namespace c10d

namespace {
Expand Down
37 changes: 37 additions & 0 deletions test/xpu/distributed/test_c10d_ops_xccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,28 @@ def reduce(xs, rootRank, rootTensor, op=None):
):
reduce(tensors, self.rank, rt, op)

for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
float_tensors = [
torch.tensor(
[self.rank + 1.0], device=f"xpu:{local_device_id}"
)
]
float_tensors_ref = [
torch.tensor(
[(self.rank + 1.0) * factor_ref],
device=f"xpu:{local_device_id}",
)
]

reduce(float_tensors_ref, rt, 0)
reduce(float_tensors, rt, 0, c10d._make_xccl_premul_sum(factor))
if self.rank == rt:
self.assertEqual(float_tensors_ref[0], float_tensors[0])

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_allgather_ops(self):
Expand Down Expand Up @@ -713,6 +735,21 @@ def perm(n, k):
expected = torch.tensor(prod_val)
self.assertEqual(expected, output_tensor)

for factor in (3.0, torch.tensor([5.0], device=self.rank)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
output = [t.float() for t in output]
tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
output_ref = [t.float() for t in output]
tensor_lists_ref = [
[t.float() * factor_ref for t in tl] for tl in tensor_lists
]
reduce_scatter(output, tensor_lists, c10d._make_xccl_premul_sum(factor))
reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
self.assertEqual(output_ref, output)

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_reduce_scatter_base_ops(self):
Expand Down
Loading