Skip to content

Commit 39c8658

Browse files
Rob Hessjeffdonahue
authored andcommitted
Separate device.hpp into device-specific headers.
1 parent 8d111e4 commit 39c8658

File tree

7 files changed

+194
-165
lines changed

7 files changed

+194
-165
lines changed

include/caffe/device.hpp

Lines changed: 3 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,12 @@
11
// Copyright 2014 BVLC and contributors.
22

3-
#ifndef CAFFE_UTIL_DEVICE_H_
4-
#define CAFFE_UTIL_DEVICE_H_
3+
#ifndef CAFFE_DEVICE_H_
4+
#define CAFFE_DEVICE_H_
55

66
extern "C" {
77
#include <cblas.h>
88
}
99

10-
#include <cublas_v2.h>
11-
#include <stdint.h>
12-
13-
#include "glog/logging.h"
14-
1510
#include "caffe/common.hpp"
1611

1712
namespace caffe {
@@ -127,164 +122,10 @@ class Device {
127122
}
128123
};
129124

130-
template<typename Dtype>
131-
class CPUDevice : public Device<Dtype> {
132-
public:
133-
CPUDevice() {
134-
}
135-
virtual ~CPUDevice() {
136-
}
137-
virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
138-
const int M, const int N, const int K, const Dtype alpha,
139-
const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C);
140-
141-
virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
142-
const Dtype alpha, const Dtype* A, const Dtype* x,
143-
const Dtype beta, Dtype* y);
144-
145-
virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y);
146-
147-
virtual void axpby(const int N, const Dtype alpha, const Dtype* X,
148-
const Dtype beta, Dtype* Y);
149-
150-
/* NOLINT_NEXT_LINE(build/include_what_you_use) */
151-
virtual void copy(const int N, const Dtype *X, Dtype *Y);
152-
153-
virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y);
154-
155-
virtual void set(const int N, const Dtype alpha, Dtype *X);
156-
157-
virtual void add_scalar(const int N, const Dtype alpha, Dtype *X);
158-
159-
virtual void scal(const int N, const Dtype alpha, Dtype *X);
160-
161-
virtual void sqr(const int N, const Dtype* a, Dtype* y);
162-
163-
virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
164-
165-
virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
166-
167-
virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
168-
169-
virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
170-
171-
virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y);
172-
173-
virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r);
174-
175-
virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma,
176-
Dtype* r);
177-
178-
virtual void rng_bernoulli(const int N, const Dtype p, int* r);
179-
180-
virtual void rng_bernoulli(const int N, const Dtype p, unsigned int* r);
181-
182-
virtual void exp(const int N, const Dtype* a, Dtype* y);
183-
184-
virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out);
185-
186-
virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y,
187-
int* out);
188-
189-
// Returns the sum of the absolute values of the elements of vector x
190-
virtual void asum(const int N, const Dtype* x, Dtype* y);
191-
192-
virtual void sign(const int N, const Dtype* x, Dtype* y);
193-
194-
virtual void sgnbit(const int N, const Dtype* x, Dtype* y);
195-
196-
virtual void fabs(const int N, const Dtype* x, Dtype* y);
197-
198-
virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y);
199-
200-
virtual void im2col(const Dtype* data_im, const int channels,
201-
const int height, const int width, const int ksize,
202-
const int pad, const int stride, Dtype* data_col);
203-
204-
virtual void col2im(const Dtype* data_col, const int channels,
205-
const int height, const int width, const int ksize,
206-
const int pad, const int stride, Dtype* data_im);
207-
};
208-
209-
template<typename Dtype>
210-
class GPUDevice : public Device<Dtype> {
211-
public:
212-
GPUDevice() {
213-
}
214-
virtual ~GPUDevice() {
215-
}
216-
virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
217-
const int M, const int N, const int K, const Dtype alpha,
218-
const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C);
219-
220-
virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
221-
const Dtype alpha, const Dtype* A, const Dtype* x,
222-
const Dtype beta, Dtype* y);
223-
224-
virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y);
225-
226-
virtual void axpby(const int N, const Dtype alpha, const Dtype* X,
227-
const Dtype beta, Dtype* Y);
228-
229-
/* NOLINT_NEXT_LINE(build/include_what_you_use) */
230-
virtual void copy(const int N, const Dtype *X, Dtype *Y);
231-
232-
virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y);
233-
234-
virtual void set(const int N, const Dtype alpha, Dtype *X);
235-
236-
virtual void add_scalar(const int N, const Dtype alpha, Dtype *X);
237-
238-
virtual void scal(const int N, const Dtype alpha, Dtype *X);
239-
240-
virtual void sqr(const int N, const Dtype* a, Dtype* y);
241-
242-
virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
243-
244-
virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
245-
246-
virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
247-
248-
virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
249-
250-
virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y);
251-
252-
virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r);
253-
254-
virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma,
255-
Dtype* r);
256-
257-
virtual void exp(const int N, const Dtype* a, Dtype* y);
258-
259-
virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out);
260-
261-
virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y,
262-
int* out);
263-
264-
// Returns the sum of the absolute values of the elements of vector x
265-
virtual void asum(const int N, const Dtype* x, Dtype* y);
266-
267-
virtual void sign(const int N, const Dtype* x, Dtype* y);
268-
269-
virtual void sgnbit(const int N, const Dtype* x, Dtype* y);
270-
271-
virtual void fabs(const int N, const Dtype* x, Dtype* y);
272-
273-
virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y);
274-
275-
virtual void im2col(const Dtype* data_im, const int channels,
276-
const int height, const int width, const int ksize,
277-
const int pad, const int stride, Dtype* data_col);
278-
279-
virtual void col2im(const Dtype* data_col, const int channels,
280-
const int height, const int width, const int psize,
281-
const int pad, const int stride, Dtype* data_im);
282-
};
283-
284125
// Device factory function
285126
template<typename Dtype>
286127
Device<Dtype>* GetDevice(Caffe::Brew mode = Caffe::UNSPECIFIED);
287128

