|
1 | 1 | // Copyright 2014 BVLC and contributors. |
2 | 2 |
|
3 | | -#ifndef CAFFE_UTIL_DEVICE_H_ |
4 | | -#define CAFFE_UTIL_DEVICE_H_ |
| 3 | +#ifndef CAFFE_DEVICE_H_ |
| 4 | +#define CAFFE_DEVICE_H_ |
5 | 5 |
|
6 | 6 | extern "C" { |
7 | 7 | #include <cblas.h> |
8 | 8 | } |
9 | 9 |
|
10 | | -#include <cublas_v2.h> |
11 | | -#include <stdint.h> |
12 | | - |
13 | | -#include "glog/logging.h" |
14 | | - |
15 | 10 | #include "caffe/common.hpp" |
16 | 11 |
|
17 | 12 | namespace caffe { |
@@ -127,164 +122,10 @@ class Device { |
127 | 122 | } |
128 | 123 | }; |
129 | 124 |
|
130 | | -template<typename Dtype> |
131 | | -class CPUDevice : public Device<Dtype> { |
132 | | - public: |
133 | | - CPUDevice() { |
134 | | - } |
135 | | - virtual ~CPUDevice() { |
136 | | - } |
137 | | - virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, |
138 | | - const int M, const int N, const int K, const Dtype alpha, |
139 | | - const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C); |
140 | | - |
141 | | - virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, |
142 | | - const Dtype alpha, const Dtype* A, const Dtype* x, |
143 | | - const Dtype beta, Dtype* y); |
144 | | - |
145 | | - virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y); |
146 | | - |
147 | | - virtual void axpby(const int N, const Dtype alpha, const Dtype* X, |
148 | | - const Dtype beta, Dtype* Y); |
149 | | - |
150 | | - /* NOLINT_NEXT_LINE(build/include_what_you_use) */ |
151 | | - virtual void copy(const int N, const Dtype *X, Dtype *Y); |
152 | | - |
153 | | - virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y); |
154 | | - |
155 | | - virtual void set(const int N, const Dtype alpha, Dtype *X); |
156 | | - |
157 | | - virtual void add_scalar(const int N, const Dtype alpha, Dtype *X); |
158 | | - |
159 | | - virtual void scal(const int N, const Dtype alpha, Dtype *X); |
160 | | - |
161 | | - virtual void sqr(const int N, const Dtype* a, Dtype* y); |
162 | | - |
163 | | - virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
164 | | - |
165 | | - virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
166 | | - |
167 | | - virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
168 | | - |
169 | | - virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
170 | | - |
171 | | - virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y); |
172 | | - |
173 | | - virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r); |
174 | | - |
175 | | - virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma, |
176 | | - Dtype* r); |
177 | | - |
178 | | - virtual void rng_bernoulli(const int N, const Dtype p, int* r); |
179 | | - |
180 | | - virtual void rng_bernoulli(const int N, const Dtype p, unsigned int* r); |
181 | | - |
182 | | - virtual void exp(const int N, const Dtype* a, Dtype* y); |
183 | | - |
184 | | - virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out); |
185 | | - |
186 | | - virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y, |
187 | | - int* out); |
188 | | - |
189 | | - // Returns the sum of the absolute values of the elements of vector x |
190 | | - virtual void asum(const int N, const Dtype* x, Dtype* y); |
191 | | - |
192 | | - virtual void sign(const int N, const Dtype* x, Dtype* y); |
193 | | - |
194 | | - virtual void sgnbit(const int N, const Dtype* x, Dtype* y); |
195 | | - |
196 | | - virtual void fabs(const int N, const Dtype* x, Dtype* y); |
197 | | - |
198 | | - virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y); |
199 | | - |
200 | | - virtual void im2col(const Dtype* data_im, const int channels, |
201 | | - const int height, const int width, const int ksize, |
202 | | - const int pad, const int stride, Dtype* data_col); |
203 | | - |
204 | | - virtual void col2im(const Dtype* data_col, const int channels, |
205 | | - const int height, const int width, const int ksize, |
206 | | - const int pad, const int stride, Dtype* data_im); |
207 | | -}; |
208 | | - |
209 | | -template<typename Dtype> |
210 | | -class GPUDevice : public Device<Dtype> { |
211 | | - public: |
212 | | - GPUDevice() { |
213 | | - } |
214 | | - virtual ~GPUDevice() { |
215 | | - } |
216 | | - virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, |
217 | | - const int M, const int N, const int K, const Dtype alpha, |
218 | | - const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C); |
219 | | - |
220 | | - virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, |
221 | | - const Dtype alpha, const Dtype* A, const Dtype* x, |
222 | | - const Dtype beta, Dtype* y); |
223 | | - |
224 | | - virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y); |
225 | | - |
226 | | - virtual void axpby(const int N, const Dtype alpha, const Dtype* X, |
227 | | - const Dtype beta, Dtype* Y); |
228 | | - |
229 | | - /* NOLINT_NEXT_LINE(build/include_what_you_use) */ |
230 | | - virtual void copy(const int N, const Dtype *X, Dtype *Y); |
231 | | - |
232 | | - virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y); |
233 | | - |
234 | | - virtual void set(const int N, const Dtype alpha, Dtype *X); |
235 | | - |
236 | | - virtual void add_scalar(const int N, const Dtype alpha, Dtype *X); |
237 | | - |
238 | | - virtual void scal(const int N, const Dtype alpha, Dtype *X); |
239 | | - |
240 | | - virtual void sqr(const int N, const Dtype* a, Dtype* y); |
241 | | - |
242 | | - virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
243 | | - |
244 | | - virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
245 | | - |
246 | | - virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
247 | | - |
248 | | - virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y); |
249 | | - |
250 | | - virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y); |
251 | | - |
252 | | - virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r); |
253 | | - |
254 | | - virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma, |
255 | | - Dtype* r); |
256 | | - |
257 | | - virtual void exp(const int N, const Dtype* a, Dtype* y); |
258 | | - |
259 | | - virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out); |
260 | | - |
261 | | - virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y, |
262 | | - int* out); |
263 | | - |
264 | | -// Returns the sum of the absolute values of the elements of vector x |
265 | | - virtual void asum(const int N, const Dtype* x, Dtype* y); |
266 | | - |
267 | | - virtual void sign(const int N, const Dtype* x, Dtype* y); |
268 | | - |
269 | | - virtual void sgnbit(const int N, const Dtype* x, Dtype* y); |
270 | | - |
271 | | - virtual void fabs(const int N, const Dtype* x, Dtype* y); |
272 | | - |
273 | | - virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y); |
274 | | - |
275 | | - virtual void im2col(const Dtype* data_im, const int channels, |
276 | | - const int height, const int width, const int ksize, |
277 | | - const int pad, const int stride, Dtype* data_col); |
278 | | - |
279 | | - virtual void col2im(const Dtype* data_col, const int channels, |
280 | | - const int height, const int width, const int psize, |
281 | | - const int pad, const int stride, Dtype* data_im); |
282 | | -}; |
283 | | - |
284 | 125 | // Device factory function |
285 | 126 | template<typename Dtype> |
286 | 127 | Device<Dtype>* GetDevice(Caffe::Brew mode = Caffe::UNSPECIFIED); |
287 | 128 |
|
288 | 129 | } // namespace caffe |
289 | 130 |
|
290 | | -#endif // CAFFE_UTIL_DEVICE_H_ |
| 131 | +#endif // CAFFE_DEVICE_H_ |
0 commit comments