Skip to content

Commit 466de9c

Browse files
[SYCL] Speed up device compilation for half and bfloat16 (#20050)
* `std::istream`/`std::ostream` aren't usable on device, so provide only declarations of `operator<<`/`operator>>` there; this limits device includes to `<iosfwd>` instead of the much heavier `<istream>`/`<ostream>`.
* Use the "lighter" `<optional>` header to get `std::hash`.
1 parent babcd91 commit 466de9c

File tree

4 files changed

+35
-24
lines changed

4 files changed

+35
-24
lines changed

sycl/include/sycl/detail/vector_arith.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
1515
#include <sycl/ext/oneapi/bfloat16.hpp>
1616

17+
#include <functional>
18+
1719
namespace sycl {
1820
inline namespace _V1 {
1921
namespace detail {

sycl/include/sycl/ext/intel/esimd/detail/half_type_traits.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,6 @@ struct is_esimd_arithmetic_type<half_raw_type, void> : std::true_type {};
115115
template <>
116116
struct is_esimd_arithmetic_type<sycl::half, void> : std::true_type {};
117117

118-
// Misc
119-
inline std::ostream &operator<<(std::ostream &O, sycl::half const &rhs) {
120-
O << static_cast<float>(rhs);
121-
return O;
122-
}
123-
124-
inline std::istream &operator>>(std::istream &I, sycl::half &rhs) {
125-
float ValFloat = 0.0f;
126-
I >> ValFloat;
127-
rhs = ValFloat;
128-
return I;
129-
}
130-
131118
} // namespace ext::intel::esimd::detail
132119
} // namespace _V1
133120
} // namespace sycl

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
#pragma once
1010

11-
#include <sycl/aliases.hpp> // for half
12-
#include <sycl/bit_cast.hpp> // for bit_cast
13-
#include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
14-
#include <sycl/half_type.hpp> // for half
11+
#include <sycl/aliases.hpp>
12+
#include <sycl/bit_cast.hpp>
13+
#include <sycl/detail/defines_elementary.hpp>
14+
#include <sycl/half_type.hpp>
1515

16-
#include <cstdint> // for uint16_t, uint32_t
16+
#include <cstdint>
1717

1818
namespace sycl {
1919
inline namespace _V1 {
@@ -126,6 +126,12 @@ class bfloat16 {
126126
// for floating-point types.
127127

128128
// Stream Operator << and >>
129+
#ifdef __SYCL_DEVICE_ONLY__
130+
// std::istream/std::ostream aren't usable on device, so don't provide a
131+
// definition; this saves compile time by needing only lightweight `<iosfwd>`.
132+
inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs);
133+
inline friend std::istream &operator>>(std::istream &I, bfloat16 &rhs);
134+
#else
129135
inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
130136
O << static_cast<float>(rhs);
131137
return O;
@@ -137,6 +143,7 @@ class bfloat16 {
137143
rhs = ValFloat;
138144
return I;
139145
}
146+
#endif
140147

141148
private:
142149
Bfloat16StorageT value;

sycl/include/sycl/half_type.hpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,25 @@
1010

1111
#include <sycl/bit_cast.hpp> // for bit_cast
1212
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
13-
#include <sycl/detail/iostream_proxy.hpp> // for istream, ostream
1413

1514
#ifdef __SYCL_DEVICE_ONLY__
1615
#include <sycl/aspects.hpp>
1716
#endif
1817

19-
#include <cstddef> // for size_t
20-
#include <cstdint> // for uint16_t, uint32_t, uint8_t
21-
#include <functional> // for hash
18+
#include <cstddef>
19+
#include <cstdint>
2220
#include <limits> // for float_denorm_style, float_r...
23-
#include <string_view> // for hash
24-
#include <type_traits> // for enable_if_t
21+
#include <type_traits>
22+
23+
// For std::hash; this seems to be the most lightweight header that provides it under
24+
// C++17:
25+
#include <optional>
26+
27+
#ifdef __SYCL_DEVICE_ONLY__
28+
#include <iosfwd>
29+
#else
30+
#include <sycl/detail/iostream_proxy.hpp>
31+
#endif
2532

2633
#if !defined(__has_builtin) || !__has_builtin(__builtin_expect)
2734
#define __builtin_expect(a, b) (a)
@@ -478,6 +485,13 @@ class [[__sycl_detail__::__uses_aspects__(aspect::fp16)]] half {
478485
#endif // __SYCL_DEVICE_ONLY__
479486

480487
// Operator << and >>
488+
#ifdef __SYCL_DEVICE_ONLY__
489+
// std::istream/std::ostream aren't usable on device, so don't provide a
490+
// definition; this saves compile time by needing only lightweight `<iosfwd>`.
491+
inline friend std::ostream &operator<<(std::ostream &O,
492+
sycl::half const &rhs);
493+
inline friend std::istream &operator>>(std::istream &I, sycl::half &rhs);
494+
#else
481495
inline friend std::ostream &operator<<(std::ostream &O,
482496
sycl::half const &rhs) {
483497
O << static_cast<float>(rhs);
@@ -490,6 +504,7 @@ class [[__sycl_detail__::__uses_aspects__(aspect::fp16)]] half {
490504
rhs = ValFloat;
491505
return I;
492506
}
507+
#endif
493508

494509
template <typename Key> friend struct std::hash;
495510

0 commit comments

Comments
 (0)