288129
} // namespace caffe
289130

290-
#endif // CAFFE_UTIL_DEVICE_H_
131+
#endif // CAFFE_DEVICE_H_

include/caffe/devices/cpu.hpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Copyright 2014 BVLC and contributors.
2+
3+
#ifndef CAFFE_DEVICES_CPU_H_
4+
#define CAFFE_DEVICES_CPU_H_
5+
6+
extern "C" {
7+
#include <cblas.h>
8+
}
9+
10+
#include "caffe/device.hpp"
11+
12+
namespace caffe {
13+
14+
template<typename Dtype>
15+
class CPUDevice : public Device<Dtype> {
16+
public:
17+
CPUDevice() {
18+
}
19+
virtual ~CPUDevice() {
20+
}
21+
virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
22+
const int M, const int N, const int K, const Dtype alpha,
23+
const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C);
24+
25+
virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
26+
const Dtype alpha, const Dtype* A, const Dtype* x,
27+
const Dtype beta, Dtype* y);
28+
29+
virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y);
30+
31+
virtual void axpby(const int N, const Dtype alpha, const Dtype* X,
32+
const Dtype beta, Dtype* Y);
33+
34+
/* NOLINT_NEXT_LINE(build/include_what_you_use) */
35+
virtual void copy(const int N, const Dtype *X, Dtype *Y);
36+
37+
virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y);
38+
39+
virtual void set(const int N, const Dtype alpha, Dtype *X);
40+
41+
virtual void add_scalar(const int N, const Dtype alpha, Dtype *X);
42+
43+
virtual void scal(const int N, const Dtype alpha, Dtype *X);
44+
45+
virtual void sqr(const int N, const Dtype* a, Dtype* y);
46+
47+
virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
48+
49+
virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
50+
51+
virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
52+
53+
virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
54+
55+
virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y);
56+
57+
virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r);
58+
59+
virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma,
60+
Dtype* r);
61+
62+
virtual void rng_bernoulli(const int N, const Dtype p, int* r);
63+
64+
virtual void rng_bernoulli(const int N, const Dtype p, unsigned int* r);
65+
66+
virtual void exp(const int N, const Dtype* a, Dtype* y);
67+
68+
virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out);
69+
70+
virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y,
71+
int* out);
72+
73+
// Returns the sum of the absolute values of the elements of vector x
74+
virtual void asum(const int N, const Dtype* x, Dtype* y);
75+
76+
virtual void sign(const int N, const Dtype* x, Dtype* y);
77+
78+
virtual void sgnbit(const int N, const Dtype* x, Dtype* y);
79+
80+
virtual void fabs(const int N, const Dtype* x, Dtype* y);
81+
82+
virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y);
83+
84+
virtual void im2col(const Dtype* data_im, const int channels,
85+
const int height, const int width, const int ksize,
86+
const int pad, const int stride, Dtype* data_col);
87+
88+
virtual void col2im(const Dtype* data_col, const int channels,
89+
const int height, const int width, const int ksize,
90+
const int pad, const int stride, Dtype* data_im);
91+
};
92+
93+
} // namespace caffe
94+
95+
#endif // CAFFE_DEVICES_CPU_H_

