@@ -383,12 +383,53 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
383383 // The output image size is the spatial size of the output.
384384 const int output_image_size = out_rows * out_cols;
385385
386+ // TODO(andydavis) Get L2/L3 cache sizes from device.
387+ const size_t l2_cache_size = 256LL << 10 ;
388+ const size_t l3_cache_size = 30LL << 20 ;
389+
390+ // Use L3 cache size as target working set size.
391+ const size_t target_working_set_size = l3_cache_size / sizeof (T);
392+
393+ // Calculate size of matrices involved in MatMul: C = A x B.
394+ const size_t size_A = output_image_size * out_depth;
395+
396+ const size_t size_B = filter_total_size * out_depth;
397+
398+ const size_t size_C = output_image_size * filter_total_size;
399+
400+ const size_t work_unit_size = size_A + size_B + size_C;
401+
402+ auto worker_threads = *(context->device ()->tensorflow_cpu_worker_threads ());
403+
404+ // Calculate per-thread work unit size.
405+ const size_t thread_work_unit_size =
406+ work_unit_size / worker_threads.num_threads ;
407+
408+ // Set minimum per-thread work unit size to size of L2 cache.
409+ const size_t min_thread_work_unit_size = l2_cache_size / sizeof (T);
410+
411+ // Use parallel tensor contractions if there is no batching, or if the
412+ // minimum per-thread work unit size threshold has been exceeded.
413+ // Otherwise, revert to multiple single-threaded matmul ops running in
414+ // parallel to keep all threads busy.
415+ // TODO(andydavis) Explore alternatives to branching the code in this way
416+ // (i.e. run multiple, parallel tensor contractions in another thread pool).
417+ const bool use_parallel_contraction =
418+ batch == 1 || thread_work_unit_size >= min_thread_work_unit_size;
419+
420+ const size_t shard_size =
421+ use_parallel_contraction
422+ ? 1
423+ : (target_working_set_size + work_unit_size - 1 ) / work_unit_size;
424+
386425 Tensor col_buffer;
387- OP_REQUIRES_OK (
388- context,
389- context->allocate_temp (
390- DataTypeToEnum<T>::value,
391- TensorShape ({output_image_size, filter_total_size}), &col_buffer));
426+ OP_REQUIRES_OK (context,
427+ context->allocate_temp (
428+ DataTypeToEnum<T>::value,
429+ TensorShape ({static_cast <int64>(shard_size),
430+ static_cast <int64>(output_image_size),
431+ static_cast <int64>(filter_total_size)}),
432+ &col_buffer));
392433
393434 // The input offset corresponding to a single input image.
394435 const int input_offset = input_rows * input_cols * in_depth;
@@ -400,31 +441,74 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
400441 auto * out_backprop_data = out_backprop.template flat <T>().data ();
401442 auto * input_backprop_data = in_backprop->template flat <T>().data ();
402443
403- typedef Eigen::TensorMap<Eigen::Tensor<T, 2 , Eigen::RowMajor>,
404- Eigen::Unaligned> TensorMap;
405- typedef Eigen::TensorMap<Eigen::Tensor<const T, 2 , Eigen::RowMajor>,
406- Eigen::Unaligned> ConstTensorMap;
444+ if (use_parallel_contraction) {
445+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2 , Eigen::RowMajor>,
446+ Eigen::Unaligned> TensorMap;
447+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2 , Eigen::RowMajor>,
448+ Eigen::Unaligned> ConstTensorMap;
407449
408- // Initialize contraction dims (we need to transpose 'B' below).
409- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1 > contract_dims;
410- contract_dims[0 ].first = 1 ;
411- contract_dims[0 ].second = 1 ;
450+ // Initialize contraction dims (we need to transpose 'B' below).
451+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1 > contract_dims;
452+ contract_dims[0 ].first = 1 ;
453+ contract_dims[0 ].second = 1 ;
412454
413- for (int image_id = 0 ; image_id < batch; ++image_id) {
414- // Compute gradient into col_buffer.
415- TensorMap C (col_buffer_data, output_image_size, filter_total_size);
455+ for (int image_id = 0 ; image_id < batch; ++image_id) {
456+ // Compute gradient into col_buffer.
457+ TensorMap C (col_buffer_data, output_image_size, filter_total_size);
416458
417- ConstTensorMap A (out_backprop_data + output_offset * image_id,
418- output_image_size, out_depth);
419- ConstTensorMap B (filter_data, filter_total_size, out_depth);
459+ ConstTensorMap A (out_backprop_data + output_offset * image_id,
460+ output_image_size, out_depth);
461+ ConstTensorMap B (filter_data, filter_total_size, out_depth);
420462
421- C.device (context->eigen_cpu_device ()) = A.contract (B, contract_dims);
463+ C.device (context->eigen_cpu_device ()) = A.contract (B, contract_dims);
422464
423- Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, filter_rows ,
424- filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride ,
425- stride, input_backprop_data);
465+ Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols,
466+ filter_rows, filter_cols, pad_top, pad_left, pad_bottom,
467+ pad_right, stride, stride, input_backprop_data);
426468
427- input_backprop_data += input_offset;
469+ input_backprop_data += input_offset;
470+ }
471+ } else {
472+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
473+ Eigen::RowMajor>> MatrixMap;
474+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
475+ Eigen::RowMajor>> ConstMatrixMap;
476+
477+ for (int image_id = 0 ; image_id < batch; image_id += shard_size) {
478+ const int shard_limit = std::min (static_cast <int >(shard_size),
479+ static_cast <int >(batch) - image_id);
480+
481+ auto shard = [&in_depth, &input_rows, &input_cols, &filter_rows,
482+ &filter_cols, &pad_top, &pad_left, &pad_bottom,
483+ &pad_right, &stride, &output_image_size,
484+ &filter_total_size, &out_depth, &input_backprop_data,
485+ &col_buffer_data, &out_backprop_data, &filter_data,
486+ &input_offset, &output_offset,
487+ &size_C](int64 start, int64 limit) {
488+ for (int shard_id = start; shard_id < limit; ++shard_id) {
489+ T* im2col_buf = col_buffer_data + shard_id * size_C;
490+ T* input_data = input_backprop_data + shard_id * input_offset;
491+ const T* out_data = out_backprop_data + shard_id * output_offset;
492+
493+ // Compute gradient into 'im2col_buf'.
494+ MatrixMap C (im2col_buf, output_image_size, filter_total_size);
495+
496+ ConstMatrixMap A (out_data, output_image_size, out_depth);
497+ ConstMatrixMap B (filter_data, filter_total_size, out_depth);
498+
499+ C.noalias () = A * B.transpose ();
500+
501+ Col2im<T>(im2col_buf, in_depth, input_rows, input_cols, filter_rows,
502+ filter_cols, pad_top, pad_left, pad_bottom, pad_right,
503+ stride, stride, input_data);
504+ }
505+ };
506+ Shard (worker_threads.num_threads , worker_threads.workers , shard_limit,
507+ work_unit_size, shard);
508+
509+ input_backprop_data += input_offset * shard_limit;
510+ out_backprop_data += output_offset * shard_limit;
511+ }
428512 }
429513 }
430514
@@ -620,8 +704,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
620704 &pad_left, &pad_bottom, &pad_right, &stride, &input_offset,
621705 &size_A](int64 start, int64 limit) {
622706 for (int shard_id = start; shard_id < limit; ++shard_id) {
623- auto input_data_shard = input_data + shard_id * input_offset;
624- auto col_data_shard = col_buffer_data + shard_id * size_A;
707+ const T* input_data_shard = input_data + shard_id * input_offset;
708+ T* col_data_shard = col_buffer_data + shard_id * size_A;
625709
626710 // When we compute the gradient with respect to the filters, we need
627711 // to do im2col to allow gemm-type computation.
0 commit comments