Skip to content

Commit 236b6d9

Browse files
committed
make explicit use of pointer to AD tape for faster access without relying on compiler optimizations
1 parent 7d521fd commit 236b6d9

File tree

75 files changed

+294
-318
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+294
-318
lines changed

stan/math/rev/arr/fun/sum.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class sum_v_vari : public vari {
2929

3030
explicit sum_v_vari(const std::vector<var>& v1)
3131
: vari(sum_of_val(v1)),
32-
v_(reinterpret_cast<vari**>(ChainableStack::instance().memalloc_.alloc(
32+
v_(reinterpret_cast<vari**>(ChainableStack::instance_->memalloc_.alloc(
3333
v1.size() * sizeof(vari*)))),
3434
length_(v1.size()) {
3535
for (size_t i = 0; i < length_; i++)

stan/math/rev/core/autodiffstackstorage.hpp

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@ namespace math {
99

1010
/**
1111
* Provides a thread_local singleton if needed. Read warnings below!
12-
* With STAN_THREADS defined, the singleton is a thread_local static pointer
13-
* for performance reasons. When STAN_THREADS is not set, we have the old
14-
* static AD stack in the instance_ field because we saw odd performance
15-
* issues on the Mac Pro[4]. The rest of this commentary is specifically
16-
* talking about the design choices in the STAN_THREADS=true case.
17-
* When a TLS is used then initialization with
12+
* For performance reasons the singleton is a static member pointer,
13+
* which is declared thread local (TLS) if STAN_THREADS is defined.
14+
* A pointer is used because it keeps access fast in the threading
15+
* case: when TLS is used, initialization with
1816
* a constant expression is required for fast access to the TLS. As
1917
* the AD storage struct is non-POD it must be initialized as a
2018
* dynamic expression such that compilers will wrap any access to the
@@ -61,7 +59,6 @@ namespace math {
6159
* [2] https://github.com/stan-dev/math/pull/826
6260
* [3]
6361
* http://discourse.mc-stan.org/t/potentially-dropping-support-for-older-versions-of-apples-version-of-clang/3780/
64-
* [4] https://github.com/stan-dev/math/pull/1135
6562
*/
6663
template <typename ChainableT, typename ChainableAllocT>
6764
struct AutodiffStackSingleton {
@@ -70,12 +67,10 @@ struct AutodiffStackSingleton {
7067

7168
AutodiffStackSingleton() : own_instance_(init()) {}
7269
~AutodiffStackSingleton() {
73-
#ifdef STAN_THREADS
7470
if (own_instance_) {
7571
delete instance_;
7672
instance_ = nullptr;
7773
}
78-
#endif
7974
}
8075

8176
struct AutodiffStackStorage {
@@ -95,39 +90,25 @@ struct AutodiffStackSingleton {
9590
explicit AutodiffStackSingleton(AutodiffStackSingleton_t const &) = delete;
9691
AutodiffStackSingleton &operator=(const AutodiffStackSingleton_t &) = delete;
9792

98-
static constexpr inline AutodiffStackStorage &instance() {
99-
return
93+
static
10094
#ifdef STAN_THREADS
101-
*
95+
#ifdef __GNUC__
96+
__thread
97+
#else
98+
thread_local
10299
#endif
103-
instance_;
104-
}
100+
#endif
101+
AutodiffStackStorage *instance_;
105102

106103
private:
107104
static bool init() {
108-
#ifdef STAN_THREADS
109105
if (!instance_) {
110106
instance_ = new AutodiffStackStorage();
111107
return true;
112108
}
113-
#endif
114109
return false;
115110
}
116111

117-
static
118-
#ifdef STAN_THREADS
119-
#ifdef __GNUC__
120-
__thread
121-
#else
122-
thread_local
123-
#endif
124-
#endif
125-
AutodiffStackStorage
126-
#ifdef STAN_THREADS
127-
*
128-
#endif
129-
instance_;
130-
131112
bool own_instance_;
132113
};
133114

@@ -141,13 +122,8 @@ thread_local
141122
#endif
142123
typename AutodiffStackSingleton<ChainableT,
143124
ChainableAllocT>::AutodiffStackStorage
144-
145-
#ifdef STAN_THREADS
146125
*AutodiffStackSingleton<ChainableT, ChainableAllocT>::instance_
147126
= nullptr;
148-
#else
149-
AutodiffStackSingleton<ChainableT, ChainableAllocT>::instance_;
150-
#endif
151127

152128
} // namespace math
153129
} // namespace stan

stan/math/rev/core/build_vari_array.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace math {
2020
template <int R, int C>
2121
vari** build_vari_array(const Eigen::Matrix<var, R, C>& x) {
2222
vari** x_vi_
23-
= ChainableStack::instance().memalloc_.alloc_array<vari*>(x.size());
23+
= ChainableStack::instance_->memalloc_.alloc_array<vari*>(x.size());
2424
for (int i = 0; i < x.size(); ++i) {
2525
x_vi_[i] = x(i).vi_;
2626
}

stan/math/rev/core/chainable_alloc.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace math {
1616
class chainable_alloc {
1717
public:
1818
chainable_alloc() {
19-
ChainableStack::instance().var_alloc_stack_.push_back(this);
19+
ChainableStack::instance_->var_alloc_stack_.push_back(this);
2020
}
2121
virtual ~chainable_alloc() {}
2222
};

stan/math/rev/core/empty_nested.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace math {
1010
* Return true if there is no nested autodiff being executed.
1111
*/
1212
static inline bool empty_nested() {
13-
return ChainableStack::instance().nested_var_stack_sizes_.empty();
13+
return ChainableStack::instance_->nested_var_stack_sizes_.empty();
1414
}
1515

1616
} // namespace math

stan/math/rev/core/gevv_vvv_vari.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class gevv_vvv_vari : public vari {
3232
length_(length) {
3333
alpha_ = alpha->vi_;
3434
// TODO(carpenter): replace this with array alloc fun call
35-
v1_ = reinterpret_cast<vari**>(ChainableStack::instance().memalloc_.alloc(
35+
v1_ = reinterpret_cast<vari**>(ChainableStack::instance_->memalloc_.alloc(
3636
2 * length_ * sizeof(vari*)));
3737
v2_ = v1_ + length_;
3838
for (size_t i = 0; i < length_; i++)

stan/math/rev/core/grad.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ static void grad(vari* vi) {
3737

3838
typedef std::vector<vari*>::reverse_iterator it_t;
3939
vi->init_dependent();
40-
it_t begin = ChainableStack::instance().var_stack_.rbegin();
41-
it_t end = empty_nested() ? ChainableStack::instance().var_stack_.rend()
40+
it_t begin = ChainableStack::instance_->var_stack_.rbegin();
41+
it_t end = empty_nested() ? ChainableStack::instance_->var_stack_.rend()
4242
: begin + nested_size();
4343
for (it_t it = begin; it < end; ++it) {
4444
(*it)->chain();

stan/math/rev/core/nested_size.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ namespace stan {
88
namespace math {
99

1010
static inline size_t nested_size() {
11-
return ChainableStack::instance().var_stack_.size()
12-
- ChainableStack::instance().nested_var_stack_sizes_.back();
11+
return ChainableStack::instance_->var_stack_.size()
12+
- ChainableStack::instance_->nested_var_stack_sizes_.back();
1313
}
1414

1515
} // namespace math

stan/math/rev/core/precomputed_gradients.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ class precomputed_gradients_vari : public vari {
5252
const std::vector<double>& gradients)
5353
: vari(val),
5454
size_(vars.size()),
55-
varis_(ChainableStack::instance().memalloc_.alloc_array<vari*>(
55+
varis_(ChainableStack::instance_->memalloc_.alloc_array<vari*>(
5656
vars.size())),
57-
gradients_(ChainableStack::instance().memalloc_.alloc_array<double>(
57+
gradients_(ChainableStack::instance_->memalloc_.alloc_array<double>(
5858
vars.size())) {
5959
check_consistent_sizes("precomputed_gradients_vari", "vars", vars,
6060
"gradients", gradients);

stan/math/rev/core/print_stack.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ namespace math {
1818
* @param o ostream to modify
1919
*/
2020
inline void print_stack(std::ostream& o) {
21-
o << "STACK, size=" << ChainableStack::instance().var_stack_.size()
21+
o << "STACK, size=" << ChainableStack::instance_->var_stack_.size()
2222
<< std::endl;
2323
// TODO(carpenter): this shouldn't need to be cast any more
24-
for (size_t i = 0; i < ChainableStack::instance().var_stack_.size(); ++i)
25-
o << i << " " << ChainableStack::instance().var_stack_[i] << " "
26-
<< (static_cast<vari*>(ChainableStack::instance().var_stack_[i]))->val_
24+
for (size_t i = 0; i < ChainableStack::instance_->var_stack_.size(); ++i)
25+
o << i << " " << ChainableStack::instance_->var_stack_[i] << " "
26+
<< (static_cast<vari*>(ChainableStack::instance_->var_stack_[i]))->val_
2727
<< " : "
28-
<< (static_cast<vari*>(ChainableStack::instance().var_stack_[i]))->adj_
28+
<< (static_cast<vari*>(ChainableStack::instance_->var_stack_[i]))->adj_
2929
<< std::endl;
3030
}
3131

0 commit comments

Comments
 (0)