Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
format file
  • Loading branch information
jonpsy authored Dec 3, 2021
commit 7c766c3e5a114b6f4707b231a30e36ac7454cec9
36 changes: 23 additions & 13 deletions tensorflow_lite_support/cc/task/processor/image_postprocessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,17 @@ GetNormalizationOptionsIfAny(const TensorMetadata& tensor_metadata) {
tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
ImagePostprocessor::Create(core::TfLiteEngine* engine,
const std::initializer_list<int> output_indices,
const std::initializer_list<int> input_indices) {
ASSIGN_OR_RETURN(auto processor, Processor::Create<ImagePostprocessor>(/* num_expected_tensors = */ 1, engine, output_indices, /* requires_metadata = */ false));
const std::initializer_list<int> input_indices) {
ASSIGN_OR_RETURN(auto processor,
Processor::Create<ImagePostprocessor>(
/* num_expected_tensors = */ 1, engine, output_indices,
/* requires_metadata = */ false));

RETURN_IF_ERROR(processor->Init(input_indices));
return processor;
}

absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
absl::Status ImagePostprocessor::Init(const std::vector<int> &input_indices) {
if (core::TfLiteEngine::OutputCount(engine_->interpreter()) != 1) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
Expand All @@ -97,13 +100,15 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
tflite::support::TfLiteSupportStatus::kInvalidNumOutputTensorsError);
}

if (GetTensor()->type != kTfLiteUInt8 && GetTensor()->type != kTfLiteFloat32) {
if (GetTensor()->type != kTfLiteUInt8 &&
GetTensor()->type != kTfLiteFloat32) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Type mismatch for output tensor %s. Requested one "
"of these types: "
"kTfLiteUint8/kTfLiteFloat32, got %s.",
GetTensor()->name, TfLiteTypeGetName(GetTensor()->type)),
GetTensor()->name,
TfLiteTypeGetName(GetTensor()->type)),
tflite::support::TfLiteSupportStatus::kInvalidOutputTensorTypeError);
}

Expand All @@ -112,8 +117,9 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
absl::StatusCode::kInvalidArgument,
absl::StrCat("The input tensor should have dimensions 1 x height x "
"width x 3. Got ",
GetTensor()->dims->data[0], " x ", GetTensor()->dims->data[1],
" x ", GetTensor()->dims->data[2], " x ",
GetTensor()->dims->data[0], " x ",
GetTensor()->dims->data[1], " x ",
GetTensor()->dims->data[2], " x ",
GetTensor()->dims->data[3], "."),
tflite::support::TfLiteSupportStatus::
kInvalidInputTensorDimensionsError);
Expand All @@ -123,12 +129,15 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
const tflite::TensorMetadata* output_metadata =
engine_->metadata_extractor()->GetOutputTensorMetadata(
tensor_indices_.at(0));
const tflite::TensorMetadata* input_metadata = engine_->metadata_extractor()->GetInputTensorMetadata(
input_indices.at(0));
const tflite::TensorMetadata* input_metadata =
engine_->metadata_extractor()->GetInputTensorMetadata(
input_indices.at(0));

// Use input metadata for normalization as fallback.
const tflite::TensorMetadata* processing_metadata =
GetNormalizationOptionsIfAny(*output_metadata).value().has_value() ? output_metadata : input_metadata;
GetNormalizationOptionsIfAny(*output_metadata).value().has_value()
? output_metadata
: input_metadata;

absl::optional<vision::NormalizationOptions> normalization_options;
ASSIGN_OR_RETURN(normalization_options,
Expand Down Expand Up @@ -180,7 +189,8 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
postprocessed_data.insert(postprocessed_data.begin(), &output_data[0],
&output_data[output_byte_size / sizeof(uint8)]);
} else { // Denormalize to [0, 255] range.
if (GetTensor()->bytes / sizeof(float) != output_byte_size / sizeof(uint8)) {
if (GetTensor()->bytes / sizeof(float) !=
output_byte_size / sizeof(uint8)) {
return tflite::support::CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Size mismatch or unsupported padding bytes between pixel data "
Expand All @@ -205,7 +215,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
for (size_t i = 0; i < output_byte_size / sizeof(uint8);
++i, ++denormalized_output_data, ++output_data) {
*denormalized_output_data = static_cast<uint8>(std::round(std::min(
255.f,
255.f,
std::max(0.f, (*output_data) * norm_options.std_values[i % 3] +
norm_options.mean_values[i % 3]))));
}
Expand All @@ -219,7 +229,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
vision::FrameBuffer::Create({postprocessed_plane}, to_buffer_dimension,
vision::FrameBuffer::Format::kRGB,
vision::FrameBuffer::Orientation::kTopLeft);
return *postprocessed_frame_buffer.get();;
return *postprocessed_frame_buffer.get();
}

} // namespace processor
Expand Down