Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,40 +843,36 @@ static void setSpecializationConstants(device_image_impl &InputImpl,
Managed<ur_program_handle_t> ProgramManager::getBuiltURProgram(
context_impl &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName, const NDRDescT &NDRDesc) {
device_impl *RootDevImpl;
ur_bool_t MustBuildOnSubdevice = true;

device_impl *BuildDev = &DeviceImpl;
// Check if we can optimize program builds for sub-devices by using a program
// built for the root device
if (!DeviceImpl.isRootDevice()) {
RootDevImpl = &DeviceImpl;
while (!RootDevImpl->isRootDevice()) {
device_impl &ParentDev = *detail::getSyclObjImpl(
RootDevImpl->get_info<info::device::parent_device>());
// Sharing is allowed within a single context only
if (!ContextImpl.hasDevice(ParentDev))
break;
RootDevImpl = &ParentDev;
}
if (!BuildDev->isRootDevice()) {
device_impl *CandidateRoot = BuildDev;
while (!CandidateRoot->isRootDevice())
CandidateRoot = &*detail::getSyclObjImpl(
CandidateRoot->get_info<info::device::parent_device>());

bool MustBuildOnSubdevice = true;
ContextImpl.getAdapter().call<UrApiKind::urDeviceGetInfo>(
RootDevImpl->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
CandidateRoot->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
sizeof(ur_bool_t), &MustBuildOnSubdevice, nullptr);
}

device_impl &RootOrSubDevImpl =
MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
// Sharing is allowed within a single context only and only if backend
// supports sharing.
if (!MustBuildOnSubdevice && ContextImpl.hasDevice(*CandidateRoot))
BuildDev = CandidateRoot;
}

const RTDeviceBinaryImage &Img =
getDeviceImage(KernelName, ContextImpl, RootOrSubDevImpl);
getDeviceImage(KernelName, ContextImpl, *BuildDev);

// Check that device supports all aspects used by the kernel
if (auto exception =
checkDevSupportDeviceRequirements(RootOrSubDevImpl, Img, NDRDesc))
checkDevSupportDeviceRequirements(*BuildDev, Img, NDRDesc))
throw *exception;

std::set<const RTDeviceBinaryImage *> DeviceImagesToLink =
collectDeviceImageDeps(Img, {RootOrSubDevImpl});
collectDeviceImageDeps(Img, {*BuildDev});

// Decompress all DeviceImagesToLink
for (const RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
Expand All @@ -888,8 +884,7 @@ Managed<ur_program_handle_t> ProgramManager::getBuiltURProgram(
std::copy(DeviceImagesToLink.begin(), DeviceImagesToLink.end(),
std::back_inserter(AllImages));

return getBuiltURProgram(std::move(AllImages), ContextImpl,
{RootOrSubDevImpl});
return getBuiltURProgram(std::move(AllImages), ContextImpl, {*BuildDev});
}

Managed<ur_program_handle_t>
Expand Down
83 changes: 76 additions & 7 deletions sycl/unittests/program_manager/SubDevices.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@

#include <helpers/TestKernel.hpp>

static ur_device_handle_t rootDevice;
static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x1;
static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x2;
static ur_device_handle_t rootDevice = (ur_device_handle_t)0x1;
// Sub-devices under rootDevice
static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x11;
static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x12;
// Sub-sub-devices under urSubDev1
static ur_device_handle_t urSubSubDev1 = (ur_device_handle_t)0x111;
static ur_device_handle_t urSubSubDev2 = (ur_device_handle_t)0x112;

namespace {
ur_result_t redefinedDeviceGet(void *pParams) {
auto params = *static_cast<ur_device_get_params_t *>(pParams);
if (*params.ppNumDevices)
**params.ppNumDevices = 1;
if (*params.pphDevices)
(*params.pphDevices)[0] = rootDevice;
return UR_RESULT_SUCCESS;
}

ur_result_t redefinedDeviceGetInfo(void *pParams) {
auto params = *static_cast<ur_device_get_info_params_t *>(pParams);
if (*params.ppropName == UR_DEVICE_INFO_SUPPORTED_PARTITIONS) {
Expand All @@ -41,13 +54,32 @@ ur_result_t redefinedDeviceGetInfo(void *pParams) {
}
}
if (*params.ppropName == UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES) {
((uint32_t *)*params.ppPropValue)[0] = 2;
if (!*params.ppPropValue)
**params.ppPropSizeRet = sizeof(uint32_t);
else
((uint32_t *)*params.ppPropValue)[0] = 2;
}
if (*params.ppropName == UR_DEVICE_INFO_PARENT_DEVICE) {
if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2)
((ur_device_handle_t *)*params.ppPropValue)[0] = rootDevice;
if (!*params.ppPropValue) {
**params.ppPropSizeRet = sizeof(ur_device_handle_t);
} else {
ur_device_handle_t &ret =
*static_cast<ur_device_handle_t *>(*params.ppPropValue);
if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2) {
ret = rootDevice;
} else if (*params.phDevice == urSubSubDev1 ||
*params.phDevice == urSubSubDev2) {
ret = urSubDev1;
} else {
ret = nullptr;
}
}
}
if (*params.ppropName == UR_DEVICE_INFO_BUILD_ON_SUBDEVICE) {
if (!*params.ppPropValue)
**params.ppPropSizeRet = sizeof(ur_bool_t);
else
((ur_device_handle_t *)*params.ppPropValue)[0] = nullptr;
((ur_bool_t *)*params.ppPropValue)[0] = false;
}
return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -77,6 +109,13 @@ ur_result_t redefinedProgramBuild(void *) {
return UR_RESULT_SUCCESS;
}

static int buildCallCount = 0;

ur_result_t redefinedProgramBuildExp(void *) {
buildCallCount++;
return UR_RESULT_SUCCESS;
}

ur_result_t redefinedContextCreate(void *) { return UR_RESULT_SUCCESS; }
} // anonymous namespace

Expand Down Expand Up @@ -128,3 +167,33 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) {
*sycl::detail::getSyclObjImpl(Ctx), subDev2,
sycl::detail::KernelInfo<TestKernel>::getName());
}

// Check that program is built once for all sub-sub-devices
TEST(SubDevices, BuildProgramForSubSubDevices) {
sycl::unittest::UrMock<> Mock;
mock::getCallbacks().set_after_callback("urDeviceGet", &redefinedDeviceGet);
mock::getCallbacks().set_after_callback("urDeviceGetInfo",
&redefinedDeviceGetInfo);
mock::getCallbacks().set_after_callback("urProgramBuildExp",
&redefinedProgramBuildExp);
sycl::platform Plt = sycl::platform();
sycl::device root = Plt.get_devices()[0];
sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt);
// Initialize sub-sub-devices
sycl::detail::device_impl &SubSub1 =
PltImpl.getOrMakeDeviceImpl(urSubSubDev1);
sycl::detail::device_impl &SubSub2 =
PltImpl.getOrMakeDeviceImpl(urSubSubDev2);

sycl::context Ctx{root};
buildCallCount = 0;
sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
*sycl::detail::getSyclObjImpl(Ctx), SubSub1,
sycl::detail::KernelInfo<TestKernel>::getName());
sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
*sycl::detail::getSyclObjImpl(Ctx), SubSub2,
sycl::detail::KernelInfo<TestKernel>::getName());

// Check that program is built only once.
EXPECT_EQ(buildCallCount, 1);
}
Loading