@@ -73,9 +73,14 @@ namespace tensorflow {
7373#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
7474class EigenAllocator : public ::Eigen::Allocator {
7575 public:
76- explicit EigenAllocator (gpu::Stream* stream, ::tensorflow::Allocator* alloc,
77- EventMgr* em)
78- : stream_(stream), allocator_(alloc), em_(em) {}
76+ EigenAllocator () {}
77+
78+ void Reinitialize (gpu::Stream* stream, ::tensorflow::Allocator* alloc,
79+ EventMgr* em) {
80+ stream_ = stream;
81+ allocator_ = alloc;
82+ em_ = em;
83+ }
7984
8085 void * allocate (size_t num_bytes) const override {
8186 void * ret = allocator_->AllocateRaw (32 /* alignment */ , num_bytes);
@@ -103,10 +108,12 @@ class EigenAllocator : public ::Eigen::Allocator {
103108#else
104109class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
105110 public:
106- EigenCudaStreamDevice (const cudaStream_t* cuda_stream, int gpu_id,
107- ::tensorflow::Allocator* alloc)
108- : stream_(cuda_stream), allocator_(alloc) {
109- Eigen::initializeDeviceProp ();
111+ EigenCudaStreamDevice () { Eigen::initializeDeviceProp (); }
112+
113+ void Reinitialize (const cudaStream_t* cuda_stream, int gpu_id,
114+ ::tensorflow::Allocator* alloc) {
115+ stream_ = cuda_stream;
116+ allocator_ = alloc;
110117 device_prop_ = &Eigen::m_deviceProperties[gpu_id];
111118 }
112119
@@ -391,10 +398,11 @@ namespace {
391398#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
392399class ConcretePerOpGpuDevice : public PerOpGpuDevice {
393400 public:
394- explicit ConcretePerOpGpuDevice (gpu::Stream* stream,
395- Allocator* base_allocator,
396- ::tensorflow::EventMgr* em)
397- : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {}
401+ void Reinitialize (gpu::Stream* stream, Allocator* base_allocator,
402+ ::tensorflow::EventMgr* em) {
403+ allocator_.Reinitialize (stream, base_allocator, em);
404+ device_.Reinitialize (stream, &allocator_);
405+ }
398406
399407 const Eigen::GpuDevice& device () const override { return device_; }
400408
@@ -405,10 +413,12 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
405413#else
406414class ConcretePerOpGpuDevice : public PerOpGpuDevice {
407415 public:
408- explicit ConcretePerOpGpuDevice (const cudaStream_t* cuda_stream, int gpu_id,
409- Allocator* base_allocator)
410- : stream_device_(cuda_stream, gpu_id, base_allocator),
411- device_(&stream_device_) {}
416+ ConcretePerOpGpuDevice () : device_(&stream_device_) {}
417+
418+ void Reinitialize (const cudaStream_t* cuda_stream, int gpu_id,
419+ Allocator* base_allocator) {
420+ stream_device_.Reinitialize (cuda_stream, gpu_id, base_allocator);
421+ }
412422
413423 const Eigen::GpuDevice& device () const override { return device_; }
414424
@@ -419,28 +429,36 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
419429#endif
420430} // namespace
421431
422- const PerOpGpuDevice* BaseGPUDevice::NewDevice (int stream_id,
423- Allocator* allocator) {
432+ void BaseGPUDevice::ReinitializeDevice (PerOpGpuDevice* device, int stream_id,
433+ Allocator* allocator) {
434+ ConcretePerOpGpuDevice* concrete_device =
435+ dynamic_cast <ConcretePerOpGpuDevice*>(device);
436+ DCHECK (concrete_device);
424437#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
425- return new ConcretePerOpGpuDevice (streams_[stream_id], allocator, em_.get ());
438+ concrete_device-> Reinitialize (streams_[stream_id], allocator, em_.get ());
426439#else
427440 const cudaStream_t* cuda_stream = reinterpret_cast <const cudaStream_t*>(
428441 streams_[stream_id]->implementation ()->CudaStreamMemberHack ());
429- return new ConcretePerOpGpuDevice (cuda_stream, gpu_id_, allocator);
442+ concrete_device-> Reinitialize (cuda_stream, gpu_id_, allocator);
430443#endif
431444}
432445
433- const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice (DeviceContext* dc,
434- Allocator* allocator) {
446+ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice () {
447+ return new ConcretePerOpGpuDevice ();
448+ }
449+
450+ void BaseGPUDevice::ReinitializeGpuDevice (PerOpGpuDevice* device,
451+ DeviceContext* dc,
452+ Allocator* allocator) {
435453 if (dc) {
436454 const GPUDeviceContext* gpu_dc = static_cast <GPUDeviceContext*>(dc);
437455 const int stream_id = gpu_dc->stream_id ();
438456 VLOG (1 ) << " eigen_gpu_device(" << dc << " ) => stream[" << stream_id
439457 << " ]" ;
440458 CHECK_LT (stream_id, streams_.size ());
441- return NewDevice ( stream_id, allocator);
459+ ReinitializeDevice (device, stream_id, allocator);
442460 } else {
443- return NewDevice ( 0 , allocator);
461+ ReinitializeDevice (device, 0 , allocator);
444462 }
445463}
446464
0 commit comments