@@ -25,6 +25,7 @@ limitations under the License.
2525#include " tensorflow/core/kernels/ops_util.h"
2626#include " tensorflow/core/lib/core/errors.h"
2727#include " tensorflow/core/lib/gtl/array_slice.h"
28+ #include " tensorflow/core/lib/strings/numbers.h"
2829#include " tensorflow/core/platform/logging.h"
2930#include " tensorflow/core/public/tensor.h"
3031#include " tensorflow/core/public/tensor_shape.h"
@@ -34,6 +35,7 @@ limitations under the License.
3435#if GOOGLE_CUDA
3536#include " tensorflow/stream_executor/stream.h"
3637#include " tensorflow/core/common_runtime/gpu_device_context.h"
38+ #include " tensorflow/core/kernels/conv_ops_gpu.h"
3739#endif // GOOGLE_CUDA
3840
3941namespace tensorflow {
@@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")
206208
207209#if GOOGLE_CUDA
208210
209- namespace {
210- template <typename T>
211- perftools::gputools::DeviceMemory<T> AsDeviceMemory (const T* cuda_memory,
212- uint64 size) {
213- perftools::gputools::DeviceMemoryBase wrapped (const_cast <T*>(cuda_memory),
214- size * sizeof (T));
215- perftools::gputools::DeviceMemory<T> typed (wrapped);
216- return typed;
// Returns the cuDNN workspace (scratch) size limit in bytes.
//
// Looks up the environment variable named by `envvar_in_mb` and interprets its
// value as a number of megabytes. If the variable is set and parses as a
// 64-bit integer, the result is that value multiplied by 2^20. Otherwise
// `default_value_in_bytes` is returned; an unparsable value additionally logs
// a warning.
//
// NOTE(review): the parsed value is not range-checked here, so a negative or
// very large setting is passed through (and multiplied) as-is -- confirm that
// callers tolerate this.
211+ int64 GetCudnnWorkspaceLimit (const string& envvar_in_mb,
212+ int64 default_value_in_bytes) {
213+ const char * workspace_limit_in_mb_str = getenv (envvar_in_mb.c_str ());
// Only attempt to parse when the variable is set and does not compare equal to
// the literal below (presumably the empty string; the spacing in this
// rendering looks like an extraction artifact -- TODO confirm).
214+ if (workspace_limit_in_mb_str != nullptr &&
215+ strcmp (workspace_limit_in_mb_str, " " ) != 0 ) {
216+ int64 scratch_limit_in_mb = -1 ;
217+ if (strings::safe_strto64 (workspace_limit_in_mb_str,
218+ &scratch_limit_in_mb)) {
// Convert megabytes to bytes (1 << 20 bytes per MB).
219+ return scratch_limit_in_mb * (1 << 20 );
220+ } else {
// Parsing failed: warn and fall through to the default below.
221+ LOG (WARNING) << " Invalid value for env-var " << envvar_in_mb << " : "
222+ << workspace_limit_in_mb_str;
223+ }
224+ }
225+ return default_value_in_bytes;
217226}
218- } // namespace
219227
220228template <typename T>
221229struct LaunchConvOp <GPUDevice, T> {
@@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
287295 input = transformed_input;
288296 }
289297
298+ {
299+ // Convert the input tensor from NHWC to NCHW.
300+ Tensor transformed_input;
301+ OP_REQUIRES_OK (ctx,
302+ ctx->allocate_temp (
303+ DataTypeToEnum<T>::value,
304+ TensorShape ({input.dim_size (0 ), input.dim_size (3 ),
305+ input.dim_size (1 ), input.dim_size (2 )}),
306+ &transformed_input));
307+ functor::NHWCToNCHW<GPUDevice, T>()(
308+ ctx->eigen_device <GPUDevice>(),
309+ const_cast <const Tensor&>(input).tensor <T, 4 >(),
310+ transformed_input.tensor <T, 4 >());
311+ input = transformed_input;
312+ }
313+
290314 perftools::gputools::dnn::BatchDescriptor input_desc;
291315 input_desc.set_count (input.dim_size (0 ))
292- .set_height (input.dim_size (1 ))
293- .set_width (input.dim_size (2 ))
294- .set_feature_map_count (input.dim_size (3 ))
295- .set_layout (perftools::gputools::dnn::DataLayout::kBatchYXDepth );
316+ .set_feature_map_count (input.dim_size (1 ))
317+ .set_height (input.dim_size (2 ))
318+ .set_width (input.dim_size (3 ))
319+ .set_layout (perftools::gputools::dnn::DataLayout::kBatchDepthYX );
296320 perftools::gputools::dnn::BatchDescriptor output_desc;
297321 output_desc.set_count (output->dim_size (0 ))
298322 .set_height (output->dim_size (1 ))
299323 .set_width (output->dim_size (2 ))
300324 .set_feature_map_count (output->dim_size (3 ))
301- .set_layout (perftools::gputools::dnn::DataLayout::kBatchYXDepth );
325+ .set_layout (perftools::gputools::dnn::DataLayout::kBatchDepthYX );
302326 perftools::gputools::dnn::FilterDescriptor filter_desc;
303327 filter_desc.set_input_filter_height (filter.dim_size (0 ))
304328 .set_input_filter_width (filter.dim_size (1 ))
@@ -320,24 +344,44 @@ struct LaunchConvOp<GPUDevice, T> {
320344 ctx->eigen_device <GPUDevice>(), To32Bit (filter.tensor <T, 4 >()),
321345 To32Bit (transformed_filter.tensor <T, 4 >()));
322346
347+ Tensor transformed_output;
348+ OP_REQUIRES_OK (
349+ ctx, ctx->allocate_temp (
350+ DataTypeToEnum<T>::value,
351+ TensorShape ({output->dim_size (0 ), output->dim_size (3 ),
352+ output->dim_size (1 ), output->dim_size (2 )}),
353+ &transformed_output));
354+
323355 auto input_ptr = AsDeviceMemory (input.template flat <T>().data (),
324356 input.template flat <T>().size ());
325357 auto filter_ptr =
326358 AsDeviceMemory (transformed_filter.template flat <T>().data (),
327359 transformed_filter.template flat <T>().size ());
328- auto output_ptr = AsDeviceMemory (output->template flat <T>().data (),
329- output->template flat <T>().size ());
330-
360+ auto output_ptr =
361+ AsDeviceMemory (transformed_output.template flat <T>().data (),
362+ transformed_output.template flat <T>().size ());
363+
364+ static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit (
365+ " TF_CUDNN_WORKSPACE_LIMIT_IN_MB" , 1LL << 30 // 1GB by default
366+ );
367+ CudnnScratchAllocator scratch_allocator (ConvolveScratchSize, ctx);
331368 bool cudnn_launch_status =
332- stream->ThenConvolve (input_desc, input_ptr, filter_desc, filter_ptr,
333- conv_desc, output_desc, &output_ptr)
369+ stream->ThenConvolveWithScratch (input_desc, input_ptr, filter_desc,
370+ filter_ptr, conv_desc, output_desc,
371+ &output_ptr, &scratch_allocator)
334372 .ok ();
335373
336374 if (!cudnn_launch_status) {
337375 ctx->SetStatus (errors::Internal (
338376 " cuDNN launch failure : input shape(" , input.shape ().DebugString (),
339377 " ) filter shape(" , filter.shape ().DebugString (), " )" ));
340378 }
379+
380+ // Convert the output tensor back from NCHW to NHWC.
381+ functor::NCHWToNHWC<GPUDevice, T>()(
382+ ctx->eigen_device <GPUDevice>(),
383+ const_cast <const Tensor&>(transformed_output).tensor <T, 4 >(),
384+ output->tensor <T, 4 >());
341385 } else {
342386 LaunchGeneric<GPUDevice, T>::launch (ctx, input_param, filter, stride,
343387 padding, output);
0 commit comments