Skip to content

Commit babcd91

Browse files
authored
[SYCL] Reuse root device build for sub-sub-devices (#20051)
The current root-device ascent logic leads to build reuse only for sub-devices, but we want reuse to happen for sub-sub-devices as well; this PR fixes that.
1 parent d452a04 commit babcd91

File tree

2 files changed

+93
-29
lines changed

2 files changed

+93
-29
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -843,40 +843,36 @@ static void setSpecializationConstants(device_image_impl &InputImpl,
843843
Managed<ur_program_handle_t> ProgramManager::getBuiltURProgram(
844844
context_impl &ContextImpl, device_impl &DeviceImpl,
845845
KernelNameStrRefT KernelName, const NDRDescT &NDRDesc) {
846-
device_impl *RootDevImpl;
847-
ur_bool_t MustBuildOnSubdevice = true;
848-
846+
device_impl *BuildDev = &DeviceImpl;
849847
// Check if we can optimize program builds for sub-devices by using a program
850848
// built for the root device
851-
if (!DeviceImpl.isRootDevice()) {
852-
RootDevImpl = &DeviceImpl;
853-
while (!RootDevImpl->isRootDevice()) {
854-
device_impl &ParentDev = *detail::getSyclObjImpl(
855-
RootDevImpl->get_info<info::device::parent_device>());
856-
// Sharing is allowed within a single context only
857-
if (!ContextImpl.hasDevice(ParentDev))
858-
break;
859-
RootDevImpl = &ParentDev;
860-
}
849+
if (!BuildDev->isRootDevice()) {
850+
device_impl *CandidateRoot = BuildDev;
851+
while (!CandidateRoot->isRootDevice())
852+
CandidateRoot = &*detail::getSyclObjImpl(
853+
CandidateRoot->get_info<info::device::parent_device>());
861854

855+
bool MustBuildOnSubdevice = true;
862856
ContextImpl.getAdapter().call<UrApiKind::urDeviceGetInfo>(
863-
RootDevImpl->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
857+
CandidateRoot->getHandleRef(), UR_DEVICE_INFO_BUILD_ON_SUBDEVICE,
864858
sizeof(ur_bool_t), &MustBuildOnSubdevice, nullptr);
865-
}
866859

867-
device_impl &RootOrSubDevImpl =
868-
MustBuildOnSubdevice == true ? DeviceImpl : *RootDevImpl;
860+
// Sharing is allowed within a single context if and only if backend
861+
// supports sharing.
862+
if (!MustBuildOnSubdevice && ContextImpl.hasDevice(*CandidateRoot))
863+
BuildDev = CandidateRoot;
864+
}
869865

870866
const RTDeviceBinaryImage &Img =
871-
getDeviceImage(KernelName, ContextImpl, RootOrSubDevImpl);
867+
getDeviceImage(KernelName, ContextImpl, *BuildDev);
872868

873869
// Check that device supports all aspects used by the kernel
874870
if (auto exception =
875-
checkDevSupportDeviceRequirements(RootOrSubDevImpl, Img, NDRDesc))
871+
checkDevSupportDeviceRequirements(*BuildDev, Img, NDRDesc))
876872
throw *exception;
877873

878874
std::set<const RTDeviceBinaryImage *> DeviceImagesToLink =
879-
collectDeviceImageDeps(Img, {RootOrSubDevImpl});
875+
collectDeviceImageDeps(Img, {*BuildDev});
880876

881877
// Decompress all DeviceImagesToLink
882878
for (const RTDeviceBinaryImage *BinImg : DeviceImagesToLink)
@@ -888,8 +884,7 @@ Managed<ur_program_handle_t> ProgramManager::getBuiltURProgram(
888884
std::copy(DeviceImagesToLink.begin(), DeviceImagesToLink.end(),
889885
std::back_inserter(AllImages));
890886

891-
return getBuiltURProgram(std::move(AllImages), ContextImpl,
892-
{RootOrSubDevImpl});
887+
return getBuiltURProgram(std::move(AllImages), ContextImpl, {*BuildDev});
893888
}
894889

895890
Managed<ur_program_handle_t>

sycl/unittests/program_manager/SubDevices.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,24 @@
1414

1515
#include <helpers/TestKernel.hpp>
1616

17-
static ur_device_handle_t rootDevice;
18-
static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x1;
19-
static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x2;
17+
static ur_device_handle_t rootDevice = (ur_device_handle_t)0x1;
18+
// Sub-devices under rootDevice
19+
static ur_device_handle_t urSubDev1 = (ur_device_handle_t)0x11;
20+
static ur_device_handle_t urSubDev2 = (ur_device_handle_t)0x12;
21+
// Sub-sub-devices under urSubDev1
22+
static ur_device_handle_t urSubSubDev1 = (ur_device_handle_t)0x111;
23+
static ur_device_handle_t urSubSubDev2 = (ur_device_handle_t)0x112;
2024

2125
namespace {
26+
// Mock override for urDeviceGet: reports exactly one device, the fake
// rootDevice handle.
//
// pParams is cast to ur_device_get_params_t. Each output is written only when
// its pointer is non-null, so both a count-only query and a handle-fetching
// query are served. Always returns UR_RESULT_SUCCESS.
ur_result_t redefinedDeviceGet(void *pParams) {
27+
auto params = *static_cast<ur_device_get_params_t *>(pParams);
28+
// Device count requested: there is a single (root) device.
if (*params.ppNumDevices)
29+
**params.ppNumDevices = 1;
30+
// Handle array supplied: its first slot receives the root device handle.
if (*params.pphDevices)
31+
(*params.pphDevices)[0] = rootDevice;
32+
return UR_RESULT_SUCCESS;
33+
}
34+
2235
ur_result_t redefinedDeviceGetInfo(void *pParams) {
2336
auto params = *static_cast<ur_device_get_info_params_t *>(pParams);
2437
if (*params.ppropName == UR_DEVICE_INFO_SUPPORTED_PARTITIONS) {
@@ -41,13 +54,32 @@ ur_result_t redefinedDeviceGetInfo(void *pParams) {
4154
}
4255
}
4356
if (*params.ppropName == UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES) {
44-
((uint32_t *)*params.ppPropValue)[0] = 2;
57+
if (!*params.ppPropValue)
58+
**params.ppPropSizeRet = sizeof(uint32_t);
59+
else
60+
((uint32_t *)*params.ppPropValue)[0] = 2;
4561
}
4662
if (*params.ppropName == UR_DEVICE_INFO_PARENT_DEVICE) {
47-
if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2)
48-
((ur_device_handle_t *)*params.ppPropValue)[0] = rootDevice;
63+
if (!*params.ppPropValue) {
64+
**params.ppPropSizeRet = sizeof(ur_device_handle_t);
65+
} else {
66+
ur_device_handle_t &ret =
67+
*static_cast<ur_device_handle_t *>(*params.ppPropValue);
68+
if (*params.phDevice == urSubDev1 || *params.phDevice == urSubDev2) {
69+
ret = rootDevice;
70+
} else if (*params.phDevice == urSubSubDev1 ||
71+
*params.phDevice == urSubSubDev2) {
72+
ret = urSubDev1;
73+
} else {
74+
ret = nullptr;
75+
}
76+
}
77+
}
78+
if (*params.ppropName == UR_DEVICE_INFO_BUILD_ON_SUBDEVICE) {
79+
if (!*params.ppPropValue)
80+
**params.ppPropSizeRet = sizeof(ur_bool_t);
4981
else
50-
((ur_device_handle_t *)*params.ppPropValue)[0] = nullptr;
82+
((ur_bool_t *)*params.ppPropValue)[0] = false;
5183
}
5284
return UR_RESULT_SUCCESS;
5385
}
@@ -77,6 +109,13 @@ ur_result_t redefinedProgramBuild(void *) {
77109
return UR_RESULT_SUCCESS;
78110
}
79111

112+
// Number of times redefinedProgramBuildExp has been invoked. Tests reset it to
// zero before the code under test runs and then compare it to the expected
// number of builds.
static int buildCallCount = 0;
113+
114+
// Mock override registered for "urProgramBuildExp": performs no build, only
// counts the invocation and reports success.
ur_result_t redefinedProgramBuildExp(void *) {
115+
buildCallCount++;
116+
return UR_RESULT_SUCCESS;
117+
}
118+
80119
ur_result_t redefinedContextCreate(void *) { return UR_RESULT_SUCCESS; }
81120
} // anonymous namespace
82121

@@ -128,3 +167,33 @@ TEST(SubDevices, DISABLED_BuildProgramForSubdevices) {
128167
*sycl::detail::getSyclObjImpl(Ctx), subDev2,
129168
sycl::detail::KernelInfo<TestKernel>::getName());
130169
}
170+
171+
// Check that program is built once for all sub-sub-devices
//
// Setup relied on from the mocks in this file: redefinedDeviceGetInfo reports
// urSubDev1 as the parent of both sub-sub-device handles, rootDevice as the
// parent of urSubDev1, and UR_DEVICE_INFO_BUILD_ON_SUBDEVICE as false. The
// context below holds only the root device, so both getBuiltURProgram requests
// are expected to be served by a single build.
172+
TEST(SubDevices, BuildProgramForSubSubDevices) {
173+
sycl::unittest::UrMock<> Mock;
174+
// Route device enumeration, device-info queries and program builds through
// the overrides defined earlier in this file.
mock::getCallbacks().set_after_callback("urDeviceGet", &redefinedDeviceGet);
175+
mock::getCallbacks().set_after_callback("urDeviceGetInfo",
176+
&redefinedDeviceGetInfo);
177+
mock::getCallbacks().set_after_callback("urProgramBuildExp",
178+
&redefinedProgramBuildExp);
179+
sycl::platform Plt = sycl::platform();
180+
// NOTE(review): presumably this is the device backed by rootDevice, since
// redefinedDeviceGet reports only that handle; confirm against UrMock.
sycl::device root = Plt.get_devices()[0];
181+
sycl::detail::platform_impl &PltImpl = *sycl::detail::getSyclObjImpl(Plt);
182+
// Initialize sub-sub-devices
183+
sycl::detail::device_impl &SubSub1 =
184+
PltImpl.getOrMakeDeviceImpl(urSubSubDev1);
185+
sycl::detail::device_impl &SubSub2 =
186+
PltImpl.getOrMakeDeviceImpl(urSubSubDev2);
187+
188+
// The context is created from the root device only; the sub-sub-devices are
// not passed to it.
sycl::context Ctx{root};
189+
// Reset the counter so only the builds triggered below are counted.
buildCallCount = 0;
190+
// Request the same kernel's program for each sub-sub-device in turn.
sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
191+
*sycl::detail::getSyclObjImpl(Ctx), SubSub1,
192+
sycl::detail::KernelInfo<TestKernel>::getName());
193+
sycl::detail::ProgramManager::getInstance().getBuiltURProgram(
194+
*sycl::detail::getSyclObjImpl(Ctx), SubSub2,
195+
sycl::detail::KernelInfo<TestKernel>::getName());
196+
197+
// Check that program is built only once.
198+
EXPECT_EQ(buildCallCount, 1);
199+
}

0 commit comments

Comments
 (0)