include/caffe/devices/gpu.hpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2014 BVLC and contributors.
2+
3+
#ifndef CAFFE_DEVICES_GPU_H_
4+
#define CAFFE_DEVICES_GPU_H_
5+
6+
extern "C" {
7+
#include <cblas.h>
8+
}
9+
10+
#include "caffe/device.hpp"
11+
12+
namespace caffe {
13+
14+
template<typename Dtype>
15+
class GPUDevice : public Device<Dtype> {
16+
public:
17+
GPUDevice() {
18+
}
19+
virtual ~GPUDevice() {
20+
}
21+
virtual void gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
22+
const int M, const int N, const int K, const Dtype alpha,
23+
const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C);
24+
25+
virtual void gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
26+
const Dtype alpha, const Dtype* A, const Dtype* x,
27+
const Dtype beta, Dtype* y);
28+
29+
virtual void axpy(const int N, const Dtype alpha, const Dtype* X, Dtype* Y);
30+
31+
virtual void axpby(const int N, const Dtype alpha, const Dtype* X,
32+
const Dtype beta, Dtype* Y);
33+
34+
/* NOLINT_NEXT_LINE(build/include_what_you_use) */
35+
virtual void copy(const int N, const Dtype *X, Dtype *Y);
36+
37+
virtual void copy_from_cpu(const int N, const Dtype* X, Dtype* Y);
38+
39+
virtual void set(const int N, const Dtype alpha, Dtype *X);
40+
41+
virtual void add_scalar(const int N, const Dtype alpha, Dtype *X);
42+
43+
virtual void scal(const int N, const Dtype alpha, Dtype *X);
44+
45+
virtual void sqr(const int N, const Dtype* a, Dtype* y);
46+
47+
virtual void add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
48+
49+
virtual void sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
50+
51+
virtual void mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
52+
53+
virtual void div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
54+
55+
virtual void powx(const int N, const Dtype* a, const Dtype b, Dtype* y);
56+
57+
virtual void rng_uniform(const int N, const Dtype a, const Dtype b, Dtype* r);
58+
59+
virtual void rng_gaussian(const int N, const Dtype mu, const Dtype sigma,
60+
Dtype* r);
61+
62+
virtual void exp(const int N, const Dtype* a, Dtype* y);
63+
64+
virtual void dot(const int N, const Dtype* x, const Dtype* y, Dtype* out);
65+
66+
virtual void hamming_distance(const int N, const Dtype* x, const Dtype* y,
67+
int* out);
68+
69+
// Returns the sum of the absolute values of the elements of vector x
70+
virtual void asum(const int N, const Dtype* x, Dtype* y);
71+
72+
virtual void sign(const int N, const Dtype* x, Dtype* y);
73+
74+
virtual void sgnbit(const int N, const Dtype* x, Dtype* y);
75+
76+
virtual void fabs(const int N, const Dtype* x, Dtype* y);
77+
78+
virtual void scale(const int N, const Dtype alpha, const Dtype *x, Dtype* y);
79+
80+
virtual void im2col(const Dtype* data_im, const int channels,
81+
const int height, const int width, const int ksize,
82+
const int pad, const int stride, Dtype* data_col);
83+
84+
virtual void col2im(const Dtype* data_col, const int channels,
85+
const int height, const int width, const int psize,
86+
const int pad, const int stride, Dtype* data_im);
87+
};
88+
89+
} // namespace caffe
90+
91+
#endif // CAFFE_DEVICES_GPU_H_

src/caffe/device_factory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "caffe/common.hpp"
44
#include "caffe/device.hpp"
5+
#include "caffe/devices/cpu.hpp"
6+
#include "caffe/devices/gpu.hpp"
57

68
namespace caffe {
79

src/caffe/devices/cpu_device.cpp renamed to src/caffe/devices/cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ extern "C" {
1212
#include <algorithm>
1313

1414
#include "caffe/common.hpp"
15-
#include "caffe/device.hpp"
15+
#include "caffe/devices/cpu.hpp"
1616
#include "caffe/util/mkl_alternate.hpp"
1717
#include "caffe/util/rng.hpp"
1818

0 commit comments

Comments
 (0)