Skip to content

Commit 96ed3c2

Browse files
Rob Hessjeffdonahue
authored andcommitted
Make device factory usage more inline with layer factory.
1 parent ff9bff0 commit 96ed3c2

File tree

7 files changed

+58
-41
lines changed

7 files changed

+58
-41
lines changed

include/caffe/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class Caffe {
105105
}
106106
return *singleton_;
107107
}
108-
enum Brew { CPU, GPU };
108+
enum Brew { UNSPECIFIED = -1, CPU, GPU };
109109
enum Phase { TRAIN, TEST };
110110

111111

include/caffe/device.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "glog/logging.h"
1010

11+
#include "caffe/common.hpp"
1112
#include "caffe/util/im2col.hpp"
1213
#include "caffe/util/math_functions.hpp"
1314

@@ -252,14 +253,9 @@ class GPUDevice : public Device<Dtype> {
252253
const int stride, Dtype* data_im);
253254
};
254255

256+
// Device factory function
255257
template<typename Dtype>
256-
class DeviceFactory {
257-
public:
258-
static Device<Dtype>* GetDevice();
259-
private:
260-
static Device<Dtype>* cpu_device_;
261-
static Device<Dtype>* gpu_device_;
262-
};
258+
Device<Dtype>* GetDevice(Caffe::Brew mode = Caffe::UNSPECIFIED);
263259

264260
} // namespace caffe
265261

include/caffe/layer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Layer {
1919
// to SetUp(), where the dimensions of the bottom blobs are provided to the
2020
// layer.
2121
explicit Layer(const LayerParameter& param)
22-
: layer_param_(param), device_(DeviceFactory<Dtype>::GetDevice()) {
22+
: layer_param_(param), device_(GetDevice<Dtype>()) {
2323
// The only thing we do is to copy blobs if there are any.
2424
if (layer_param_.blobs_size() > 0) {
2525
blobs_.resize(layer_param_.blobs_size());

src/caffe/device_factory.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright 2014 BVLC and contributors.
2+
3+
#include "caffe/common.hpp"
4+
#include "caffe/device.hpp"
5+
6+
namespace caffe {
7+
8+
template<typename Dtype>
9+
class DeviceFactory {
10+
public:
11+
static Device<Dtype>* GetDevice(Caffe::Brew mode);
12+
private:
13+
static Device<Dtype>* cpu_device_;
14+
static Device<Dtype>* gpu_device_;
15+
};
16+
17+
template<typename Dtype>
18+
Device<Dtype>* DeviceFactory<Dtype>::GetDevice(Caffe::Brew mode) {
19+
if (mode == Caffe::UNSPECIFIED)
20+
mode = Caffe::mode();
21+
switch (mode) {
22+
case Caffe::CPU:
23+
return cpu_device_;
24+
case Caffe::GPU:
25+
return gpu_device_;
26+
default:
27+
LOG(FATAL) << "Unknown caffe mode.";
28+
return static_cast<Device<Dtype>*>(NULL);
29+
}
30+
}
31+
32+
template<typename Dtype>
33+
Device<Dtype>* DeviceFactory<Dtype>::cpu_device_ = new CPUDevice<Dtype>();
34+
35+
template<typename Dtype>
36+
Device<Dtype>* DeviceFactory<Dtype>::gpu_device_ = new GPUDevice<Dtype>();
37+
38+
INSTANTIATE_CLASS(DeviceFactory);
39+
40+
// A function to get the device either for the current mode (if
41+
// Caffe::UNSPECIFIED is passed, which is the default), or for a specific mode
42+
// (if a specific device's Caffe::Brew is passed).
43+
template<typename Dtype>
44+
Device<Dtype>* GetDevice(Caffe::Brew mode) {
45+
return DeviceFactory<Dtype>::GetDevice(mode);
46+
}
47+
48+
template Device<float>* GetDevice(Caffe::Brew mode);
49+
template Device<double>* GetDevice(Caffe::Brew mode);
50+
51+
} // namespace caffe

src/caffe/net.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ void Net<Dtype>::Update() {
684684
const int count = params_[i]->count();
685685
const Dtype* this_diff = params_[i]->const_diff();
686686
Dtype* owner_diff = params_[param_owners_[i]]->mutable_diff();
687-
Device<Dtype>* device = DeviceFactory<Dtype>::GetDevice();
687+
Device<Dtype>* device = GetDevice<Dtype>();
688688
device->add(count, this_diff, owner_diff, owner_diff);
689689
}
690690
// Now, update the owned parameters.

src/caffe/solver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ void SGDSolver<Dtype>::ComputeUpdateValue() {
367367
}
368368
Dtype momentum = this->param_.momentum();
369369
Dtype weight_decay = this->param_.weight_decay();
370-
Device<Dtype>* device = DeviceFactory<Dtype>::GetDevice();
370+
Device<Dtype>* device = GetDevice<Dtype>();
371371
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
372372
// Compute the value to history, and then copy them to the blob's diff.
373373
Dtype local_rate = rate * net_params_lr[param_id];

src/caffe/util/device.cpp

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)