1
- #ifndef STAN_MATH_PRIM_MAT_FUN_OPENCL_COPY_HPP
2
- #define STAN_MATH_PRIM_MAT_FUN_OPENCL_COPY_HPP
1
+ #ifndef STAN_MATH_OPENCL_COPY_HPP
2
+ #define STAN_MATH_OPENCL_COPY_HPP
3
3
#ifdef STAN_OPENCL
4
4
5
5
#include < stan/math/opencl/opencl_context.hpp>
@@ -28,21 +28,16 @@ namespace math {
28
28
* the destination matrix that is stored
29
29
* on the OpenCL device.
30
30
*
31
- * @tparam T type of data in the Eigen matrix
32
- * @param dst destination matrix on the OpenCL device
31
+ * @tparam R Compile time rows of the Eigen matrix
32
+ * @tparam C Compile time columns of the Eigen matrix
33
33
* @param src source Eigen matrix
34
- *
35
- * @throw <code>std::invalid_argument</code> if the
36
- * matrices do not have matching dimensions
34
+ * @return matrix_cl with a copy of the data in the source matrix
37
35
*/
38
36
template <int R, int C>
39
- void copy (matrix_cl& dst, const Eigen::Matrix<double , R, C>& src) {
40
- check_size_match (" copy (Eigen -> (OpenCL))" , " src.rows()" , src.rows (),
41
- " dst.rows()" , dst.rows ());
42
- check_size_match (" copy (Eigen -> (OpenCL))" , " src.cols()" , src.cols (),
43
- " dst.cols()" , dst.cols ());
37
+ inline matrix_cl to_matrix_cl (const Eigen::Matrix<double , R, C>& src) {
38
+ matrix_cl dst (src.rows (), src.cols ());
44
39
if (src.size () == 0 ) {
45
- return ;
40
+ return dst ;
46
41
}
47
42
try {
48
43
/* *
@@ -61,28 +56,23 @@ void copy(matrix_cl& dst, const Eigen::Matrix<double, R, C>& src) {
61
56
} catch (const cl::Error& e) {
62
57
check_opencl_error (" copy Eigen->(OpenCL)" , e);
63
58
}
59
+ return dst;
64
60
}
65
61
66
62
/* *
67
63
* Copies the source matrix that is stored
68
64
* on the OpenCL device to the destination Eigen
69
65
* matrix.
70
66
*
71
- * @tparam T type of data in the Eigen matrix
72
- * @param dst destination Eigen matrix
73
67
* @param src source matrix on the OpenCL device
74
- *
75
- * @throw <code>std::invalid_argument</code> if the
76
- * matrices do not have matching dimensions
68
+ * @return Eigen matrix with a copy of the data in the source matrix
77
69
*/
78
- template <int R, int C>
79
- void copy (Eigen::Matrix<double , R, C>& dst, const matrix_cl& src) {
80
- check_size_match (" copy ((OpenCL) -> Eigen)" , " src.rows()" , src.rows (),
81
- " dst.rows()" , dst.rows ());
82
- check_size_match (" copy ((OpenCL) -> Eigen)" , " src.cols()" , src.cols (),
83
- " dst.cols()" , dst.cols ());
70
+ inline Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic> from_matrix_cl (
71
+ const matrix_cl& src) {
72
+ Eigen::Matrix<double , Eigen::Dynamic, Eigen::Dynamic> dst (src.rows (),
73
+ src.cols ());
84
74
if (src.size () == 0 ) {
85
- return ;
75
+ return dst ;
86
76
}
87
77
try {
88
78
/* *
@@ -103,6 +93,7 @@ void copy(Eigen::Matrix<double, R, C>& dst, const matrix_cl& src) {
103
93
} catch (const cl::Error& e) {
104
94
check_opencl_error (" copy (OpenCL)->Eigen" , e);
105
95
}
96
+ return dst;
106
97
}
107
98
108
99
/* *
@@ -184,19 +175,15 @@ inline matrix_cl packed_copy(const std::vector<double>& src, int rows) {
184
175
* destination matrix. Both matrices
185
176
* are stored on the OpenCL device.
186
177
*
187
- * @param dst destination matrix
188
178
* @param src source matrix
189
- *
179
+ * @return matrix_cl with copies of values in the source matrix
190
180
* @throw <code>std::invalid_argument</code> if the
191
181
* matrices do not have matching dimensions
192
182
*/
193
- inline void copy (matrix_cl& dst, const matrix_cl& src) {
194
- check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.rows()" , src.rows (),
195
- " dst.rows()" , dst.rows ());
196
- check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.cols()" , src.cols (),
197
- " dst.cols()" , dst.cols ());
183
+ inline matrix_cl copy_cl (const matrix_cl& src) {
184
+ matrix_cl dst (src.rows (), src.cols ());
198
185
if (src.size () == 0 ) {
199
- return ;
186
+ return dst ;
200
187
}
201
188
try {
202
189
/* *
@@ -216,16 +203,18 @@ inline void copy(matrix_cl& dst, const matrix_cl& src) {
216
203
} catch (const cl::Error& e) {
217
204
check_opencl_error (" copy (OpenCL)->(OpenCL)" , e);
218
205
}
206
+ return dst;
219
207
}
220
208
221
209
/* *
222
210
* Copy A 1 by 1 source matrix from the Device to the host.
223
- * @tparam An arithmetic type to pass the value from the OpenCL matrix to.
224
- * @param dst Arithmetic to receive the matrix_cl value.
211
+ * @tparam T An arithmetic type to pass the value from the OpenCL matrix to.
225
212
* @param src A 1x1 matrix on the device.
213
+ * @return dst Arithmetic to receive the matrix_cl value.
226
214
*/
227
215
template <typename T, std::enable_if_t <std::is_arithmetic<T>::value, int > = 0 >
228
- inline void copy (T& dst, const matrix_cl& src) {
216
+ inline T from_matrix_cl (const matrix_cl& src) {
217
+ T dst;
229
218
check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.rows()" , src.rows (),
230
219
" dst.rows()" , 1 );
231
220
check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.cols()" , src.cols (),
@@ -240,16 +229,18 @@ inline void copy(T& dst, const matrix_cl& src) {
240
229
} catch (const cl::Error& e) {
241
230
check_opencl_error (" copy (OpenCL)->(OpenCL)" , e);
242
231
}
232
+ return dst;
243
233
}
244
234
245
235
/* *
246
236
* Copy an arithmetic type to the device.
247
- * @tparam An arithmetic type to pass the value from the OpenCL matrix to.
237
+ * @tparam T An arithmetic type to pass the value from the OpenCL matrix to.
248
238
* @param src Arithmetic to receive the matrix_cl value.
249
- * @param dst A 1x1 matrix on the device.
239
+ * @return A 1x1 matrix on the device.
250
240
*/
251
241
template <typename T, std::enable_if_t <std::is_arithmetic<T>::value, int > = 0 >
252
- inline void copy (matrix_cl& dst, const T& src) {
242
+ inline matrix_cl to_matrix_cl (const T& src) {
243
+ matrix_cl dst (1 , 1 );
253
244
check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.rows()" , dst.rows (),
254
245
" dst.rows()" , 1 );
255
246
check_size_match (" copy ((OpenCL) -> (OpenCL))" , " src.cols()" , dst.cols (),
@@ -263,6 +254,7 @@ inline void copy(matrix_cl& dst, const T& src) {
263
254
} catch (const cl::Error& e) {
264
255
check_opencl_error (" copy (OpenCL)->(OpenCL)" , e);
265
256
}
257
+ return dst;
266
258
}
267
259
268
260
} // namespace math
0 commit comments